Skip to content

Commit

Permalink
Rename AbstractModel to NNsightModel. NNsightModel now exists will fi…
Browse files Browse the repository at this point in the history
…lled in basic implementations of the formally abstract methods. Can maybe be used for very basic models. Remove alteration stuff (maybe restore that later)
  • Loading branch information
JadenFiotto-Kaufman committed Jan 11, 2024
1 parent 33ec382 commit fa0f4cc
Show file tree
Hide file tree
Showing 19 changed files with 105 additions and 303 deletions.
2 changes: 1 addition & 1 deletion docs/source/documentation/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ nnsight.models
:members:


.. automodule:: nnsight.models.AbstractModel
.. automodule:: nnsight.models.NNsightModel
:members:


Expand Down
2 changes: 1 addition & 1 deletion docs/source/notebooks/walkthrough.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
"\n",
"Use `dispatch=True` to load the model onto the specified device locally on initialization, not on first use.\n",
"\n",
"`LanguageModel` is one wrapper for `nnsight` functionality. Another one `DiffusionModel` for text-to-image disffusion models. We encourage you to create your own by inheriting from `AbstractModel`\n",
"`LanguageModel` is one wrapper for `nnsight` functionality. Another one `DiffusionModel` for text-to-image disffusion models. We encourage you to create your own by inheriting from `NNsightModel`\n",
"\n",
"</details>\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/sourcelatex/documentation/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ nnsight.models
:members:


.. automodule:: nnsight.models.AbstractModel
.. automodule:: nnsight.models.NNsightModel
:members:


Expand Down
2 changes: 1 addition & 1 deletion examples/lora.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

import torch
from nnsight import AbstractModel, LanguageModel, util
from nnsight import NNsightModel, LanguageModel, util
from nnsight.module import Module
from nnsight.toolbox.optim.lora import LORA
from torch.utils.data import DataLoader, Dataset
Expand Down
2 changes: 1 addition & 1 deletion examples/optim.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

import torch
from nnsight import AbstractModel, LanguageModel, util
from nnsight import NNsightModel, LanguageModel, util
from nnsight.module import Module
from torch.utils.data import DataLoader, Dataset

Expand Down
2 changes: 1 addition & 1 deletion examples/optim2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

import torch
from nnsight import AbstractModel, LanguageModel, util
from nnsight import NNsightModel, LanguageModel, util
from nnsight.module import Module
from torch.utils.data import DataLoader, Dataset

Expand Down
2 changes: 1 addition & 1 deletion src/nnsight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
with open(os.path.join(PATH, "config.yaml"), "r") as file:
CONFIG = ConfigModel(**yaml.safe_load(file))

from .models.AbstractModel import AbstractModel
from .models.NNsightModel import NNsightModel
from .models.DiffuserModel import DiffuserModel
from .models.LanguageModel import LanguageModel
from .module import Module
Expand Down
3 changes: 0 additions & 3 deletions src/nnsight/alteration/__init__.py

This file was deleted.

90 changes: 0 additions & 90 deletions src/nnsight/alteration/gpt.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/nnsight/contexts/DirectInvoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from .Runner import Runner

if TYPE_CHECKING:
from ..models.AbstractModel import AbstractModel
from ..models.NNsightModel import NNsightModel


class DirectInvoker(Runner, Invoker):
def __init__(
self, model: "AbstractModel", *args, fwd_args: Dict[str, Any] = None, **kwargs
self, model: "NNsightModel", *args, fwd_args: Dict[str, Any] = None, **kwargs
):
if fwd_args is None:
fwd_args = dict()
Expand Down
8 changes: 4 additions & 4 deletions src/nnsight/contexts/Tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from ..tracing.Graph import Graph

if TYPE_CHECKING:
from ..models.AbstractModel import AbstractModel
from ..models.NNsightModel import NNsightModel


class Tracer(AbstractContextManager):
"""The Tracer class creates a :class:`nnsight.tracing.Graph.Graph` around the meta_model of a :class:`nnsight.models.AbstractModel.AbstractModel` which tracks and manages the operations performed on the inputs and outputs of said model.
"""The Tracer class creates a :class:`nnsight.tracing.Graph.Graph` around the meta_model of a :class:`nnsight.models.NNsightModel.NNsightModel` which tracks and manages the operations performed on the inputs and outputs of said model.
Attributes:
model (nnsight.models.AbstractModel.AbstractModel): nnsight Model object that ths context manager traces and executes.
model (nnsight.models.NNsightModel.NNsightModel): nnsight Model object that ths context manager traces and executes.
graph (nnsight.tracing.Graph.Graph): Graph which operations performed on the input and output of Modules are added and later executed.
args (List[Any]): Positional arguments to be passed to function that executes the model.
kwargs (Dict[str,Any]): Keyword arguments to be passed to function that executes the model.
Expand All @@ -30,7 +30,7 @@ class Tracer(AbstractContextManager):

