import itertools
from typing import Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_state import _get_module_fsdp_state


def _get_post_forward_mesh_info(
    reshard_after_forward: Union[bool, int], mesh_info: FSDPMeshInfo
) -> Optional[FSDPMeshInfo]:
    shard_mesh_size = mesh_info.shard_mesh_size
    if not isinstance(reshard_after_forward, (bool, int)):
        raise ValueError(
            "reshard_after_forward should be a bool or an int representing the "
            f"group size to reshard to, not {reshard_after_forward}"
        )
    # NOTE: `isinstance(False, int)` returns `True`, so explicitly rule out
    # bools before treating the argument as an int group size.
    if not isinstance(reshard_after_forward, bool) and isinstance(
        reshard_after_forward, int
    ):
        if (
            reshard_after_forward < 1
            or reshard_after_forward > shard_mesh_size
            or shard_mesh_size % reshard_after_forward != 0
        ):
            raise ValueError(
                "If passing reshard_after_forward as an int, it should be a "
                f"factor of {shard_mesh_size}, not {reshard_after_forward}"
            )
        elif reshard_after_forward == 1:
            # A group size of 1 is equivalent to not resharding after forward
            reshard_after_forward = False
        elif reshard_after_forward == shard_mesh_size:
            # A group size equal to the shard mesh size is equivalent to `True`
            reshard_after_forward = True
    post_forward_mesh_info = None
    if reshard_after_forward is True:
        post_forward_mesh_info = mesh_info
    elif reshard_after_forward is not False:  # int group size
        # View the shard mesh as (replicate, shard) groups of size
        # `reshard_after_forward` to reshard over the smaller groups
        post_forward_mesh_tensor = mesh_info.mesh.mesh.view(-1, reshard_after_forward)
        post_forward_mesh = DeviceMesh(
            mesh_info.mesh.device_type, post_forward_mesh_tensor
        )
        post_forward_mesh_info = HSDPMeshInfo(
            post_forward_mesh, shard_mesh_dim=1, replicate_mesh_dim=0
        )
    return post_forward_mesh_info
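
# Illustrative sketch (commented out, not executed on import; the exact
# `FSDPMeshInfo` construction is an assumption): with a shard mesh of size 8,
# `reshard_after_forward=2` reshards parameters after forward over groups of 2
# ranks, realized as an HSDP-style mesh of shape (4, 2).
#
#   mesh_info = FSDPMeshInfo(init_device_mesh("cuda", (8,)), shard_mesh_dim=0)
#   post_info = _get_post_forward_mesh_info(2, mesh_info)
#   # post_info.mesh.mesh.shape == torch.Size([4, 2])
#   assert _get_post_forward_mesh_info(True, mesh_info) is mesh_info
#   assert _get_post_forward_mesh_info(False, mesh_info) is None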


def _init_default_fully_shard_mesh() -> DeviceMesh:
    """Default to the global accelerator mesh if possible else the global CPU mesh."""
    if not dist.distributed_c10d.is_initialized():
        dist.distributed_c10d.init_process_group()
    default_pg = dist.distributed_c10d._get_default_group()
    device = torch._C._get_accelerator()
    mesh = init_device_mesh(device.type, mesh_shape=(default_pg.size(),))
    return mesh


def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device:
    if mesh.device_type == "cpu":
        return torch.device("cpu")
    device_handle = _get_device_handle(mesh.device_type)
    return torch.device(mesh.device_type, device_handle.current_device())


def _ignore_module(
    module: nn.Module,
    ignored_params: set[nn.Parameter],
    ignore_decision: dict[nn.Module, bool],
) -> bool:
    """
    Decide if it is safe to ignore a module for applying fully_shard.
    """
    if module in ignore_decision:
        # Memoized result from visiting this module through another path
        return ignore_decision[module]

    if len(list(module.buffers(recurse=False))) > 0:
        # Cannot ignore a module that directly owns buffers
        ignore_decision[module] = False
        return False

    for _, param in module.named_parameters(recurse=False):
        if param not in ignored_params:
            # Cannot ignore a module that directly owns a non-ignored parameter
            ignore_decision[module] = False
            return False

    # Only ignore this module if all of its children can be ignored
    for child in list(module.children()):
        ignore_child = _ignore_module(child, ignored_params, ignore_decision)
        if not ignore_child:
            ignore_decision[module] = False
            return False

    # No buffers, and all parameters in this subtree are ignored
    ignore_decision[module] = True
    return True
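
# Illustrative sketch (commented out; the modules are hypothetical): a module
# whose parameters are all ignored and that owns no buffers can be ignored,
# and the per-module decision is memoized in `ignore_decision`.
#
#   frozen = nn.Linear(4, 4)
#   ignored = set(frozen.parameters())
#   decisions: dict[nn.Module, bool] = {}
#   assert _ignore_module(frozen, ignored, decisions) is True
#   assert _ignore_module(nn.Linear(4, 4), ignored, decisions) is False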


def _adjust_managed_modules(
    modules: list[nn.Module], ignored_params: set[nn.Parameter]
) -> list[nn.Module]:
    """
    Adjust the given list of managed modules by removing those with all
    parameters ignored.
    """
    ignore_decision: dict[nn.Module, bool] = {}
    new_modules = []
    for module in modules:
        ignored = _ignore_module(module, ignored_params, ignore_decision)
        if not ignored:
            new_modules.append(module)
    return new_modules


def _get_managed_modules(
    root_modules: tuple[nn.Module, ...],
    ignored_params: Optional[set[nn.Parameter]] = None,
) -> list[nn.Module]:
    modules: list[nn.Module] = []
    root_modules_set = set(root_modules)
    # Track visited modules to avoid visiting shared modules multiple times
    visited_modules: set[nn.Module] = set()

    def dfs(module: nn.Module) -> None:
        """
        Runs a DFS to collect managed modules, not recursing into modules with
        a non-composable API or ``fully_shard`` already applied.
        """
        if not _is_composable_with_fsdp(module):
            return
        elif (
            module not in root_modules_set
            and _get_module_fsdp_state(module) is not None
        ):
            return  # nested `fully_shard` module keeps its own state
        visited_modules.add(module)
        for submodule in module.children():
            if submodule not in visited_modules:
                dfs(submodule)
        modules.append(module)

    for root_module in root_modules:
        dfs(root_module)

    if ignored_params is None:
        return modules

    adjusted_modules = _adjust_managed_modules(modules, ignored_params)
    return adjusted_modules
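
# Illustrative sketch (commented out; the model is hypothetical): for a plain
# model with no nested `fully_shard` applied, the DFS returns submodules in
# post-order with the root module last.
#
#   model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
#   managed = _get_managed_modules((model,))
#   # managed == [model[0], model[1], model]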


def _verify_managed_param(name: str, param: nn.Parameter) -> None:
    """
    Verify if the parameter is accepted by fully_shard. The only restriction now
    is that the parameter cannot be a scalar tensor (i.e. a 0-dim tensor) since
    we need at least one dim to shard.
    """
    if len(param.shape) == 0:
        raise ValueError(
            "fully_shard doesn't support scalar parameters. "
            f"Change {name} to a 1D tensor with numel equal to 1."
        )
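
# Illustrative sketch (commented out): a 0-dim parameter is rejected, while the
# suggested fix of using a 1D tensor with a single element is accepted.
#
#   _verify_managed_param("scale", nn.Parameter(torch.ones(1)))  # ok
#   _verify_managed_param("scale", nn.Parameter(torch.tensor(1.0)))  # ValueError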


def _get_managed_states(
    modules: list[nn.Module], ignored_params: Optional[set[nn.Parameter]] = None
) -> tuple[list[nn.Parameter], list[torch.Tensor]]:
    params: list[nn.Parameter] = []
    buffers: list[torch.Tensor] = []
    # Track visited parameters/buffers to avoid adding shared states (e.g. from
    # weight tying) more than once
    visited_params: set[nn.Parameter] = set()
    visited_buffers: set[torch.Tensor] = set()
    if ignored_params is None:
        ignored_params = set()

    for module in modules:
        for name, param in module.named_parameters(recurse=False):
            if param in ignored_params:
                # Ignored parameters are not managed by fully_shard
                continue
            if param not in visited_params:
                _verify_managed_param(name, param)
                params.append(param)
                visited_params.add(param)
        for buffer in module.buffers(recurse=False):
            if buffer not in visited_buffers:
                buffers.append(buffer)
                visited_buffers.add(buffer)
    return params, buffers
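
# Illustrative sketch (commented out; weight tying is the assumed scenario):
# a parameter shared by two managed modules is returned only once.
#
#   emb, head = nn.Embedding(16, 4), nn.Linear(4, 16, bias=False)
#   head.weight = emb.weight  # tie weights
#   params, buffers = _get_managed_states([emb, head])
#   assert len(params) == 1 and len(buffers) == 0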


def _move_states_to_device(
    params: list[nn.Parameter],
    buffers: list[torch.Tensor],
    device: torch.device,
) -> None:
    """
    We have FSDP move states to device for simpler and faster initialization
    since FSDP almost always uses CUDA for training. We move parameters/buffers
    rather than modules to support ignoring parameters/buffers in the future.
    """
    for tensor in itertools.chain(params, buffers):
        if tensor.device == device or tensor.device.type == "meta":
            # Keep meta-device tensors on meta device for deferred initialization
            continue
        if isinstance(tensor, DTensor):
            if (dtensor_mesh_type := tensor.device_mesh.device_type) != device.type:
                raise ValueError(
                    "Requires DTensor to have mesh of the same type as the FSDP mesh "
                    f"but got {dtensor_mesh_type} for DTensor and {device.type} for FSDP"
                )
            raise AssertionError(
                f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}"
            )
        tensor_ = tensor
        if is_traceable_wrapper_subclass(tensor_):
            with torch.no_grad():  # construct the on-device copy without autograd history
                tensor_on_device = nn.Parameter(tensor.to(device))
            torch.utils.swap_tensors(tensor, tensor_on_device)
        else:
            tensor.data = tensor.to(device)
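
# Illustrative sketch (commented out; assumes a CUDA device is available):
# meta-device parameters stay on meta for deferred initialization, while CPU
# parameters are moved in place onto the FSDP device.
#
#   with torch.device("meta"):
#       meta_param = nn.Parameter(torch.empty(4))
#   cpu_param = nn.Parameter(torch.zeros(4))
#   _move_states_to_device([meta_param, cpu_param], [], torch.device("cuda", 0))
#   assert meta_param.device.type == "meta" and cpu_param.device.type == "cuda"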