diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index d279fe8d..0ccaa091 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -11,14 +11,4 @@ CHECKPOINT_FILE_NAME = "checkpoint_metadata.json" MODEL_CONFIG_FILE_NAME = "model_config.json" - -# TODO(xrsrke): remove this shit -ITERATION_STEP = 1 -# TODO(xrsrke): refactor to training stage, -# keep it in the same class as iteration_step - -is_ready_to_log = False - -# TODO(xrsrke): refactor CPU_WEIGHTS = {} -ACCUM_GRADS = {} diff --git a/src/nanotron/fp8/__init__.py b/src/nanotron/fp8/__init__.py index 963d424b..abdbd15c 100644 --- a/src/nanotron/fp8/__init__.py +++ b/src/nanotron/fp8/__init__.py @@ -2,7 +2,6 @@ from nanotron.fp8.dtypes import DTypes # noqa from nanotron.fp8.linear import FP8Linear # noqa -from nanotron.fp8.parameter import FP8Parameter # noqa from nanotron.fp8.tensor import FP8Tensor # noqa try: diff --git a/src/nanotron/fp8/parameter.py b/src/nanotron/fp8/parameter.py deleted file mode 100644 index 3e736724..00000000 --- a/src/nanotron/fp8/parameter.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import Optional, Union - -import torch -from torch import nn - -from nanotron import constants -from nanotron.fp8.constants import FP8_DTYPES, FP8LM_RECIPE, INITIAL_AMAX, INITIAL_SCALING_FACTOR -from nanotron.fp8.dtypes import DTypes -from nanotron.fp8.meta import FP8Meta -from nanotron.fp8.tensor import FP8Tensor, update_scaling_factor - -# from nanotron.config.fp8_config import FP8Args - - -class FP8Parameter(nn.Parameter): - """ - A custom FP8 parameter class that allows - fp8 gradients (which are integer tensors) - to flow into FP8 tensors. - """ - - def __new__(cls, data: torch.Tensor, dtype: DTypes, requires_grad: bool = True, interval: int = 1) -> nn.Parameter: - assert isinstance(data, torch.Tensor), "data must be a tensor" - assert data.dtype not in FP8_DTYPES, "Currently only support turn a non-fp8 tensor to an fp8 parameter" - assert data.device != torch.device("cpu"), "FP8Parameter only supports CUDA tensors" - - with torch.no_grad(): - from typing import cast - - if constants.CONFIG is None: - sync_amax_in_weight = False - else: - fp8_config = cast(FP8Args, constants.CONFIG.fp8) - sync_amax_in_weight = fp8_config.sync_amax_in_weight - - # TODO(xrsrke): support take an FP8 Tensor as data - # currently we can't only quantize a tensor to FP8 after the parameter is created - # because it raise "Only Tensors of floating point and complex dtype can require gradients" - # TODO(xrsrke): delete this fp32 tensor from memory after quantization - self = torch.Tensor._make_subclass(cls, data, requires_grad) - self._data = FP8Tensor(data, dtype=dtype, interval=interval, sync=sync_amax_in_weight) - # TODO(xrsrke): don't store fp32 raw data in memory after quantization - - # if constants.ITERATION_STEP == 1: - # self.orig_data = data.data - - # TODO(xrsrke): don't fixed these, take it from the FP8 recipe - fp8e4m3_scale = update_scaling_factor( - amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32), - scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR), - dtype=DTypes.FP8E4M3, - ) - fp8e5m2_scale = update_scaling_factor( - amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32), - scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32), - dtype=DTypes.FP8E5M2, - ) - - # TODO(xrsrke): add type hints of fp8_grad_meta to FP8Parameter - self.fp8_grad_meta = FP8GradMeta( - input_grad=FP8Meta( - amax=INITIAL_AMAX, - dtype=DTypes.FP8E4M3, - scale=fp8e4m3_scale, - 
interval=FP8LM_RECIPE.linear.input_grad.interval, - ), - # TODO(xrsrke): change weight_grad to data_grad - # because this is the gradient of the parameter itself - weight_grad=FP8Meta( - amax=INITIAL_AMAX, - dtype=DTypes.FP8E4M3, - scale=fp8e4m3_scale, - interval=FP8LM_RECIPE.linear.weight_grad.interval, - ), - # kfloat8_e5m2 - output_grad=FP8Meta( - amax=INITIAL_AMAX, - dtype=DTypes.FP8E5M2, - scale=fp8e5m2_scale, - interval=FP8LM_RECIPE.linear.output_grad.interval, - ), - ) - self._grad = None - - return self - - @property - def data(self) -> FP8Tensor: - return self._data - - @data.setter - def data(self, data: FP8Tensor): - self._data = data - - # # NOTE: because pytorch don't allow to assign an int grad to a tensor - # # so we bypass it by using a property - @property - def grad(self) -> Optional[Union[torch.Tensor, FP8Tensor]]: - return self.data._grad - # return self.data.grad - - @grad.setter - def grad(self, value: Optional[Union[torch.Tensor, FP8Tensor]]): - self.data._grad = value - - @property - def dtype(self) -> torch.dtype: - return self._data.dtype - - @property - def fp8_meta(self) -> FP8Meta: - return self.data.fp8_meta - - def __repr__(self) -> str: - return f"FP8Parameter({self.data}, fp8_meta={self.fp8_meta}, requires_grad={self.requires_grad}, fp8_grad_meta={self.fp8_grad_meta})" diff --git a/src/nanotron/parallel/parameters.py b/src/nanotron/parallel/parameters.py index 28995e98..5a8bb684 100644 --- a/src/nanotron/parallel/parameters.py +++ b/src/nanotron/parallel/parameters.py @@ -243,7 +243,6 @@ def __repr__(self): @property def data(self): - # from nanotron.fp8.parameter import FP8Parameter return self._data @data.setter diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 5cf2e990..b6756be4 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -216,12 +216,6 @@ def __init__( if hasattr(p, "_is_future_fp8") and p._is_future_fp8 is True: constants.CPU_WEIGHTS[n.replace("module.", "")] = p.data.cpu().clone() - # NOTE: sanity check all hash are different - # param_hash = [] - # for p in self.model.parameters(): - # assert hash(p) not in param_hash - # param_hash.append(hash(p)) - # NOTE: if we cast model to FP8 before wrapping it with NanotronParameter, # then we can create a NanotronParameter that has dtype=[torch.int8, torch.uint8] # which then it allows us to assign [torch.int8, torch.uint8] gradients to the parameter @@ -231,7 +225,6 @@ def __init__( # Please ensure that the gradient and the tensor have the same dtype" # NOTE: the reason that we cast after initializing the optimizer is that # we want to create some master weights for fp8 parameters, before quantizing them - if self.config.model.dtype == torch.int8: self.model = convert_model_to_fp8(self.model, config=self.config) diff --git a/tests/fp8/_test_fp8_parameter.py b/tests/fp8/_test_fp8_parameter.py deleted file mode 100644 index 1d464615..00000000 --- a/tests/fp8/_test_fp8_parameter.py +++ /dev/null @@ -1,151 +0,0 @@ -import pytest -import torch -from nanotron.constants import CHECKPOINT_VERSION -from nanotron.fp8.constants import FP8_DTYPES -from nanotron.fp8.dtypes import DTypes -from nanotron.fp8.meta import FP8Meta -from nanotron.fp8.parameter import FP8Parameter -from nanotron.fp8.tensor import FP8Tensor -from nanotron.parallel import ParallelContext -from nanotron.parallel.parameters import NanotronParameter -from nanotron.parallel.sharded_parameters import SplitConfig, create_sharded_parameter_from_config -from nanotron.serialize.metadata import TensorMetadata 
-from nanotron.testing.parallel import init_distributed, rerun_if_address_is_in_use -from torch import nn - - -def create_sharded_fp8_parameter(param: nn.Parameter, parallel_context: ParallelContext): - split_config = SplitConfig( - split_dim=0, - contiguous_chunks=(8, 8), - ) - param = create_sharded_parameter_from_config(parameter=param, pg=parallel_context.tp_pg, split_config=split_config) - return param - - -@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) -def test_create_fp8_parameter(dtype): - tensor = torch.randn(16, 16, device="cuda", dtype=torch.float32) - - fp8_parameter = FP8Parameter(tensor, dtype) - - assert isinstance(fp8_parameter.data, FP8Tensor) - assert fp8_parameter.requires_grad is True - assert fp8_parameter.grad is None - assert fp8_parameter.dtype in FP8_DTYPES - - assert isinstance(fp8_parameter.fp8_meta, FP8Meta) - assert isinstance(fp8_parameter.data.fp8_meta, FP8Meta) - assert fp8_parameter.data.fp8_meta is fp8_parameter.fp8_meta - - -def test_fp8_parameter_grad_metadata(): - GRAD_META = ["input_grad", "weight_grad", "output_grad"] - tensor = torch.randn(16, 16, device="cuda", dtype=torch.float32) - fp8_parameter = FP8Parameter(tensor, DTypes.FP8E4M3) - - assert all(hasattr(fp8_parameter.fp8_grad_meta, attr) for attr in GRAD_META) - assert all(isinstance(getattr(fp8_parameter.fp8_grad_meta, attr), FP8Meta) for attr in GRAD_META) - - -@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) -@pytest.mark.parametrize("grad_dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) -def test_setting_fp8_gradient_to_fp8_parameter(dtype, grad_dtype): - fp8_parameter = FP8Parameter(torch.randn(16, 16, device="cuda"), dtype) - fp8_grad = FP8Tensor(torch.randn(16, 16, device="cuda"), dtype=grad_dtype) - - fp8_parameter.grad = fp8_grad - - assert torch.equal(fp8_parameter.grad, fp8_parameter.data.grad) - assert id(fp8_parameter.grad) == id(fp8_parameter.data.grad) - assert fp8_parameter.grad.data_ptr() == fp8_parameter.data.grad.data_ptr() - - -@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) -def test_fp8_parameter_storage_memory(dtype): - data = torch.randn(16, 16, device="cuda", dtype=torch.float32) - fp8_parameter = FP8Parameter(data, dtype) - - assert id(fp8_parameter.data) != id(data) - assert fp8_parameter.data_ptr() == data.data_ptr() - assert fp8_parameter.data.data_ptr() != data.data_ptr() - - -@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) -def test_set_data_in_fp8_parameter(dtype): - data = torch.randn(16, 16, device="cuda", dtype=torch.float32) - fp8_parameter = FP8Parameter(data, dtype) - - new_data = torch.randn(16, 16, device="cuda", dtype=torch.float32) - new_fp8_data = FP8Tensor(new_data, dtype=dtype) - - fp8_parameter.data = new_fp8_data - - assert fp8_parameter.data is new_fp8_data - assert torch.equal(fp8_parameter.data, new_fp8_data) - assert fp8_parameter.data.data_ptr() == new_fp8_data.data_ptr() - - assert fp8_parameter.fp8_meta is new_fp8_data.fp8_meta - - -@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) -def test_set_gradient_in_fp8_parameter(dtype): - data = torch.randn(16, 16, device="cuda", dtype=torch.float32) - fp8_parameter = FP8Parameter(data, dtype) - - grad = torch.randn(16, 16, device="cuda", dtype=torch.float32) - fp8_grad = FP8Tensor(grad, dtype=dtype) - - fp8_parameter.grad = fp8_grad - - assert fp8_parameter.grad is fp8_grad - assert torch.equal(fp8_parameter.grad, fp8_grad) - assert fp8_parameter.grad.data_ptr() == fp8_grad.data_ptr() - - assert 
fp8_parameter.data.grad is fp8_parameter.grad - - -@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) -@rerun_if_address_is_in_use() -def test_create_sharded_fp8_parameter(dtype): - init_distributed(tp=2, dp=1, pp=1)(_test_create_sharded_fp8_parameter)(dtype=dtype) - - -def _test_create_sharded_fp8_parameter(parallel_context: ParallelContext, dtype: DTypes): - data = torch.randn(16, 64, device="cuda") - param = FP8Parameter(data, dtype) - - param = create_sharded_fp8_parameter(param, parallel_context) - sharded_info = param.get_sharded_info() - - assert isinstance(param, NanotronParameter) - assert isinstance(param.data, FP8Tensor) - assert isinstance(param.data.fp8_meta, FP8Meta) - - metadata = TensorMetadata( - version=CHECKPOINT_VERSION, - local_global_slices_pairs=sharded_info.local_global_slices_pairs, - unsharded_shape=sharded_info.unsharded_shape, - ) - metadata_str_dict = metadata.to_str_dict() - # Assert metadata_str_dict is Dict[str, str] - assert isinstance(metadata_str_dict, dict) - assert all(isinstance(key, str) for key in metadata_str_dict.keys()) - assert all(isinstance(value, str) for value in metadata_str_dict.values()) - - metadata_from_str_dict = TensorMetadata.from_str_dict(metadata_str_dict) - assert metadata == metadata_from_str_dict - - parallel_context.destroy() - - -# TODO(xrsrke): add test for preventing torch autograd do the backward pass -# on a FP8Parameter - -# TODO(xrsrke): test CPU parameter - - -# TODO(xrsrke): test convert model to FP8 -# include the FP8's NanotronParameter's dtype and requires_grad - -# TODO(xrsrke): test set FP8 gradients to FP8 NanotronParameter diff --git a/tests/fp8/_test_linear.py b/tests/fp8/_test_linear.py index 4c53e00d..05beef4b 100644 --- a/tests/fp8/_test_linear.py +++ b/tests/fp8/_test_linear.py @@ -6,7 +6,6 @@ from nanotron.fp8.constants import FP8_DTYPES, QTYPE_TO_DTYPE from nanotron.fp8.dtypes import DTypes from nanotron.fp8.linear import FP8Linear, FP8LinearMeta -from nanotron.fp8.parameter import FP8Parameter from nanotron.fp8.recipe import FP8LinearRecipe from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8 from nanotron.fp8.utils import convert_linear_to_fp8, convert_to_fp8_module, is_overflow_underflow_nan
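
For context, here is a minimal, framework-free sketch of the flow that the trainer.py hunk above keeps after FP8Parameter is removed: fp32 CPU master copies are snapshotted into constants.CPU_WEIGHTS for every parameter flagged `_is_future_fp8`, and only after the optimizer has been built from the fp32 weights is the model cast to FP8 (`convert_model_to_fp8(self.model, config=self.config)` in the diff). The snippet below illustrates that ordering with plain PyTorch; CPU_WEIGHTS, the toy model, and the `_is_future_fp8` flag here are hypothetical stand-ins, not nanotron's actual objects.

import torch
from torch import nn

# Hypothetical stand-in for nanotron.constants.CPU_WEIGHTS
CPU_WEIGHTS = {}

# Toy model standing in for self.model (assumption, not nanotron's real model)
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))

# Mark which parameters will later be quantized, mirroring the `_is_future_fp8` flag
for p in model.parameters():
    p._is_future_fp8 = True

# Snapshot fp32 master copies on CPU before any quantization happens, mirroring
# `constants.CPU_WEIGHTS[n.replace("module.", "")] = p.data.cpu().clone()` in trainer.__init__
for n, p in model.named_parameters():
    if getattr(p, "_is_future_fp8", False):
        CPU_WEIGHTS[n.replace("module.", "")] = p.data.cpu().clone()

# The optimizer is created while the parameters are still fp32, and only afterwards
# would the model be cast to FP8, so master weights come from full-precision values.
optimizer = torch.optim.Adam(model.parameters())

This mirrors the rationale kept in the trainer.py comments: quantize after the optimizer is initialized so master weights for FP8 parameters are derived from the full-precision tensors rather than from already-quantized data.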