FIX: Enable non-strict loading of state dicts #295

Merged
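
For context: PyTorch's Module.load_state_dict already supports non-strict loading, where missing entries are collected and reported instead of raising. Before this change, quanto's quantized tensors popped keys from the state dict unconditionally, so a missing key raised a KeyError even with strict=False. A minimal sketch of the stock PyTorch behavior this PR aligns with, using a plain nn.Linear rather than a quantized module:

import torch

m = torch.nn.Linear(4, 4)
sd = m.state_dict()
del sd["bias"]

# strict=True (the default) raises a RuntimeError naming the missing key
try:
    m.load_state_dict(sd)
except RuntimeError as e:
    print(e)  # Missing key(s) in state_dict: "bias".

# strict=False returns the incompatibilities instead of raising
result = m.load_state_dict(sd, strict=False)
assert result.missing_keys == ["bias"]
assert result.unexpected_keys == []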
10 changes: 7 additions & 3 deletions optimum/quanto/nn/qmodule.py
@@ -157,6 +157,7 @@ def _load_from_state_dict(
         if self.weight_qtype is not None and weight_name not in state_dict:
             # The weight Tensor is not present because it is a flattened QTensor
             weight_prefix = weight_name + "."
+            # note: deserialized_weight can be None if a key is missing in the state_dict
             if self.weight_qtype.bits == 8:
                 deserialized_weight = WeightQBytesTensor.load_from_state_dict(
                     state_dict,
@@ -165,6 +166,7 @@ def _load_from_state_dict(
                     axis=0,
                     size=self.weight.size(),
                     stride=self.weight.stride(),
+                    missing_keys=missing_keys,
                 )
             else:
                 deserialized_weight = QBitsTensor.load_from_state_dict(
@@ -175,13 +177,15 @@ def _load_from_state_dict(
                     group_size=self.weight_group_size,
                     size=self.weight.size(),
                     stride=self.weight.stride(),
+                    missing_keys=missing_keys,
                 )
-            deserialized_weight = deserialized_weight.optimize()
+            if deserialized_weight is not None:
+                deserialized_weight = deserialized_weight.optimize()
 
             assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
-            if assign_to_params_buffers:
+            if assign_to_params_buffers and (deserialized_weight is not None):
                 self.weight = torch.nn.Parameter(deserialized_weight)
-            else:
+            elif deserialized_weight is not None:
                 if type(self.weight.data) is not type(deserialized_weight):
                     # Reloading frozen weights into unfrozen module: move to the correct device and force assignment
                     self.weight = torch.nn.Parameter(deserialized_weight.to(self.weight.device))
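For orientation: missing_keys here is the mutable list PyTorch passes into every Module._load_from_state_dict override; whatever is appended to it is what load_state_dict later raises on (strict=True) or returns (strict=False). A minimal sketch of that contract, with a hypothetical Stub module that is not part of this PR:

import torch

class Stub(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("marker", torch.zeros(1))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Record the absence instead of raising, mirroring the guarded
        # load_from_state_dict helpers in this PR.
        if prefix + "marker" not in state_dict:
            missing_keys.append(prefix + "marker")
            return
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

result = Stub().load_state_dict({}, strict=False)
assert result.missing_keys == ["marker"]
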
6 changes: 5 additions & 1 deletion optimum/quanto/tensor/qbits/packed.py
@@ -111,7 +111,11 @@ def dtype(self):
         return torch.uint8
 
     @staticmethod
-    def load_from_state_dict(state_dict, prefix, bits, size, stride):
+    def load_from_state_dict(state_dict, prefix, bits, size, stride, missing_keys):
+        if prefix + "_data" not in state_dict:
+            missing_keys.append(prefix + "_data")
+            return
+
         inner_tensors_dict = {"_data": state_dict.pop(prefix + "_data")}
         meta = [name.replace(prefix, "") for name in state_dict.keys() if name.startswith(prefix)]
         meta = {"bits": str(bits), "size": str(list(size)), "stride": str(stride)}
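The guard above is the pattern each deserializer in this PR now follows: check for the key, record it in missing_keys and bail out, otherwise pop. If one wanted to factor it out, a hypothetical helper (not part of the PR) could look like:

def pop_or_record(state_dict, key, missing_keys):
    # Pop `key` from `state_dict`, or record it as missing and return None.
    if key not in state_dict:
        missing_keys.append(key)
        return None
    return state_dict.pop(key)
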
15 changes: 12 additions & 3 deletions optimum/quanto/tensor/qbits/qbits.py
@@ -165,7 +165,7 @@ def dequantize(self):
         return QBitsDequantizer.apply(self)
 
     @staticmethod
-    def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stride):
+    def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stride, missing_keys):
         if group_size is None:
             data_size = size
             data_stride = stride
@@ -176,11 +176,20 @@ def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stri
             data_stride = (data_size[1], 1)
         inner_tensors_dict = {
             "_data": PackedTensor.load_from_state_dict(
-                state_dict, prefix + "_data.", qtype.bits, data_size, data_stride
+                state_dict, prefix + "_data.", qtype.bits, data_size, data_stride, missing_keys=missing_keys
             )
         }
+        missing = inner_tensors_dict["_data"] is None
         for name in ["_scale", "_shift"]:
-            inner_tensors_dict[name] = state_dict.pop(prefix + name)
+            if prefix + name not in state_dict:
+                missing_keys.append(prefix + name)
+                missing = True
+            else:
+                inner_tensors_dict[name] = state_dict.pop(prefix + name)
+
+        if missing:  # could not deserialize because of missing keys
+            return None
+
         meta = {
             "qtype": qtype.name,
             "axis": str(axis),
13 changes: 11 additions & 2 deletions optimum/quanto/tensor/weights/qbytes.py
@@ -71,10 +71,19 @@ def quantize(cls, base: torch.Tensor, qtype: qtype, axis: int, scale: torch.Tens
         return WeightQBytesQuantizer.apply(base, qtype, axis, scale)
 
     @staticmethod
-    def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride):
+    def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride, missing_keys):
         inner_tensors_dict = {}
+        missing = False
         for name in ["_data", "_scale"]:
-            inner_tensors_dict[name] = state_dict.pop(prefix + name)
+            if prefix + name not in state_dict:
+                missing_keys.append(prefix + name)
+                missing = True
+            else:
+                inner_tensors_dict[name] = state_dict.pop(prefix + name)
+
+        if missing:  # could not deserialize because of missing keys
+            return None
+
         meta = {
             "qtype": qtype.name,
             "axis": str(axis),
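Taken together, the tensor-level changes thread one contract through the stack: PackedTensor returns None when its _data key is absent, QBitsTensor and WeightQBytesTensor return None when any inner tensor (_data, _scale, _shift) is absent, and the qmodule loader above skips weight assignment when it receives None. A compressed, hypothetical sketch of that propagation (names abbreviated), reusing the pop_or_record helper sketched earlier:

def load_qbits_sketch(state_dict, prefix, missing_keys):
    # PackedTensor layer: the packed data lives under "<prefix>_data._data"
    packed = pop_or_record(state_dict, prefix + "_data._data", missing_keys)
    # QBitsTensor layer: scale and shift sit directly under the weight prefix
    scale = pop_or_record(state_dict, prefix + "_scale", missing_keys)
    shift = pop_or_record(state_dict, prefix + "_shift", missing_keys)
    if packed is None or scale is None or shift is None:
        return None  # the qmodule loader checks for None before assigning
    return (packed, scale, shift)
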
21 changes: 21 additions & 0 deletions test/models/test_quantized_model_for_causal_lm.py
@@ -126,3 +126,24 @@ def test_causal_lm_base_push_to_hub(staging, in_org):
     compare_models(quantized, requantized)
 
     delete_repo(hub_repo_id, token=staging["token"])
+
+
+@pytest.mark.skipif(not is_transformers_available(), reason="requires transformers")
+@pytest.mark.parametrize("model_id", ["facebook/opt-125m"])
+@pytest.mark.parametrize("qtype", [qint4, qint8], ids=["qint4", "qint8"])
+def test_quantized_model_load_state_dict_non_strict(model_id, qtype):
+    # see issue #278
+    quantized = quantized_model_for_causal_lm(model_id, qtype, exclude=None)
+    sd = quantized.state_dict()
+
+    # delete a key used by both qint4 and qint8 from the state dict
+    key = "model.decoder.layers.0.self_attn.k_proj.weight._scale"
+    del sd[key]
+
+    # strict loading should raise a RuntimeError, which is what PyTorch does in this case
+    with pytest.raises(RuntimeError, match=key):
+        quantized.load_state_dict(sd)
+
+    # non-strict loading should not raise an error
+    result = quantized.load_state_dict(sd, strict=False)
+    assert result.missing_keys == [key]