import functools
import logging
from collections.abc import Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING

import torch
import torch.nn as nn
from torch._logging import warning_once
from torch.autograd import Variable
from torch.autograd.graph import _MultiHandle
from torch.distributed._composable_state import (
    _get_module_state,
    _insert_module_state,
    _State,
)
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.utils import _to_kwargs
from torch.utils._pytree import tree_flatten, tree_map

from ._fsdp_api import MixedPrecisionPolicy
from ._fsdp_common import (
    _cast_fp_tensor,
    compiled_autograd_enabled,
    detect_compiled_autograd,
    TrainingState,
)
from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup


if TYPE_CHECKING:
    from ._fsdp_param import FSDPParam


logger = logging.getLogger("torch.distributed.fsdp.fully_shard")


class FSDPStateContext:
    """This has state shared across FSDP states."""

    def __init__(self) -> None:
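        # All FSDP states in the root module's tree, collected in `_lazy_init`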
        self.all_states: list[FSDPState] = []
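        # The state that ran this iteration's root pre-forward logic; it is
        # cleared again in that state's post-forward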
        self.iter_forward_root: Optional[FSDPState] = None
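        # Whether the root post-backward final callback has been queued for
        # the current backward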
        self.post_backward_final_callback_queued: bool = False
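        # Whether this backward is the last one for the current iteration, in
        # which case the final callback finalizes backward state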
        self.is_last_backward: bool = True
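        # Optional event recorded after the optimizer step; if set, the
        # all-gather streams wait on it in the root pre-forward instead of
        # waiting on the current stream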
        self.post_optim_event: Optional[torch.Event] = None


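# If dynamo is configured to skip FSDP hooks, run the wrapped hook eagerly
# under `torch._dynamo.disable` instead of tracing it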
def disable_if_config_true(func):
    @functools.wraps(func)
    def fsdp_hook_wrapper(*args, **kwargs):
        if torch._dynamo.config.skip_fsdp_hooks:
            return torch._dynamo.disable(func, recursive=True)(*args, **kwargs)
        else:
            return func(*args, **kwargs)

    return fsdp_hook_wrapper


class FSDPState(_State):
    def __init__(self) -> None:
        super().__init__()
        self._fsdp_param_group: Optional[FSDPParamGroup] = None
        self._is_root: Optional[bool] = None
        self._state_ctx = FSDPStateContext()
        self._comm_ctx = FSDPCommContext()
        self._training_state: TrainingState = TrainingState.IDLE
        self._states_to_forward_prefetch: list[FSDPState] = []
        self._states_to_backward_prefetch: list[FSDPState] = []
        self._modules_to_run_forward: set[nn.Module] = set()

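    # Attach this state to `modules`, record the device and mixed-precision
    # policy, and register the forward hooks that drive unshard/reshard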
    def init(
        self,
        modules: tuple[nn.Module, ...],
        device: torch.device,
        mp_policy: MixedPrecisionPolicy,
    ) -> None:
        for module in modules:
            _insert_module_state(module, self)
        self._modules = modules
        self._device = device
        self._device_handle = _get_device_handle(device.type)
        self._mp_policy = mp_policy
        if len(modules) == 1:
            self._pre_forward_hook_handle = modules[0].register_forward_pre_hook(
                self._pre_forward, prepend=True, with_kwargs=True
            )
            self._post_forward_hook_handle = modules[0].register_forward_hook(
                self._post_forward, prepend=False
            )
        else:
            hook_handle = _register_group_forward_hooks(
                modules,
                self._pre_forward,
                self._post_forward,
                self._modules_to_run_forward,
            )
            self._pre_forward_hook_handle = hook_handle
            self._post_forward_hook_handle = hook_handle

    def _root_pre_forward(
        self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
        self._lazy_init()
        if self._state_ctx.iter_forward_root is not None:
            return args, kwargs
        if not compiled_autograd_enabled():
            logger.debug("FSDP::root_pre_forward")
        self._state_ctx.iter_forward_root = self
        with torch.profiler.record_function("FSDP::root_pre_forward"):
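            # Have the all-gather streams wait on the optimizer step (via the
            # recorded event if provided, else the current stream) before
            # issuing this iteration's all-gathers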
            if (event := self._state_ctx.post_optim_event) is not None:
                self._comm_ctx.all_gather_copy_in_stream.wait_event(event)
                self._comm_ctx.all_gather_stream.wait_event(event)
                self._state_ctx.post_optim_event = None
            else:
                current_stream = self._device_handle.current_stream()
                self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream)
                self._comm_ctx.all_gather_stream.wait_stream(current_stream)
            if self._device.type in ["cuda", "hpu", "xpu", "mtia"]:
                with torch.profiler.record_function("FSDP::inputs_to_device"):
                    args_tuple, kwargs_tuple = _to_kwargs(
                        args, kwargs, self._device, False
                    )
                args, kwargs = args_tuple[0], kwargs_tuple[0]
        return args, kwargs

    def _lazy_init(self) -> None:
        """
        Lazy initialization runs once all modules' parallelisms have been
        finalized (e.g. FSDP has been applied to all desired modules). At that
        point, we can determine which state is the root: the first state to
        run forward.
        """
        if self._is_root is not None:
            return
        self._is_root = True
        if len(self._modules) > 1:
            raise RuntimeError(
                f"FSDP requires a single root module but got {self._modules}"
            )
        detect_compiled_autograd()
        root_module = self._modules[0]
        visited_states: set[FSDPState] = set()
        for module_name, module in root_module.named_modules():
            if (state := _get_module_fsdp_state(module)) is None:
                continue
            if module is not root_module:
                if state not in visited_states and state._is_root is not None:
                    raise RuntimeError(
                        "FSDP state has already been lazily initialized for "
                        f"{module_name}\nFSDP requires running forward through "
                        "the root module first"
                    )
                state._is_root = False
            self._state_ctx.all_states.append(state)
            visited_states.add(state)
        if self._fsdp_param_group:
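            # The root does not reshard after forward: its parameters would
            # only be freed and re-all-gathered immediately for backward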
            self._fsdp_param_group.post_forward_mesh_info = None
        self._init_fqns()
        self._init_shared_state()
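        # Run the parameter groups' lazy init after `_init_fqns` so that any
        # errors they raise can reference parameter/module FQNs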
        for state in self._state_ctx.all_states:
            if state._fsdp_param_group:
                state._fsdp_param_group.lazy_init()

    def _init_shared_state(self) -> None:
        self._comm_ctx.lazy_init(self._device)
        for state in self._state_ctx.all_states:
            state._state_ctx = self._state_ctx
            state._comm_ctx = self._comm_ctx
            if fsdp_param_group := state._fsdp_param_group:
                fsdp_param_group.comm_ctx = self._comm_ctx

    def _init_fqns(self) -> None:
        """Sets module and parameter FQN attributes for debugging."""
        assert self._is_root
        root_module = self._modules[0]
        param_to_fsdp_param: dict[nn.Parameter, FSDPParam] = {}
        module_to_fsdp_param_group: dict[nn.Module, FSDPParamGroup] = {}
        for state in self._state_ctx.all_states:
            if fsdp_param_group := state._fsdp_param_group:
                for fsdp_param in fsdp_param_group.fsdp_params:
                    param_to_fsdp_param[fsdp_param.sharded_param] = fsdp_param
                for module in fsdp_param_group.modules:
                    module_to_fsdp_param_group[module] = fsdp_param_group
        for param_name, param in root_module.named_parameters():
            if param in param_to_fsdp_param:
                param_to_fsdp_param[param]._param_fqn = param_name
        for module_name, module in root_module.named_modules():
            if module in module_to_fsdp_param_group:
                module_fqn = module_to_fsdp_param_group[module]._module_fqn
                if module_fqn is None:
                    module_to_fsdp_param_group[module]._module_fqn = module_name
                else:
                    assert isinstance(module_fqn, str), f"{module_fqn}"
                    module_fqn += f", {module_name}"
                    module_to_fsdp_param_group[module]._module_fqn = module_fqn

    @disable_if_config_true
    def _pre_forward(
        self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
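        # If already in pre-backward (e.g. recomputing forward under
        # activation checkpointing), the pre-backward hook handles the unshard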
        if self._training_state == TrainingState.PRE_BACKWARD:
            return args, kwargs
        self._training_state = TrainingState.FORWARD
        args, kwargs = self._root_pre_forward(module, args, kwargs)
        if self._mp_policy.cast_forward_inputs and self._mp_policy.param_dtype:
            with torch.profiler.record_function("FSDP::cast_forward_inputs"):
                cast_fn = functools.partial(
                    _cast_fp_tensor, self._mp_policy.param_dtype
                )
                args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs)
        if self._fsdp_param_group:
            args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
        for fsdp_state in self._states_to_forward_prefetch:
            if (target_param_group := fsdp_state._fsdp_param_group) is not None:
                FSDPParamGroup._prefetch_unshard(target_param_group, "forward")
        return args, kwargs

    @disable_if_config_true
    def _post_forward(self, module: nn.Module, input: Any, output: Any) -> Any:
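        # If already in pre-backward (e.g. recomputing forward under
        # activation checkpointing), the post-backward hook handles the reshard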
        if self._training_state == TrainingState.PRE_BACKWARD:
            return output
        if self._fsdp_param_group:
            output = self._fsdp_param_group.post_forward(module, input, output)
        output = self._register_pre_backward_hook(output)
        self._training_state = TrainingState.IDLE
        if self._state_ctx.iter_forward_root is self:
            if all_gather_state := self._comm_ctx.all_gather_state:
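                # Make the all-gather streams wait on the last all-gather
                # before freeing its cached state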
                self._comm_ctx.all_gather_copy_in_stream.wait_event(
                    all_gather_state.event
                )
                self._comm_ctx.all_gather_stream.wait_event(all_gather_state.event)
                self._comm_ctx.all_gather_state = None
            self._state_ctx.iter_forward_root = None
        if self._mp_policy.output_dtype is not None:
            with torch.profiler.record_function("FSDP::cast_forward_outputs"):
                output = tree_map(
                    functools.partial(_cast_fp_tensor, self._mp_policy.output_dtype),
                    output,
                )
        return output

    def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor:
        self._training_state = TrainingState.PRE_BACKWARD
        self._register_root_post_backward_final_callback()
        if self._fsdp_param_group:
            default_prefetch = len(self._states_to_backward_prefetch) == 0
            self._fsdp_param_group.pre_backward(default_prefetch)
        for fsdp_state in self._states_to_backward_prefetch:
            if (target_param_group := fsdp_state._fsdp_param_group) is not None:
                FSDPParamGroup._prefetch_unshard(target_param_group, "backward")
        return grad

    def _root_post_backward_final_callback(self) -> None:
        if not compiled_autograd_enabled():
            logger.debug("FSDP::root_post_backward")
        with torch.profiler.record_function("FSDP::root_post_backward_callback"):
            for state in self._state_ctx.all_states:
                fsdp_param_group = state._fsdp_param_group
                if (
                    fsdp_param_group
                    and fsdp_param_group._training_state != TrainingState.POST_BACKWARD
                ):
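                    # Run post-backward here in case this group's autograd
                    # hooks did not fire (e.g. no forward input required grad)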
                    fsdp_param_group.post_backward()
                state._training_state = TrainingState.IDLE
                if fsdp_param_group:
                    fsdp_param_group._training_state = TrainingState.IDLE
                if self._state_ctx.is_last_backward:
                    state._finalize_backward()
            if self._state_ctx.is_last_backward:
                self._comm_ctx.post_forward_order.clear()
                if self._comm_ctx.reduce_scatter_state is not None:
                    self._device_handle.current_stream().wait_event(
                        self._comm_ctx.reduce_scatter_state.event
                    )
                    self._comm_ctx.reduce_scatter_state = None
            self._state_ctx.post_backward_final_callback_queued = False

    def _finalize_backward(self) -> None:
        if self._modules_to_run_forward:
            msg = (
                f"{len(self._modules_to_run_forward)} of the {len(self._modules)} "
                f"modules passed to fully_shard did not run forward before backward, "
                "which is error-prone since FSDP post-forward/pre-backward logic "
                "will not run for these modules. We recommend passing only modules "
                "that run forward together. Modules that did not run forward: "
                f"{list(self._modules_to_run_forward)}"
            )
            warning_once(logger, msg, stacklevel=2)
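            # Clear so that these modules can run forward again next iteration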
            self._modules_to_run_forward.clear()
        if self._fsdp_param_group:
            self._fsdp_param_group.finalize_backward()

    def _register_pre_backward_hook(self, output: Any) -> Any:
        if not torch.is_grad_enabled():
            return output
        flat_outputs, _ = tree_flatten(output)
        for t in flat_outputs:
            if torch.is_tensor(t) and t.requires_grad:
                t.register_hook(self._pre_backward)
        return output

    def _register_root_post_backward_final_callback(self):
        if self._state_ctx.post_backward_final_callback_queued:
            return
        self._state_ctx.post_backward_final_callback_queued = True
        Variable._execution_engine.queue_callback(
            self._root_post_backward_final_callback
        )


def _get_module_fsdp_state(module: nn.Module) -> Optional[FSDPState]:
    state = _get_module_state(module)
    if isinstance(state, FSDPState):
        return state
    return None


def _register_group_forward_hooks(
    modules: Sequence[nn.Module],
    pre_hook: Callable,
    post_hook: Callable,
    modules_to_run: set[nn.Module],
):
    """
    Registers group forward pre- and post-hooks. The pre-hook runs upon the
    first module's pre-forward, and the post-hook runs upon the last module's
    post-forward. If at least one module does not run forward, then the
    post-hook does not run.
    """
    modules_set = set(modules)

    @disable_if_config_true
    @functools.wraps(pre_hook)
    def wrapped_pre_hook(*args: Any, **kwargs: Any):
        if len(modules_to_run) == 0:  # first module to run forward
            modules_to_run.update(modules_set)
            return pre_hook(*args, **kwargs)

    @disable_if_config_true
    def get_wrapped_post_hook(module: nn.Module):
        @functools.wraps(post_hook)
        def wrapped_post_hook(*args: Any, **kwargs: Any):
            modules_to_run.discard(module)
            if len(modules_to_run) == 0:  # last module to run forward
                return post_hook(*args, **kwargs)

        return wrapped_post_hook

    pre_handles = [
        module.register_forward_pre_hook(
            wrapped_pre_hook, prepend=True, with_kwargs=True
        )
        for module in modules
    ]
    post_handles = [
        module.register_forward_hook(
            get_wrapped_post_hook(module), prepend=False, always_call=True
        )
        for module in modules
    ]
    return _MultiHandle(tuple(pre_handles + post_handles))