Allow users to specify what kind of AutoModel they want for LanguageModel

JadenFiotto-Kaufman committed Dec 12, 2023
1 parent 26832ab commit e573b8a
Showing 1 changed file with 12 additions and 10 deletions.
src/nnsight/models/LanguageModel.py: 12 additions & 10 deletions
@@ -8,12 +8,12 @@
 from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                           BatchEncoding, PretrainedConfig, PreTrainedModel,
                           PreTrainedTokenizer)
-
+from transformers.models.auto import modeling_auto
 from .AbstractModel import AbstractModel


 class LanguageModel(AbstractModel):
-    """LanguageModels are nnsight wrappers around AutoModelForCausalLM models.
+    """LanguageModels are nnsight wrappers around transformer auto models.

     Inputs can be in the form of:
         Prompt: (str)
@@ -25,21 +25,23 @@ class LanguageModel(AbstractModel):
     If using a custom model, you also need to provide the tokenizer like ``LanguageModel(custom_model, tokenizer=tokenizer)``

-    Calls to generate pass arguments downstream to :func:`AutoModelForCausalLM.generate`
+    Calls to generate pass arguments downstream to :func:`GenerationMixin.generate`

     Attributes:
         config (PretrainedConfig): Huggingface config file loaded from repository or checkpoint.
         tokenizer (PreTrainedTokenizer): Tokenizer for LMs.
-        meta_model (PreTrainedModel): Meta version of underlying AutoModelForCausalLM model.
-        local_model (PreTrainedModel): Local version of underlying AutoModelForCausalLM model.
+        automodel (type): AutoModel type from transformer auto models.
+        meta_model (PreTrainedModel): Meta version of underlying auto model.
+        local_model (PreTrainedModel): Local version of underlying auto model.
     """

-    def __init__(self, *args, tokenizer=None, **kwargs) -> None:
+    def __init__(self, *args, tokenizer=None, automodel=AutoModelForCausalLM, **kwargs) -> None:
         self.config: PretrainedConfig = None
         self.tokenizer: PreTrainedTokenizer = tokenizer
         self.meta_model: PreTrainedModel = None
         self.local_model: PreTrainedModel = None
+        self.automodel = automodel if not isinstance(automodel, str) else getattr(modeling_auto, automodel)

         super().__init__(*args, **kwargs)

@@ -54,10 +56,10 @@ def _load_meta(self, repoid_or_path, *args, **kwargs) -> PreTrainedModel:
             )
             self.tokenizer.pad_token = self.tokenizer.eos_token

-        return AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
+        return self.automodel.from_config(self.config, trust_remote_code=True)

     def _load_local(self, repoid_or_path, *args, **kwargs) -> PreTrainedModel:
-        return AutoModelForCausalLM.from_pretrained(
+        return self.automodel.from_pretrained(
             repoid_or_path, *args, config=self.config, **kwargs
         )

@@ -117,7 +119,7 @@ def _prepare_inputs(
                 _inputs["labels"] = labels["input_ids"]

             return _inputs

         inputs = self._tokenize(inputs)

         if labels is not None:
@@ -144,7 +146,7 @@ def _batch_inputs(
         return batched_inputs, len(prepared_inputs["input_ids"])

     def _example_input(self) -> Dict[str, torch.Tensor]:
-        return BatchEncoding({"input_ids": torch.tensor([[0]])})
+        return BatchEncoding({"input_ids": torch.tensor([[0]]), "labels": torch.tensor([[0]])})

     def _scan(self, prepared_inputs, *args, **kwargs) -> None:
         # TODO
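
A minimal usage sketch of the new parameter; the AutoModelForSeq2SeqLM choice and the "t5-small" checkpoint are illustrative assumptions, not part of this commit:

from transformers import AutoModelForSeq2SeqLM
from nnsight import LanguageModel  # assumes LanguageModel is re-exported at the package level

# Default behaviour is unchanged: the wrapper is backed by AutoModelForCausalLM.
gpt2 = LanguageModel("gpt2")

# Pass an AutoModel class directly...
t5 = LanguageModel("t5-small", automodel=AutoModelForSeq2SeqLM)

# ...or name one as a string, which __init__ resolves via
# getattr(transformers.models.auto.modeling_auto, "AutoModelForSeq2SeqLM").
t5 = LanguageModel("t5-small", automodel="AutoModelForSeq2SeqLM")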
