From b1872e1ed8542d9aee8e659f2336289a89588062 Mon Sep 17 00:00:00 2001
From: Angel Gonzalez
Date: Tue, 7 May 2024 11:35:59 +0200
Subject: [PATCH 01/24] Adding checkpoint after training ends

---
 src/nanotron/config/config.py | 1 +
 src/nanotron/trainer.py       | 5 ++++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index d9946f26..e26fac75 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -129,6 +129,7 @@ class CheckpointsArgs:
     checkpoints_path: Path
     checkpoint_interval: int
     save_initial_state: Optional[bool] = False
+    save_final_state: Optional[bool] = False
     resume_checkpoint_path: Optional[Path] = None
     checkpoints_path_is_shared_file_system: Optional[bool] = False

diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 0eda00dc..70d023fb 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -442,7 +442,10 @@ def train(
                 self.save_checkpoint()
         dist.barrier()  # let's wait for everyone before leaving
-
+
+        if self.config.checkpoints.save_final_state:
+            self.save_checkpoint()
+
         self.post_training()

     def training_step(
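For reference, the new flag is driven from the checkpoints config; a minimal sketch of constructing it directly (the path value is a hypothetical destination, the field names are the ones visible in the hunk above):

```
from pathlib import Path

from nanotron.config.config import CheckpointsArgs

# With save_final_state=True, train() writes one last checkpoint right before
# post_training(), even if the final step does not land on checkpoint_interval.
ckpt_config = CheckpointsArgs(
    checkpoints_path=Path("checkpoints"),  # hypothetical destination
    checkpoint_interval=1000,
    save_initial_state=False,
    save_final_state=True,
)
```
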
From bcf405d9af2028773d6d76cd4ff658540b87a3f1 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Thu, 27 Jun 2024 11:56:53 +0000
Subject: [PATCH 02/24] Implemented global memory buffer to reduce activation
 memory of differentiable distributed operations

---
 src/nanotron/config/config.py                 |  2 +-
 .../distributed_differentiable_primitives.py  | 17 +++--------------
 src/nanotron/parallel/utils.py                | 19 +++++++++++++++++++
 src/nanotron/utils.py                         | 17 +++++++++++++++++
 4 files changed, 40 insertions(+), 15 deletions(-)

diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index d5b9976f..d72ea97f 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -111,7 +111,7 @@ def __post_init__(self):
 class DataArgs:
     """Arguments related to the data and data files processing"""

-    dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs]
+    dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]]
     seed: Optional[int]
     num_loading_workers: Optional[int] = 1

diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index 873d77df..57a67c42 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -19,6 +19,7 @@

 from nanotron import distributed as dist
 from nanotron.distributed import ProcessGroup
+from nanotron.parallel.utils import MemoryBuffer


 class DifferentiableIdentity(torch.autograd.Function):
@@ -67,13 +68,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
             group = torch_dist.distributed_c10d._get_default_group()

         unsharded_batch_size = sharded_batch_size * group.size()
-        unsharded_tensor = torch.empty(
-            unsharded_batch_size,
-            *rest_size,
-            device=tensor.device,
-            dtype=tensor.dtype,
-            requires_grad=tensor.requires_grad,
-        )
+        unsharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size, *rest_size), dtype=tensor.dtype)

         # `tensor` can sometimes not be contiguous
         # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
@@ -108,13 +103,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
         # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
         tensor = tensor.contiguous()

-        sharded_tensor = torch.empty(
-            unsharded_batch_size // group.size(),
-            *rest_size,
-            device=tensor.device,
-            dtype=tensor.dtype,
-            requires_grad=tensor.requires_grad,
-        )
+        sharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size//group.size(), *rest_size), dtype=tensor.dtype)
         dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
         return sharded_tensor

diff --git a/src/nanotron/parallel/utils.py b/src/nanotron/parallel/utils.py
index b9ac12ae..eb4e441d 100644
--- a/src/nanotron/parallel/utils.py
+++ b/src/nanotron/parallel/utils.py
@@ -1,13 +1,32 @@
 import functools
+import operator
 import os

+import torch
 from torch import nn

 from nanotron import distributed as dist
+from nanotron.utils import Singleton
 from nanotron.parallel import ParallelContext
 from nanotron.parallel.tied_parameters import get_tied_id_to_param


+class MemoryBuffer(metaclass=Singleton):
+    """
+    Global memory buffer to store intermediate activations that need not be cached for the backward pass.
+    """
+
+    def __init__(self):
+        self.buffer = {}
+
+    def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
+        required_numel = functools.reduce(operator.mul, shape, 1)
+        if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel:
+            self.buffer[name, dtype] = torch.empty(required_numel, dtype=dtype, device=torch.cuda.current_device(),
+                                                   requires_grad=False)
+        return self.buffer[name, dtype][:required_numel].view(shape)
+
+
 def assert_cuda_max_connections_set_to_1(func):
     flag_is_set_to_1 = None

diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py
index 14fe1ca8..8065962b 100644
--- a/src/nanotron/utils.py
+++ b/src/nanotron/utils.py
@@ -15,6 +15,23 @@
 from nanotron import distributed as dist


+class Singleton(type):
+    """
+    Singleton metaclass.
+    Create objects using this class as the metaclass to enable singleton behaviour.
+    For instance:
+    ```
+    class Logger(metaclass=Singleton):
+        ...
+    ```
+    """
+    _instances = {}
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+
 class ContextManagers:
     """
     Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
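The two pieces this patch introduces (the `Singleton` metaclass and the `MemoryBuffer`) are easy to sanity-check in isolation. A minimal CPU-only sketch of the same pattern — the patched class allocates on `torch.cuda.current_device()` instead:

```
import functools
import operator

import torch


class Singleton(type):
    # Same idea as in src/nanotron/utils.py: one shared instance per class.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class MemoryBuffer(metaclass=Singleton):
    def __init__(self):
        self.buffer = {}

    def get(self, name, shape, dtype=torch.float32):
        numel = functools.reduce(operator.mul, shape, 1)
        if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < numel:
            self.buffer[name, dtype] = torch.empty(numel, dtype=dtype)  # CPU here, for the sketch
        return self.buffer[name, dtype][:numel].view(shape)


a = MemoryBuffer().get("dist", (4, 8))
b = MemoryBuffer().get("dist", (2, 8))   # smaller request: reuses the same storage
assert MemoryBuffer() is MemoryBuffer()  # singleton behaviour
assert b.data_ptr() == a.data_ptr()      # no new allocation happened
```

This is why the buffer is only safe for tensors that need not survive until backward: a later `get` under the same key overwrites the earlier contents.
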
From ed1ca7d0b55b07696d1c622d713c793f3ca53e28 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Thu, 27 Jun 2024 14:29:14 +0000
Subject: [PATCH 03/24] GLU fusion

---
 src/nanotron/models/llama.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index ca8894b9..d310fe2a 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -163,8 +163,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
         )
-        # TODO @nouamane: why can't we torch.jit.script GLUActivation?
-        self.split_silu_mul = GLUActivation(config.hidden_act)
+        self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act))

     def forward(self, hidden_states):  # [seq_length, batch_size, hidden_dim]
         merged_states = self.gate_up_proj(hidden_states)
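The compiled module replaces the old jit-script TODO. A standalone sketch of the same split-SiLU-multiply, assuming the SiLU gate/up convention that `GLUActivation` implements for LLaMA (the class body here is an illustration, not the nanotron source):

```
import torch
from torch import nn


class GLUActivation(nn.Module):
    # gate_up_proj packs [gate; up] along the last dimension;
    # the activation splits it back and fuses silu(gate) * up.
    def forward(self, merged_states: torch.Tensor) -> torch.Tensor:
        gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1)
        return nn.functional.silu(gate_states) * up_states


split_silu_mul = torch.compile(GLUActivation())  # lets the compiler fuse split/silu/mul
out = split_silu_mul(torch.randn(16, 2, 8))      # [seq_length, batch_size, 2 * intermediate]
```
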
From 9b0de5be04afb9cac631399593aef8de6aa852a6 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Thu, 27 Jun 2024 14:42:42 +0000
Subject: [PATCH 04/24] precommit

---
 src/nanotron/models/llama.py                  | 2 +-
 .../distributed_differentiable_primitives.py  | 4 +++-
 src/nanotron/parallel/utils.py                | 7 ++++---
 src/nanotron/utils.py                         | 8 +++++---
 4 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index d310fe2a..3319b0ef 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """PyTorch LLaMa model."""

-from typing import Dict, Optional, Union, List
+from typing import Dict, Optional, Union

 import torch
 from torch import nn
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index 57a67c42..aa460cc6 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -103,7 +103,9 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
         # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
         tensor = tensor.contiguous()

-        sharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size//group.size(), *rest_size), dtype=tensor.dtype)
+        sharded_tensor = MemoryBuffer().get(
+            "dist", (unsharded_batch_size // group.size(), *rest_size), dtype=tensor.dtype
+        )
         dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
         return sharded_tensor

diff --git a/src/nanotron/parallel/utils.py b/src/nanotron/parallel/utils.py
index eb4e441d..f694b0e6 100644
--- a/src/nanotron/parallel/utils.py
+++ b/src/nanotron/parallel/utils.py
@@ -6,9 +6,9 @@
 from torch import nn

 from nanotron import distributed as dist
-from nanotron.utils import Singleton
 from nanotron.parallel import ParallelContext
 from nanotron.parallel.tied_parameters import get_tied_id_to_param
+from nanotron.utils import Singleton


 class MemoryBuffer(metaclass=Singleton):
@@ -22,8 +22,9 @@ def __init__(self):
     def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
         required_numel = functools.reduce(operator.mul, shape, 1)
         if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel:
-            self.buffer[name, dtype] = torch.empty(required_numel, dtype=dtype, device=torch.cuda.current_device(),
-                                                   requires_grad=False)
+            self.buffer[name, dtype] = torch.empty(
+                required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
+            )
         return self.buffer[name, dtype][:required_numel].view(shape)


diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py
index 8065962b..b3831801 100644
--- a/src/nanotron/utils.py
+++ b/src/nanotron/utils.py
@@ -1,11 +1,10 @@
 import functools
 import inspect
-import math
 import os
 import random
 import socket
 from contextlib import ExitStack, contextmanager
-from typing import Callable, ContextManager, List, Optional
+from typing import ContextManager, List, Optional

 import torch
 from packaging import version
@@ -25,7 +24,9 @@ class Logger(metaclass=Singleton):
         ...
     ```
     """
+
     _instances = {}
+
     def __call__(cls, *args, **kwargs):
         if cls not in cls._instances:
             cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
         return cls._instances[cls]
@@ -69,7 +70,7 @@ def main_rank_first(group: dist.ProcessGroup):
 @contextmanager
 def local_ranks_zero_first(group: Optional[dist.ProcessGroup] = None):
     """Context manager that executes the code in the context with all the local rank zero of the group going first.
-    Usefull to run only once per node first (e.g. to create local files, etc)
+    Useful to run only once per node first (e.g. to create local files, etc)
     """
     is_main = int(os.environ.get("LOCAL_RANK", 0)) == 0
     if is_main:
@@ -140,6 +141,7 @@ def get_untyped_storage(tensor: torch.Tensor) -> torch.UntypedStorage:
     else:
         return tensor.storage().untyped()

+
 def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: torch.dtype):
     # TODO @thomasw21: Figure out what's the best Pytorch way of building a tensor from a storage.
     device = untyped_storage.device
From 803b6da3233a642a0ba7a62484310d1496db81dc Mon Sep 17 00:00:00 2001
From: AleHD
Date: Tue, 16 Jul 2024 11:39:32 +0200
Subject: [PATCH 05/24] Wrong backward fixed

---
 .../parallel/tensor_parallel/column_linear.py | 62 +++++++++++++++++++
 .../distributed_differentiable_primitives.py  | 27 +++++---
 .../parallel/tensor_parallel/functional.py    |  4 +-
 3 files changed, 85 insertions(+), 8 deletions(-)
 create mode 100644 src/nanotron/parallel/tensor_parallel/column_linear.py

diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py
new file mode 100644
index 00000000..eaab5abe
--- /dev/null
+++ b/src/nanotron/parallel/tensor_parallel/column_linear.py
@@ -0,0 +1,62 @@
+from typing import Optional
+
+import torch
+from torch.nn import functional as F
+
+import nanotron.distributed as dist
+from nanotron.parallel.utils import MemoryBuffer
+
+
+class ColumnLinearContextParallel(torch.autograd.Function):
+    """
+    Column linear with memory_buffer for the allgather, context parallel
+    enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
+    async communication disabled.
+    """
+    @staticmethod
+    def forward(ctx, input: torch.Tensor, weight: torch.Tensor,
+                bias: Optional[torch.Tensor], group: dist.ProcessGroup):
+
+        # Prepare context.
+        ctx.save_for_backward(input, weight, bias)
+        ctx.group = group
+
+        # Do allgather.
+        sharded_batch_size, *rest_size = input.shape
+        unsharded_batch_size = sharded_batch_size * group.size()
+        total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
+        dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
+
+        # Get linear output.
+        out = F.linear(total_input, weight, bias)
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        # Allgather the inputs again.
+        input, weight, bias = ctx.saved_tensors
+        group = ctx.group
+        sharded_batch_size, *rest_size = input.shape
+        total_input = sharded_batch_size * group.size()
+        unsharded_batch_size = sharded_batch_size * group.size()
+        total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
+        dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
+
+        # Get the grad_output and total_input on the correct views to be able to transpose them below.
+        grad_output = grad_output.contiguous()
+        assert grad_output.dim() == 3
+        grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2))
+        total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2))
+
+        # Compute gradients.
+        grad_input = grad_output @ weight
+        sub_grad_input = torch.empty(input.size(), dtype=input.dtype, device=input.device, requires_grad=False)
+        dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
+        grad_weight = grad_output.T @ total_input
+        grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None
+
+        return sub_grad_input, grad_weight, grad_bias, None
+
+def column_linear_context_parallel(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor],
+                                   group: dist.ProcessGroup):
+    return ColumnLinearContextParallel.apply(input, weight, bias, group)
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index aa460cc6..d66826e3 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -19,7 +19,6 @@

 from nanotron import distributed as dist
 from nanotron.distributed import ProcessGroup
-from nanotron.parallel.utils import MemoryBuffer


 class DifferentiableIdentity(torch.autograd.Function):
@@ -68,7 +67,13 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
             group = torch_dist.distributed_c10d._get_default_group()

         unsharded_batch_size = sharded_batch_size * group.size()
-        unsharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size, *rest_size), dtype=tensor.dtype)
+        unsharded_tensor = torch.empty(
+            unsharded_batch_size,
+            *rest_size,
+            device=tensor.device,
+            dtype=tensor.dtype,
+            requires_grad=tensor.requires_grad,
+        )

         # `tensor` can sometimes not be contiguous
         # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
@@ -84,8 +89,11 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):

     @staticmethod
     def backward(ctx, grad_output):
+        #print(f"{torch.distributed.get_rank()} grad_output: {grad_output}")
         group = ctx.group
-        return DifferentiableReduceScatterSum.apply(grad_output, group), None
+        out = DifferentiableReduceScatterSum.apply(grad_output, group)
+        #print(f"{torch.distributed.get_rank()} grad_grad: {out}")
+        return out, None, None


 class DifferentiableReduceScatterSum(torch.autograd.Function):
@@ -104,9 +112,13 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
         # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
         tensor = tensor.contiguous()

-        sharded_tensor = MemoryBuffer().get(
-            "dist", (unsharded_batch_size // group.size(), *rest_size), dtype=tensor.dtype
-        )
+        sharded_tensor = torch.empty(
+            unsharded_batch_size // group.size(),
+            *rest_size,
+            device=tensor.device,
+            dtype=tensor.dtype,
+            requires_grad=False,
+        )
         dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
         return sharded_tensor

@@ -124,7 +136,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
     @staticmethod
     def backward(ctx, grad_output):
         group = ctx.group
-        return DifferentiableAllGather.apply(grad_output, group), None
+        #print(f"{torch.distributed.get_rank()} Calling AllGather because of backward of reducescatter")
+        return DifferentiableAllGather.apply(grad_output, group, False), None


 # -----------------
@@ -140,7 +153,7 @@ def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None)
     return DifferentiableAllReduceSum.apply(tensor, group)


-def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None):
+def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None)
     return DifferentiableAllGather.apply(tensor, group)

diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index fdef48ac..b3602707 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -26,6 +26,7 @@
     differentiable_reduce_scatter_sum,
 )
 from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
+from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel
 from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1


@@ -352,7 +353,7 @@ def column_linear(
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
         input = differentiable_identity(input, group=group)
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
-        input = differentiable_all_gather(input, group=group)
+        return column_linear_context_parallel(input, weight, bias, group)
     else:
         raise ValueError(f"Got unexpected mode: {tp_mode}.")

@@ -473,6 +474,7 @@ def row_linear(

     out = F.linear(input, weight, bias)

+    #print("Calling row linear")
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
         out = differentiable_all_reduce_sum(out, group=group)
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
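The gradient formulas used in the new backward (`grad_input = grad_output @ W`, `grad_weight = grad_output.T @ X`, `grad_bias = sum(grad_output)`) can be checked against autograd without any process group; a single-process sketch:

```
import torch

x = torch.randn(6, 4, requires_grad=True)  # stand-in for the gathered input
w = torch.randn(5, 4, requires_grad=True)  # [out_features, in_features], as F.linear expects
b = torch.randn(5, requires_grad=True)

out = torch.nn.functional.linear(x, w, b)  # out = x @ w.T + b
grad_output = torch.randn_like(out)
out.backward(grad_output)

# Same formulas as ColumnLinearContextParallel.backward (before the reduce-scatter).
assert torch.allclose(x.grad, grad_output @ w, atol=1e-5)
assert torch.allclose(w.grad, grad_output.T @ x, atol=1e-5)
assert torch.allclose(b.grad, grad_output.sum(dim=0), atol=1e-5)
```

In the distributed version the only extra step is reduce-scattering `grad_input` back to the per-rank shard, which is what `sub_grad_input` holds.
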
From 59bfb6b38c4487b04e425d8540b6e44b2a7fbcf9 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Tue, 16 Jul 2024 11:42:27 +0200
Subject: [PATCH 06/24] Removed useless prints

---
 .../tensor_parallel/distributed_differentiable_primitives.py | 3 ---
 src/nanotron/parallel/tensor_parallel/functional.py          | 1 -
 2 files changed, 4 deletions(-)

diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index d66826e3..f1102908 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -84,10 +84,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):

     @staticmethod
     def backward(ctx, grad_output):
-        #print(f"{torch.distributed.get_rank()} grad_output: {grad_output}")
         group = ctx.group
         out = DifferentiableReduceScatterSum.apply(grad_output, group)
-        #print(f"{torch.distributed.get_rank()} grad_grad: {out}")
         return out, None, None
@@ -124,7 +122,6 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
     @staticmethod
     def backward(ctx, grad_output):
         group = ctx.group
-        #print(f"{torch.distributed.get_rank()} Calling AllGather because of backward of reducescatter")
         return DifferentiableAllGather.apply(grad_output, group, False), None

diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index b3602707..cedbb219 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -474,7 +474,6 @@ def row_linear(

     out = F.linear(input, weight, bias)

-    #print("Calling row linear")
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
         out = differentiable_all_reduce_sum(out, group=group)
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
From 2c69e9ad2887b7e78c88c2db3209713542dad7e2 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Wed, 17 Jul 2024 10:01:44 +0000
Subject: [PATCH 07/24] Minor fixes

---
 .../distributed_differentiable_primitives.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index f1102908..bd41347a 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -86,7 +86,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
     def backward(ctx, grad_output):
         group = ctx.group
         out = DifferentiableReduceScatterSum.apply(grad_output, group)
-        return out, None, None
+        return out, None


 class DifferentiableReduceScatterSum(torch.autograd.Function):
@@ -122,7 +122,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
     @staticmethod
     def backward(ctx, grad_output):
         group = ctx.group
-        return DifferentiableAllGather.apply(grad_output, group, False), None
+        return DifferentiableAllGather.apply(grad_output, group), None


 # -----------------
@@ -138,7 +138,7 @@ def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None)
     return DifferentiableAllReduceSum.apply(tensor, group)


-def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None)
+def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None):
     return DifferentiableAllGather.apply(tensor, group)
From 30439fdee7cac456be4a2c28798b42c931f7cf72 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Wed, 17 Jul 2024 11:16:40 +0000
Subject: [PATCH 08/24] precommit

---
 .../parallel/tensor_parallel/column_linear.py       | 12 ++++++++----
 src/nanotron/parallel/tensor_parallel/functional.py |  7 +++----
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py
index eaab5abe..21daba36 100644
--- a/src/nanotron/parallel/tensor_parallel/column_linear.py
+++ b/src/nanotron/parallel/tensor_parallel/column_linear.py
@@ -13,9 +13,11 @@ class ColumnLinearContextParallel(torch.autograd.Function):
     enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
     async communication disabled.
     """
+
     @staticmethod
-    def forward(ctx, input: torch.Tensor, weight: torch.Tensor,
-                bias: Optional[torch.Tensor], group: dist.ProcessGroup):
+    def forward(
+        ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup
+    ):

         # Prepare context.
         ctx.save_for_backward(input, weight, bias)
@@ -57,6 +59,8 @@ def backward(ctx, grad_output: torch.Tensor):

         return sub_grad_input, grad_weight, grad_bias, None

-def column_linear_context_parallel(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor],
-                                   group: dist.ProcessGroup):
+
+def column_linear_context_parallel(
+    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup
+):
     return ColumnLinearContextParallel.apply(input, weight, bias, group)
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index cedbb219..f4e9de30 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -19,14 +19,13 @@
 from torch.nn import functional as F

 import nanotron.distributed as dist
+from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel
 from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
-    differentiable_all_gather,
     differentiable_all_reduce_sum,
     differentiable_identity,
     differentiable_reduce_scatter_sum,
 )
 from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
-from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel
 from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1


@@ -90,10 +89,10 @@

     @staticmethod
     def backward(ctx, grad_output):
-        # Retreive tensors from the forward path.
+        # Retrieve tensors from the forward path.
         softmax, target_mask, masked_target_1d = ctx.saved_tensors

-        # All the inputs have softmax as thier gradient.
+        # All the inputs have softmax as their gradient.
         grad_input = softmax
         # For simplicity, work with the 2D gradient.
         sharded_hidden_size = softmax.size()[-1]
From 1e02a9ce9c9b564f4a4274ee62e7208e3d5d9df8 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Wed, 17 Jul 2024 12:47:20 +0000
Subject: [PATCH 09/24] Added tp_recompute_allgather option

---
 src/nanotron/config/parallelism_config.py     |  2 +
 src/nanotron/models/llama.py                  |  3 ++
 .../parallel/tensor_parallel/column_linear.py | 51 ++++++++++++-------
 .../parallel/tensor_parallel/functional.py    |  4 +-
 src/nanotron/parallel/tensor_parallel/nn.py   |  3 ++
 5 files changed, 44 insertions(+), 19 deletions(-)

diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py
index 5912425b..e9a6f2a4 100644
--- a/src/nanotron/config/parallelism_config.py
+++ b/src/nanotron/config/parallelism_config.py
@@ -32,6 +32,8 @@ class ParallelismArgs:
     tp_mode: Optional[TensorParallelLinearMode] = None
     tp_linear_async_communication: Optional[bool] = None

+    tp_recompute_allgather: bool = False
+
     expert_parallel_size: int = 1

     def __post_init__(self):
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 3319b0ef..a31ebec6 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -154,6 +154,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication,
             contiguous_chunks=gate_up_contiguous_chunks,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
         )
         self.down_proj = TensorParallelRowLinear(
             config.intermediate_size,
@@ -314,6 +315,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication,
             contiguous_chunks=qkv_contiguous_chunks,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
         )
         # TODO(kunhao): We want to have only one version per device and not one version per layer.
         self.rotary_embedding = RotaryEmbedding(
@@ -742,6 +744,7 @@ def __init__(
                 # TODO @thomasw21: refactor so that we store that default in a single place.
                 "mode": self.tp_mode,
                 "async_communication": tp_linear_async_communication,
+                "tp_recompute_allgather": parallel_config.tp_recompute_allgather,
             },
             module_input_keys={"x"},
             module_output_keys={"logits"},
diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py
index 21daba36..2f743199 100644
--- a/src/nanotron/parallel/tensor_parallel/column_linear.py
+++ b/src/nanotron/parallel/tensor_parallel/column_linear.py
@@ -16,33 +16,47 @@ class ColumnLinearContextParallel(torch.autograd.Function):

     @staticmethod
     def forward(
-        ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup
+        ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup,
+        tp_recompute_allgather: bool
     ):

-        # Prepare context.
-        ctx.save_for_backward(input, weight, bias)
-        ctx.group = group
-
         # Do allgather.
         sharded_batch_size, *rest_size = input.shape
         unsharded_batch_size = sharded_batch_size * group.size()
-        total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
+        if tp_recompute_allgather:
+            total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
+        else:
+            total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device)
         dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)

+        # Prepare context.
+        ctx.group = group
+        ctx.tp_recompute_allgather = tp_recompute_allgather
+        ctx.input_size = input.shape
+        if tp_recompute_allgather:
+            ctx.save_for_backward(input, weight, bias)
+        else:
+            ctx.save_for_backward(total_input, weight, bias)
+
         # Get linear output.
         out = F.linear(total_input, weight, bias)
         return out

     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
-        # Allgather the inputs again.
-        input, weight, bias = ctx.saved_tensors
+        # Either allgather the inputs again or get them from context.
         group = ctx.group
-        sharded_batch_size, *rest_size = input.shape
-        total_input = sharded_batch_size * group.size()
-        unsharded_batch_size = sharded_batch_size * group.size()
-        total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
-        dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
+        tp_recompute_allgather = ctx.tp_recompute_allgather
+        input_size = ctx.input_size
+        if tp_recompute_allgather:
+            input, weight, bias = ctx.saved_tensors
+            sharded_batch_size, *rest_size = input.shape
+            total_input = sharded_batch_size * group.size()
+            unsharded_batch_size = sharded_batch_size * group.size()
+            total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
+            dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
+        else:
+            total_input, weight, bias = ctx.saved_tensors

         # Get the grad_output and total_input on the correct views to be able to transpose them below.
         grad_output = grad_output.contiguous()
@@ -51,16 +65,17 @@ def backward(ctx, grad_output: torch.Tensor):
         total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2))

         # Compute gradients.
+        grad_weight = grad_output.T @ total_input
         grad_input = grad_output @ weight
-        sub_grad_input = torch.empty(input.size(), dtype=input.dtype, device=input.device, requires_grad=False)
+        sub_grad_input = torch.empty(input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False)
         dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
-        grad_weight = grad_output.T @ total_input
         grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None

-        return sub_grad_input, grad_weight, grad_bias, None
+        return sub_grad_input, grad_weight, grad_bias, None, None


 def column_linear_context_parallel(
-    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup
+    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup,
+    tp_recompute_allgather: bool = False
 ):
-    return ColumnLinearContextParallel.apply(input, weight, bias, group)
+    return ColumnLinearContextParallel.apply(input, weight, bias, group, tp_recompute_allgather)
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index f4e9de30..c16ae492 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -21,6 +21,7 @@
 import nanotron.distributed as dist
 from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel
 from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
+    differentiable_all_gather,
     differentiable_all_reduce_sum,
     differentiable_identity,
     differentiable_reduce_scatter_sum,
@@ -345,6 +346,7 @@ def column_linear(
     group: dist.ProcessGroup,
     tp_mode: TensorParallelLinearMode,
     async_communication: bool,
+    tp_recompute_allgather: bool = True
 ):
     if async_communication:
         return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)
@@ -352,7 +354,7 @@ def column_linear(
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
         input = differentiable_identity(input, group=group)
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
-        return column_linear_context_parallel(input, weight, bias, group)
+        return column_linear_context_parallel(input, weight, bias, group, tp_recompute_allgather)
     else:
         raise ValueError(f"Got unexpected mode: {tp_mode}.")
diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py
index 40e89968..42ffc828 100644
--- a/src/nanotron/parallel/tensor_parallel/nn.py
+++ b/src/nanotron/parallel/tensor_parallel/nn.py
@@ -51,6 +51,7 @@ def __init__(
         dtype=None,
         async_communication: bool = False,
         contiguous_chunks: Optional[Tuple[int, ...]] = None,
+        tp_recompute_allgather: bool = False,
     ):
         self.pg = pg
         self.world_size = pg.size()
@@ -59,6 +60,7 @@ def __init__(

         self.in_features = in_features
         self.out_features = out_features // self.world_size
+        self.tp_recompute_allgather = tp_recompute_allgather

         super().__init__(
             in_features=self.in_features,
@@ -91,6 +93,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             group=self.pg,
             tp_mode=self.mode,
             async_communication=self.async_communication,
+            tp_recompute_allgather=self.tp_recompute_allgather,
         )

     def extra_repr(self) -> str:
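What the new option trades: with `tp_recompute_allgather=True` the forward saves only the sharded input and re-runs the all-gather in backward (less activation memory, one extra collective per layer); with `False` it keeps the gathered tensor alive until backward. A sketch of setting it through the config — `dp`/`pp`/`tp` are the usual parallelism sizes, and the exact constructor call is an assumption:

```
from nanotron.config.parallelism_config import ParallelismArgs
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode

parallelism = ParallelismArgs(
    dp=2,
    pp=1,
    tp=4,
    tp_mode=TensorParallelLinearMode.REDUCE_SCATTER,
    tp_linear_async_communication=False,
    tp_recompute_allgather=True,  # save memory, redo the all-gather in backward
)
```
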
From 9cc81bb6fe680b72cf6114f7258af0483886ada1 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Wed, 17 Jul 2024 13:39:14 +0000
Subject: [PATCH 10/24] Changed recompute default

---
 src/nanotron/config/parallelism_config.py     |  2 +-
 .../parallel/tensor_parallel/column_linear.py | 19 ++++++++++++++-----
 .../parallel/tensor_parallel/functional.py    |  3 +--
 3 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py
index e9a6f2a4..cc5d406a 100644
--- a/src/nanotron/config/parallelism_config.py
+++ b/src/nanotron/config/parallelism_config.py
@@ -32,7 +32,7 @@ class ParallelismArgs:
     tp_mode: Optional[TensorParallelLinearMode] = None
     tp_linear_async_communication: Optional[bool] = None

-    tp_recompute_allgather: bool = False
+    tp_recompute_allgather: bool = True

     expert_parallel_size: int = 1

diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py
index 2f743199..880d5ff0 100644
--- a/src/nanotron/parallel/tensor_parallel/column_linear.py
+++ b/src/nanotron/parallel/tensor_parallel/column_linear.py
@@ -16,8 +16,12 @@ class ColumnLinearContextParallel(torch.autograd.Function):

     @staticmethod
     def forward(
-        ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup,
-        tp_recompute_allgather: bool
+        ctx,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        group: dist.ProcessGroup,
+        tp_recompute_allgather: bool,
     ):

         # Do allgather.
@@ -67,7 +71,9 @@ def backward(ctx, grad_output: torch.Tensor):
         # Compute gradients.
         grad_weight = grad_output.T @ total_input
         grad_input = grad_output @ weight
-        sub_grad_input = torch.empty(input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False)
+        sub_grad_input = torch.empty(
+            input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
+        )
         dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
         grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None

@@ -75,7 +81,10 @@ def backward(ctx, grad_output: torch.Tensor):


 def column_linear_context_parallel(
-    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup,
-    tp_recompute_allgather: bool = False
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    group: dist.ProcessGroup,
+    tp_recompute_allgather: bool = True,
 ):
     return ColumnLinearContextParallel.apply(input, weight, bias, group, tp_recompute_allgather)
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index c16ae492..454cc447 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -21,7 +21,6 @@
 import nanotron.distributed as dist
 from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel
 from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
-    differentiable_all_gather,
     differentiable_all_reduce_sum,
     differentiable_identity,
     differentiable_reduce_scatter_sum,
@@ -345,7 +344,7 @@ def column_linear(
     group: dist.ProcessGroup,
     tp_mode: TensorParallelLinearMode,
     async_communication: bool,
-    tp_recompute_allgather: bool = True
+    tp_recompute_allgather: bool = True,
 ):
     if async_communication:
         return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)
From 956fbfd09a2f0c2358fcb90be395f97ffa79632e Mon Sep 17 00:00:00 2001
From: AleHD
Date: Wed, 17 Jul 2024 13:40:34 +0000
Subject: [PATCH 11/24] Changed recompute default

---
 src/nanotron/parallel/tensor_parallel/nn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py
index 42ffc828..4c7325cd 100644
--- a/src/nanotron/parallel/tensor_parallel/nn.py
+++ b/src/nanotron/parallel/tensor_parallel/nn.py
@@ -51,7 +51,7 @@ def __init__(
         dtype=None,
         async_communication: bool = False,
         contiguous_chunks: Optional[Tuple[int, ...]] = None,
-        tp_recompute_allgather: bool = False,
+        tp_recompute_allgather: bool = True,
     ):
         self.pg = pg
         self.world_size = pg.size()
From 9992f1c5919fd4038e85cd9b3fb1dd4faa81daf1 Mon Sep 17 00:00:00 2001
From: tj-solergibert
Date: Wed, 17 Jul 2024 14:28:57 +0000
Subject: [PATCH 12/24] Little fixes

---
 src/nanotron/config/config.py | 4 ++--
 tools/preprocess_data.py      | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index fe194883..2e1a98cc 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -96,10 +96,10 @@ class NanosetDatasetsArgs:
     dataset_folder: Union[str, dict, List[str]]

     def __post_init__(self):
-        if isinstance(self.dataset_folder, str):  # Case 1: 1 Dataset file
+        if isinstance(self.dataset_folder, str):  # Case 1: 1 Dataset folder
             self.dataset_folder = [self.dataset_folder]
             self.dataset_weights = [1]
-        elif isinstance(self.dataset_folder, List):  # Case 2: > 1 Dataset file
+        elif isinstance(self.dataset_folder, List):  # Case 2: > 1 Dataset folder
             self.dataset_weights = None  # Set to None so we consume all the samples randomly
         elif isinstance(self.dataset_folder, dict):  # Case 3: dict with > 1 dataset_folder and weights
             tmp_dataset_folder = self.dataset_folder.copy()
diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py
index 38db67f1..f3cdab70 100644
--- a/tools/preprocess_data.py
+++ b/tools/preprocess_data.py
@@ -95,6 +95,7 @@ def main(args):
                 output_folder=args.output_folder,
                 tokenizer_name_or_path=args.tokenizer_name_or_path,
                 eos_token=args.eos_token,
+                shuffle=False,
                 max_tokens_per_file=1e9,
             ),
         ],
From b9e92017614e0326acab86b7665e0c8e8718bfc3 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Wed, 17 Jul 2024 15:21:04 +0000
Subject: [PATCH 13/24] Moved ColumnLinearNoAsync module for consistency

---
 .../parallel/tensor_parallel/column_linear.py | 90 ------------------
 .../parallel/tensor_parallel/functional.py    | 78 +++++++++++++++-
 2 files changed, 76 insertions(+), 92 deletions(-)
 delete mode 100644 src/nanotron/parallel/tensor_parallel/column_linear.py

diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py
deleted file mode 100644
index 880d5ff0..00000000
--- a/src/nanotron/parallel/tensor_parallel/column_linear.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from typing import Optional
-
-import torch
-from torch.nn import functional as F
-
-import nanotron.distributed as dist
-from nanotron.parallel.utils import MemoryBuffer
-
-
-class ColumnLinearContextParallel(torch.autograd.Function):
-    """
-    Column linear with memory_buffer for the allgather, context parallel
-    enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
-    async communication disabled.
-    """
-
-    @staticmethod
-    def forward(
-        ctx,
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        bias: Optional[torch.Tensor],
-        group: dist.ProcessGroup,
-        tp_recompute_allgather: bool,
-    ):
-
-        # Do allgather.
-        sharded_batch_size, *rest_size = input.shape
-        unsharded_batch_size = sharded_batch_size * group.size()
-        if tp_recompute_allgather:
-            total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
-        else:
-            total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device)
-        dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
-
-        # Prepare context.
-        ctx.group = group
-        ctx.tp_recompute_allgather = tp_recompute_allgather
-        ctx.input_size = input.shape
-        if tp_recompute_allgather:
-            ctx.save_for_backward(input, weight, bias)
-        else:
-            ctx.save_for_backward(total_input, weight, bias)
-
-        # Get linear output.
-        out = F.linear(total_input, weight, bias)
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output: torch.Tensor):
-        # Either allgather the inputs again or get them from context.
-        group = ctx.group
-        tp_recompute_allgather = ctx.tp_recompute_allgather
-        input_size = ctx.input_size
-        if tp_recompute_allgather:
-            input, weight, bias = ctx.saved_tensors
-            sharded_batch_size, *rest_size = input.shape
-            total_input = sharded_batch_size * group.size()
-            unsharded_batch_size = sharded_batch_size * group.size()
-            total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
-            dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
-        else:
-            total_input, weight, bias = ctx.saved_tensors
-
-        # Get the grad_output and total_input on the correct views to be able to transpose them below.
-        grad_output = grad_output.contiguous()
-        assert grad_output.dim() == 3
-        grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2))
-        total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2))
-
-        # Compute gradients.
-        grad_weight = grad_output.T @ total_input
-        grad_input = grad_output @ weight
-        sub_grad_input = torch.empty(
-            input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
-        )
-        dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
-        grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None
-
-        return sub_grad_input, grad_weight, grad_bias, None, None
-
-
-def column_linear_context_parallel(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    bias: Optional[torch.Tensor],
-    group: dist.ProcessGroup,
-    tp_recompute_allgather: bool = True,
-):
-    return ColumnLinearContextParallel.apply(input, weight, bias, group, tp_recompute_allgather)
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index 454cc447..2b93fb02 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -19,7 +19,7 @@
 from torch.nn import functional as F

 import nanotron.distributed as dist
-from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel
+from nanotron.parallel.utils import MemoryBuffer
 from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
     differentiable_all_reduce_sum,
     differentiable_identity,
@@ -338,6 +338,80 @@ def backward(ctx, grad_output):
         raise ValueError(f"Got unexpected mode: {tp_mode}.")


+class _ColumnLinearContextParallelNoAsync(torch.autograd.Function):
+    """
+    Column linear with memory_buffer for the allgather, context parallel
+    enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
+    async communication disabled.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        group: dist.ProcessGroup,
+        tp_recompute_allgather: bool,
+    ):
+
+        # Do allgather.
+        sharded_batch_size, *rest_size = input.shape
+        unsharded_batch_size = sharded_batch_size * group.size()
+        if tp_recompute_allgather:
+            total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
+        else:
+            total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device)
+        dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
+
+        # Prepare context.
+        ctx.group = group
+        ctx.tp_recompute_allgather = tp_recompute_allgather
+        ctx.input_size = input.shape
+        if tp_recompute_allgather:
+            ctx.save_for_backward(input, weight, bias)
+        else:
+            ctx.save_for_backward(total_input, weight, bias)
+
+        # Get linear output.
+        out = F.linear(total_input, weight, bias)
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        # Either allgather the inputs again or get them from context.
+        group = ctx.group
+        tp_recompute_allgather = ctx.tp_recompute_allgather
+        input_size = ctx.input_size
+        if tp_recompute_allgather:
+            input, weight, bias = ctx.saved_tensors
+            sharded_batch_size, *rest_size = input.shape
+            total_input = sharded_batch_size * group.size()
+            unsharded_batch_size = sharded_batch_size * group.size()
+            total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
+            dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
+        else:
+            total_input, weight, bias = ctx.saved_tensors
+
+        # Get the grad_output and total_input on the correct views to be able to transpose them below.
+        grad_output = grad_output.contiguous()
+        assert grad_output.dim() == 3
+        grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2))
+        total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2))
+
+        # Compute gradients.
+        grad_weight = grad_output.T @ total_input
+        grad_input = grad_output @ weight
+        sub_grad_input = torch.empty(
+            input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
+        )
+        dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
+        grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None
+
+        return sub_grad_input, grad_weight, grad_bias, None, None
+
+
 def column_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -353,7 +427,7 @@ def column_linear(
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
         input = differentiable_identity(input, group=group)
     elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
-        return column_linear_context_parallel(input, weight, bias, group, tp_recompute_allgather)
+        return _ColumnLinearContextParallelNoAsync.apply(input, weight, bias, group, tp_recompute_allgather)
     else:
         raise ValueError(f"Got unexpected mode: {tp_mode}.")
From 7cc6653c69c19cd42da05e6a8712159a146407e7 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Thu, 18 Jul 2024 11:16:42 +0000
Subject: [PATCH 14/24] memory efficient async linear

---
 .../parallel/tensor_parallel/functional.py | 48 ++++++-------------
 1 file changed, 15 insertions(+), 33 deletions(-)

diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index 47c0b5a1..1a82254e 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -149,14 +149,7 @@ def forward(ctx, tensor, weight, bias, group, tp_mode):
             group = dist.distributed_c10d._get_default_group()

         gathered_batch_size = sharded_batch_size * group.size()
-        gathered_tensor = torch.empty(
-            gathered_batch_size,
-            *intermediate_size,
-            hidden_size,
-            device=tensor.device,
-            dtype=tensor.dtype,
-            requires_grad=tensor.requires_grad,
-        )
+        gathered_tensor = MemoryBuffer().get("allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype)

         handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)

@@ -261,7 +254,7 @@ def backward(ctx, grad_output):
         use_bias = ctx.use_bias
         tp_mode = ctx.tp_mode

-        handle: Optional[dist.Work] = None
+        handle1: Optional[dist.Work] = None
         if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
             # TODO @thomasw21: gather along another dimension
             sharded_batch_size, *rest_size = tensor.shape
@@ -273,14 +266,8 @@ def backward(ctx, grad_output):
             else:
                 unsharded_batch_size = sharded_batch_size * group.size()

-                unsharded_tensor = torch.empty(
-                    unsharded_batch_size,
-                    *rest_size,
-                    device=tensor.device,
-                    dtype=tensor.dtype,
-                    requires_grad=False,
-                )
-                handle = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True)
+                unsharded_tensor = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype)
+                handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True)
                 # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
                 # gather is scheduled before the tensor gradient computation
                 total_tensor = unsharded_tensor
@@ -289,9 +276,6 @@ def backward(ctx, grad_output):

         grad_tensor = grad_output.matmul(weight)

-        if handle is not None:
-            handle.wait()
-
         # Doing gather + slicing during the NeMo forward pass can make this tensor
         # not be contiguous. PyTorch only checks if the tensor is contiguous, and only
         # clones it if it's not contiguous:
@@ -303,7 +287,7 @@ def backward(ctx, grad_output):
         grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
         total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim)

-        handle: Optional[dist.Work] = None
+        handle2: Optional[dist.Work] = None
         if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
             if group.size() == 1:
                 sub_grad_tensor = grad_tensor
@@ -312,23 +296,27 @@ def backward(ctx, grad_output):
                 sub_grad_tensor = torch.empty(
                     tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False
                 )
                 # reduce_scatter
-                handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)
+                handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)
                 # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
                 # reduce scatter is scheduled before the weight gradient computation
         elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
             # Asynchronous all-reduce
-            handle = dist.all_reduce(grad_tensor, group=group, async_op=True)
+            handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True)
             # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
             # all-reduce is scheduled before the weight gradient computation
         else:
             raise ValueError()

+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+        if handle1 is not None:
+            handle1.wait()
+
         # TODO @thomasw21: This sounds like we don't have the optimal physical layout
         grad_weight = grad_output.t().matmul(total_tensor)
-        grad_bias = grad_output.sum(dim=0) if use_bias else None

-        if handle is not None:
-            handle.wait()
+        if handle2 is not None:
+            handle2.wait()

         if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
             return sub_grad_tensor, grad_weight, grad_bias, None, None
@@ -472,13 +460,7 @@ def backward(ctx, grad_output):
             else:
                 unsharded_batch_size = sharded_batch_size * group.size()

-                total_grad_output = torch.empty(
-                    unsharded_batch_size,
-                    *rest_size,
-                    device=grad_output.device,
-                    dtype=grad_output.dtype,
-                    requires_grad=False,
-                )
+                total_grad_output = MemoryBuffer().get("allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype)

         # Doing gather + slicing during the NeMo forward pass can make this tensor
         # not be contiguous. PyTorch only checks if the tensor is contiguous, and only
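The `handle1`/`handle2` split is the standard overlap pattern: launch a collective with `async_op=True`, run the matmuls that do not depend on its result, and `wait()` only where the result is first needed (nanotron additionally relies on `CUDA_DEVICE_MAX_CONNECTIONS=1` for the scheduling to hold on GPU). A single-process gloo sketch of just the ordering:

```
import os

import torch
import torch.distributed as dist

# Single-process gloo group, only to demonstrate the async_op pattern.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.randn(4, 8)
w = torch.randn(8, 8)

handle = dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=True)  # communication in flight
y = w @ w.T      # independent compute overlaps with the collective
handle.wait()    # block only when the reduced tensor is actually needed
z = x @ w        # dependent compute after the wait

dist.destroy_process_group()
```
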
From cb0f2609e357747a0a3001dd9b12c649f9e6eef7 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Thu, 18 Jul 2024 11:17:09 +0000
Subject: [PATCH 15/24] precommit

---
 .../parallel/tensor_parallel/functional.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index 1a82254e..3821d544 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -19,14 +19,13 @@
 from torch.nn import functional as F

 import nanotron.distributed as dist
-from nanotron.parallel.utils import MemoryBuffer
 from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
     differentiable_all_reduce_sum,
     differentiable_identity,
     differentiable_reduce_scatter_sum,
 )
 from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
-from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1
+from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1


 class _ShardedCrossEntropy(torch.autograd.Function):
@@ -148,7 +148,9 @@ def forward(ctx, tensor, weight, bias, group, tp_mode):

         gathered_batch_size = sharded_batch_size * group.size()
-        gathered_tensor = MemoryBuffer().get("allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype)
+        gathered_tensor = MemoryBuffer().get(
+            "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype
+        )

         handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)
@@ -266,7 +267,9 @@ def backward(ctx, grad_output):
             else:
                 unsharded_batch_size = sharded_batch_size * group.size()

-                unsharded_tensor = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype)
+                unsharded_tensor = MemoryBuffer().get(
+                    "allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
+                )
                 handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True)
                 # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
                 # gather is scheduled before the tensor gradient computation
@@ -402,7 +405,6 @@ def backward(ctx, grad_output: torch.Tensor):

         return sub_grad_input, grad_weight, grad_bias, None, None


-
 def column_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -460,7 +462,9 @@ def backward(ctx, grad_output):
             else:
                 unsharded_batch_size = sharded_batch_size * group.size()

-                total_grad_output = MemoryBuffer().get("allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype)
+                total_grad_output = MemoryBuffer().get(
+                    "allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
+                )

         # Doing gather + slicing during the NeMo forward pass can make this tensor
         # not be contiguous. PyTorch only checks if the tensor is contiguous, and only
From 6d85d038d52ffc06ec4f2ae4705deae3b05d25d8 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Thu, 18 Jul 2024 11:46:39 +0000
Subject: [PATCH 16/24] Added no_recompute_allgather mode to async

---
 .../parallel/tensor_parallel/functional.py | 36 +++++++++++++------
 1 file changed, 26 insertions(+), 10 deletions(-)

diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index 3821d544..29480f9a 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -120,10 +120,12 @@ class _ColumnLinearAsyncCommunication(torch.autograd.Function):

     @staticmethod
     @assert_cuda_max_connections_set_to_1
-    def forward(ctx, tensor, weight, bias, group, tp_mode):
+    def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather):
         ctx.use_bias = bias is not None
         ctx.tp_mode = tp_mode
         ctx.group = group
+        ctx.tp_recompute_allgather = tp_recompute_allgather
+        ctx.tensor_shape = tensor.size()

         if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
             gathered_tensor = tensor
@@ -140,7 +142,7 @@ def forward(ctx, tensor, weight, bias, group, tp_mode):
         # `tensor` can sometimes not be contiguous
         # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
         tensor = tensor.contiguous()
-        ctx.save_for_backward(tensor, weight)
+        # ctx.save_for_backward(tensor, weight)

         # TODO @thomasw21: gather along another dimension
         sharded_batch_size, *intermediate_size, hidden_size = tensor.shape
@@ -148,9 +150,19 @@ def forward(ctx, tensor, weight, bias, group, tp_mode):
             group = dist.distributed_c10d._get_default_group()

         gathered_batch_size = sharded_batch_size * group.size()
-        gathered_tensor = MemoryBuffer().get(
-            "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype
-        )
+        if tp_recompute_allgather:
+            gathered_tensor = MemoryBuffer().get(
+                "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype
+            )
+        else:
+            gathered_tensor = torch.empty(
+                gathered_batch_size,
+                *intermediate_size,
+                hidden_size,
+                device=tensor.device,
+                dtype=tensor.dtype,
+                requires_grad=False,
+            )

         handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)

@@ -198,6 +210,10 @@ def forward(ctx, tensor, weight, bias, group, tp_mode):
         # Wait communication
         handle.wait()
+        if tp_recompute_allgather:
+            ctx.save_for_backward(tensor, weight)
+        else:
+            ctx.save_for_backward(gathered_tensor, weight)

         # Compute all the other shards that are obtained from AllGather
         # weights: w0 w1 w2 w3
@@ -256,7 +272,7 @@ def backward(ctx, grad_output):
         tp_mode = ctx.tp_mode

         handle1: Optional[dist.Work] = None
-        if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
+        if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather:
             # TODO @thomasw21: gather along another dimension
             sharded_batch_size, *rest_size = tensor.shape
             if group is None:
@@ -296,7 +312,7 @@ def backward(ctx, grad_output):
                 sub_grad_tensor = grad_tensor
             else:
                 sub_grad_tensor = torch.empty(
-                    tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False
+                    ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False
                 )
                 # reduce_scatter
                 handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)
@@ -322,9 +338,9 @@ def backward(ctx, grad_output):
             handle2.wait()

         if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
-            return sub_grad_tensor, grad_weight, grad_bias, None, None
+            return sub_grad_tensor, grad_weight, grad_bias, None, None, None
         elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-            return grad_tensor, grad_weight, grad_bias, None, None
+            return grad_tensor, grad_weight, grad_bias, None, None, None
         else:
             raise ValueError(f"Got unexpected mode: {tp_mode}.")
@@ -412,7 +428,7 @@ def column_linear(
     tp_recompute_allgather: bool = True,
 ):
     if async_communication:
-        return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)
+        return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)
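The forward-time choice between saving the sharded input (and re-gathering in backward) or saving the gathered tensor is the whole memory knob; a back-of-the-envelope sketch of what each branch keeps alive per linear layer (sizes are made up but representative):

```
import torch

tp_size, sharded_batch, hidden = 4, 1024, 4096
sharded = torch.empty(sharded_batch, hidden, dtype=torch.bfloat16)
gathered = torch.empty(tp_size * sharded_batch, hidden, dtype=torch.bfloat16)

# Activation bytes held until backward, per column-linear layer:
saved_recompute = sharded.numel() * sharded.element_size()       #  8 MiB
saved_no_recompute = gathered.numel() * gathered.element_size()  # 32 MiB
assert saved_no_recompute == tp_size * saved_recompute
```

The recompute branch pays for those savings with one extra all-gather per layer in backward, which is why it is exposed as a config switch rather than hard-coded.
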
From 2afd00769c7c5891341e2d7880492bfc80c524f6 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Tue, 23 Jul 2024 09:52:12 +0000
Subject: [PATCH 17/24] Fixed List not found

---
 src/nanotron/models/llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 7ee44390..49ea86e6 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """PyTorch LLaMa model."""

-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Union, List

 import torch
 from torch import nn
From 7e758db3068948178edd2151232e4abd7b2d5ffd Mon Sep 17 00:00:00 2001
From: AleHD
Date: Tue, 23 Jul 2024 13:17:34 +0000
Subject: [PATCH 18/24] Fixed tp=1 case

---
 .../parallel/tensor_parallel/functional.py | 22 +++++++++++--------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index 2b93fb02..22c8ca3c 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -358,11 +358,14 @@ def forward(
         # Do allgather.
         sharded_batch_size, *rest_size = input.shape
         unsharded_batch_size = sharded_batch_size * group.size()
-        if tp_recompute_allgather:
+        if group.size() == 1:
+            total_input = input.contiguous()
+        elif tp_recompute_allgather:
             total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
+            dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
         else:
             total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device)
-        dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
+            dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)

         # Prepare context.
         ctx.group = group
@@ -386,21 +389,22 @@ def backward(ctx, grad_output: torch.Tensor):
         group = ctx.group
         tp_recompute_allgather = ctx.tp_recompute_allgather
         input_size = ctx.input_size
-        if tp_recompute_allgather:
+        if group.size() == 1 or not tp_recompute_allgather:
+            total_input, weight, bias = ctx.saved_tensors
+        else:
             input, weight, bias = ctx.saved_tensors
             sharded_batch_size, *rest_size = input.shape
             total_input = sharded_batch_size * group.size()
             unsharded_batch_size = sharded_batch_size * group.size()
             total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
             dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
-        else:
-            total_input, weight, bias = ctx.saved_tensors

-        # Get the grad_output and total_input on the correct views to be able to transpose them below.
+        # Convert the tensor shapes to 2D for execution compatibility
         grad_output = grad_output.contiguous()
-        assert grad_output.dim() == 3
-        grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2))
-        total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2))
+        grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
+        total_tensor_first_dims, total_tensor_last_dim = total_tensor.shape[:-1], total_tensor.shape[-1]
+        grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
+        total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim)

         # Compute gradients.
+        # Convert the tensor shapes to 2D for execution compatibility
         grad_output = grad_output.contiguous()
-        assert grad_output.dim() == 3
-        grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2))
-        total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2))
+        grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
+        total_tensor_first_dims, total_tensor_last_dim = total_tensor.shape[:-1], total_tensor.shape[-1]
+        grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
+        total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim)

         # Compute gradients.
         grad_weight = grad_output.T @ total_input

From cd84d4fa7bff4017a9b1653c4b68290c00eae649 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Tue, 30 Jul 2024 19:16:14 +0200
Subject: [PATCH 19/24] Fixed column parallel

---
 .../parallel/tensor_parallel/functional.py   | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index 054d41e9..468855a5 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -19,14 +19,13 @@
 from torch.nn import functional as F

 import nanotron.distributed as dist
-from nanotron.parallel.utils import MemoryBuffer
 from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
     differentiable_all_reduce_sum,
     differentiable_identity,
     differentiable_reduce_scatter_sum,
 )
 from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
-from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1
+from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1


 class _ShardedCrossEntropy(torch.autograd.Function):
@@ -399,23 +398,25 @@ def backward(ctx, grad_output: torch.Tensor):
         # Convert the tensor shapes to 2D for execution compatibility
         grad_output = grad_output.contiguous()
         grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
-        total_tensor_first_dims, total_tensor_last_dim = total_tensor.shape[:-1], total_tensor.shape[-1]
+        total_input_first_dims, total_input_last_dim = total_input.shape[:-1], total_input.shape[-1]
         grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
-        total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim)
+        total_input = total_input.view(math.prod(total_input_first_dims), total_input_last_dim)

         # Compute gradients.
         grad_weight = grad_output.T @ total_input
         grad_input = grad_output @ weight
-        sub_grad_input = torch.empty(
-            input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
-        )
-        dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
+        if group.size() == 1:
+            sub_grad_input = grad_input
+        else:
+            sub_grad_input = torch.empty(
+                input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
+            )
+            dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
         grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None

         return sub_grad_input, grad_weight, grad_bias, None, None

-
 def column_linear(
     input: torch.Tensor,
     weight: torch.Tensor,

From d3db06acb2e3fe235ae512861ff64ee9fbc9ac11 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Tue, 30 Jul 2024 19:16:28 +0200
Subject: [PATCH 20/24] Added tp_recompute_allgather test

---
 tests/test_tensor_parallel.py | 22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py
index f5dcaeb0..8e73973b 100644
--- a/tests/test_tensor_parallel.py
+++ b/tests/test_tensor_parallel.py
@@ -18,17 +18,30 @@
 @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
 @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
 @pytest.mark.parametrize("async_communication", [False, True])
+@pytest.mark.parametrize("tp_recompute_allgather", [False, True])
 @rerun_if_address_is_in_use()
-def test_column_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool):
+def test_column_linear(
+    tp: int,
+    dp: int,
+    pp: int,
+    tp_mode: TensorParallelLinearMode,
+    async_communication: bool,
+    tp_recompute_allgather: bool,
+):
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication:
         pytest.skip("ALL_REDUCE mode does not support async communication")
+    if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather:
+        pytest.skip("ALL_REDUCE mode is unaffected by tp_recompute_allgather")

     init_distributed(tp=tp, dp=dp, pp=pp)(_test_column_linear)(
-        tp_mode=tp_mode, async_communication=async_communication
+        tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather
     )


 def _test_column_linear(
-    parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool
+    parallel_context: ParallelContext,
+    tp_mode: TensorParallelLinearMode,
+    async_communication: bool,
+    tp_recompute_allgather: bool,
 ):
     if async_communication:
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
@@ -44,6 +57,7 @@ def _test_column_linear(
         mode=tp_mode,
         device="cuda",
         async_communication=async_communication,
+        tp_recompute_allgather=tp_recompute_allgather,
     )

     # Un-sharded
@@ -86,7 +100,7 @@ def _test_column_linear(
         random_input = sharded_random_input
     else:
         ValueError(f"Unsupported mode: {tp_mode}")
-    # It's important that `random_input` and `sharded_random_input` are two seperate tensors with seperate storage
+    # It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage
     sharded_random_input = sharded_random_input.clone()
     random_input.requires_grad = True
     sharded_random_input.requires_grad = True

From 4c94b99a8cafd5f8dc2b7f208341a17a4d818234 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Tue, 30 Jul 2024 19:34:31 +0200
Subject: [PATCH 21/24] Added tp_recompute_allgather test
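
Same change as for the column-linear test: `test_row_linear` is parametrized over
`tp_recompute_allgather` as well, and the ALL_REDUCE combinations are skipped,
since the flag only affects the all-gather performed in REDUCE_SCATTER mode.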
---
 tests/test_tensor_parallel.py | 23 ++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py
index 8e73973b..16008eaa 100644
--- a/tests/test_tensor_parallel.py
+++ b/tests/test_tensor_parallel.py
@@ -164,15 +164,32 @@
 @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
 @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
 @pytest.mark.parametrize("async_communication", [False, True])
+@pytest.mark.parametrize("tp_recompute_allgather", [False, True])
 @rerun_if_address_is_in_use()
-def test_row_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool):
+def test_row_linear(
+    tp: int,
+    dp: int,
+    pp: int,
+    tp_mode: TensorParallelLinearMode,
+    async_communication: bool,
+    tp_recompute_allgather: bool,
+):
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication:
         pytest.skip("ALL_REDUCE mode does not support async communication")
+    if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather:
+        pytest.skip("ALL_REDUCE mode is not affected by tp_recompute_allgather")

-    init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(tp_mode=tp_mode, async_communication=async_communication)
+    init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(
+        tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather
+    )


-def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool):
+def _test_row_linear(
+    parallel_context: ParallelContext,
+    tp_mode: TensorParallelLinearMode,
+    async_communication: bool,
+    tp_recompute_allgather: bool,
+):
     if async_communication:
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
     out_features = 3

From 3c3561158eb053176b6f148a2366ea1aa56fdc7d Mon Sep 17 00:00:00 2001
From: Phuc Nguyen
Date: Wed, 31 Jul 2024 17:15:34 +0000
Subject: [PATCH 22/24] change to correct config_nanoset.yaml path

---
 docs/nanoset.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/nanoset.md b/docs/nanoset.md
index 9dce21b7..61393438 100644
--- a/docs/nanoset.md
+++ b/docs/nanoset.md
@@ -79,7 +79,7 @@ To work with `Nanosets`, we just need to configure 1 argument:

 Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py).
 ```shell
-torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml
+torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml
 ```

 ## Under the hood

From 7daa186e84e03fa66fa83129d9b0acffb1a668ba Mon Sep 17 00:00:00 2001
From: AleHD
Date: Fri, 2 Aug 2024 15:40:46 +0200
Subject: [PATCH 23/24] Minor restyling

---
 .../parallel/tensor_parallel/functional.py   | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index 468855a5..1fb86cb5 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -115,7 +115,7 @@ def sharded_cross_entropy(sharded_logits, target, group: dist.ProcessGroup, dtyp
     return _ShardedCrossEntropy.apply(sharded_logits, target, group)


-class _ColumnLinearAsyncCommunication(torch.autograd.Function):
+class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function):
     """Adapted from https://github.com/NVIDIA/Megatron-LM/blob/e6d7e09845590d0a36bc7f29eb28db974fb8da4e/megatron/core/tensor_parallel/layers.py#L215"""

     @staticmethod
@@ -408,6 +408,9 @@ def backward(ctx, grad_output: torch.Tensor):
         if group.size() == 1:
             sub_grad_input = grad_input
         else:
+            # Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
+            # We set grad_input to be contiguous in case it isn't already.
+            grad_input = grad_input.contiguous()
             sub_grad_input = torch.empty(
                 input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
             )
@@ -427,16 +430,14 @@ def column_linear(
     tp_recompute_allgather: bool = True,
 ):
     if async_communication:
-        return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)
+        return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(input, weight, bias, group, tp_mode)

     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
         input = differentiable_identity(input, group=group)
-    elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
+        return F.linear(input, weight, bias)
+    if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
         return _ColumnLinearContextParallelNoAsync.apply(input, weight, bias, group, tp_recompute_allgather)
-    else:
-        raise ValueError(f"Got unexpected mode: {tp_mode}.")
-
-    return F.linear(input, weight, bias)
+    raise ValueError(f"Got unexpected mode: {tp_mode}.")

From 31c3c5ad0a845ff6318842c663223e9621586a3d Mon Sep 17 00:00:00 2001
From: AleHD
Date: Fri, 2 Aug 2024 15:54:44 +0200
Subject: [PATCH 24/24] Fixed names

---
 src/nanotron/parallel/tensor_parallel/functional.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index 1fb86cb5..7a88aec6 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -115,7 +115,7 @@ def sharded_cross_entropy(sharded_logits, target, group: dist.ProcessGroup, dtyp
     return _ShardedCrossEntropy.apply(sharded_logits, target, group)


-class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function):
+class _ColumnLinearAsyncCommunication(torch.autograd.Function):
     """Adapted from https://github.com/NVIDIA/Megatron-LM/blob/e6d7e09845590d0a36bc7f29eb28db974fb8da4e/megatron/core/tensor_parallel/layers.py#L215"""

     @staticmethod
@@ -337,7 +337,7 @@ def backward(ctx, grad_output):
         raise ValueError(f"Got unexpected mode: {tp_mode}.")


-class _ColumnLinearContextParallelNoAsync(torch.autograd.Function):
+class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function):
     """
     Column linear with memory_buffer for the allgather,
     context parallel enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
@@ -430,13 +430,15 @@ def column_linear(
     tp_recompute_allgather: bool = True,
 ):
     if async_communication:
-        return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(input, weight, bias, group, tp_mode)
+        return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)

     if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
         input = differentiable_identity(input, group=group)
         return F.linear(input, weight, bias)
     if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
-        return _ColumnLinearContextParallelNoAsync.apply(input, weight, bias, group, tp_recompute_allgather)
+        return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
+            input, weight, bias, group, tp_recompute_allgather
+        )
     raise ValueError(f"Got unexpected mode: {tp_mode}.")
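
Taken together, patches 16-24 hinge on one idea: with `tp_recompute_allgather=True`, column linear saves only the sharded input and re-gathers it during backward (reusing the global `MemoryBuffer`), instead of keeping the full gathered activation alive between forward and backward. The toy autograd function below is a minimal single-process sketch of that trade-off; `ToyColumnLinear` and `fake_all_gather` are illustrative names, not nanotron APIs, and `fake_all_gather` stands in for `dist.all_gather_into_tensor` in the degenerate `group.size() == 1` case handled by patch 18.

```python
import math

import torch


def fake_all_gather(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for dist.all_gather_into_tensor with a group of size 1:
    # the "gathered" tensor is just a contiguous view of the shard.
    return x.contiguous()


class ToyColumnLinear(torch.autograd.Function):
    """Toy sketch of the save-vs-recompute choice behind tp_recompute_allgather."""

    @staticmethod
    def forward(ctx, inp, weight, tp_recompute_allgather):
        total_input = fake_all_gather(inp)
        ctx.tp_recompute_allgather = tp_recompute_allgather
        if tp_recompute_allgather:
            # Cheap on memory: keep only the shard; the gather is redone in backward.
            ctx.save_for_backward(inp, weight)
        else:
            # Keep the gathered activation (tp_size times larger), skip one gather.
            ctx.save_for_backward(total_input, weight)
        return total_input @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        saved, weight = ctx.saved_tensors
        total_input = fake_all_gather(saved) if ctx.tp_recompute_allgather else saved
        # Flatten all leading dims with math.prod, as the patched backward does,
        # so inputs are not required to be 3D.
        grad_output = grad_output.contiguous()
        go_2d = grad_output.view(math.prod(grad_output.shape[:-1]), grad_output.shape[-1])
        ti_2d = total_input.view(math.prod(total_input.shape[:-1]), total_input.shape[-1])
        grad_weight = go_2d.t() @ ti_2d
        grad_input = grad_output @ weight
        return grad_input, grad_weight, None


# Quick check: both settings produce the same input gradient.
x = torch.randn(2, 3, 8, requires_grad=True)
w = torch.randn(5, 8, requires_grad=True)
grads = []
for flag in (True, False):
    x.grad = None
    ToyColumnLinear.apply(x, w, flag).sum().backward()
    grads.append(x.grad.clone())
assert torch.allclose(grads[0], grads[1])
```

The design point the sketch makes explicit: recompute trades one extra all-gather in backward for not holding the gathered activation across the forward/backward gap, which is why the real implementation routes the recomputed gather through the shared `MemoryBuffer` rather than a fresh allocation.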