
Commit

refactor(packed): remove duplicate unpack_weights
dacorvo committed Mar 8, 2024
1 parent e086c56 commit aabef99
Showing 1 changed file with 1 addition and 36 deletions.
37 changes: 1 addition & 36 deletions quanto/tensor/packed.py
@@ -53,41 +53,6 @@ def lshift(t: torch.Tensor, bits: int):
return packed


def unpack_weights(uint8weights: torch.Tensor, bits: int) -> torch.Tensor:
    """
    Unpack int4 / int2 weights (packed into a uint8) into an expanded torch.uint8 tensor.
    What does unpacking mean? Assume we have packed four 2-bit values into an 8-bit value
    (because torch does not have native support for 2-bit datatypes):
    > 1110 0100
    Unpacking them means retrieving the original four 2-bit values:
    > 0000 0011 | 0000 0010 | 0000 0001 | 0000 0000
    Args:
        uint8weights (`torch.Tensor`):
            The packed tensor in `torch.uint8` precision
        bits (`int`):
            The actual number of `bits` - can be 2 or 4
    """
    unpacked = []
    values_per_item = 8 // bits

    def rshift(t: torch.Tensor, bits: int):
        if t.device.type == "mps":
            # rshift is not supported on the MPS device
            return t // (2**bits)
        return t >> bits

    # Unpack each set of values independently
    for i in range(values_per_item):
        mask = 2 ** (bits * (i + 1)) - 1
        unpacked.append(rshift(uint8weights & mask, bits * i))
    # Return the concatenated unpacked tensors
    return torch.cat(unpacked).to(torch.uint8)


class PackedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, bits, size, stride, requires_grad=False):
@@ -117,7 +82,7 @@ def pack(cls, t, bits=4):
        return PackedTensor(data, bits, t.size(), t.stride())

    def unpack(self):
        unpacked_data = unpack_weights(self._data, self._bits)
        unpacked_data = torch.ops.quanto.unpack(self._data, self._bits)
        # Adjust the first dimension, as unpacked data may have extra rows if the original shape is not a multiple of 8 // bits
        return unpacked_data[: self.shape[0]]

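For context, here is a minimal, self-contained sketch of the shift-and-mask unpacking that the removed helper performed, which appears to duplicate what the dedicated `torch.ops.quanto.unpack` operator now provides via `PackedTensor.unpack`. The `unpack_uint8` name and the example values below are illustrative only and are not part of the repository; the MPS fallback of the original helper is omitted.

import torch

def unpack_uint8(packed: torch.Tensor, bits: int) -> torch.Tensor:
    """Hypothetical sketch: expand 2-bit or 4-bit values packed into uint8 back to one uint8 per value."""
    values_per_item = 8 // bits
    unpacked = []
    for i in range(values_per_item):
        # Keep slot i and the slots below it, then shift slot i down to the low bits.
        mask = 2 ** (bits * (i + 1)) - 1
        unpacked.append((packed & mask) >> (bits * i))
    # Slots are concatenated from the lowest to the highest.
    return torch.cat(unpacked).to(torch.uint8)

# The docstring example: 0b11100100 packs the four 2-bit values 3, 2, 1 and 0.
packed = torch.tensor([0b11100100], dtype=torch.uint8)
print(unpack_uint8(packed, bits=2))  # tensor([0, 1, 2, 3], dtype=torch.uint8)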
