fix(modeling_base): partial loading of a sharded checkpoint (#578)
maxreciprocate authored Oct 26, 2023
1 parent 0575ce5 commit 91a0f43
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions trlx/models/modeling_base.py
@@ -240,11 +240,15 @@ def from_pretrained(  # noqa: max-complexity
             )
             logger.info("Trained peft adapter loaded")

-        # No peft
         if base_model is None:
+            # No peft
+            # Disable warnings about missing weights when loading the base model
+            verbosity = transformers.logging.get_verbosity()
+            transformers.logging.set_verbosity_error()
             base_model = cls._auto_model_parent_class.from_pretrained(
                 pretrained_model_name_or_path, *model_args, **from_pretrained_kwargs
             )
+            transformers.logging.set_verbosity(verbosity)

         elif isinstance(pretrained_model_name_or_path, transformers.PreTrainedModel):
             base_model = pretrained_model_name_or_path
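This first hunk temporarily silences the transformers logger so that the expected "missing weights" warnings are not emitted while the base model loads, then restores the caller's verbosity. A minimal sketch of the same save-and-restore pattern in isolation (the model name is illustrative, and a try/finally is added here for safety; the commit itself restores verbosity unconditionally):

    import transformers

    # Remember the caller's log level, silence everything below ERROR
    # while loading, then restore the original level afterwards.
    verbosity = transformers.logging.get_verbosity()
    transformers.logging.set_verbosity_error()
    try:
        model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    finally:
        transformers.logging.set_verbosity(verbosity)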
@@ -288,11 +292,9 @@ def from_pretrained(  # noqa: max-complexity
                 )
                 with open(index_file_name, "r") as f:
                     index = json.load(f)
-                # Collect files containing weights from supported modules
-                files_to_download = set()
-                for k, v in index["weight_map"].items():
-                    if any([module in k for module in cls._supported_modules]):
-                        files_to_download.add(v)
+
+                # Load all weights from the shards
+                files_to_download = set(index["weight_map"].values())
                 is_sharded = True

             if is_sharded:
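This second hunk is the substance of the fix for #578: the old code downloaded only those shard files containing parameters whose names matched cls._supported_modules, so shards holding the remaining base-model weights were never fetched and the sharded checkpoint loaded only partially. The new code downloads every shard listed in the index. A minimal sketch of the two selection strategies over a typical pytorch_model.bin.index.json weight map (parameter names, file names, and the supported-modules list are all illustrative):

    # Illustrative weight map from a sharded checkpoint's index file
    index = {
        "weight_map": {
            "transformer.h.0.attn.weight": "pytorch_model-00001-of-00002.bin",
            "transformer.h.1.mlp.weight": "pytorch_model-00002-of-00002.bin",
            "v_head.summary.weight": "pytorch_model-00002-of-00002.bin",
        }
    }
    supported_modules = ["v_head"]  # stand-in for cls._supported_modules

    # Old behavior: keep only shards containing a supported module --
    # here that silently drops pytorch_model-00001-of-00002.bin.
    old = {v for k, v in index["weight_map"].items()
           if any(m in k for m in supported_modules)}

    # New behavior: download every shard so the base model loads fully.
    new = set(index["weight_map"].values())

    assert old == {"pytorch_model-00002-of-00002.bin"}
    assert new == {"pytorch_model-00001-of-00002.bin",
                   "pytorch_model-00002-of-00002.bin"}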
