Skip to content

Commit

Permalink
Merge pull request #45 from JadenFiotto-Kaufman/nopickle
Browse files Browse the repository at this point in the history
Nopickle
  • Loading branch information
JadenFiotto-Kaufman authored Jan 4, 2024
2 parents 9fa48aa + 8d4de2f commit efda5b1
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 33 deletions.
22 changes: 12 additions & 10 deletions src/nnsight/contexts/Invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from contextlib import AbstractContextManager
from typing import Any, Dict

import torch

from ..module import Module
from ..tracing.Proxy import Proxy
from .Tracer import Tracer

Expand Down Expand Up @@ -59,14 +62,15 @@ def __enter__(self) -> Invoker:
self.tracer.model._scan(self.input, *self.tracer.args, **self.tracer.kwargs)
else:
for name, module in self.tracer.model.meta_model.named_modules():
module._output = None
module._input = None
module._backward_output = None
module._backward_input = None
if isinstance(module, Module):
module.clear()

self.tracer.batch_start += self.tracer.batch_size

self.tracer.batched_input, self.tracer.batch_size = self.tracer.model._batch_inputs(self.input, self.tracer.batched_input)

(
self.tracer.batched_input,
self.tracer.batch_size,
) = self.tracer.model._batch_inputs(self.input, self.tracer.batched_input)

return self

Expand All @@ -93,10 +97,8 @@ def next(self, increment: int = 1) -> None:
)
else:
for name, module in self.tracer.model.meta_model.named_modules():
module._output = None
module._input = None
module._backward_output = None
module._backward_input = None
if isinstance(module, Module):
module.clear()

def save_all(self) -> Dict[str, Proxy]:
"""Saves the output of all modules and returns a dictionary of [module_path -> save proxy]
Expand Down
6 changes: 4 additions & 2 deletions src/nnsight/contexts/Runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def run_server(self):
kwargs=self.kwargs,
repo_id=self.model.repoid_path_clsname,
batched_input=self.batched_input,
intervention_graph=self.graph,
intervention_graph=self.graph.nodes,
generation=self.generation,
include_output=self.remote_include_output,
)
Expand Down Expand Up @@ -161,7 +161,9 @@ def blocking_response(data):

sio.emit(
"blocking_request",
request.model_dump(exclude_defaults=True, exclude_none=True),
request.model_dump(
mode="json", exclude=["session_id", "received", "blocking", "id"]
),
)

sio.wait()
Expand Down
6 changes: 4 additions & 2 deletions src/nnsight/contexts/Tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,17 @@ def __init__(
self,
model: "AbstractModel",
*args,
validate:bool = True,
validate: bool = True,
**kwargs,
) -> None:
self.model = model

self.args = args
self.kwargs = kwargs

self.graph = Graph(self.model.meta_model, proxy_class=InterventionProxy, validate=validate)
self.graph = Graph(
self.model.meta_model, proxy_class=InterventionProxy, validate=validate
)

self.batch_size: int = 0
self.batch_start: int = 0
Expand Down
16 changes: 11 additions & 5 deletions src/nnsight/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Module(torch.nn.Module):

def __init__(self) -> None:
self.module_path: str = None

self.input_shape: torch.Size = None
self.input_type: torch.dtype = None
self.output_shape: torch.Size = None
Expand All @@ -62,10 +63,17 @@ def __init__(self) -> None:
self._input: InterventionProxy = None
self._backward_output: InterventionProxy = None
self._backward_input: InterventionProxy = None

self._graph: Graph = None

self.tracer: Tracer = None

def clear(self):
self._output: InterventionProxy = None
self._input: InterventionProxy = None
self._backward_output: InterventionProxy = None
self._backward_input: InterventionProxy = None

def __call__(
self, *args: List[Any], **kwds: Dict[str, Any]
) -> Union[Any, InterventionProxy]:
Expand Down Expand Up @@ -279,11 +287,9 @@ def wrap(module: torch.nn.Module) -> Module:
Module: The wrapped Module.
"""

def hook(module: Module, input: Any, input_kwargs:Dict, output: Any):
module._output = None
module._input = None
module._backward_output = None
module._backward_input = None
def hook(module: Module, input: Any, input_kwargs: Dict, output: Any):

module.clear()

input = (input, input_kwargs)

Expand Down
36 changes: 22 additions & 14 deletions src/nnsight/pydantics/Request.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import pickle
from __future__ import annotations

from datetime import datetime
from typing import Any, Dict, List, Union
from typing import Dict, List, Union

from pydantic import BaseModel, field_serializer
from pydantic import BaseModel, ConfigDict

from ..tracing.Graph import Graph
from .format import types
from .format.types import *

class RequestModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

args: List
kwargs: Dict
repo_id: str
batched_input: Union[bytes, Any]
intervention_graph: Union[bytes, Any]
batched_input: types.ValueTypes
intervention_graph: Union[Dict[str, Union[types.NodeType, types.NodeModel]], Graph]
generation: bool

id: str = None
Expand All @@ -19,15 +25,17 @@ class RequestModel(BaseModel):
blocking: bool = False
include_output: bool = False

@field_serializer("intervention_graph")
def intervention_graph_serialize(self, value, _info) -> bytes:
value.compile(None)
def compile(self) -> RequestModel:
graph = Graph(None, validate=False)

for node in self.intervention_graph.values():
node.compile(graph, self.intervention_graph)

self.intervention_graph = graph

self.batched_input = self.batched_input.compile(None, None)

for node in value.nodes.values():
node.proxy_value = None
return self

return pickle.dumps(value)

@field_serializer("batched_input")
def serialize(self, value, _info) -> bytes:
return pickle.dumps(value)
RequestModel.model_rebuild()
1 change: 1 addition & 0 deletions src/nnsight/pydantics/format/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .functions import FUNCTIONS_WHITELIST
38 changes: 38 additions & 0 deletions src/nnsight/pydantics/format/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import operator
from inspect import getmembers, isbuiltin, ismethoddescriptor

import torch

from ... import util
from ...tracing.Proxy import Proxy

FUNCTIONS_WHITELIST = {}
FUNCTIONS_WHITELIST.update(
{
f"_VariableFunctionsClass.{key}": value
for key, value in getmembers(torch._C._VariableFunctions, isbuiltin)
}
)
FUNCTIONS_WHITELIST.update(
{
f"Tensor.{key}": value
for key, value in getmembers(torch._C._TensorBase, ismethoddescriptor)
}
)
FUNCTIONS_WHITELIST.update(
{
f"{key}": value
for key, value in getmembers(operator, isbuiltin)
if not key.startswith("_")
}
)
FUNCTIONS_WHITELIST.update(
{
"null": "null",
"module": "module",
"argument": "argument",
"swp": "swp",
"fetch_attr": util.fetch_attr,
"Proxy.proxy_call": Proxy.proxy_call,
}
)
Loading

0 comments on commit efda5b1

Please sign in to comment.