Skip to content

Commit

Permalink
Upgrade transformers deps to support mistral
Browse files Browse the repository at this point in the history
- Raft is not compatible with `transformers >= 4.35.0`, add version
  check in `src/lmflow/pipeline/auto_pipeline.py` to avoid conflicts
  • Loading branch information
research4pan committed Mar 4, 2024
1 parent 467f0c9 commit 3711298
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
numpy==1.24.2
datasets==2.14.6
tokenizers==0.13.3
tokenizers>=0.13.3
peft==0.4.0
torch>=2.0.1
wandb==0.14.0
deepspeed==0.10.0
trl==0.5.0
sentencepiece
transformers>=4.31.0,<4.35.0
transformers>=4.38.0
flask
flask_cors
icetk
Expand Down
18 changes: 15 additions & 3 deletions src/lmflow/pipeline/auto_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,32 @@
# coding=utf-8
"""Return a pipeline automatically based on its name.
"""
import pkg_resources

def is_package_version_at_least(package_name, min_version):
try:
package_version = pkg_resources.get_distribution(package_name).version
if (pkg_resources.parse_version(package_version)
< pkg_resources.parse_version(min_version)):
return False
except pkg_resources.DistributionNotFound:
return False
return True

from lmflow.pipeline.evaluator import Evaluator
from lmflow.pipeline.finetuner import Finetuner
from lmflow.pipeline.inferencer import Inferencer
from lmflow.pipeline.raft_aligner import RaftAligner


PIPELINE_MAPPING = {
"evaluator": Evaluator,
"finetuner": Finetuner,
"inferencer": Inferencer,
"raft_aligner": RaftAligner,
}

if not is_package_version_at_least('transformers', '4.35.0'):
from lmflow.pipeline.raft_aligner import RaftAligner
PIPELINE_MAPPING['raft_aligner'] = RaftAligner


class AutoPipeline:
"""
Expand Down

0 comments on commit 3711298

Please sign in to comment.