diff --git a/.github/workflows/gpu_test.yaml b/.github/workflows/gpu_test.yaml
index 7a664b2e29..829e9384a6 100644
--- a/.github/workflows/gpu_test.yaml
+++ b/.github/workflows/gpu_test.yaml
@@ -46,7 +46,7 @@ jobs:
         run: python -m pip install --upgrade pip
       - name: Install torch nightly
         if: ${{ matrix.torch-version == 'nightly' }}
-        run: python -m pip install --pre torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121
+        run: python -m pip install --pre torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu126
       - name: Install torch stable
         if: ${{ matrix.torch-version == 'stable' }}
         run: python -m pip install torch torchvision torchao
diff --git a/README.md b/README.md
index 289d433426..0d014e8d2a 100644
--- a/README.md
+++ b/README.md
@@ -170,7 +170,7 @@ pip install torchtune

 ```bash
 # Install PyTorch, torchvision, torchao nightlies
-pip install --pre --upgrade torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
+pip install --pre --upgrade torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu126 # full options are cpu/cu118/cu121/cu124/cu126
 pip install --pre --upgrade torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu
 ```
diff --git a/docs/source/install.rst b/docs/source/install.rst
index 7b5f908da1..de09573fe2 100644
--- a/docs/source/install.rst
+++ b/docs/source/install.rst
@@ -19,7 +19,7 @@ nightly versions with the following commands:
     pip install torch torchvision torchao

     # Or nightly install for latest features
-    pip install --pre torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
+    pip install --pre torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu126 # full options are cpu/cu118/cu121/cu124/cu126


 Install via PyPI
diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 745ef64eb4..9ef5e6533f 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -766,7 +766,9 @@ def train(self) -> None:
                 if self._optimizer_in_bwd:
                     torch.distributed.all_reduce(num_tokens)
                     torch.distributed.all_reduce(running_loss)
-                    current_loss = current_loss / num_tokens
+
+                    # We multiply by world_size to undo FSDP2 gradient normalization.
+                    current_loss = current_loss * (world_size / num_tokens)

                 current_loss.backward()

@@ -778,7 +780,8 @@ def train(self) -> None:
                         # This will ensure that the logged loss matches what we're optimizing
                         torch.distributed.all_reduce(running_loss)
                         # Manually scale the gradients from unnormalized loss by total # of tokens
-                        training.scale_grads(self._model, 1 / num_tokens)
+                        # We multiply by world_size to undo FSDP2 gradient normalization.
+                        training.scale_grads(self._model, world_size / num_tokens)
                         if self._clip_grad_norm is not None:
                             grad_norm = torch.nn.utils.clip_grad_norm_(
                                 self._model.parameters(),
diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py
index a11adabb97..77fc50927c 100644
--- a/recipes/knowledge_distillation_distributed.py
+++ b/recipes/knowledge_distillation_distributed.py
@@ -769,7 +769,6 @@ def save_checkpoint(self, epoch: int) -> None:
     def _loss_step(
         self, batch: Dict[str, torch.Tensor]
     ) -> (torch.Tensor, torch.Tensor):
-
         # Both are shape [b, s]
         tokens, labels = batch["tokens"], batch["labels"]

@@ -875,7 +874,8 @@ def train(self) -> None:
                         torch.distributed.all_reduce(running_class_loss)
                         torch.distributed.all_reduce(running_kd_loss)
                         # Manually scale the gradients from unnormalized loss by total # of tokens
-                        training.scale_grads(self._model, 1 / num_tokens)
+                        # We multiply by world_size to undo FSDP2 gradient normalization.
+                        training.scale_grads(self._model, world_size / num_tokens)
                         class_loss_to_log = running_class_loss.item() / num_tokens
                         kd_loss_to_log = running_kd_loss.item() / num_tokens
                         self._optimizer.step()
diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py
index ac05e2060a..2cdfcd8010 100644
--- a/recipes/lora_finetune_distributed.py
+++ b/recipes/lora_finetune_distributed.py
@@ -822,7 +822,8 @@ def train(self) -> None:
                         # This will ensure that the logged loss matches what we're optimizing
                         torch.distributed.all_reduce(running_loss)
                         # Manually scale the gradients from unnormalized loss by total # of tokens
-                        training.scale_grads(self._model, 1 / num_tokens)
+                        # We multiply by world_size to undo FSDP2 gradient normalization.
+                        training.scale_grads(self._model, world_size / num_tokens)
                         if self._clip_grad_norm is not None:
                             grad_norm = torch.nn.utils.clip_grad_norm_(
                                 self._model.parameters(),
diff --git a/recipes/lora_finetune_distributed_multi_dataset.py b/recipes/lora_finetune_distributed_multi_dataset.py
index 30ece70347..ce482bfa27 100644
--- a/recipes/lora_finetune_distributed_multi_dataset.py
+++ b/recipes/lora_finetune_distributed_multi_dataset.py
@@ -851,7 +851,8 @@ def train(self) -> None:
                         # This will ensure that the logged loss matches what we're optimizing
                         torch.distributed.all_reduce(running_loss)
                         # Manually scale the gradients from unnormalized loss by total # of tokens
-                        training.scale_grads(self._model, 1 / num_tokens)
+                        # We multiply by world_size to undo FSDP2 gradient normalization.
+                        training.scale_grads(self._model, world_size / num_tokens)
                         if self._clip_grad_norm is not None:
                             grad_norm = torch.nn.utils.clip_grad_norm_(
                                 self._model.parameters(),
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index eaa2974579..f1b1302b7d 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -837,7 +837,9 @@ def train(self) -> None:
                 if self._optimizer_in_bwd:
                     torch.distributed.all_reduce(num_tokens)
                     torch.distributed.all_reduce(running_loss)
-                    current_loss = current_loss / num_tokens
+
+                    # We multiply by world_size to undo FSDP2 gradient normalization.
+                    current_loss = current_loss * (world_size / num_tokens)

                 current_loss.backward()

@@ -849,7 +851,8 @@ def train(self) -> None:
                         # This will ensure that the logged loss matches what we're optimizing
                         torch.distributed.all_reduce(running_loss)
                         # Manually scale the gradients from unnormalized loss by total # of tokens
-                        training.scale_grads(self._model, 1 / num_tokens)
+                        # We multiply by world_size to undo FSDP2 gradient normalization.
+                        training.scale_grads(self._model, world_size / num_tokens)
                         if self._clip_grad_norm is not None:
                             grad_norm = torch.nn.utils.clip_grad_norm_(
                                 self._model.parameters(),
diff --git a/recipes/qat_lora_finetune_distributed.py b/recipes/qat_lora_finetune_distributed.py
index b9080de77d..133c39c94b 100644
--- a/recipes/qat_lora_finetune_distributed.py
+++ b/recipes/qat_lora_finetune_distributed.py
@@ -866,7 +866,8 @@ def train(self) -> None:
                         # This will ensure that the logged loss matches what we're optimizing
                         torch.distributed.all_reduce(running_loss)
                         # Manually scale the gradients from unnormalized loss by total # of tokens
-                        training.scale_grads(self._model, 1 / num_tokens)
+                        # We multiply by world_size to undo FSDP2 gradient normalization.
+                        training.scale_grads(self._model, world_size / num_tokens)
                         if self._clip_grad_norm is not None:
                             grad_norm = torch.nn.utils.clip_grad_norm_(
                                 self._model.parameters(),
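
Note (not part of the patch): the repeated world_size / num_tokens factor compensates for FSDP2 mean-reducing gradients across data-parallel ranks. Each rank backpropagates an unnormalized sum of token losses, FSDP2 then averages the resulting gradients over world_size ranks, so scaling by world_size / num_tokens recovers the gradient of the global per-token loss. A minimal numeric sketch follows; the rank counts and gradient values are invented for illustration, and only the scaling identity comes from the patch.

# Illustrative sketch only -- not part of the diff above. Values are made up.
world_size = 4                                    # number of data-parallel ranks
tokens_per_rank = [100, 80, 120, 60]              # tokens processed on each rank
grad_sum_per_rank = [10.0, 8.0, 12.0, 6.0]        # d(sum of token losses)/d(param) on each rank

# all_reduce(num_tokens) leaves every rank with the global token count.
num_tokens = sum(tokens_per_rank)                 # 360

# FSDP2 reduces gradients with a mean across ranks (i.e. divides the sum by world_size).
fsdp2_grad = sum(grad_sum_per_rank) / world_size  # 9.0

# What the recipes want: gradient of the summed loss normalized by the global token count.
target_grad = sum(grad_sum_per_rank) / num_tokens  # 0.1

# Multiplying the FSDP2-averaged gradient by world_size / num_tokens recovers the target,
# which is what training.scale_grads(self._model, world_size / num_tokens) applies per parameter.
rescaled_grad = fsdp2_grad * (world_size / num_tokens)
assert abs(rescaled_grad - target_grad) < 1e-12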