
Commit

fix style
mht-sharma committed Jun 24, 2024
1 parent 8a0bb53 commit 557e18e
Showing 12 changed files with 47 additions and 14 deletions.
4 changes: 2 additions & 2 deletions Dockerfile_amd
@@ -213,5 +213,5 @@ FROM base-copy
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

-ENTRYPOINT ["/tgi-entrypoint.sh"]
-CMD ["--json-output"]
+# ENTRYPOINT ["/tgi-entrypoint.sh"]
+# CMD ["--json-output"]
3 changes: 3 additions & 0 deletions Makefile
@@ -53,3 +53,6 @@ run-falcon-7b-instruct-quantize:

clean:
rm -rf target aml

+interact:
+docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 64g --net host -v /home/mohit/.cache/huggingface/hub/:/data -v $(PWD):/tgi tgi-mht
2 changes: 1 addition & 1 deletion docs/source/basic_tutorials/fp_kv_cache.md
@@ -58,4 +58,4 @@ Use [AutoFP8](https://github.com/neuralmagic/AutoFP8) with calibration data to g

TGI provides a utility to extract the FP8 KV cache scales from an `AutoFP8` quantized model and save them to the FP16 model for use with TGI. For more information: <path to script>

-Alternatively, you can use other quantizer tools, such as Nvidia AMMO, to obtain these scaling factors.
+Alternatively, you can use other quantizer tools, such as Nvidia AMMO, to obtain these scaling factors.
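
For context on this doc change: the extracted scaling factors are applied when the KV cache is written in FP8 and read back for attention. A minimal sketch of that round trip (not TGI code; it assumes per-tensor symmetric scaling and PyTorch's float8_e4m3fn storage type):

```
import torch

def quantize_kv_to_fp8(kv: torch.Tensor, kv_scale: float) -> torch.Tensor:
    # Divide by the scale so fp16 values fit the narrow fp8 range, then cast.
    return (kv / kv_scale).to(torch.float8_e4m3fn)

def dequantize_kv_from_fp8(kv_fp8: torch.Tensor, kv_scale: float) -> torch.Tensor:
    # Multiply back by the scale when attention reads the cache.
    return kv_fp8.to(torch.float16) * kv_scale
```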
4 changes: 2 additions & 2 deletions docs/source/basic_tutorials/launcher.md
@@ -234,7 +234,7 @@ Options:
--hostname <HOSTNAME>
The IP address to listen on

-[env: HOSTNAME=hf-amd-mi250-dev]
+[env: HOSTNAME=]
[default: 0.0.0.0]

```
@@ -279,7 +279,7 @@ Options:
--huggingface-hub-cache <HUGGINGFACE_HUB_CACHE>
The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk for instance

-[env: HUGGINGFACE_HUB_CACHE=/data]
+[env: HUGGINGFACE_HUB_CACHE=]

```
## WEIGHTS_CACHE_OVERRIDE
2 changes: 1 addition & 1 deletion examples/fp8_kvcache/README.md
@@ -37,4 +37,4 @@ To extract KV cache scaling factors from a quantized FP8 model and save them to

```
python extract_fp8_kv_scales.py --quantized-model neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV --model meta-llama/Meta-Llama-3-8B-Instruct --save-path Meta-Llama-3-8B-Instruct
-```
+```
6 changes: 4 additions & 2 deletions server/text_generation_server/cli.py
@@ -91,10 +91,12 @@ def serve(
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)

if kv_cache_dtype in {"fp8", "fp8_e5m2"}:
if SYSTEM not in {"cuda", "rocm"}:
-raise RuntimeError(f"`{kv_cache_dtype}` KV cache is only supported on Nvidia and AMD GPUs.")
+raise RuntimeError(
+    f"`{kv_cache_dtype}` KV cache is only supported on Nvidia and AMD GPUs."
+)
if kv_cache_dtype == "fp8_e5m2" and SYSTEM != "cuda":
raise RuntimeError(f"`fp8_e5m2` KV cache is only supported on Nvidia GPUs.")

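As context for the reformatted check above, the platform rule it enforces can be restated as a small standalone helper (a sketch only; this function does not exist in TGI): the e4m3 `fp8` cache is allowed on CUDA and ROCm, while `fp8_e5m2` is CUDA-only.

```
def fp8_kv_cache_supported(kv_cache_dtype: str, system: str) -> bool:
    # Mirrors the checks in cli.py: e4m3 fp8 runs on Nvidia and AMD GPUs,
    # e5m2 fp8 only on Nvidia GPUs; other dtypes are not restricted here.
    if kv_cache_dtype == "fp8":
        return system in {"cuda", "rocm"}
    if kv_cache_dtype == "fp8_e5m2":
        return system == "cuda"
    return True
```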
4 changes: 3 additions & 1 deletion server/text_generation_server/layers/attention/cuda.py
@@ -23,7 +23,9 @@ def reshape_and_cache(
kv_cache_dtype: str = "auto",
kv_scale: int = 1.0,
):
-cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale)
+cache_ops.reshape_and_cache(
+    key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale
+)


def paged_attention(
4 changes: 3 additions & 1 deletion server/text_generation_server/layers/attention/rocm.py
@@ -28,7 +28,9 @@ def reshape_and_cache(
kv_cache_dtype: str = "auto",
kv_scale: int = 1.0,
):
-cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale)
+cache_ops.reshape_and_cache(
+    key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale
+)


def paged_attention(
2 changes: 2 additions & 0 deletions server/text_generation_server/models/__init__.py
@@ -114,6 +114,7 @@
if MAMBA_AVAILABLE:
__all__.append(Mamba)


class ModelType(enum.Enum):
IDEFICS2 = {
"type": "idefics2",
@@ -244,6 +245,7 @@ class ModelType(enum.Enum):
"multimodal": True,
}


FP8_KVCACHE_SUPPORTED_MODELS = {
"llama",
"baichun",
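The `FP8_KVCACHE_SUPPORTED_MODELS` set above is an allow-list of architectures; presumably it is consulted when a model is loaded with an fp8 KV cache. A hypothetical gate (name and call site assumed, not shown in this diff):

```
def check_fp8_kvcache_support(model_type: str, kv_cache_dtype: str) -> None:
    # Hypothetical helper: reject an fp8 KV cache for architectures that are
    # not in the FP8_KVCACHE_SUPPORTED_MODELS allow-list.
    if kv_cache_dtype == "fp8" and model_type not in FP8_KVCACHE_SUPPORTED_MODELS:
        raise ValueError(
            f"fp8 KV cache is not supported for model type `{model_type}`"
        )
```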
@@ -45,6 +45,7 @@
)

from loguru import logger

if SYSTEM == "rocm":
try:
from vllm import _custom_C
@@ -138,7 +139,9 @@ def __init__(
self.kv_cache_dtype = config.kv_cache_dtype

if self.kv_cache_dtype == "fp8":
-self.kv_scale = weights.get_kv_cache_scaling_factor(prefix, self.kv_cache_dtype)
+self.kv_scale = weights.get_kv_cache_scaling_factor(
+    prefix, self.kv_cache_dtype
+)
else:
self.kv_scale = 1.0
logger.info(f"kv_cache_dtype: {self.kv_cache_dtype}, kv_scale: {self.kv_scale}")
@@ -168,7 +171,15 @@ def forward(

self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots, self.kv_cache_dtype, self.kv_scale)
+reshape_and_cache(
+    kv[:, 0],
+    kv[:, 1],
+    kv_cache[0],
+    kv_cache[1],
+    slots,
+    self.kv_cache_dtype,
+    self.kv_scale,
+)

# output tensor
attn_output = torch.empty_like(query)
9 changes: 8 additions & 1 deletion server/text_generation_server/server.py
@@ -269,6 +269,13 @@ async def serve_inner(
set_model_id(model_id)
asyncio.run(
serve_inner(
-model_id, revision, sharded, quantize, speculate, dtype, kv_cache_dtype, trust_remote_code
+model_id,
+revision,
+sharded,
+quantize,
+speculate,
+dtype,
+kv_cache_dtype,
+trust_remote_code,
)
)
6 changes: 5 additions & 1 deletion server/text_generation_server/utils/weights.py
@@ -89,7 +89,11 @@ def get_tensor(self, tensor_name: str, to_device=True):
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32. Exl2 uses int16
# as well.
-if tensor.dtype not in [torch.int16, torch.int32,torch.int64] and not tensor_name.endswith("kv_scale"):
+if tensor.dtype not in [
+    torch.int16,
+    torch.int32,
+    torch.int64,
+] and not tensor_name.endswith("kv_scale"):
tensor = tensor.to(dtype=self.dtype)
if to_device:
tensor = tensor.to(device=self.device)
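The condition reformatted above encodes a dtype rule: packed integer weights (GPTQ u4 disguised as int32, Exl2 int16) and `kv_scale` tensors keep their stored dtype, while everything else is cast to the serving dtype. A standalone restatement of that rule (a sketch, not the `Weights` class itself):

```
import torch

def should_cast_to_serving_dtype(tensor: torch.Tensor, tensor_name: str) -> bool:
    # Packed integer weights and FP8 KV-cache scales keep their stored dtype;
    # all other tensors are converted to the serving dtype (e.g. float16).
    if tensor.dtype in (torch.int16, torch.int32, torch.int64):
        return False
    if tensor_name.endswith("kv_scale"):
        return False
    return True
```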
