diff --git a/quanto/tensor/packed.py b/quanto/tensor/packed.py
index 442a1611..2c0beb63 100644
--- a/quanto/tensor/packed.py
+++ b/quanto/tensor/packed.py
@@ -53,41 +53,6 @@ def lshift(t: torch.Tensor, bits: int):
     return packed
 
 
-def unpack_weights(uint8weights: torch.Tensor, bits: int) -> torch.Tensor:
-    """
-    Un-Pack int4 / int2 weights (packed in a uint8) into an expanded torch.uint8 tensor
-    What un-packing means? Assume we have packed 4 2-bit values in 8-bit
-    (because torch does not have native support for 2-bit datatypes)
-
-    > 1110 0100
-
-    Unpacking them means retrieving the original 4 2-bit values:
-
-    > 0000 0011 | 0000 0010 | 0000 0001 | 0000 0000
-
-    Args:
-        uint8weights (`torch.Tensor`):
-            The packed tensor in `torch.uint8` precision
-        bits (`int`):
-            The actual `bits` - can be 2, 4
-    """
-    unpacked = []
-    values_per_item = 8 // bits
-
-    def rshift(t: torch.Tensor, bits: int):
-        if t.device.type == "mps":
-            # rshift is not supported on MPS device
-            return t // (2**bits)
-        return t >> bits
-
-    # Unpack each set of values independently
-    for i in range(values_per_item):
-        mask = 2 ** (bits * (i + 1)) - 1
-        unpacked.append(rshift(uint8weights & mask, bits * i))
-    # Return the concatenated unpacked tensors
-    return torch.cat(unpacked).to(torch.uint8)
-
-
 class PackedTensor(torch.Tensor):
     @staticmethod
     def __new__(cls, data, bits, size, stride, requires_grad=False):
@@ -117,7 +82,7 @@ def pack(cls, t, bits=4):
         return PackedTensor(data, bits, t.size(), t.stride())
 
     def unpack(self):
-        unpacked_data = unpack_weights(self._data, self._bits)
+        unpacked_data = torch.ops.quanto.unpack(self._data, self._bits)
         # Adjust the first dimension, as unpacked data may have extra rows if the original shape is not a multiple of 8 // bits
         return unpacked_data[: self.shape[0]]
 
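
For reviewers, a minimal sketch of the unpacking semantics that the removed Python helper implemented and that the new torch.ops.quanto.unpack call is assumed to reproduce (the op's implementation is not part of this diff). The standalone unpack_weights_reference name below is illustrative only, not a library function:

import torch

def unpack_weights_reference(packed: torch.Tensor, bits: int) -> torch.Tensor:
    # Each uint8 holds 8 // bits values; value i occupies bits [bits * i, bits * (i + 1)).
    values_per_item = 8 // bits
    mask = (1 << bits) - 1
    unpacked = []
    for i in range(values_per_item):
        # Shift the i-th field down to the low bits, then mask off everything above it.
        # (The removed helper special-cased MPS, where >> used to be unsupported.)
        unpacked.append((packed >> (bits * i)) & mask)
    # Concatenate along the first dimension, matching the removed helper's layout.
    return torch.cat(unpacked).to(torch.uint8)

# The byte 0b11100100 packs the four 2-bit values 0, 1, 2, 3 (the docstring example above).
packed = torch.tensor([0b11100100], dtype=torch.uint8)
print(unpack_weights_reference(packed, bits=2))  # tensor([0, 1, 2, 3], dtype=torch.uint8)

As the trailing comment in the second hunk notes, PackedTensor.unpack then slices off any extra rows produced when the original first dimension is not a multiple of 8 // bits.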