From 8b592524d5d42747e9bc6581e80367906c0cf49b Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Mon, 26 Aug 2024 18:28:37 +0200
Subject: [PATCH] fix: Enable non-strict loading of state dicts

Resolves #278

PyTorch allows loading state dicts with the strict=False argument to
ignore missing keys. This is now also supported in optimum-quanto;
before this fix, a KeyError would be raised.

One context where this is important is parameter-efficient fine-tuning
with adapters such as LoRA. There, we want to load only a small subset
of parameters and leave the other model weights untouched, which
requires non-strict loading.
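For illustration, a minimal sketch of the call this enables (qmodel and
sd are placeholder names for any quantized model and a state dict that
is missing some of its keys, not identifiers from this patch):

    # keys absent from sd are now reported instead of raising a KeyError
    result = qmodel.load_state_dict(sd, strict=False)
    print(result.missing_keys)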
---
 optimum/quanto/nn/qmodule.py                  | 10 ++++++---
 optimum/quanto/tensor/qbits/packed.py         |  6 +++++-
 optimum/quanto/tensor/qbits/qbits.py          | 15 ++++++++++---
 optimum/quanto/tensor/weights/qbytes.py       | 13 ++++++++++--
 .../test_quantized_model_for_causal_lm.py     | 21 +++++++++++++++++++
 5 files changed, 56 insertions(+), 9 deletions(-)

diff --git a/optimum/quanto/nn/qmodule.py b/optimum/quanto/nn/qmodule.py
index 11ba373d..cceccacc 100644
--- a/optimum/quanto/nn/qmodule.py
+++ b/optimum/quanto/nn/qmodule.py
@@ -157,6 +157,7 @@ def _load_from_state_dict(
         if self.weight_qtype is not None and weight_name not in state_dict:
             # The weight Tensor is not present because it is a flattened QTensor
             weight_prefix = weight_name + "."
+            # note: deserialized_weight can be None if a key is missing in the state_dict
             if self.weight_qtype.bits == 8:
                 deserialized_weight = WeightQBytesTensor.load_from_state_dict(
                     state_dict,
@@ -165,6 +166,7 @@ def _load_from_state_dict(
                     axis=0,
                     size=self.weight.size(),
                     stride=self.weight.stride(),
+                    missing_keys=missing_keys,
                 )
             else:
                 deserialized_weight = QBitsTensor.load_from_state_dict(
@@ -175,13 +177,15 @@ def _load_from_state_dict(
                     group_size=self.weight_group_size,
                     size=self.weight.size(),
                     stride=self.weight.stride(),
+                    missing_keys=missing_keys,
                 )
-            deserialized_weight = deserialized_weight.optimize()
+            if deserialized_weight is not None:
+                deserialized_weight = deserialized_weight.optimize()
 
         assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
-        if assign_to_params_buffers:
+        if assign_to_params_buffers and (deserialized_weight is not None):
             self.weight = torch.nn.Parameter(deserialized_weight)
-        else:
+        elif deserialized_weight is not None:
             if type(self.weight.data) is not type(deserialized_weight):
                 # Reloading frozen weights into unfrozen module: move to the correct device and force assignment
                 self.weight = torch.nn.Parameter(deserialized_weight.to(self.weight.device))
diff --git a/optimum/quanto/tensor/qbits/packed.py b/optimum/quanto/tensor/qbits/packed.py
index 368d84af..56269679 100644
--- a/optimum/quanto/tensor/qbits/packed.py
+++ b/optimum/quanto/tensor/qbits/packed.py
@@ -111,7 +111,11 @@ def dtype(self):
         return torch.uint8
 
     @staticmethod
-    def load_from_state_dict(state_dict, prefix, bits, size, stride):
+    def load_from_state_dict(state_dict, prefix, bits, size, stride, missing_keys):
+        if prefix + "_data" not in state_dict:
+            missing_keys.append(prefix + "_data")
+            return
+
         inner_tensors_dict = {"_data": state_dict.pop(prefix + "_data")}
         meta = [name.replace(prefix, "") for name in state_dict.keys() if name.startswith(prefix)]
         meta = {"bits": str(bits), "size": str(list(size)), "stride": str(stride)}
diff --git a/optimum/quanto/tensor/qbits/qbits.py b/optimum/quanto/tensor/qbits/qbits.py
index f3c7326f..9217c940 100644
--- a/optimum/quanto/tensor/qbits/qbits.py
+++ b/optimum/quanto/tensor/qbits/qbits.py
@@ -165,7 +165,7 @@ def dequantize(self):
         return QBitsDequantizer.apply(self)
 
     @staticmethod
-    def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stride):
+    def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stride, missing_keys):
         if group_size is None:
             data_size = size
             data_stride = stride
@@ -176,11 +176,20 @@ def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stri
             data_stride = (data_size[1], 1)
         inner_tensors_dict = {
             "_data": PackedTensor.load_from_state_dict(
-                state_dict, prefix + "_data.", qtype.bits, data_size, data_stride
+                state_dict, prefix + "_data.", qtype.bits, data_size, data_stride, missing_keys=missing_keys
             )
         }
+        missing = inner_tensors_dict["_data"] is None
         for name in ["_scale", "_shift"]:
-            inner_tensors_dict[name] = state_dict.pop(prefix + name)
+            if prefix + name not in state_dict:
+                missing_keys.append(prefix + name)
+                missing = True
+            else:
+                inner_tensors_dict[name] = state_dict.pop(prefix + name)
+
+        if missing:  # could not deserialize because of missing keys
+            return None
+
         meta = {
             "qtype": qtype.name,
             "axis": str(axis),
diff --git a/optimum/quanto/tensor/weights/qbytes.py b/optimum/quanto/tensor/weights/qbytes.py
index 9e216f18..e0592350 100644
--- a/optimum/quanto/tensor/weights/qbytes.py
+++ b/optimum/quanto/tensor/weights/qbytes.py
@@ -71,10 +71,19 @@ def quantize(cls, base: torch.Tensor, qtype: qtype, axis: int, scale: torch.Tens
         return WeightQBytesQuantizer.apply(base, qtype, axis, scale)
 
     @staticmethod
-    def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride):
+    def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride, missing_keys):
         inner_tensors_dict = {}
+        missing = False
         for name in ["_data", "_scale"]:
-            inner_tensors_dict[name] = state_dict.pop(prefix + name)
+            if prefix + name not in state_dict:
+                missing_keys.append(prefix + name)
+                missing = True
+            else:
+                inner_tensors_dict[name] = state_dict.pop(prefix + name)
+
+        if missing:  # could not deserialize because of missing keys
+            return None
+
         meta = {
             "qtype": qtype.name,
             "axis": str(axis),
diff --git a/test/models/test_quantized_model_for_causal_lm.py b/test/models/test_quantized_model_for_causal_lm.py
index 302eb3cb..7366a90c 100644
--- a/test/models/test_quantized_model_for_causal_lm.py
+++ b/test/models/test_quantized_model_for_causal_lm.py
@@ -126,3 +126,24 @@ def test_causal_lm_base_push_to_hub(staging, in_org):
     compare_models(quantized, requantized)
 
     delete_repo(hub_repo_id, token=staging["token"])
+
+
+@pytest.mark.skipif(not is_transformers_available(), reason="requires transformers")
+@pytest.mark.parametrize("model_id", ["facebook/opt-125m"])
+@pytest.mark.parametrize("qtype", [qint4, qint8], ids=["qint4", "qint8"])
+def test_quantized_model_load_state_dict_non_strict(model_id, qtype):
+    # see issue #278
+    quantized = quantized_model_for_causal_lm(model_id, qtype, exclude=None)
+    sd = quantized.state_dict()
+
+    # delete a key used by both qint4 and qint8 from the state dict
+    key = "model.decoder.layers.0.self_attn.k_proj.weight._scale"
+    del sd[key]
+
+    # strict loading should raise a RuntimeError, which is what PyTorch does in this case
+    with pytest.raises(RuntimeError, match=key):
+        quantized.load_state_dict(sd)
+
+    # non-strict loading should not raise an error
+    result = quantized.load_state_dict(sd, strict=False)
+    assert result.missing_keys == [key]