Commit: add docs
mht-sharma committed Jun 24, 2024
1 parent f0d95b0 commit 8a0bb53
Showing 4 changed files with 200 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -43,6 +43,8 @@
title: Monitoring TGI with Prometheus and Grafana
- local: basic_tutorials/train_medusa
title: Train Medusa
- local: basic_tutorials/fp8_kv_cache
title: Accelerating Inference with FP8 KV Cache
title: Tutorials
- sections:
- local: conceptual/streaming
61 changes: 61 additions & 0 deletions docs/source/basic_tutorials/fp8_kv_cache.md
@@ -0,0 +1,61 @@
# Accelerating Inference with FP8 KV Cache

Text Generation Inference (TGI) now supports FP8 KV Cache, improving inference speed and memory efficiency on both Nvidia and AMD GPUs and enabling faster, more scalable text generation.

By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce its memory footprint (a rough sizing example follows the list below). This reduction allows for:
* Increased token storage capacity in the cache
* Improved throughput in text generation tasks
* More efficient GPU memory utilization
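
To make the saving concrete, here is a back-of-the-envelope sizing sketch. The model shape below (32 layers, 8 KV heads, head dimension 128, i.e. a Llama-3-8B-style configuration) is only an illustrative assumption, not something TGI requires:

```
num_layers, num_kv_heads, head_dim = 32, 8, 128  # assumed Llama-3-8B-like shape

def kv_bytes_per_token(bytes_per_element):
    # Two tensors (K and V) per layer, each with num_kv_heads * head_dim elements.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_element

print(kv_bytes_per_token(2) / 1024)  # FP16: 128.0 KiB per token
print(kv_bytes_per_token(1) / 1024)  # FP8:   64.0 KiB per token
# Halving the per-token footprint roughly doubles the tokens that fit in the cache.
```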

## FP8 Formats: E4M3 and E5M2
The Open Compute Project (OCP) defines two common 8-bit floating point data formats:

E4M3:

* 1 sign bit
* 4 biased exponent bits
* 3 mantissa bits

E5M2:

* 1 sign bit
* 5 biased exponent bits
* 2 mantissa bits

E4M3 offers higher precision for representing floating point numbers. However, because of its limited dynamic range, E4M3 typically requires a higher-precision (usually FP32) scaling factor to be stored alongside each quantized tensor. Currently, TGI supports only per-tensor (scalar) scaling factors, as illustrated below.
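
The following is a minimal sketch of per-tensor E4M3 quantization, assuming a PyTorch build that exposes `torch.float8_e4m3fn` (2.1 or later). It is not TGI's internal implementation; it only shows why a scalar FP32 scale accompanies each quantized tensor:

```
import torch

def quantize_e4m3_per_tensor(x):
    # E4M3 finfo: the largest representable magnitude is 448.
    finfo = torch.finfo(torch.float8_e4m3fn)
    # One FP32 scalar scale for the whole tensor (per-tensor scaling).
    scale = x.abs().max().clamp(min=1e-12) / finfo.max
    x_fp8 = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return x_fp8, scale.float()

def dequantize(x_fp8, scale):
    return x_fp8.to(torch.float32) * scale

x = torch.randn(4, 128) * 3.0
x_fp8, scale = quantize_e4m3_per_tensor(x)
print(scale.item(), (dequantize(x_fp8, scale) - x).abs().max().item())
```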

## Current Hardware Support

* Nvidia GPUs: support both FP8 E4M3 and FP8 E5M2.
* AMD GPUs: support FP8 E4M3.

## FP8 E5M2 KV Cache
Example usage:
```
text-generation-launcher --model-id <> --kv-cache-dtype fp8_e5m2
```

## FP8 E4M3 KV Cache
While E4M3 offers higher precision, it requires careful handling of scaling factors to maintain accuracy. Therefore, it is recommended to provide KV cache scaling factors as part of the FP16 checkpoint. If scaling factors are not provided, a default factor of 1.0 is used, which may lead to accuracy loss.

Example usage:
```
text-generation-launcher --model-id <> --kv-cache-dtype fp8
```

### Checkpoint structure for KV scales
The FP8 KV cache scaling factors are stored in the FP16 checkpoint as a `.kv_scale` parameter on each `Attention` module, for example:

```
model.layers.0.self_attn.kv_scale < F32
model.layers.1.self_attn.kv_scale < F32
...
```
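
To check whether a checkpoint already carries these scales, you can list the `.kv_scale` entries with `safetensors`. A quick sketch; `model.safetensors` is a placeholder file name, and sharded checkpoints spread the entries across several files:

```
from safetensors import safe_open

# Collect all per-tensor KV cache scales stored in one safetensors file.
with safe_open("model.safetensors", framework="pt") as f:
    kv_scales = {name: f.get_tensor(name).item() for name in f.keys() if ".kv_scale" in name}

print(kv_scales)  # e.g. {"model.layers.0.self_attn.kv_scale": 0.0123, ...}
```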

### Generating model with KV Cache scales

Use [AutoFP8](https://github.com/neuralmagic/AutoFP8) with calibration data to generate per-tensor scales for FP8 quantized KV Cache. For more details, see the following example: https://github.com/neuralmagic/AutoFP8/blob/main/example_dataset.py
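
For orientation, a calibration-based AutoFP8 run typically looks like the sketch below. The exact API may differ between AutoFP8 versions; in particular, the `kv_cache_quant_targets` argument is an assumption here, so follow the linked example for the authoritative flow:

```
from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Tiny stand-in calibration set; use a real dataset (as in the linked example)
# to obtain scales that generalize.
texts = ["The quick brown fox jumps over the lazy dog."] * 16
examples = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to("cuda")

quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    # Assumed option for emitting KV cache scales; check your AutoFP8 version.
    kv_cache_quant_targets=("k_proj", "v_proj"),
)

model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config=quantize_config)
model.quantize(examples)
model.save_quantized("Meta-Llama-3-8B-Instruct-FP8-KV")
```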

TGI provides a utility to extract the FP8 KV cache scales from an `AutoFP8`-quantized model and save them to the FP16 model for use with TGI. For more information, see the extraction utility and README under `examples/fp8_kvcache/`.

Alternatively, you can use other quantizer tools, such as Nvidia AMMO, to obtain these scaling factors.
40 changes: 40 additions & 0 deletions examples/fp8_kvcache/README.md
@@ -0,0 +1,40 @@
# FP8 (fp8_e4m3) KV Cache Scaling Factor Extraction Utility

This utility extracts KV cache scaling factors from an FP8 (`fp8_e4m3`) quantized Hugging Face (HF) model. The extracted scaling factors are then saved to the corresponding unquantized HF model, which can be used with Text Generation Inference (TGI).

Note: This tool specifically works with models quantized using the [AutoFP8](https://github.com/neuralmagic/AutoFP8/tree/main) repository.

The extracted FP8 KV cache scaling factors are added to the FP16 checkpoint of the unquantized HF model as a `.kv_scale` parameter on each `Attention` module, as shown below:

```
model.layers.0.self_attn.kv_scale < F32
model.layers.1.self_attn.kv_scale < F32
...
```

## Prerequisites

- text-generation-server
- AutoFP8

## CLI options
```
usage: extract_fp8_kv_scales.py [-h] [--quantized-model QUANTIZED_MODEL] [--model MODEL] [--save-path SAVE_PATH]
Extract FP8 KV cache scales and add them to an FP16 model.
options:
-h, --help show this help message and exit
--quantized-model QUANTIZED_MODEL
Path to the FP8 model checkpoint to extract KV cache scales
--model MODEL Model ID of the FP16 model to save the KV cache scales
--save-path SAVE_PATH
Path to save the FP16 model with the kv scales
```

## Example usage
To extract KV cache scaling factors from a quantized FP8 model and save them to an unquantized FP16 model, use the following command:

```
python extract_fp8_kv_scales.py --quantized-model neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV --model meta-llama/Meta-Llama-3-8B-Instruct --save-path Meta-Llama-3-8B-Instruct
```
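
Once the scales are saved, the resulting checkpoint can be served with the FP8 E4M3 KV cache enabled, using the launcher flag described in the FP8 KV cache tutorial (assuming `Meta-Llama-3-8B-Instruct` is the save path used above):

```
text-generation-launcher --model-id Meta-Llama-3-8B-Instruct --kv-cache-dtype fp8
```
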
97 changes: 97 additions & 0 deletions examples/fp8_kvcache/extract_fp8_kv_scales.py
@@ -0,0 +1,97 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pathlib import Path
from text_generation_server.utils.hub import (
weight_files,
download_weights,
weight_hub_files,
)
from safetensors import safe_open
import argparse


def load_model(ckpt_path):
model_args = {"torch_dtype": "auto"}

model = AutoModelForCausalLM.from_pretrained(
ckpt_path, device_map="auto", **model_args, trust_remote_code=True
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

return model, tokenizer


def set_nested_attribute(obj, attribute_path, value):
    # Walk a dotted path like "model.layers.0.self_attn.kv_scale" and set the
    # final attribute ("kv_scale") on the resolved module.
    keys = attribute_path.split(".")
current_obj = obj
for key in keys[:-1]:
current_obj = getattr(current_obj, key)
setattr(current_obj, keys[-1], value)


def apply_kv_scales_to_model(model, layer_scales_map):
    # Register each scale on its attention module as a frozen (non-trainable)
    # parameter named `kv_scale`, so it is serialized with the checkpoint.
    for layer_name, scale_value in layer_scales_map.items():
scale_param = torch.nn.Parameter(torch.tensor(scale_value), requires_grad=False)
set_nested_attribute(model, layer_name, scale_param)


def extract_kv_scales(quantized_model):
    # Scan the quantized checkpoint's safetensors files and collect every
    # ".kv_scale" tensor as {parameter_name: float}.
    def fetch_parameters(filename):
with safe_open(filename, framework="pt") as f:
for name in f.keys():
param_tensor = f.get_tensor(name)
yield name, param_tensor

    checkpoint_dir = Path(quantized_model)
    if checkpoint_dir.is_dir():
        # Local checkpoint: read the safetensors shards directly.
        safetensors_files = sorted(checkpoint_dir.glob("*.safetensors"))
    else:
        # Hub model ID: download the weights, then resolve the cached files.
        hub_filenames = weight_hub_files(quantized_model)
        download_weights(hub_filenames, quantized_model)
        safetensors_files = weight_files(quantized_model, extension=".safetensors")

    layer_scales_map = {}
    for tensor_file in safetensors_files:
for name, param in fetch_parameters(tensor_file):
if ".kv_scale" in name:
layer_scales_map[name] = param.item()

return layer_scales_map


def main(quantized_model, model_id, save_path):
layer_scales_map = extract_kv_scales(quantized_model)

model, tokenizer = load_model(model_id)

apply_kv_scales_to_model(model, layer_scales_map)

model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

print(f"Model saved to {save_path}")


if __name__ == "__main__":
parser = argparse.ArgumentParser(
        description="Extract FP8 KV cache scales and add them to an FP16 model."
)
parser.add_argument(
"--quantized-model",
type=str,
help="Path to the FP8 model checkpoint to extract KV cache scales",
)
parser.add_argument(
"--model",
type=str,
help="Model ID of the FP16 model to save the KV cache scales",
)
parser.add_argument(
"--save-path",
type=str,
help="Path to save the FP16 model with the kv scales",
)

args = parser.parse_args()

main(args.quantized_model, args.model, args.save_path)
