From 8abbcbd868656f66b9a0ecb8fad22ebbc6c38438 Mon Sep 17 00:00:00 2001
From: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
Date: Mon, 6 Nov 2023 17:21:30 +0100
Subject: [PATCH] Add revision to push_to_hub (#292)

* Add revision to push_to_hub

* make style

* remove fairscale

* style happy

* Update modeling_base.py

* Update modeling_decoder.py
---
 optimum/neuron/modeling_base.py    | 2 ++
 optimum/neuron/modeling_decoder.py | 2 ++
 optimum/neuron/trainers.py         | 5 -----
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/optimum/neuron/modeling_base.py b/optimum/neuron/modeling_base.py
index 7e6e3da3f..be6a4950f 100644
--- a/optimum/neuron/modeling_base.py
+++ b/optimum/neuron/modeling_base.py
@@ -317,6 +317,7 @@ def push_to_hub(
         save_directory: str,
         repository_id: str,
         private: Optional[bool] = None,
+        revision: Optional[str] = None,
         use_auth_token: Union[bool, str] = True,
         endpoint: Optional[str] = None,
     ) -> str:
@@ -348,6 +349,7 @@ def push_to_hub(
                 repo_id=repository_id,
                 path_or_fileobj=os.path.join(os.getcwd(), local_file_path),
                 path_in_repo=hub_file_path,
+                revision=revision,
             )

     def forward(self, *args, **kwargs):
diff --git a/optimum/neuron/modeling_decoder.py b/optimum/neuron/modeling_decoder.py
index 6f8f3dbbb..a05022aef 100644
--- a/optimum/neuron/modeling_decoder.py
+++ b/optimum/neuron/modeling_decoder.py
@@ -259,6 +259,7 @@ def push_to_hub(
         save_directory: str,
         repository_id: str,
         private: Optional[bool] = None,
+        revision: Optional[str] = None,
         use_auth_token: Union[bool, str] = True,
         endpoint: Optional[str] = None,
     ) -> str:
@@ -290,4 +291,5 @@ def push_to_hub(
                 repo_id=repository_id,
                 path_or_fileobj=os.path.join(os.getcwd(), local_file_path),
                 path_in_repo=hub_file_path,
+                revision=revision,
             )
diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 2de816818..36af2efd1 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -27,8 +27,6 @@
 import torch
 from packaging import version
 from transformers import PreTrainedModel, Seq2SeqTrainer, Trainer, TrainingArguments
-from transformers.dependency_versions_check import dep_version_check
-from transformers.integrations import is_fairscale_available
 from transformers.modeling_utils import unwrap_model
 from transformers.trainer import (
     OPTIMIZER_NAME,
@@ -80,9 +78,6 @@
 else:
     IS_SAGEMAKER_MP_POST_1_10 = False

-if is_fairscale_available():
-    dep_version_check("fairscale")
-
 logger = logging.get_logger("transformers.trainer")
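
A minimal usage sketch of the new `revision` argument (not part of the patch).
The model class and Hub identifiers below are illustrative, and the target
branch is assumed to already exist on the Hub; only the `push_to_hub`
signature itself comes from the diff above:

    # Hypothetical example: push compiled Neuron artifacts to a branch
    # other than `main`, using the `revision` parameter added in this PR.
    from optimum.neuron import NeuronModelForCausalLM

    # Compile an illustrative base model for Neuron and save it locally.
    model = NeuronModelForCausalLM.from_pretrained(
        "gpt2", export=True, batch_size=1, sequence_length=128
    )
    model.save_pretrained("gpt2-neuron")

    # Upload the saved directory; `revision` routes every uploaded file
    # to the given branch instead of the repo's default branch.
    model.push_to_hub(
        "gpt2-neuron",                       # save_directory
        repository_id="my-org/gpt2-neuron",  # hypothetical repo
        revision="neuron-compiled",          # new in this PR
    )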