Fix emb model export and load with trfrs (#756)
# What does this PR do?

Fixes #744 

With this PR, embedding models can once again be exported through either
the transformers library or the sentence-transformers library, depending
on which class is called:

* With Transformers

```python
import torch
from optimum.neuron import NeuronModelForFeatureExtraction
from transformers import AutoConfig, AutoTokenizer

# Compiler options: cast matmul operations to fp16
compiler_args = {"auto_cast": "matmul", "auto_cast_type": "fp16"}
# Static input shapes the model is compiled for
input_shapes = {"batch_size": 4, "sequence_length": 512}
model = NeuronModelForFeatureExtraction.from_pretrained(
    model_id="TaylorAI/bge-micro-v2", # BERT SMALL
    export=True,
    disable_neuron_cache=True,
    **compiler_args,
    **input_shapes,
)
```

* With Sentence Transformers

```python
import torch
from optimum.neuron import NeuronModelForSentenceTransformers
from transformers import AutoConfig, AutoTokenizer

# Compiler options: cast matmul operations to fp16
compiler_args = {"auto_cast": "matmul", "auto_cast_type": "fp16"}
# Static input shapes the model is compiled for
input_shapes = {"batch_size": 4, "sequence_length": 512}
model = NeuronModelForSentenceTransformers.from_pretrained(
    model_id="TaylorAI/bge-micro-v2", # BERT SMALL
    export=True,
    disable_neuron_cache=True,
    **compiler_args,
    **input_shapes,
)
```
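
The class-dependent behaviour described above can be illustrated with a minimal, self-contained sketch. It does not use `optimum-neuron`: `export_model`, `TracedModelBase`, `FeatureExtractionModel` and `SentenceTransformersModel` are hypothetical stand-ins, and the `"sentence_transformers"` value is an assumption made for illustration. The only point is the pattern of forwarding a class-level `library_name` from `_export` to the exporter.

```python
from typing import Any, Dict


def export_model(model_id: str, library_name: str, **shapes: int) -> Dict[str, Any]:
    """Hypothetical stand-in for the exporter; it only records its arguments."""
    return {"model_id": model_id, "library_name": library_name, "shapes": shapes}


class TracedModelBase:
    # Each subclass declares which library its checkpoints are loaded with.
    library_name = "transformers"

    @classmethod
    def _export(cls, model_id: str, **kwargs_shapes: int) -> Dict[str, Any]:
        # Forward the class-level attribute so the exporter does not have to
        # guess the library from the checkpoint contents.
        return export_model(model_id, library_name=cls.library_name, **kwargs_shapes)


class FeatureExtractionModel(TracedModelBase):
    pass  # inherits library_name = "transformers"


class SentenceTransformersModel(TracedModelBase):
    library_name = "sentence_transformers"  # assumed value, for illustration


shapes = {"batch_size": 4, "sequence_length": 512}
print(FeatureExtractionModel._export("TaylorAI/bge-micro-v2", **shapes)["library_name"])
print(SentenceTransformersModel._export("TaylorAI/bge-micro-v2", **shapes)["library_name"])
```

In this sketch the same checkpoint id is routed to a different library purely by the class the caller picks, so callers do not need to pass `library_name` explicitly.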
JingyaHuang authored Jan 9, 2025
1 parent 6639b1d commit fe71b3c
Showing 2 changed files with 1 addition and 1 deletion.
1 change: 1 addition & 0 deletions optimum/neuron/modeling_traced.py
@@ -363,6 +363,7 @@ def _export(
 local_files_only=local_files_only,
 token=token,
 do_validation=False,
+library_name=cls.library_name,
 **kwargs_shapes,
 )
 config = AutoConfig.from_pretrained(save_dir_path)
1 change: 0 additions & 1 deletion tests/inference/test_modeling.py
@@ -109,7 +109,6 @@ def test_load_model_from_hub_subfolder(self):
 self.TINY_SUBFOLDER_MODEL_ID,
 subfolder="my_subfolder",
 export=True,
-library_name="transformers",
 **self.STATIC_INPUTS_SHAPES,
 )
 self.assertIsInstance(model.model, torch.jit._script.ScriptModule)
