diff --git a/benchmark/text-generation-inference/performance/phi4-14b/docker-compose.yaml b/benchmark/text-generation-inference/performance/phi4-14b/docker-compose.yaml
new file mode 100644
index 000000000..960e93ea8
--- /dev/null
+++ b/benchmark/text-generation-inference/performance/phi4-14b/docker-compose.yaml
@@ -0,0 +1,76 @@
+version: '3.7'
+
+services:
+  tgi-1:
+    image: neuronx-tgi:latest
+    ports:
+      - "8081:8081"
+    environment:
+      - PORT=8081
+      - MODEL_ID=${MODEL_ID}
+      - HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
+      - HF_NUM_CORES=8
+      - MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
+      - MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
+      - MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
+      - MAX_CONCURRENT_REQUESTS=512
+      - HF_TOKEN=${HF_TOKEN}
+    devices:
+      - "/dev/neuron0"
+      - "/dev/neuron1"
+      - "/dev/neuron2"
+      - "/dev/neuron3"
+
+  tgi-2:
+    image: neuronx-tgi:latest
+    ports:
+      - "8082:8082"
+    environment:
+      - PORT=8082
+      - MODEL_ID=${MODEL_ID}
+      - HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
+      - HF_NUM_CORES=8
+      - MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
+      - MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
+      - MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
+      - MAX_CONCURRENT_REQUESTS=512
+      - HF_TOKEN=${HF_TOKEN}
+    devices:
+      - "/dev/neuron4"
+      - "/dev/neuron5"
+      - "/dev/neuron6"
+      - "/dev/neuron7"
+
+  tgi-3:
+    image: neuronx-tgi:latest
+    ports:
+      - "8083:8083"
+    environment:
+      - PORT=8083
+      - MODEL_ID=${MODEL_ID}
+      - HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
+      - HF_NUM_CORES=8
+      - MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
+      - MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
+      - MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
+      - MAX_CONCURRENT_REQUESTS=512
+      - HF_TOKEN=${HF_TOKEN}
+    devices:
+      - "/dev/neuron8"
+      - "/dev/neuron9"
+      - "/dev/neuron10"
+      - "/dev/neuron11"
+
+  loadbalancer:
+    image: nginx:alpine
+    ports:
+      - "8080:80"
+    volumes:
+      - ./nginx.conf:/etc/nginx/nginx.conf:ro
+    depends_on:
+      - tgi-1
+      - tgi-2
+      - tgi-3
+    deploy:
+      placement:
+        constraints: [node.role == manager]
diff --git a/benchmark/text-generation-inference/performance/phi4-14b/nginx.conf b/benchmark/text-generation-inference/performance/phi4-14b/nginx.conf
new file mode 100644
index 000000000..37a3b8721
--- /dev/null
+++ b/benchmark/text-generation-inference/performance/phi4-14b/nginx.conf
@@ -0,0 +1,15 @@
+### Nginx TGI Load Balancer
+events {}
+http {
+    upstream tgicluster {
+        server tgi-1:8081;
+        server tgi-2:8082;
+        server tgi-3:8083;
+    }
+    server {
+        listen 80;
+        location / {
+            proxy_pass http://tgicluster;
+        }
+    }
+}
diff --git a/optimum/exporters/neuron/model_configs/decoder_configs.py b/optimum/exporters/neuron/model_configs/decoder_configs.py
index 30ddc808e..10708024d 100644
--- a/optimum/exporters/neuron/model_configs/decoder_configs.py
+++ b/optimum/exporters/neuron/model_configs/decoder_configs.py
@@ -14,11 +14,11 @@
 # limitations under the License.
 """Neuron export configurations for models using transformers_neuronx."""
-
 from optimum.exporters.tasks import TasksManager
 
 from ....neuron.models.granite.model import GraniteForSampling
 from ....neuron.models.qwen2.model import Qwen2ForSampling
+from ....neuron.models.phi4.model import Phi4ForSampling
 from ..config import TextNeuronDecoderConfig
@@ -70,3 +70,9 @@ class Qwen2NeuronConfig(TextNeuronDecoderConfig):
 class GraniteNeuronConfig(TextNeuronDecoderConfig):
     NEURONX_CLASS = GraniteForSampling
     CONTINUOUS_BATCHING = True
+
+
+@register_in_tasks_manager("phi4", "text-generation")
+class Phi4NeuronConfig(TextNeuronDecoderConfig):
+    NEURONX_CLASS = Phi4ForSampling
+    CONTINUOUS_BATCHING = True
diff --git a/optimum/neuron/models/phi4/__init__.py b/optimum/neuron/models/phi4/__init__.py
new file mode 100644
index 000000000..fdc025786
--- /dev/null
+++ b/optimum/neuron/models/phi4/__init__.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/optimum/neuron/models/phi4/config.py b/optimum/neuron/models/phi4/config.py
new file mode 100644
index 000000000..4fbbc936e
--- /dev/null
+++ b/optimum/neuron/models/phi4/config.py
@@ -0,0 +1,27 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import PretrainedConfig
+from transformers_neuronx.llama.config import LlamaConfig
+
+
+class Phi4Config(LlamaConfig):
+    """The Phi4 model uses the same configuration as the TnX Llama model."""
+
+    def __init__(
+        self, config: PretrainedConfig, n_positions: int, batch_size: int, amp: str, tp_degree: int, **kwargs
+    ):
+        super().__init__(config, n_positions, batch_size, amp, tp_degree, **kwargs)
+        self.model_type = "phi4"
diff --git a/optimum/neuron/models/phi4/model.py b/optimum/neuron/models/phi4/model.py
new file mode 100644
index 000000000..fd3d51322
--- /dev/null
+++ b/optimum/neuron/models/phi4/model.py
@@ -0,0 +1,294 @@
+# Copyright Amazon Web Services and its Affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import warnings
+
+import torch
+from transformers import PretrainedConfig
+from transformers_neuronx import base, bucket, decoder, ops, utils
+from transformers_neuronx.config import NeuronConfig
+from transformers_neuronx.constants import KV_SHARD_PAD, LAYOUT_HSB
+from transformers_neuronx.llama.hlo import LlamaForSamplingNoEmbeddingHlo
+
+from .config import Phi4Config
+from .modules import Phi4ForCausalLM
+
+
+class Phi4ForSampling(base.NeuronModelBase):
+    """The Phi4 model is essentially a Llama model with fused qkv and gate_up projections.
+
+    The implementation in this class is very similar to the one used for Llama in TnX.
+    The only differences are:
+    - the config (Phi4Config) and base model (Phi4ForCausalLM) used in __init__,
+    - the splitting of the fused qkv and gate_up projections when loading weights from the checkpoint model.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        *,
+        n_positions: int = 2048,
+        batch_size: int = 1,
+        amp: str = "f32",
+        tp_degree: int = 2,
+        context_length_estimate: int = None,
+        context_unroll: int = None,
+        unroll: int = None,
+        neuron_config: NeuronConfig = None,
+        prefixed_length: int = 0,
+        **kwargs,
+    ):
+        config = Phi4Config(config, n_positions, batch_size, amp, tp_degree)
+        super().__init__(Phi4ForCausalLM, config)
+        self.context_pre_hook = None
+        self.context_hook = None
+        self.config = config
+        self.neuron_config = neuron_config if neuron_config else NeuronConfig()
+        if self.neuron_config.shard_over_sequence:
+            n_kv_head = self.config.num_key_value_heads
+            kv_shard_degree = self.config.tp_degree // n_kv_head
+            assert kv_shard_degree <= KV_SHARD_PAD, f"kv shard degree {kv_shard_degree} exceeds the default maximum of {KV_SHARD_PAD}"
+            warnings.warn(f"shard over sequence enabled, increasing n_positions {n_positions} by {KV_SHARD_PAD}")
+            if isinstance(n_positions, list):
+                npos = sorted(n_positions)
+                npos[-1] += KV_SHARD_PAD
+            else:
+                npos = n_positions + KV_SHARD_PAD
+            self.config.n_positions = npos
+            config.n_positions = npos
+            n_positions = npos
+        if self.neuron_config.on_device_generation:
+            self.neuron_config.on_device_generation.vocab_size = self.config.vocab_size
+
+        self.layers_after_partition = self.neuron_config.auto_layer_partition(config.num_hidden_layers)
+        self.prefixed_length = prefixed_length
+
+        if context_unroll is None:
+            context_unroll = len(self.layers_after_partition)
+        self.context_unroll = context_unroll
+
+        if unroll is None:
+            unroll = len(self.layers_after_partition)
+        self.unroll = unroll
+
+        self.token_buckets = bucket.token_sizes(n_positions)
+        self.context_buckets = bucket.context_sizes(context_length_estimate, self.token_buckets)
+        # The input length should be divisible by tp_degree to activate sequence parallelism
+        if neuron_config and neuron_config.sequence_parallel_norm:
+            for bucket_size in self.context_buckets:
+                if (
+                    bucket_size > neuron_config.sequence_parallel_norm_threshold
+                    and bucket_size % self.config.tp_degree != 0
+                ):
+                    raise ValueError(
+                        f"Sequence parallel normalization requires the bucket size ({bucket_size}) to be divisible by the tensor parallel degree ({self.config.tp_degree})"
+                    )
+        self.window_context_buckets = []
+        if prefixed_length:
+            if prefixed_length not in self.context_buckets:
+                self.context_buckets.append(prefixed_length)
+                self.context_buckets = sorted(self.context_buckets)
+
+        self.batch_sizes = bucket.batch_sizes(batch_size)
+        self.context_batch_sizes = (
+            [1] if self.neuron_config and self.neuron_config.continuous_batching else self.batch_sizes
+        )
+        hlo_builder = LlamaForSamplingNoEmbeddingHlo(config, neuron_config=self.neuron_config)
+        self.decoder_param_set = decoder.DecoderLmHeadForSamplingNoEmbedding(
+            tp_degree=tp_degree,
+            n_positions_list=self.token_buckets,
+            n_active_tokens=1,
+            batch_size=self.batch_sizes,
+            attention_head_size=config.attention_head_size,
+            amp=amp,
+            num_layers=len(self.layers_after_partition),
+            n_head=config.num_attention_heads,
+            n_kv_head=config.num_key_value_heads,
+            unroll=unroll,
+            neuron_config=self.neuron_config,
+            allow_pad=True,
+            builder=hlo_builder,
+        )
+        self.decoder_lm_head = self.decoder_param_set.init_token_decoder(
+            unroll=self.unroll, buckets=self.token_buckets, model_obj=self
+        )
+        self.decoder_lm_head_for_context = self.decoder_param_set.init_context_decoder(
+            unroll=self.context_unroll, buckets=self.context_buckets, model_obj=self
+        )
+        self.decoder_lm_head_for_speculation = {}
+        self.decoder_lm_head_for_window_context = {}
+
+    def load_weights(self):
+        self.materialize_embeddings()
+        ops.init()
+
+        for layer_id, layer in enumerate(self.chkpt_model.model.layers):
+            if layer_id not in self.layers_after_partition:
+                continue
+            layer.materialize()
+            attn = layer.self_attn
+            mlp = layer.mlp
+            if self.neuron_config and self.neuron_config.quant:
+                is_unit_scale = self.neuron_config.quant.is_unit_scale(layer_id)
+            else:
+                is_unit_scale = False
+
+            # Split the fused qkv_proj and gate_up_proj checkpoint weights into separate tensors.
+            # The query projection covers num_attention_heads heads while key and value cover
+            # num_key_value_heads heads each, so the split has to be GQA-aware.
+            fused_attn = attn.qkv_proj.weight.clone().detach()
+            fused_gate_up = mlp.gate_up_proj.weight.clone().detach()
+            q_size = self.config.num_attention_heads * self.config.attention_head_size
+            kv_size = self.config.num_key_value_heads * self.config.attention_head_size
+            q_weight, k_weight, v_weight = torch.split(fused_attn, [q_size, kv_size, kv_size], dim=0)
+            gate, up = torch.chunk(fused_gate_up, 2, dim=0)
+
+            new_layer = self.decoder_lm_head.new_layer(is_unit_scale=is_unit_scale)
+            new_layer.add_pre_attention_layer_norm(layer.input_layernorm.weight.detach(), None)
+            new_layer.add_attention_query(q_weight)
+            new_layer.add_attention_key(k_weight)
+            new_layer.add_attention_value(v_weight)
+            if self.neuron_config and self.neuron_config.attn_output_transposed:
+                new_layer.add_attention_output(attn.o_proj.weight.T.detach(), None, sharding=0, transposed=True)
+            else:
+                new_layer.add_attention_output(attn.o_proj.weight.detach(), None, sharding=1, transposed=False)
+            new_layer.add_pre_mlp_layer_norm(layer.post_attention_layernorm.weight.detach(), None)
+
+            # Note: Automatic MLP padding is safe since zeros are *only* introduced to intermediary state
+            if self.neuron_config.fuse_mlp:
+                assert (
+                    fused_gate_up.shape[0] % self.config.tp_degree == 0
+                ), f"mlp weights are not divisible by tp_degree {self.config.tp_degree}"
+                new_layer.add_mlp_input(fused_gate_up)
+                if self.neuron_config.mlp_out_weight_transpose:
+                    new_layer.add_mlp_output(
+                        mlp.down_proj.weight.T.detach(),
+                        None,
+                        sharding=0,
+                        transposed=True,
+                    )
+                else:
+                    new_layer.add_mlp_output(
+                        mlp.down_proj.weight.detach(),
+                        None,
+                        sharding=1,
+                        transposed=False,
+                    )
+            else:
+                new_layer.add_parameter(gate, sharding=1, allow_pad=True, allow_quantize=True, allow_transform=True)
+                new_layer.add_parameter(up, sharding=1, allow_pad=True, allow_quantize=True, allow_transform=True)
+                if self.neuron_config.weight_tiling:
+                    new_layer.add_parameter(
+                        mlp.down_proj.weight.T, sharding=0, allow_pad=True, allow_quantize=True, allow_transform=True
+                    )
+                else:
+                    if self.neuron_config.mlp_out_weight_transpose:
+                        new_layer.add_parameter(
+                            mlp.down_proj.weight.T, sharding=0, allow_pad=True, allow_quantize=True
+                        )
+                    else:
+                        new_layer.add_parameter(
+                            mlp.down_proj.weight, sharding=1, allow_pad=True, allow_quantize=True, out_feature_dim=0
+                        )
+            new_layer.to_neuron()
+            layer.nullify()
+        if self.neuron_config.shard_over_sequence:
+            self.decoder_lm_head.add_pre_layer_parameter(torch.arange(self.config.tp_degree), sharding=0)
+        # For pipeline parallel, we need to load ln and lm_head for now even if the pipeline stage doesn't compute them, because
+        # 1) we need the ln_lm_head hlo for pp0 to get the logits shape and dtype
+        # 2) we don't need these for intermediate pp stages, but to keep things simple, just include ln_lm_head for all pp stages for now
+        # 3) to get ln_lm_head hlo, we need to do weight loading and sharding
+        # 4) this will introduce extra memory allocation, but ln_lm_head i/o tensor is much smaller and we can get rid of it when we can construct hlo in init
+        ln_f = self.chkpt_model.model.norm
+        ln_f.materialize()
+        self.decoder_lm_head.add_final_layer_norm(ln_f.weight.detach(), None)
+
+        lm_head = self.chkpt_model.lm_head
+        lm_head.materialize()
+        self.decoder_lm_head.add_lm_head(lm_head.weight.detach().T)
+        if self.neuron_config.on_device_embedding:
+            if self.neuron_config.sequence_parallel_norm:
+                self.decoder_lm_head.add_pre_layer_parameter(
+                    self.chkpt_model.model.embed_tokens.weight, sharding=None, allow_pad=True
+                )
+            else:
+                self.decoder_lm_head.add_pre_layer_parameter(
+                    self.chkpt_model.model.embed_tokens.weight, sharding=1, allow_pad=True
+                )
+        lm_head.nullify()
+
+        self.decoder_lm_head.to_neuron()
+        self.init_rest_of_model()
+
+    def materialize_embeddings(self):
+        # Materialize the embedding to CPU
+        self.chkpt_model.model.embed_tokens.materialize()
+
+    def init_rest_of_model(self):
+        # Pipeline parallel doesn't support the executor right now
+        if not self.neuron_config.is_pp():
+            self.decoder_lm_head.use_executor = True
+
+        if self.context_buckets:
+            for context_length_estimate in self.context_buckets:
+                for batch_size in self.context_batch_sizes:
+                    model = self.decoder_lm_head.build_weight_shared(
+                        share_caches=True, new=self.decoder_lm_head_for_context[context_length_estimate, batch_size]
+                    )
+                    # PERF: No latency improvement seen in multi-layer models from executor
+                    # Pipeline parallel doesn't support the executor right now
+                    if self.context_unroll == self.config.num_hidden_layers and not self.neuron_config.is_pp():
+                        model.use_executor = True
+                    self.decoder_lm_head_for_context[context_length_estimate, batch_size] = model
+
+        if self.decoder_lm_head_for_speculation:
+            for i, k in enumerate(self.decoder_lm_head_for_speculation):
+                model = self.decoder_lm_head.build_weight_shared(
+                    share_caches=True,
+                    new=self.decoder_lm_head_for_speculation[k],
+                    embed_weight=self.chkpt_model.model.embed_tokens.weight,
+                )
+                self.decoder_lm_head_for_speculation[k] = model
+
+        if self.decoder_lm_head_for_window_context:
+            for i, k in enumerate(self.decoder_lm_head_for_window_context):
+                model = self.decoder_lm_head.build_weight_shared(
+                    share_caches=True, new=self.decoder_lm_head_for_window_context[k]
+                )
+                self.decoder_lm_head_for_window_context[k] = model
+
+    def set_prefixed(self, input_ids):
+        self.prefixed_input_ids = input_ids[:, : self.prefixed_length]
+        prefixed_length = self.prefixed_length
+        self.prefixed_length = 0
+        self.forward(self.prefixed_input_ids)
+        self.prefixed_length = prefixed_length
+
+    def preprocess_and_embed(self, input_ids, cache_ids=None, start_ids=None, **kwargs):
+        padded_inputs, *rst = self._preprocess(input_ids, start_ids=start_ids, cache_ids=cache_ids, **kwargs)
+        if not self.neuron_config.on_device_embedding:
+            input_embeddings = self.chkpt_model.model.embed_tokens(padded_inputs)
+            if self.neuron_config.attention_layout == LAYOUT_HSB:
+                input_embeddings = input_embeddings.transpose(0, -1).contiguous()
+        else:
+            # The embedding layer is on device and will be computed as part of self._forward(), so don't compute here
+            input_embeddings = None
+        return padded_inputs, input_embeddings, *rst
+
+    def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None, input_embeddings=None, **kwargs):
+        if last_token_id is not None:  # preprocess_and_embed() has already been invoked
+            rst = cache_ids, start_ids, last_token_id
+        else:  # invoke preprocess_and_embed()
+            input_ids, input_embeddings, *rst = self.preprocess_and_embed(input_ids, cache_ids, start_ids, **kwargs)
+        # Either input_embeddings are generated (off device embedding), or input_ids will be padded from preprocess_and_embed (on device embedding)
+        inputs = input_embeddings if input_embeddings is not None else input_ids
+        logits = self._forward(inputs, *rst)
+        logits = self._postprocess(logits, start_ids=start_ids, **kwargs)
+        return logits
diff --git a/optimum/neuron/models/phi4/modules.py b/optimum/neuron/models/phi4/modules.py
new file mode 100644
index 000000000..297edbf20
--- /dev/null
+++ b/optimum/neuron/models/phi4/modules.py
@@ -0,0 +1,79 @@
+# Copyright Amazon Web Services and its Affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from transformers_neuronx import dtypes, module, utils
+
+from .config import Phi4Config
+
+
+class Phi4ForCausalLM(module.PretrainedModel):
+    def __init__(self, config: Phi4Config):
+        super().__init__()
+        dtype, _, _ = utils.parse_amp(config.amp)
+        dtype = dtypes.to_torch_dtype(dtype)
+        self.model = Phi4Model(config)
+        self.lm_head = module.LowMemoryLazyLinear(config.vocab_size, dtype=dtype, bias=False)
+
+    def get_tied_parameters(self):
+        return [(self.model.embed_tokens.weight, self.lm_head.weight)]
+
+    def get_base_model(self):
+        return self.model
+
+
+class Phi4Model(module.LowMemoryModule):
+    def __init__(self, config: Phi4Config):
+        super().__init__()
+        self.embed_tokens = module.LowMemoryEmbedding(config.vocab_size, config.hidden_size)
+        self.layers = module.LowMemoryModuleList([Phi4DecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.norm = Phi4RMSNorm(config)
+
+
+class Phi4RMSNorm(module.LowMemoryModule):
+    def __init__(self, config: Phi4Config) -> None:
+        super().__init__()
+        self.weight = module.UninitializedParameter()
+
+
+class Phi4DecoderLayer(module.LowMemoryModule):
+    def __init__(self, config: Phi4Config):
+        super().__init__()
+        self.self_attn = Phi4Attention(config)
+        self.mlp = Phi4MLP(config)
+        self.input_layernorm = Phi4RMSNorm(config)
+        self.post_attention_layernorm = Phi4RMSNorm(config)
+
+
+class Phi4Attention(module.LowMemoryModule):
+    def __init__(self, config: Phi4Config):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        dtype, _, _ = utils.parse_amp(config.amp)
+        dtype = dtypes.to_torch_dtype(dtype)
+        # The checkpoint stores query, key and value as a single fused qkv_proj projection,
+        # which load_weights() in model.py splits back into separate q/k/v weights.
+        self.qkv_proj = module.LowMemoryLazyLinear(
+            (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False, dtype=dtype
+        )
+        self.o_proj = module.LowMemoryLazyLinear(self.hidden_size, bias=False, dtype=dtype)
+
+
+class Phi4MLP(module.LowMemoryModule):
+    def __init__(self, config: Phi4Config):
+        super().__init__()
+        dtype, _, _ = utils.parse_amp(config.amp)
+        dtype = dtypes.to_torch_dtype(dtype)
+        # gate_proj and up_proj are stored as a single fused gate_up_proj projection in the checkpoint,
+        # which load_weights() in model.py splits back into separate gate and up weights.
+        self.gate_up_proj = module.LowMemoryLazyLinear(2 * config.intermediate_size, bias=False, dtype=dtype)
+        self.down_proj = module.LowMemoryLazyLinear(config.hidden_size, bias=False, dtype=dtype)
diff --git a/tests/decoder/conftest.py b/tests/decoder/conftest.py
index 677b8ffbf..ba69511e3 100644
--- a/tests/decoder/conftest.py
+++ b/tests/decoder/conftest.py
@@ -49,6 +49,10 @@
         "model_id": "dacorvo/Mixtral-tiny",
         "export_kwargs": {"batch_size": 4, "sequence_length": 1024, "num_cores": 2, "auto_cast_type": "fp16"},
     },
+    "phi4": {
+        "model_id": "microsoft/phi-4",
+        "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"},
+    },
 }
 
 
diff --git a/tests/decoder/test_decoder_export.py b/tests/decoder/test_decoder_export.py
index 61aa57481..8e005e230 100644
--- a/tests/decoder/test_decoder_export.py
+++ b/tests/decoder/test_decoder_export.py
@@ -32,6 +32,7 @@
     "opt": "hf-internal-testing/tiny-random-OPTForCausalLM",
     "qwen2": "yujiepan/qwen2.5-128k-tiny-random",
     "granite": "hf-internal-testing/tiny-random-GraniteForCausalLM",
+    "phi4": "yujiepan/phi-4-tiny-random",
 }
 
 
diff --git a/text-generation-inference/tests/fixtures/model.py b/text-generation-inference/tests/fixtures/model.py
index 6fa63ce86..8d8edcf22 100644
--- a/text-generation-inference/tests/fixtures/model.py
+++ b/text-generation-inference/tests/fixtures/model.py
@@ -45,6 +45,10 @@
         "model_id": "ibm-granite/granite-3.1-2b-instruct",
         "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"},
     },
+    "phi4": {
+        "model_id": "microsoft/phi-4",
+        "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"},
+    },
 }
 
 
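The phi4-14b benchmark stack defined above can be exercised end to end once the three TGI replicas are up (for example via "docker compose up -d" with MODEL_ID, HF_AUTO_CAST_TYPE, MAX_BATCH_SIZE, MAX_INPUT_TOKENS, MAX_TOTAL_TOKENS and HF_TOKEN set in the environment). The sketch below is illustrative rather than part of the benchmark tooling: it assumes the default port mapping from the compose file (nginx listening on host port 8080) and uses a placeholder prompt and generation parameters.

    import requests

    # The nginx load balancer defined in nginx.conf listens on host port 8080 and
    # distributes requests across the three neuronx-tgi replicas (tgi-1, tgi-2, tgi-3).
    TGI_ENDPOINT = "http://localhost:8080/generate"

    payload = {
        "inputs": "What is Deep Learning?",  # placeholder prompt
        "parameters": {"max_new_tokens": 64, "do_sample": False},
    }

    response = requests.post(TGI_ENDPOINT, json=payload, timeout=120)
    response.raise_for_status()
    print(response.json()["generated_text"])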