From afdfbf16514402d7cee556bfb07595a46a986180 Mon Sep 17 00:00:00 2001
From: Phuc Nguyen
Date: Fri, 29 Nov 2024 17:43:46 +0000
Subject: [PATCH] fix hanging due to NanotronParameter.__repr__ (param.data ==
 NanotronParameter)

---
 src/nanotron/fp8/utils.py           |  8 +++++++-
 src/nanotron/parallel/parameters.py | 13 +++++++------
 tests/fp8/test_fp8_model.py         | 15 +++++++++------
 3 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/src/nanotron/fp8/utils.py b/src/nanotron/fp8/utils.py
index 6fe6eac1..ac425f1e 100644
--- a/src/nanotron/fp8/utils.py
+++ b/src/nanotron/fp8/utils.py
@@ -375,8 +375,14 @@ def convert_model_to_fp8(model: NanotronModel, config: FP8Args) -> NanotronModel
             # NOTE: convert it to the residual stream's dtype
             # for p in module.parameters():
             #     p.data = p.data.to(self.config.model.dtype)
-            module.to(dtype=config.resid_dtype)
+            # for p in module.parameters():
+            #     p.data = p.data.to(dtype=config.resid_dtype) if p.data
+            #     pass
+            # assert module.weight.data.__class__ == torch.Tensor
+            # module.to(dtype=config.resid_dtype)
             # pass
             # assert module.weight.data.__class__ == torch.Tensor
+            # NOTE: module.to(dtype=...) causes param.data == NanotronParameter
+            assert config.resid_dtype == torch.float32, "datatype conversion is not supported because of error 8"
 
     return model

diff --git a/src/nanotron/parallel/parameters.py b/src/nanotron/parallel/parameters.py
index f19eba9a..67a9ac7b 100644
--- a/src/nanotron/parallel/parameters.py
+++ b/src/nanotron/parallel/parameters.py
@@ -264,8 +264,9 @@ def is_sharded(self) -> bool:
             self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME
         )
 
-    # def __repr__(self):
-    #     return f"NanotronParameter({super().__repr__()})"
+    def __repr__(self):
+        # return f"NanotronParameter({super().__repr__()})"
+        return "NanotronParameter()"
 
     @property
     def data(self):
@@ -293,13 +294,13 @@ def data(self, data):
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         from nanotron.fp8.tensor import FP8Tensor
 
-        print(f"__torch_dispatch__ called with func: {func}, args: {args}, kwargs: {kwargs}")
+        # print(f"__torch_dispatch__ called with func: {func}, args: {args}, kwargs: {kwargs}")
 
-        if func in {torch._tensor_str._str, repr}:
-            return super().__torch_dispatch__(func, types, args, kwargs)
+        # if func in {torch._tensor_str._str, repr}:
+        #     return super().__torch_dispatch__(func, types, args, kwargs)
 
         def unwrap(e):
-            print(f"Unwrapping: {e} (type: {type(e)})")
+            # print(f"Unwrapping: {e} (type: {type(e)})")
             return e._data if e.__class__ == NanotronParameter else e
 
         def wrap(e):

diff --git a/tests/fp8/test_fp8_model.py b/tests/fp8/test_fp8_model.py
index 61d9dfbc..41cff327 100644
--- a/tests/fp8/test_fp8_model.py
+++ b/tests/fp8/test_fp8_model.py
@@ -44,12 +44,15 @@ def _test_initialize_fp8_model(parallel_context: ParallelContext, fp8_config: FP8Args):
             assert all(
                 p.dtype == fp8_config.resid_dtype for p in module.parameters()
             ), f"name: {name}, __class__: {module.weight.data.__class__}"
-            try:
-                assert all(
-                    p.data.__class__ == nn.Parameter for p in module.parameters()
-                ), f"name: {name}, __class__: {module.weight.data.__class__}"
-            except:
-                assert 1 == 1
+            # try:
+            #     assert all(
+            #         p.data.__class__ == nn.Parameter for p in module.parameters()
+            #     ), f"name: {name}, __class__: {module.weight.data.__class__}"
+            # except:
+            #     assert 1 == 1
+            assert all(
+                p.data.__class__ == nn.Parameter for p in module.parameters()
+            ), f"name: {name}, __class__: {module.weight.data.__class__}"
         else:
             assert all(
                 isinstance(p.data.__class__, FP8Tensor) for p in module.parameters()
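
Note on the failure mode: the hang comes from repr() of a NanotronParameter whose
.data resolves to another NanotronParameter, so the default tensor printing path keeps
re-entering the wrapper; the patch sidesteps this by making __repr__ return a constant
string. Below is a minimal, torch-free sketch of that cycle and of the same fix. The
FakeParameter class and its names are illustrative only, not the actual nanotron
implementation, and in plain Python the buggy variant raises RecursionError rather
than hanging, but the cycle is the same one that stalls torch's printing path.

# Illustrative sketch only (not nanotron code): why delegating __repr__ to the
# wrapped value never terminates once param.data is itself a wrapper.
class FakeParameter:
    def __init__(self, data=None):
        self._data = data

    @property
    def data(self):
        # Mirrors the situation the patch guards against: after a dtype cast,
        # ._data can end up being another wrapper instead of a raw tensor/value.
        return self._data

    # Buggy variant (kept commented out, as in the patch): printing .data
    # re-enters FakeParameter.__repr__ and recurses once there is a cycle.
    # def __repr__(self):
    #     return f"FakeParameter({self.data!r})"

    # Fix mirrored by this patch: a constant repr that never touches the wrapped value.
    def __repr__(self):
        return "FakeParameter()"


p = FakeParameter()
p._data = FakeParameter(data=p)  # simulate param.data == NanotronParameter
print(repr(p))                   # prints "FakeParameter()" instead of recursing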