diff --git a/quanto/quantize.py b/quanto/quantize.py index 806dd01a..f939b7d4 100644 --- a/quanto/quantize.py +++ b/quanto/quantize.py @@ -23,14 +23,9 @@ def set_module_by_name(parent_module, name, child_module): if len(module_names) == 1: setattr(parent_module, name, child_module) else: - next_module = parent_module - for idx in range(len(module_names) - 1): - next_module_name = module_names[idx] - if next_module_name.isnumeric(): - next_module = next_module[int(next_module_name)] - else: - next_module = getattr(next_module, next_module_name) - setattr(next_module, module_names[-1], child_module) + parent_module_name = name[: name.rindex(".")] + parent_module = parent_module.get_submodule(parent_module_name) + setattr(parent_module, module_names[-1], child_module) def quantize(model, modules=None, **kwargs):