def __init__(
self,
model: "AbstractModel",
model: "NNsightModel",
*args,
validate: bool = True,
**kwargs,
Expand Down
2 changes: 1 addition & 1 deletion src/nnsight/contexts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
The primary two classes involved here are :class:`Tracer <nnsight.contexts.Tracer.Tracer>` and :class:`Invoker <nnsight.contexts.Invoker.Invoker>`.
The :class:`Tracer <nnsight.contexts.Tracer.Tracer>` class creates a :class:`Graph <nnsight.tracing.Graph.Graph>` around the meta_model of an :class:`AbstractModel <nnsight.models.AbstractModel.AbstractModel>`. The graph tracks and manages the operations performed on the inputs and outputs of said model.
The :class:`Tracer <nnsight.contexts.Tracer.Tracer>` class creates a :class:`Graph <nnsight.tracing.Graph.Graph>` around the meta_model of an :class:`NNsightModel <nnsight.models.NNsightModel.NNsightModel>`. The graph tracks and manages the operations performed on the inputs and outputs of said model.
Modules in the meta_model expose their ``.output`` and ``.input`` attributes which when accessed, add to the computation graph of the tracer.
To do this, they need to know about the current Tracer object, so each Module's ``.tracer`` object is set to be the current Tracer.
Expand Down
4 changes: 2 additions & 2 deletions src/nnsight/models/DiffuserModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch.utils.hooks import RemovableHandle
from transformers import BatchEncoding, CLIPTextModel, CLIPTokenizer

from .AbstractModel import AbstractModel
from .NNsightModel import NNsightModel


class Diffuser(torch.nn.Module):
Expand Down Expand Up @@ -132,7 +132,7 @@ def scan(
has_nsfw_concept = None


class DiffuserModel(AbstractModel):
class DiffuserModel(NNsightModel):
def __init__(self, *args, tokenizer=None, **kwargs) -> None:
self.local_model: Diffuser = None
self.meta_model: Diffuser = None
Expand Down
29 changes: 5 additions & 24 deletions src/nnsight/models/LanguageModel.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from __future__ import annotations

import collections
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Dict, List, Union

import torch
from torch.utils.hooks import RemovableHandle
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BatchEncoding, PretrainedConfig, PreTrainedModel,
PreTrainedTokenizer)
from transformers.models.auto import modeling_auto

from .AbstractModel import AbstractModel
from .NNsightModel import NNsightModel


class LanguageModel(AbstractModel):
class LanguageModel(NNsightModel):
"""LanguageModels are nnsight wrappers around transformer auto models.
Inputs can be in the form of:
Expand Down Expand Up @@ -52,9 +50,6 @@ def __init__(

super().__init__(*args, **kwargs)

def _register_increment_hook(self, hook: Callable) -> RemovableHandle:
return self.local_model.register_forward_hook(hook)

def _load_meta(self, repoid_or_path, *args, **kwargs) -> PreTrainedModel:
self.config = AutoConfig.from_pretrained(repoid_or_path, *args, **kwargs)

Expand Down Expand Up @@ -160,23 +155,9 @@ def _example_input(self) -> Dict[str, torch.Tensor]:
{"input_ids": torch.tensor([[0]]), "labels": torch.tensor([[0]])}
)

def _scan(self, prepared_inputs, *args, **kwargs) -> None:
# TODO
# Actually use args and kwargs. Dont do this now because the args may be specific to _generation which throws unused args errors
# Maybe inspect signature and filter out unused args.
self.meta_model(**prepared_inputs.copy().to("meta"))

def _forward(self, prepared_inputs, *args, **kwargs) -> Any:
return self.local_model(
*args, **prepared_inputs.to(self.local_model.device), **kwargs
)

def _generation(
self, prepared_inputs, *args, max_new_tokens: int = 1, **kwargs
) -> Any:
return self.local_model.generate(
*args,
**prepared_inputs.to(self.local_model.device),
max_new_tokens=max_new_tokens,
**kwargs,
return super()._generation(
prepared_inputs, *args, max_new_tokens=max_new_tokens, **kwargs
)
Loading

0 comments on commit fa0f4cc

Please sign in to comment.