fix scaling
mht-sharma committed Dec 18, 2024
1 parent 988c1dc commit bffccdd
Showing 1 changed file with 43 additions and 29 deletions.
72 changes: 43 additions & 29 deletions server/text_generation_server/layers/fp8.py
@@ -63,43 +63,50 @@ def normalize_e4m3fn_to_e4m3fnuz(
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert weight.dtype == torch.float8_e4m3fn
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
weight_as_int8 = weight.view(torch.int8)
ROCM_FP8_NAN_AS_INT = -128
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
weight = weight_as_int8.view(torch.float8_e4m3fnuz)

# For the same bits representation, e4m3fnuz value is half of
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
if weight.dtype == torch.float8_e4m3fn:
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
weight_as_int8 = weight.view(torch.int8)
ROCM_FP8_NAN_AS_INT = -128
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
weight = weight_as_int8.view(torch.float8_e4m3fnuz)

# For the same bits representation, e4m3fnuz value is half of
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
return weight, weight_scale, input_scale


def per_tensor_dequantize(
tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
tensor: torch.Tensor,
inv_scale: Union[float, torch.Tensor],
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
fake_qweight = tensor.to(torch.float16)
fake_qweight = tensor.to(dtype)
dq_weight = fake_qweight * inv_scale
return dq_weight


def requantize_with_max_scale(
weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: int
weight: torch.Tensor,
weight_scale: torch.Tensor,
logical_widths: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requantization.
max_w_scale = weight_scale.max().float()

start = 0
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
weight_dq = per_tensor_dequantize(
weight[start:end, :], weight_scale[idx], dtype
)
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
weight_dq, max_w_scale
)
@@ -112,7 +119,7 @@ def fp8_quantize(
weight: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[torch.Tensor] = None,
qdtype: torch.dtype = quant_dtype,
qdtype: torch.dtype = torch.float8_e4m3fn,
scalar: bool = False,
):
"""
@@ -125,7 +132,7 @@ def fp8_quantize(
shape = weight.shape
qweight, scale = marlin_kernels.scaled_fp8_quant(
weight.reshape(-1, shape[-1]),
dtype=qdtype,
dtype=quant_dtype,
scale=scale,
scale_ub=scale_upper_bound,
# TODO: don't do this when we have to use the Torch kernel.
@@ -145,6 +152,8 @@ def fp8_quantize(
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
scale = scale.float().reciprocal()
else:
if SYSTEM == "rocm":
scale = scale / 2.0
# Use reciprocal to avoid more expensive division.
qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)

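For reference, the Torch fallback of fp8_quantize follows the standard per-tensor dynamic-quantization recipe visible in the hunk above: scale the tensor so its absolute maximum lands on the fp8 finite range, clamp, and keep the reciprocal as the dequantization scale. A rough standalone sketch of that recipe (naive_fp8_quantize is a hypothetical name, not the repository's function):

import torch

def naive_fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
    # Per-tensor dynamic quantization: map the tensor's absolute maximum
    # onto the finite fp8 range, clamp, and keep the reciprocal as the
    # dequantization scale (so weight ≈ qweight * scale afterwards).
    finfo = torch.finfo(qdtype)
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    return qweight, scale.float().reciprocal()

w = torch.randn(16, 32)
q, s = naive_fp8_quantize(w)
print((q.to(torch.float32) * s - w).abs().max())  # small reconstruction error
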
@@ -263,12 +272,6 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
]
scale = torch.cat(scale, dim=0).reshape(-1)

if scale.numel() == len(prefixes):
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.to(weights.device), logical_widths
)

input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
@@ -281,6 +284,17 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
else None
)

if SYSTEM == "rocm":
w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
w, scale, input_scale
)

if scale.numel() == len(prefixes):
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.to(weights.device), logical_widths, weights.dtype
)

return Fp8Weight(
weight=w,
weight_scale=scale,
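
The reordering in get_multi_weights_col places the ROCm e4m3fnuz normalization before requantize_with_max_scale, so the maximum is taken over the already-doubled scales. The requantization itself collapses the per-shard scales of a fused weight (e.g. q/k/v) into the single largest one: each shard is dequantized with its own scale and re-quantized with the shared maximum. A simplified illustration of that idea, using hypothetical names:

import torch

def requant_shards_to_max_scale(shards, scales, qdtype=torch.float8_e4m3fn):
    # shards: fp8 tensors for each logical block (e.g. q, k, v);
    # scales: the matching per-shard dequantization scales.
    max_scale = torch.stack(scales).max()
    finfo = torch.finfo(qdtype)
    requantized = []
    for q, s in zip(shards, scales):
        deq = q.to(torch.float32) * s  # undo the shard's own scale
        requantized.append(
            (deq / max_scale).clamp(finfo.min, finfo.max).to(qdtype)
        )
    # One shared scale now dequantizes the whole fused weight.
    return torch.cat(requantized, dim=0), max_scale
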
@@ -366,7 +380,7 @@ def __init__(
if CUTLASS_FP8_AVAILABLE:
log_once(logger.info, "Using cutlass w8a8 kernels")
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=qweight, weight_scale=scale
)

