Skip to content

Commit

Permalink
Add the deepspeed injection_policy of mistral (#1309)
Browse files Browse the repository at this point in the history
Signed-off-by: yuanwu <[email protected]>
  • Loading branch information
yuanwu2017 authored and regisss committed Sep 6, 2024
1 parent 5164d51 commit bcd73f8
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions optimum/habana/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,9 @@ def get_ds_injection_policy(config):

policy = {LlamaDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")}

if model_type == "mistral":
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

policy = {MistralDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")}

return policy

0 comments on commit bcd73f8

Please sign in to comment.