Skip to content

Commit

Permalink
Merge pull request #54 from JadenFiotto-Kaufman/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
JadenFiotto-Kaufman authored Jan 18, 2024
2 parents 95bfff2 + 2631325 commit 97ef041
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 10 deletions.
4 changes: 4 additions & 0 deletions src/nnsight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from .models.LanguageModel import LanguageModel
from .module import Module
from .patching import Patch, Patcher
from .logger import logger

logger.disabled = not CONFIG.APP.LOGGING


# Below do default patching:
DEFAULT_PATCHER = Patcher()
Expand Down
4 changes: 3 additions & 1 deletion src/nnsight/config.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
API:
HOST: ndif.dev
HOST: ndif.dev
APP:
LOGGING: False
19 changes: 10 additions & 9 deletions src/nnsight/models/LanguageModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def _tokenize(
torch.Tensor,
Dict[str, Any],
],
**kwargs,
):
if isinstance(inputs, BatchEncoding):
return inputs
Expand All @@ -93,7 +94,7 @@ def _tokenize(

return self.tokenizer.pad(inputs, return_tensors="pt")

return self.tokenizer(inputs, return_tensors="pt", padding=True)
return self.tokenizer(inputs, return_tensors="pt", padding=True, **kwargs)

def _prepare_inputs(
self,
Expand All @@ -111,23 +112,23 @@ def _prepare_inputs(
**kwargs,
) -> BatchEncoding:
if isinstance(inputs, dict):
_inputs = self._tokenize(inputs["input_ids"])

for ai, attn_mask in enumerate(inputs['attention_mask']):
_inputs = self._tokenize(inputs["input_ids"], **kwargs)

_inputs['attention_mask'][ai, -len(attn_mask):] = attn_mask
if "attention_mask" in inputs:
for ai, attn_mask in enumerate(inputs["attention_mask"]):
_inputs["attention_mask"][ai, -len(attn_mask) :] = attn_mask

if "labels" in inputs:
labels = self._tokenize(inputs["labels"])
labels = self._tokenize(labels)
labels = self._tokenize(inputs["labels"], **kwargs)

_inputs["labels"] = labels["input_ids"]

return _inputs

inputs = self._tokenize(inputs)
inputs = self._tokenize(inputs, **kwargs)

if labels is not None:
labels = self._tokenize(labels)
labels = self._tokenize(labels, **kwargs)

inputs["labels"] = labels["input_ids"]

Expand Down
5 changes: 5 additions & 0 deletions src/nnsight/pydantics/Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,10 @@ class ApiConfigModel(BaseModel):
HOST: str


# Application-level settings, introduced in this commit alongside the new
# `APP:` section of config.yaml.
class AppConfigModel(BaseModel):
    LOGGING: bool  # master toggle; __init__.py sets `logger.disabled = not CONFIG.APP.LOGGING`


# Top-level configuration schema mirroring config.yaml's two sections.
class ConfigModel(BaseModel):
    API: ApiConfigModel  # remote API settings (e.g. HOST)
    APP: AppConfigModel  # local application settings (e.g. LOGGING) — new in this commit

0 comments on commit 97ef041

Please sign in to comment.