# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
| """ | |
| Distributed ops for supporting sequence parallel. | |
| """ | |
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import Tensor

from common.cache import Cache
from common.distributed.advanced import (
    get_sequence_parallel_group,
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
)

from .basic import get_device

_SEQ_DATA_BUF = defaultdict(lambda: [None, None, None])
_SEQ_DATA_META_SHAPES = defaultdict()
_SEQ_DATA_META_DTYPES = defaultdict()
_SEQ_DATA_ASYNC_COMMS = defaultdict(list)
_SYNC_BUFFER = defaultdict(dict)
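# Note on the module-level buffers above (descriptive comment, not from the
# original source): _SEQ_DATA_BUF keeps, per name, a small double buffer plus
# the rank-local result ([buf_0, buf_1, local_result]) so SPDistForward can
# overlap broadcasting the next rank's data with consuming the current one;
# the META dicts cache shapes/dtypes so receive buffers can be allocated
# without extra communication, and _SEQ_DATA_ASYNC_COMMS stores in-flight
# async broadcast handles that are waited on before their buffer is read.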
def single_all_to_all(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: dist.ProcessGroup,
    async_op: bool = False,
):
    """
    Perform an all-to-all on a tensor, scattering it along `scatter_dim` and
    gathering it along `gather_dim` across the given process group.
    """
    seq_world_size = dist.get_world_size(group)
    prev_scatter_dim = scatter_dim
    if scatter_dim != 0:
        local_input = local_input.transpose(0, scatter_dim)
        if gather_dim == 0:
            gather_dim = scatter_dim
        scatter_dim = 0
    inp_shape = list(local_input.shape)
    inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
    input_t = local_input.reshape(
        [seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]
    ).contiguous()
    output = torch.empty_like(input_t)
    comm = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)
    if async_op:
        # Let the caller's code do the transpose & reshape after waiting on `comm`.
        return output, comm, prev_scatter_dim
    # The first dim is seq_world_size, so we can split it directly.
    output = torch.cat(output.split(1), dim=gather_dim + 1).squeeze(0)
    if prev_scatter_dim:
        output = output.transpose(0, prev_scatter_dim).contiguous()
    return output
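# Illustrative shape sketch (comment only, not from the original source): with
# a sequence-parallel world size of 2 and a local tensor of shape
# [batch, seq, heads, dim] = [1, 8, 4, 16],
#   single_all_to_all(x, scatter_dim=2, gather_dim=1, group=sp_group)
# returns a tensor of shape [1, 16, 2, 16]: the head dim is scattered across
# ranks while the sequence dim is gathered from both ranks.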
def _all_to_all(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: dist.ProcessGroup,
):
    seq_world_size = dist.get_world_size(group)
    input_list = [
        t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)
    ]
    output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()

class SeqAllToAll(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        scatter_dim: int,
        gather_dim: int,
        async_op: bool,
    ) -> Tensor:
        ctx.group = group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.async_op = async_op
        if async_op:
            output, comm, prev_scatter_dim = single_all_to_all(
                local_input, scatter_dim, gather_dim, group, async_op=async_op
            )
            ctx.prev_scatter_dim = prev_scatter_dim
            return output, comm
        return _all_to_all(local_input, scatter_dim, gather_dim, group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None, None]:
        if ctx.async_op:
            input_t = torch.cat(grad_output[0].split(1), dim=ctx.gather_dim + 1).squeeze(0)
            if ctx.prev_scatter_dim:
                input_t = input_t.transpose(0, ctx.prev_scatter_dim)
        else:
            input_t = grad_output[0]
        return (
            None,
            # The gradient of an all-to-all is the inverse all-to-all: scatter
            # along the forward gather dim, gather along the forward scatter dim.
            _all_to_all(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group),
            None,
            None,
            None,
        )
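# Illustrative usage sketch for SeqAllToAll (comment only; assumes an already
# initialized sequence-parallel process group `sp_group`):
#
#   # Redistribute a [b, s_local, h, d] tensor so every rank holds the full
#   # sequence but only h // sp_world heads:
#   y = SeqAllToAll.apply(sp_group, x, 2, 1, False)
#
# With async_op=True the forward instead returns (output, comm); the caller is
# expected to wait on `comm` and then reorder/reshape the output itself.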
class Slice(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int) -> Tensor:
        ctx.group = group
        ctx.rank = dist.get_rank(group)
        seq_world_size = dist.get_world_size(group)
        ctx.seq_world_size = seq_world_size
        ctx.dim = dim
        dim_size = local_input.shape[dim]
        return local_input.split(dim_size // seq_world_size, dim=dim)[ctx.rank].contiguous()

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None]:
        dim_size = list(grad_output.size())
        split_size = dim_size[0]
        dim_size[0] = dim_size[0] * ctx.seq_world_size
        output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device())
        dist._all_gather_base(output, grad_output, group=ctx.group)
        return (None, torch.cat(output.split(split_size), dim=ctx.dim), None)

class Gather(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        dim: int,
        grad_scale: Optional[bool] = False,
    ) -> Tensor:
        ctx.group = group
        ctx.rank = dist.get_rank(group)
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        seq_world_size = dist.get_world_size(group)
        ctx.seq_world_size = seq_world_size
        dim_size = list(local_input.size())
        split_size = dim_size[0]
        ctx.part_size = dim_size[dim]
        dim_size[0] = dim_size[0] * seq_world_size
        output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device())
        dist._all_gather_base(output, local_input.contiguous(), group=ctx.group)
        return torch.cat(output.split(split_size), dim=dim)

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        if ctx.grad_scale:
            grad_output = grad_output * ctx.seq_world_size
        return (
            None,
            grad_output.split(ctx.part_size, dim=ctx.dim)[ctx.rank].contiguous(),
            None,
            None,
        )
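# Descriptive note (not from the original source): Slice and Gather are
# gradient mirrors of each other. Slice keeps only the rank-local shard in
# forward and all-gathers the incoming gradient in backward; Gather
# all-gathers in forward and returns only the rank-local shard of the gradient
# in backward (optionally scaled by the world size when grad_scale is True).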
def gather_seq_scatter_heads_qkv(
    qkv_tensor: Tensor,
    *,
    seq_dim: int,
    qkv_shape: Optional[Tensor] = None,
    cache: Cache = Cache(disable=True),
    restore_shape: bool = True,
):
    """
    A func to sync a sequence-split qkv tensor with all-to-all.
    qkv_tensor: the tensor we want to do all-to-all with. The last dim must
        be the projection dim, which we split into 3 parts. After splitting,
        the scatter dim for the all-to-all is projection_idx + 1.
    seq_dim: gather_dim for the all-to-all comm
    restore_shape: if True, the output has the same number of dims as the input
    """
    group = get_sequence_parallel_group()
    if not group:
        return qkv_tensor
    world = get_sequence_parallel_world_size()
    orig_shape = qkv_tensor.shape
    scatter_dim = qkv_tensor.dim()
    bef_all2all_shape = list(orig_shape)
    qkv_proj_dim = bef_all2all_shape[-1]
    bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3]
    qkv_tensor = qkv_tensor.view(bef_all2all_shape)
    qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, False)
    if restore_shape:
        out_shape = list(orig_shape)
        out_shape[seq_dim] *= world
        out_shape[-1] = qkv_proj_dim // world
        qkv_tensor = qkv_tensor.view(out_shape)
    # remove padding
    if qkv_shape is not None:
        unpad_dim_size = cache(
            "unpad_dim_size", lambda: torch.sum(torch.prod(qkv_shape, dim=-1)).item()
        )
        if unpad_dim_size % world != 0:
            padding_size = qkv_tensor.size(seq_dim) - unpad_dim_size
            qkv_tensor = _unpad_tensor(qkv_tensor, seq_dim, padding_size)
    return qkv_tensor
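# Illustrative shape sketch (comment only; assumes sp_world = 2 and a fused
# qkv projection output of shape [b, s_local, 3 * h * d]):
#
#   qkv = gather_seq_scatter_heads_qkv(qkv, seq_dim=1)
#   # qkv: [b, s_local * 2, (3 * h * d) // 2]
#
# Every rank now holds the full sequence but only the channels belonging to
# h // 2 heads of q, k and v (assuming a head-major [h, d] channel layout).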
def slice_inputs(x: Tensor, dim: int, padding: bool = True):
    """
    A func to slice the input sequence along `dim` under sequence parallelism
    """
    group = get_sequence_parallel_group()
    if group is None:
        return x
    sp_rank = get_sequence_parallel_rank()
    sp_world = get_sequence_parallel_world_size()
    dim_size = x.shape[dim]
    unit = (dim_size + sp_world - 1) // sp_world
    if padding and dim_size % sp_world:
        padding_size = sp_world - (dim_size % sp_world)
        x = _pad_tensor(x, dim, padding_size)
    slc = [slice(None)] * len(x.shape)
    slc[dim] = slice(unit * sp_rank, unit * (sp_rank + 1))
    return x[tuple(slc)]
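# Illustrative sketch (comment only): with sp_world = 4 and an input of shape
# [b, 10, d], slice_inputs(x, dim=1) first zero-pads the sequence to length 12
# and then returns the rank-local chunk of shape [b, 3, d]; the padding is
# stripped again later by remove_seqeunce_parallel_padding / gather_outputs.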
def remove_seqeunce_parallel_padding(x: Tensor, dim: int, unpad_dim_size: int):
    """
    A func to remove the padding part of the tensor based on its original shape
    """
    group = get_sequence_parallel_group()
    if group is None:
        return x
    sp_world = get_sequence_parallel_world_size()
    if unpad_dim_size % sp_world == 0:
        return x
    padding_size = sp_world - (unpad_dim_size % sp_world)
    assert (padding_size + unpad_dim_size) % sp_world == 0
    return _unpad_tensor(x, dim=dim, padding_size=padding_size)
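# Descriptive note (not from the original source): this is the inverse of the
# padding applied in slice_inputs and gather_heads_scatter_seq; when the
# original length is already divisible by the sequence-parallel world size it
# is a no-op.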
def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor:
    """
    A func to sync the attention result with all-to-all under sequence parallelism:
    gather along the head dim and scatter along the sequence dim.
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    dim_size = x.size(seq_dim)
    sp_world = get_sequence_parallel_world_size()
    if dim_size % sp_world != 0:
        padding_size = sp_world - (dim_size % sp_world)
        x = _pad_tensor(x, seq_dim, padding_size)
    return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
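# Illustrative shape sketch (comment only; sp_world = 2): an attention output
# of shape [b, s_full, h_local, d] becomes [b, s_full // 2, h_local * 2, d]
# after gather_heads_scatter_seq(x, head_dim=2, seq_dim=1), restoring the
# sequence-sharded layout used outside the attention block.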
def gather_seq_scatter_heads(x: Tensor, seq_dim: int, head_dim: int) -> Tensor:
    """
    A func to sync the embedding input with all-to-all under sequence parallelism:
    gather along the sequence dim and scatter along the head dim.
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    return SeqAllToAll.apply(group, x, head_dim, seq_dim, False)
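# Illustrative shape sketch (comment only; sp_world = 2): the inverse of the
# function above. A tensor of shape [b, s_local, h, d] becomes
# [b, s_local * 2, h // 2, d] after gather_seq_scatter_heads(x, seq_dim=1,
# head_dim=2), so attention can see the full sequence with a local subset of
# heads.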
def scatter_heads(x: Tensor, dim: int) -> Tensor:
    """
    A func to split heads before attention under sequence parallelism
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    return Slice.apply(group, x, dim)


def gather_heads(x: Tensor, dim: int, grad_scale: Optional[bool] = False) -> Tensor:
    """
    A func to gather heads for the attention result under sequence parallelism
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    return Gather.apply(group, x, dim, grad_scale)
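# Illustrative usage sketch (comment only): scatter_heads / gather_heads are
# the slice-and-all-gather alternative to the all-to-all pair above; `attention`
# below is a hypothetical placeholder for whatever consumes the sharded heads.
#
#   h = scatter_heads(h, dim=2)                  # keep h // sp_world heads locally
#   h = attention(h)
#   h = gather_heads(h, dim=2, grad_scale=True)  # all-gather the heads back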
def gather_outputs(
    x: Tensor,
    *,
    gather_dim: int,
    padding_dim: Optional[int] = None,
    unpad_shape: Optional[Tensor] = None,
    cache: Cache = Cache(disable=True),
    scale_grad: bool = True,
):
    """
    A func to gather the model outputs under sequence parallelism and strip any
    padding that was added when the inputs were sliced.
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    x = Gather.apply(group, x, gather_dim, scale_grad)
    if padding_dim is not None:
        unpad_dim_size = cache(
            "unpad_dim_size", lambda: torch.sum(torch.prod(unpad_shape, dim=1)).item()
        )
        x = remove_seqeunce_parallel_padding(x, padding_dim, unpad_dim_size)
    return x
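# Illustrative usage sketch (comment only; `latent_shapes` is a hypothetical
# [n, 3] tensor holding per-sample token-grid shapes, so the row-wise product
# summed over samples gives the unpadded sequence length):
#
#   out = gather_outputs(out, gather_dim=1, padding_dim=1, unpad_shape=latent_shapes)
#   # out now holds the full, unpadded sequence on every rank.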
def _pad_tensor(x: Tensor, dim: int, padding_size: int):
    shape = list(x.shape)
    shape[dim] = padding_size
    pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
    return torch.cat([x, pad], dim=dim)


def _unpad_tensor(x: Tensor, dim: int, padding_size: int):
    slc = [slice(None)] * len(x.shape)
    slc[dim] = slice(0, -padding_size)
    return x[tuple(slc)]

def _broadcast_data(data, shape, dtype, src, group, async_op):
    comms = []
    if isinstance(data, (list, tuple)):
        for i, sub_shape in enumerate(shape):
            comms += _broadcast_data(data[i], sub_shape, dtype[i], src, group, async_op)
    elif isinstance(data, dict):
        for key, sub_data in data.items():
            comms += _broadcast_data(sub_data, shape[key], dtype[key], src, group, async_op)
    elif isinstance(data, Tensor):
        comms.append(dist.broadcast(data, src=src, group=group, async_op=async_op))
    return comms


def _traverse(data: Any, op: Callable) -> Union[None, List, Dict, Any]:
    if isinstance(data, (list, tuple)):
        return [_traverse(sub_data, op) for sub_data in data]
    elif isinstance(data, dict):
        return {key: _traverse(sub_data, op) for key, sub_data in data.items()}
    elif isinstance(data, Tensor):
        return op(data)
    else:
        return None


def _get_shapes(data):
    return _traverse(data, op=lambda x: x.shape)


def _get_dtypes(data):
    return _traverse(data, op=lambda x: x.dtype)


def _construct_broadcast_buffer(shapes, dtypes, device):
    if isinstance(shapes, torch.Size):
        return torch.empty(shapes, dtype=dtypes, device=device)
    if isinstance(shapes, (list, tuple)):
        buffer = []
        for i, sub_shape in enumerate(shapes):
            buffer.append(_construct_broadcast_buffer(sub_shape, dtypes[i], device))
    elif isinstance(shapes, dict):
        buffer = {}
        for key, sub_shape in shapes.items():
            buffer[key] = _construct_broadcast_buffer(sub_shape, dtypes[key], device)
    else:
        return None
    return buffer
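# Descriptive note (not from the original source): the helpers above mirror
# each other over nested list/tuple/dict structures of tensors. _get_shapes /
# _get_dtypes record the per-rank structure, _construct_broadcast_buffer
# allocates matching empty tensors on the receiving ranks, and _broadcast_data
# walks the same structure issuing one dist.broadcast per tensor (returning
# the async work handles when async_op is True).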
class SPDistForward:
    """A forward tool to sync each rank's result across the sequence-parallel group.

    Args:
        name: a distinct str used to key the cached meta info and async comms
        comm_shape: set to True if different ranks may have different shapes
        device: the device for the current rank; defaults to get_device() if empty
    """

    def __init__(
        self,
        name: str,
        comm_shape: bool,
        device: Optional[torch.device] = None,
    ):
        self.name = name
        self.comm_shape = comm_shape
        if device:
            self.device = device
        else:
            self.device = get_device()

    def __call__(self, inputs) -> Any:
        group = get_sequence_parallel_group()
        if not group:
            yield inputs
        else:
            device = self.device
            sp_world = get_sequence_parallel_world_size()
            sp_rank = get_sequence_parallel_rank()
            for local_step in range(sp_world):
                src_rank = dist.get_global_rank(group, local_step)
                is_src = sp_rank == local_step
                local_shapes = []
                local_dtypes = []
                if local_step == 0:
                    local_result = inputs
                    _SEQ_DATA_BUF[self.name][-1] = local_result
                    local_shapes = _get_shapes(local_result)
                    local_dtypes = _get_dtypes(local_result)
                    if self.comm_shape:
                        group_shapes_lists = [None] * sp_world
                        dist.all_gather_object(group_shapes_lists, local_shapes, group=group)
                        _SEQ_DATA_META_SHAPES[self.name] = group_shapes_lists
                    else:
                        _SEQ_DATA_META_SHAPES[self.name] = [local_shapes] * sp_world
                    _SEQ_DATA_META_DTYPES[self.name] = local_dtypes
                shapes = _SEQ_DATA_META_SHAPES[self.name][local_step]
                dtypes = _SEQ_DATA_META_DTYPES[self.name]
                buf_id = local_step % 2
                if local_step == 0:
                    sync_data = (
                        local_result
                        if is_src
                        else _construct_broadcast_buffer(shapes, dtypes, device)
                    )
                    _broadcast_data(sync_data, shapes, dtypes, src_rank, group, False)
                    _SEQ_DATA_BUF[self.name][buf_id] = sync_data
                # wait for async comm ops
                if _SEQ_DATA_ASYNC_COMMS[self.name]:
                    for comm in _SEQ_DATA_ASYNC_COMMS[self.name]:
                        comm.wait()
                # before yielding the synced result, launch the async broadcast for the next step
                if local_step < sp_world - 1:
                    next_buf_id = 1 - buf_id
                    shapes = _SEQ_DATA_META_SHAPES[self.name][local_step + 1]
                    src_rank = dist.get_global_rank(group, local_step + 1)
                    is_src = sp_rank == local_step + 1
                    next_sync_data = (
                        _SEQ_DATA_BUF[self.name][-1]
                        if is_src
                        else _construct_broadcast_buffer(shapes, dtypes, device)
                    )
                    _SEQ_DATA_ASYNC_COMMS[self.name] = _broadcast_data(
                        next_sync_data, shapes, dtypes, src_rank, group, True
                    )
                    _SEQ_DATA_BUF[self.name][next_buf_id] = next_sync_data
                yield _SEQ_DATA_BUF[self.name][buf_id]
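# Illustrative usage sketch for SPDistForward (comment only; `model` and
# `batch` are hypothetical): every sequence-parallel rank passes in its own
# batch, and the generator yields each rank's batch exactly once, overlapping
# the broadcast for step k + 1 with the consumption of step k via the double
# buffer in _SEQ_DATA_BUF.
#
#   sp_forward = SPDistForward(name="demo", comm_shape=True)
#   for shared_batch in sp_forward(batch):
#       out = model(shared_batch)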
sync_inputs = SPDistForward(name="bef_fwd", comm_shape=True)


def sync_data(data, sp_idx, name="tmp"):
    """
    Broadcast an arbitrary picklable object from sequence-parallel rank `sp_idx`
    to every rank in the sequence-parallel group.
    """
    group = get_sequence_parallel_group()
    if group is None:
        return data
    # if sp_idx in _SYNC_BUFFER[name]:
    #     return _SYNC_BUFFER[name][sp_idx]
    sp_rank = get_sequence_parallel_rank()
    src_rank = dist.get_global_rank(group, sp_idx)
    objects = [data] if sp_rank == sp_idx else [None]
    dist.broadcast_object_list(objects, src=src_rank, group=group)
    # _SYNC_BUFFER[name] = {sp_idx: objects[0]}
    return objects[0]
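# Illustrative usage sketch (comment only; `noise_seed` is a hypothetical
# value produced on sequence-parallel rank 0 that all ranks must agree on):
#
#   noise_seed = sync_data(noise_seed, sp_idx=0)
#
# Unlike SPDistForward, this relies on dist.broadcast_object_list, so it works
# for any picklable object at the cost of pickling instead of an in-place
# tensor broadcast.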