diff --git a/src/nanotron/fp8/meta.py b/src/nanotron/fp8/meta.py
index eed45ea2..cba94bc2 100644
--- a/src/nanotron/fp8/meta.py
+++ b/src/nanotron/fp8/meta.py
@@ -38,7 +38,7 @@ class FP8Meta:
     sync_amax: bool = False
 
     @property
-    def te_dtype(self) -> tex.DType:
+    def te_dtype(self) -> "tex.DType":
         from nanotron.fp8.tensor import convert_torch_dtype_to_te_dtype
 
         return convert_torch_dtype_to_te_dtype(self.dtype)
diff --git a/src/nanotron/fp8/tensor.py b/src/nanotron/fp8/tensor.py
index f3e8bcdb..989f33e9 100644
--- a/src/nanotron/fp8/tensor.py
+++ b/src/nanotron/fp8/tensor.py
@@ -197,7 +197,7 @@ def _quantize(tensor: torch.Tensor, fp8_meta: "FP8Meta") -> torch.Tensor:
     return (tensor * fp8_meta.scale).to(torch.float16)
 
 
-def convert_torch_dtype_to_te_dtype(dtype: torch.dtype) -> tex.DType:
+def convert_torch_dtype_to_te_dtype(dtype: torch.dtype) -> "tex.DType":
     # NOTE: transformer engine maintains it own dtype mapping
     # so we need to manually map torch dtypes to TE dtypes
     TORCH_DTYPE_TE_DTYPE_NAME_MAPPING = {
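
Context for reviewers (not part of the diff): quoting the return annotation as `"tex.DType"` turns it into a forward reference, which Python stores as a plain string and never evaluates at runtime, so these modules can be imported even when the `tex` extension module is unavailable. Below is a minimal, hedged sketch of the general pattern; `decimal` is used here purely as a stand-in for an optional heavy dependency such as transformer engine, and the function name is hypothetical.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime.
    import decimal  # stand-in for an optional heavy dependency


def to_decimal(x: str) -> "decimal.Decimal":
    # The quoted annotation is just a string at runtime, so the module that
    # defines the return type does not need to be importable at load time.
    import decimal  # lazy import, only required when the function is called

    return decimal.Decimal(x)


print(to_decimal("3.14"))  # Decimal('3.14')
```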