Aella-Nemotron-12B / _fsdp_init.py
amarinference's picture
Upload folder using huggingface_hub
65fef67 verified
import itertools
from typing import Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_state import _get_module_fsdp_state
def _get_post_forward_mesh_info(
reshard_after_forward: Union[bool, int], mesh_info: FSDPMeshInfo
) -> Optional[FSDPMeshInfo]:
shard_mesh_size = mesh_info.shard_mesh_size
if not isinstance(reshard_after_forward, (bool, int)):
raise ValueError(
"reshard_after_forward should be a bool or an int representing the "
f"group size to reshard to, not {reshard_after_forward}"
)
# NOTE: `isinstance(False, int)` returns `True`.
if not isinstance(reshard_after_forward, bool) and isinstance(
reshard_after_forward, int
):
if (
reshard_after_forward < 1
or reshard_after_forward > shard_mesh_size
or shard_mesh_size % reshard_after_forward != 0
):
raise ValueError(
"If passing reshard_after_forward as an int, it should be a "
f"factor of {shard_mesh_size}, not {reshard_after_forward}"
)
elif reshard_after_forward == 1:
reshard_after_forward = False
elif reshard_after_forward == shard_mesh_size:
reshard_after_forward = True
post_forward_mesh_info = None
if reshard_after_forward is True:
post_forward_mesh_info = mesh_info
elif reshard_after_forward is not False: # int case
# For HSDP, we can flatten the two replicate dims into the 0th dim
post_forward_mesh_tensor = mesh_info.mesh.mesh.view(-1, reshard_after_forward)
post_forward_mesh = DeviceMesh(
mesh_info.mesh.device_type, post_forward_mesh_tensor
)
post_forward_mesh_info = HSDPMeshInfo(
post_forward_mesh, shard_mesh_dim=1, replicate_mesh_dim=0
)
return post_forward_mesh_info
def _init_default_fully_shard_mesh() -> DeviceMesh:
"""Default to global CUDA mesh if possible else global CPU mesh."""
if not dist.distributed_c10d.is_initialized():
dist.distributed_c10d.init_process_group()
default_pg = dist.distributed_c10d._get_default_group()
device = torch._C._get_accelerator()
mesh = init_device_mesh(device.type, mesh_shape=(default_pg.size(),))
return mesh
def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device:
if mesh.device_type == "cpu":
return torch.device("cpu")
device_handle = _get_device_handle(mesh.device_type)
return torch.device(mesh.device_type, device_handle.current_device())
def _ignore_module(
module: nn.Module,
ignored_params: set[nn.Parameter],
ignore_decision: dict[nn.Module, bool],
) -> bool:
"""
Decide if it is safe to ignore a module for applying fully_shard.
"""
if module in ignore_decision:
return ignore_decision[module]
if len(list(module.buffers(recurse=False))) > 0:
# Cannot ignore a module with any buffer
ignore_decision[module] = False
return False
for _, param in module.named_parameters(recurse=False):
if param not in ignored_params:
# at least one param is not ignored. So this module shouldn't be.
ignore_decision[module] = False
return False
# Need to consider descendants of module
for child in list(module.children()):
ignore_child = _ignore_module(child, ignored_params, ignore_decision)
if not ignore_child:
# Cannot ignore module if one of its children is not ignored
ignore_decision[module] = False
return False
# Safe to ignore module
ignore_decision[module] = True
return True
def _adjust_managed_modules(
modules: list[nn.Module], ignored_params: set[nn.Parameter]
) -> list[nn.Module]:
"""
Adjust the given list of managed modules by removing those with all parameters ignored.
"""
ignore_decision: dict[nn.Module, bool] = {}
new_modules = []
for module in modules:
ignored = _ignore_module(module, ignored_params, ignore_decision)
if not ignored:
new_modules.append(module)
return new_modules
def _get_managed_modules(
root_modules: tuple[nn.Module, ...],
ignored_params: Optional[set[nn.Parameter]] = None,
) -> list[nn.Module]:
modules: list[nn.Module] = []
root_modules_set = set(root_modules)
# Track visisted modules to avoid visiting shared modules multiple times
visited_modules: set[nn.Module] = set()
def dfs(module: nn.Module) -> None:
"""
Runs a DFS to collect managed modules, not recursing into modules with
a non-composable API or ``fully_shard`` already applied.
"""
if not _is_composable_with_fsdp(module):
return
elif (
module not in root_modules_set
and _get_module_fsdp_state(module) is not None
):
return # nested `fully_shard` module
visited_modules.add(module)
for submodule in module.children():
if submodule not in visited_modules:
dfs(submodule)
modules.append(module)
for root_module in root_modules:
dfs(root_module)
if ignored_params is None:
return modules
adjusted_modules = _adjust_managed_modules(modules, ignored_params)
return adjusted_modules
def _verify_managed_param(name: str, param: nn.Parameter) -> None:
"""
Verify if the parameter is accepted by fully_shard. The only restriction now
is that the parameter cannot be a scalar tensor (param.numel == 0) since we
need at least one dim to shard.
"""
if len(param.shape) == 0:
raise ValueError(
"fully_shard doesn't support scalar parameters. "
f"Change {name} to a 1D tensor with numel equal to 1."
)
def _get_managed_states(
modules: list[nn.Module], ignored_params: Optional[set[nn.Parameter]] = None
) -> tuple[list[nn.Parameter], list[torch.Tensor]]:
params: list[nn.Parameter] = []
buffers: list[torch.Tensor] = []
# Track visited parameters/buffers to avoid visiting shared parameters and
# buffers multiple times
visited_params: set[nn.Parameter] = set()
visited_buffers: set[torch.Tensor] = set()
if ignored_params is None:
ignored_params = set()
for module in modules:
for name, param in module.named_parameters(recurse=False):
if param in ignored_params:
# do not include an ignored parameters
continue
if param not in visited_params:
_verify_managed_param(name, param)
params.append(param)
visited_params.add(param)
for buffer in module.buffers(recurse=False):
if buffer not in visited_buffers:
buffers.append(buffer)
visited_buffers.add(buffer)
return params, buffers
def _move_states_to_device(
params: list[nn.Parameter],
buffers: list[torch.Tensor],
device: torch.device,
) -> None:
"""
We have FSDP move states to device for simpler and faster initialization
since FSDP almost always uses CUDA for training. We move parameters/buffers
rather than modules since modules to support ignoring parameters/buffers in
the future.
"""
# Follow the logic in `nn.Module._apply`
for tensor in itertools.chain(params, buffers):
if tensor.device == device or tensor.device.type == "meta":
# Keep meta-device tensors on meta device for deferred init
continue
if isinstance(tensor, DTensor):
if (dtensor_mesh_type := tensor.device_mesh.device_type) != device.type:
raise ValueError(
"Requires DTensor to have mesh of the same type as the FSDP mesh "
f"but got {dtensor_mesh_type} for DTensor and {device.type} for FSDP"
)
raise AssertionError(
f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}"
)
tensor_ = tensor
if is_traceable_wrapper_subclass(tensor_):
with torch.no_grad(): # avoid autograd increasing C++ refcount by 1
tensor_on_device = nn.Parameter(tensor.to(device))
torch.utils.swap_tensors(tensor, tensor_on_device)
else:
tensor.data = tensor.to(device)