From 2cca808c31461df7c1aa097071c79e327fa7c57a Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 10 Dec 2024 10:25:38 +0000 Subject: [PATCH 1/6] (vllm) updated vllm rocm kernels --- Dockerfile_amd | 13 +++++ server/Makefile-vllm | 2 +- .../layers/attention/kv_cache.py | 4 +- .../layers/attention/rocm.py | 58 ++++++++++++------- .../layers/layernorm.py | 37 +++++++----- .../text_generation_server/layers/linear.py | 8 +-- .../layers/moe/__init__.py | 5 +- .../layers/moe/unquantized.py | 4 +- .../text_generation_server/layers/rotary.py | 2 +- .../custom_modeling/flash_cohere_modeling.py | 2 +- .../custom_modeling/flash_dbrx_modeling.py | 4 +- .../flash_deepseek_v2_modeling.py | 6 +- .../custom_modeling/flash_gptj_modeling.py | 2 +- .../custom_modeling/flash_llama_modeling.py | 33 +++++++++-- .../custom_modeling/flash_mistral_modeling.py | 6 +- .../custom_modeling/idefics_modeling.py | 2 +- 16 files changed, 120 insertions(+), 68 deletions(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index 7638947a5c7..beb9101d72f 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -234,6 +234,7 @@ FROM kernel-builder AS vllm-builder WORKDIR /usr/src COPY server/Makefile-vllm Makefile +RUN pip install setuptools_scm # Build specific version of vllm RUN make build-vllm-rocm @@ -267,6 +268,15 @@ COPY server/exllamav2_kernels/ . RUN python setup.py build +FROM kernel-builder AS moe-kernels +WORKDIR /usr/src +ENV MOE_KERNELS_BRANCH=b74b163a2e28042068ac8355e07e6dde926a967a +ENV VLLM_TARGET_DEVICE=rocm +RUN git clone https://github.com/mht-sharma/moe-kernels.git && \ + cd moe-kernels && \ + git checkout ${MOE_KERNELS_BRANCH} && \ + python setup.py install + FROM install_deps AS base-copy # Text Generation Inference base env @@ -289,6 +299,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 # Copy build artifacts from exllamav2 kernels builder COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages +# Copy build artifacts from moe kernels +COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages + # Install server COPY proto proto COPY server server diff --git a/server/Makefile-vllm b/server/Makefile-vllm index 45a7980d4bd..90da96d26fc 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,4 +1,4 @@ -commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247 +commit_rocm := de990cd12537f78f74e40b5c8ee1a62d63d734dd build-vllm-rocm: if [ ! -d 'vllm' ]; then \ diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py index cad1d98a0b8..93d74732408 100644 --- a/server/text_generation_server/layers/attention/kv_cache.py +++ b/server/text_generation_server/layers/attention/kv_cache.py @@ -215,7 +215,9 @@ def paged_reshape_and_cache( raise ImportError( f"Could not import vllm paged attention. Make sure your installation is correct. 
Complete error: {e}" ) - ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) + ops.reshape_and_cache( + key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0 + ) elif SYSTEM == "ipex": import intel_extension_for_pytorch as ipex diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index ea11c2c2615..d65054a1ff2 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -6,26 +6,42 @@ from text_generation_server.layers.attention import Seqlen from text_generation_server.utils.log import log_master from loguru import logger +import vllm._custom_ops as ops major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 -_PARTITION_SIZE_V1V2 = 512 +_PARTITION_SIZE_V1V2 = 1024 _PARTITION_SIZE_CUSTOM = 256 +_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName +_ON_MI250_MI300 = any( + arch in _GPU_ARCH for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"] +) + use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} ENGINE = "triton" if use_triton else "ck" use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" -try: - if use_rocm_custom_paged_attn: - from vllm._custom_C import paged_attention_custom -except ImportError as e: - log_master( - logger.info, - f"Custom Paged Attention not available. Complete error: {e}", + + +def _use_rocm_custom_paged_attention( + qtype: torch.dtype, + head_size: int, + block_size: int, + gqa_ratio: int, + max_seq_len: int, +) -> bool: + # rocm custom page attention not support on navi (gfx1*) + return ( + use_rocm_custom_paged_attn + and _ON_MI250_MI300 + and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_seq_len <= 131072 ) - use_rocm_custom_paged_attn = False def paged_attention( @@ -66,13 +82,8 @@ def paged_attention( num_kv_heads = kv_cache.key.shape[1] gqa_ratio = num_heads // num_kv_heads - use_custom = ( - use_rocm_custom_paged_attn - and (query.dtype == torch.half or query.dtype == torch.bfloat16) - and (head_size == 128 or head_size == 64) - and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) - and max_s <= 32768 + use_custom = _use_rocm_custom_paged_attention( + query.dtype, head_size, block_size, gqa_ratio, max_s ) if not use_custom: @@ -90,8 +101,6 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - import vllm._custom_ops as ops - use_v1 = ( max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) @@ -103,7 +112,7 @@ def paged_attention( query, kv_cache.key, kv_cache.value, - kv_head_mapping, + num_kv_heads, softmax_scale, block_tables, input_lengths, @@ -112,6 +121,7 @@ def paged_attention( None, "auto", 1.0, + 1.0, ) else: # Run PagedAttention V2. 
@@ -129,6 +139,7 @@ def paged_attention( max_logits = torch.empty_like(exp_sums) if not use_custom: + logger.info("Using PagedAttention V2") ops.paged_attention_v2( out, exp_sums, @@ -137,7 +148,7 @@ def paged_attention( query, kv_cache.key, kv_cache.value, - kv_head_mapping, + num_kv_heads, softmax_scale, block_tables, input_lengths, @@ -146,9 +157,10 @@ def paged_attention( None, "auto", 1.0, + 1.0, ) else: - paged_attention_custom( + ops.paged_attention_rocm( out, exp_sums, max_logits, @@ -164,6 +176,10 @@ def paged_attention( max_s, None, "auto", + 1.0, + 1.0, + None, + 512, ) return out diff --git a/server/text_generation_server/layers/layernorm.py b/server/text_generation_server/layers/layernorm.py index ce5289f9337..8c7a2eb048f 100644 --- a/server/text_generation_server/layers/layernorm.py +++ b/server/text_generation_server/layers/layernorm.py @@ -72,7 +72,7 @@ def forward(self, hidden_states, residual=None): return normed_hidden_states, residual elif SYSTEM == "rocm": - from vllm._C import ops + import vllm._custom_ops as ops class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): @@ -121,6 +121,27 @@ def forward(self, hidden_states, residual=None): residual is not None, ) return out, residual if residual is not None else hidden_states + elif SYSTEM == "rocm": + # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. + if residual is not None: + ops.fused_add_rms_norm( + hidden_states, + residual, + self.weight.data, + self.variance_epsilon, + ) + return hidden_states, residual + + residual = hidden_states + + out = torch.empty_like(hidden_states) + ops.rms_norm( + out, + hidden_states, + self.weight.data, + self.variance_epsilon, + ) + return out, residual elif hidden_states.shape[-1] > 8192: if residual is not None: hidden_states += residual @@ -164,20 +185,6 @@ def forward(self, hidden_states, residual=None): res = hidden_states return normed_hidden_states, res - elif SYSTEM == "rocm": - # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. - if residual is not None: - hidden_states += residual - residual = hidden_states - - out = torch.empty_like(hidden_states) - ops.rms_norm( - out, - hidden_states, - self.weight.data, - self.variance_epsilon, - ) - return out, residual else: raise ValueError( "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index 08306d57969..166c2874460 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -11,10 +11,10 @@ if ROCM_USE_SKINNY_GEMM: try: - from vllm import _custom_C + import vllm._custom_ops as ops except Exception as e: raise ImportError( - f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}" + f"Could not load `vllm._custom_ops` for ROCm skinny gemm. 
Full error: {e}" ) @@ -95,12 +95,12 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: out = torch.empty( inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device ) - _custom_C.wvSpltK(weight, inp, out, n, self.cu_count) + ops.wvSpltK(weight, inp, out, n, self.cu_count) elif m % 4 == 0 and n == 1 and k <= 8192: out = torch.empty( inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device ) - _custom_C.LLMM1(weight, inp, out, 4) + ops.LLMM1(weight, inp, out, 4) else: out = F.linear(inp, weight) diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py index a5ae7ff4fde..be40d78a8a1 100644 --- a/server/text_generation_server/layers/moe/__init__.py +++ b/server/text_generation_server/layers/moe/__init__.py @@ -24,10 +24,7 @@ UnquantizedWeight, ) -if SYSTEM == "rocm": - from .fused_moe_rocm import grouped_topk - from vllm.model_executor.layers.fused_moe import fused_topk -elif SYSTEM == "ipex": +if SYSTEM == "ipex": from intel_extension_for_pytorch.llm.modules import GatedMLPMOE else: from moe_kernels.fused_moe import fused_topk, grouped_topk diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py index 75af040906c..3c9bcabace2 100644 --- a/server/text_generation_server/layers/moe/unquantized.py +++ b/server/text_generation_server/layers/moe/unquantized.py @@ -6,9 +6,7 @@ from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import UnquantizedWeight, Weights -if SYSTEM == "rocm": - from vllm.model_executor.layers.fused_moe import fused_moe -elif SYSTEM == "ipex": +if SYSTEM == "ipex": from intel_extension_for_pytorch.llm.modules import GatedMLPMOE else: from moe_kernels.fused_moe import fused_moe diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 123bbadbb9e..e346d0f8946 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -7,7 +7,7 @@ if SYSTEM == "cuda": import rotary_emb elif SYSTEM == "rocm": - from vllm._C import ops + import vllm._custom_ops as ops elif SYSTEM == "ipex": import intel_extension_for_pytorch as ipex diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 68719106fca..ece15e942b2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -75,7 +75,7 @@ def forward( rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) elif SYSTEM == "rocm": - from vllm._C import ops + import vllm._custom_ops as ops # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. 
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 2d1aa96c285..aa0327825d7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -23,9 +23,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.utils.import_utils import SYSTEM -if SYSTEM == "rocm": - from vllm.model_executor.layers.fused_moe import fused_moe -elif SYSTEM == "ipex": +if SYSTEM == "ipex": from intel_extension_for_pytorch.llm.modules import GatedMLPMOE else: from moe_kernels.fused_moe import fused_moe diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 906a83a4151..2fb733cd309 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -43,9 +43,9 @@ if SYSTEM == "rocm": try: - from vllm import _custom_C + import vllm._custom_ops as ops except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") + raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}") class DeepseekV2Config(PretrainedConfig): @@ -408,7 +408,7 @@ def forward(self, hidden_states: torch.Tensor, reduce: bool = True): dtype=hidden_states.dtype, device="cuda", ) - _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) + ops.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) return self.down_proj(out, reduce=reduce) else: gate_up_states = self.gate_up_proj(hidden_states) diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index 692f8ca31be..45b90679da3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -91,7 +91,7 @@ def forward( rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) elif SYSTEM == "rocm": - from vllm._C import ops + import vllm._custom_ops as ops # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 2c007d15648..b093bcd73a5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -64,9 +64,9 @@ if SYSTEM == "rocm": try: - from vllm import _custom_C + import vllm._custom_ops as ops except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. 
Full error: {e}") + raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}") def load_attention(config, prefix: str, weights, layer_id): @@ -392,16 +392,37 @@ def forward(self, hidden_states, adapter_data): dtype=hidden_states.dtype, device="cuda", ) - _custom_C.LLMM_Silu( + ops.LLMM_Silu( self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 ) return self.down_proj(out, adapter_data) else: gate_up_states = self.gate_up_proj(hidden_states, adapter_data) - gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj( - self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + # x = gate_up_states.view(-1, 1,self.intermediate_size) + # from loguru import logger + # logger.info(f"gate_up_states: {gate_up_states.shape}") + # x = self.act(gate_up_states[:, 0]) * gate_up_states[:, 1] + # logger.info(f"x: {x.shape}") + + # return self.down_proj( + # x, adapter_data + # ) + + # gate_up_states: torch.Size([4096, 2, 14336]) + # x: torch.Size([4096, 14336]) + + # gate_up_states = self.gate_up_proj(hidden_states, adapter_data) + # x = gate_up_states.view(-1, 2, self.intermediate_size) + # # x = gate_up_states[:, 0] * self.act(gate_up_states[:, 1]) + + output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,) + + out = torch.empty( + output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device ) + ops.silu_and_mul(out, gate_up_states) + + return self.down_proj(out, adapter_data) class FlashLlamaLayer(nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index c66c732f21d..0fa172d039e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -49,9 +49,9 @@ if SYSTEM == "rocm": try: - from vllm import _custom_C + import vllm._custom_ops as ops except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") + raise ImportError(f"Could not load `vllm._custom_ops`. 
Full error: {e}") class MistralConfig(PretrainedConfig): @@ -318,7 +318,7 @@ def forward(self, hidden_states, adapter_data): dtype=hidden_states.dtype, device="cuda", ) - _custom_C.LLMM_Silu( + ops.LLMM_Silu( self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8 ) return self.down_proj(out, adapter_data) diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py index fc6becc4b09..9fc9bca63b8 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -52,7 +52,7 @@ if SYSTEM == "cuda": import dropout_layer_norm elif SYSTEM == "rocm": - from vllm._C import ops + import vllm._custom_ops as ops else: dropout_layer_norm = None From ca071bdd1d34cd7fab8e9430b2173cecb7e4664e Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 10 Dec 2024 10:41:40 +0000 Subject: [PATCH 2/6] revert silu --- .../custom_modeling/flash_llama_modeling.py | 27 +++---------------- 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index b093bcd73a5..10309006af9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -398,31 +398,10 @@ def forward(self, hidden_states, adapter_data): return self.down_proj(out, adapter_data) else: gate_up_states = self.gate_up_proj(hidden_states, adapter_data) - # x = gate_up_states.view(-1, 1,self.intermediate_size) - # from loguru import logger - # logger.info(f"gate_up_states: {gate_up_states.shape}") - # x = self.act(gate_up_states[:, 0]) * gate_up_states[:, 1] - # logger.info(f"x: {x.shape}") - - # return self.down_proj( - # x, adapter_data - # ) - - # gate_up_states: torch.Size([4096, 2, 14336]) - # x: torch.Size([4096, 14336]) - - # gate_up_states = self.gate_up_proj(hidden_states, adapter_data) - # x = gate_up_states.view(-1, 2, self.intermediate_size) - # # x = gate_up_states[:, 0] * self.act(gate_up_states[:, 1]) - - output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,) - - out = torch.empty( - output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data ) - ops.silu_and_mul(out, gate_up_states) - - return self.down_proj(out, adapter_data) class FlashLlamaLayer(nn.Module): From 07e9ec2b66056d7e73854ae78bbe3783fdd66604 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 10 Dec 2024 10:54:52 +0000 Subject: [PATCH 3/6] update partition size --- server/text_generation_server/layers/attention/rocm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index d65054a1ff2..a401b589eb6 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -179,7 +179,7 @@ def paged_attention( 1.0, 1.0, None, - 512, + _PARTITION_SIZE, ) return out From 1194cdb1ba090507cfcbc079a9c8c0bd84fef1f7 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 10 Dec 2024 10:58:01 +0000 Subject: [PATCH 4/6] remove grouped_topk 
---
 .../layers/moe/fused_moe_rocm.py | 52 ------------------------------
 1 file changed, 52 deletions(-)
 delete mode 100644 server/text_generation_server/layers/moe/fused_moe_rocm.py

diff --git a/server/text_generation_server/layers/moe/fused_moe_rocm.py b/server/text_generation_server/layers/moe/fused_moe_rocm.py
deleted file mode 100644
index 68accb99022..00000000000
--- a/server/text_generation_server/layers/moe/fused_moe_rocm.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# coding=utf-8
-# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Tuple
-
-import torch
-import torch.distributed
-
-
-# TODO: Remove the functions once moe_kernel are built for ROCM
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids

From 999eba8096b38045ca8208be82fbbccefe0590d1 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Wed, 11 Dec 2024 12:10:35 +0000
Subject: [PATCH 5/6] (nit) remove log

---
 server/text_generation_server/layers/attention/rocm.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index a401b589eb6..0cfac25bd93 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -139,7 +139,6 @@ def paged_attention(
     max_logits = torch.empty_like(exp_sums)
 
     if not use_custom:
-        logger.info("Using PagedAttention V2")
         ops.paged_attention_v2(
             out,
             exp_sums,

From 12cb6aa6b7ff0f2e9c9ed19ad33b94856a77bb34 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Fri, 13 Dec 2024 14:51:11 +0000
Subject: [PATCH 6/6] add flash decoding

---
 server/Makefile-flash-att-v2 | 2 +-
 .../layers/attention/rocm.py | 48 +++++++++++++++++--
 .../models/flash_causal_lm.py | 14 ++++--
 3 files changed, 53 insertions(+), 11 deletions(-)

diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
index a9cdf782270..9a946d97f8b 100644
--- a/server/Makefile-flash-att-v2
+++ b/server/Makefile-flash-att-v2
@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
+flash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd
 
 build-flash-attention-v2-cuda:
 	pip install -U packaging wheel
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 0cfac25bd93..69a245ad5aa 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -5,6 +5,10 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils.log import log_master
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)
 from loguru import logger
 import vllm._custom_ops as ops
 
@@ -73,11 +77,44 @@ def paged_attention(
     # limitations under the License.
     #
 
+    if ATTENTION == "flashdecoding":
+        max_q = 1
+        max_k = max_s
+        import flash_attn_2_cuda
+
+        if softcap is None:
+            softcap = 0.0
+        out = flash_attn_2_cuda.varlen_fwd(
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            None,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            None,  # pad_k
+            None,
+            block_tables,
+            None,
+            max_q,
+            max_k,
+            0.0,  # dropout
+            softmax_scale,
+            False,  # zero_tensors
+            True,  # causal
+            -1,  # Window_left
+            -1,  # Window right
+            softcap,
+            False,  # return softmax
+            None,  # generator
+        )
+        return out[0]
+
     if softcap is not None:
         raise RuntimeError("Paged attention doesn't support softcapping")
 
     # value_cache => [num_blocks, num_heads, head_size, block_size]
-    block_size = kv_cache.value.shape[3]
+    # block_size = kv_cache.value.shape[3]
+    block_size = BLOCK_SIZE
     num_seqs, num_heads, head_size = query.shape
 
     num_kv_heads = kv_cache.key.shape[1]
@@ -247,14 +284,15 @@ def attention(
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     return flash_attn_2_cuda.varlen_fwd(
         query,
-        key,
-        value,
+        # flashdecoding: pass the KV caches, paged: pass the KV.
+ kv_cache.key if ATTENTION == "flashdecoding" else key, + kv_cache.value if ATTENTION == "flashdecoding" else value, out, seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - None, + seqlen.cu_seqlen_k, None, None, + block_tables if ATTENTION == "flashdecoding" else None, None, seqlen.max_q, seqlen.max_k, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 8989110a7ad..fe15a30efc4 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1659,7 +1659,7 @@ def warmup( for seqlen in tuning_sequences: log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") - self.tunableop_warmup(seqlen) + self.tunableop_warmup(seqlen, max_total_tokens) torch.cuda.tunable.write_file(tunableop_filepath) if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1": torch.cuda.tunable.tuning_enable(False) @@ -1689,7 +1689,7 @@ def warmup( assert max_total_tokens is not None return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens - def tunableop_warmup(self, seqlen: int): + def tunableop_warmup(self, seqlen: int, max_bt: int): input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) @@ -1703,11 +1703,15 @@ def tunableop_warmup(self, seqlen: int): [0, seqlen], device=self.device, dtype=torch.int32 ) max_s = seqlen + + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(seqlen) + block_tables = block_tables.reshape((seqlen, max_bt)) + seqlen = Seqlen( input_lengths=input_lengths, cache_lengths=cache_lengths_tensor, - cu_seqlen_q=cu_seqlen_prefill, - max_q=1, max_k=seqlen, ) @@ -1717,7 +1721,7 @@ def tunableop_warmup(self, seqlen: int): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=self.kv_cache, - block_tables=None, + block_tables=block_tables, seqlen=seqlen, slots=slots, max_s=max_s,