diff --git a/src/nnsight/contexts/Invoker.py b/src/nnsight/contexts/Invoker.py
index 4b62943d..ec5f1aee 100644
--- a/src/nnsight/contexts/Invoker.py
+++ b/src/nnsight/contexts/Invoker.py
@@ -3,6 +3,9 @@
 from contextlib import AbstractContextManager
 from typing import Any, Dict
 
+import torch
+
+from ..module import Module
 from ..tracing.Proxy import Proxy
 from .Tracer import Tracer
 
@@ -59,14 +62,15 @@ def __enter__(self) -> Invoker:
             self.tracer.model._scan(self.input, *self.tracer.args, **self.tracer.kwargs)
         else:
             for name, module in self.tracer.model.meta_model.named_modules():
-                module._output = None
-                module._input = None
-                module._backward_output = None
-                module._backward_input = None
+                if isinstance(module, Module):
+                    module.clear()
 
         self.tracer.batch_start += self.tracer.batch_size
-
-        self.tracer.batched_input, self.tracer.batch_size = self.tracer.model._batch_inputs(self.input, self.tracer.batched_input)
+
+        (
+            self.tracer.batched_input,
+            self.tracer.batch_size,
+        ) = self.tracer.model._batch_inputs(self.input, self.tracer.batched_input)
 
         return self
 
@@ -93,10 +97,8 @@ def next(self, increment: int = 1) -> None:
             )
         else:
             for name, module in self.tracer.model.meta_model.named_modules():
-                module._output = None
-                module._input = None
-                module._backward_output = None
-                module._backward_input = None
+                if isinstance(module, Module):
+                    module.clear()
 
     def save_all(self) -> Dict[str, Proxy]:
         """Saves the output of all modules and returns a dictionary of [module_path -> save proxy]
diff --git a/src/nnsight/contexts/Runner.py b/src/nnsight/contexts/Runner.py
index fdc79f48..7b1a2c73 100644
--- a/src/nnsight/contexts/Runner.py
+++ b/src/nnsight/contexts/Runner.py
@@ -91,7 +91,7 @@ def run_server(self):
             kwargs=self.kwargs,
             repo_id=self.model.repoid_path_clsname,
             batched_input=self.batched_input,
-            intervention_graph=self.graph,
+            intervention_graph=self.graph.nodes,
             generation=self.generation,
             include_output=self.remote_include_output,
         )
@@ -161,7 +161,9 @@ def blocking_response(data):
 
         sio.emit(
             "blocking_request",
-            request.model_dump(exclude_defaults=True, exclude_none=True),
+            request.model_dump(
+                mode="json", exclude=["session_id", "received", "blocking", "id"]
+            ),
         )
 
         sio.wait()
diff --git a/src/nnsight/contexts/Tracer.py b/src/nnsight/contexts/Tracer.py
index be0a5f60..60276711 100644
--- a/src/nnsight/contexts/Tracer.py
+++ b/src/nnsight/contexts/Tracer.py
@@ -32,7 +32,7 @@ def __init__(
         self,
         model: "AbstractModel",
         *args,
-        validate:bool = True,
+        validate: bool = True,
         **kwargs,
     ) -> None:
         self.model = model
@@ -40,7 +40,9 @@ def __init__(
         self.args = args
         self.kwargs = kwargs
 
-        self.graph = Graph(self.model.meta_model, proxy_class=InterventionProxy, validate=validate)
+        self.graph = Graph(
+            self.model.meta_model, proxy_class=InterventionProxy, validate=validate
+        )
 
         self.batch_size: int = 0
         self.batch_start: int = 0
diff --git a/src/nnsight/module.py b/src/nnsight/module.py
index 1f59dc7c..7b32dae4 100644
--- a/src/nnsight/module.py
+++ b/src/nnsight/module.py
@@ -53,6 +53,7 @@ class Module(torch.nn.Module):
 
     def __init__(self) -> None:
         self.module_path: str = None
+
         self.input_shape: torch.Size = None
         self.input_type: torch.dtype = None
         self.output_shape: torch.Size = None
@@ -62,10 +63,17 @@ def __init__(self) -> None:
         self._output: InterventionProxy = None
         self._input: InterventionProxy = None
         self._backward_output: InterventionProxy = None
         self._backward_input: InterventionProxy = None
+        self._graph: Graph = None
 
         self.tracer: Tracer = None
 
+    def clear(self):
+        self._output: InterventionProxy = None
+        self._input: InterventionProxy = None
+        self._backward_output: InterventionProxy = None
+        self._backward_input: InterventionProxy = None
+
     def __call__(
         self, *args: List[Any], **kwds: Dict[str, Any]
     ) -> Union[Any, InterventionProxy]:
@@ -279,11 +287,9 @@ def wrap(module: torch.nn.Module) -> Module:
         Module: The wrapped Module.
     """
 
-    def hook(module: Module, input: Any, input_kwargs:Dict, output: Any):
-        module._output = None
-        module._input = None
-        module._backward_output = None
-        module._backward_input = None
+    def hook(module: Module, input: Any, input_kwargs: Dict, output: Any):
+
+        module.clear()
 
         input = (input, input_kwargs)
 
diff --git a/src/nnsight/pydantics/Request.py b/src/nnsight/pydantics/Request.py
index 6b77652b..b84997eb 100644
--- a/src/nnsight/pydantics/Request.py
+++ b/src/nnsight/pydantics/Request.py
@@ -1,16 +1,22 @@
-import pickle
+from __future__ import annotations
+
 from datetime import datetime
-from typing import Any, Dict, List, Union
+from typing import Dict, List, Union
 
-from pydantic import BaseModel, field_serializer
+from pydantic import BaseModel, ConfigDict
+from ..tracing.Graph import Graph
+from .format import types
+from .format.types import *
 
 
 class RequestModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     args: List
     kwargs: Dict
     repo_id: str
-    batched_input: Union[bytes, Any]
-    intervention_graph: Union[bytes, Any]
+    batched_input: types.ValueTypes
+    intervention_graph: Union[Dict[str, Union[types.NodeType, types.NodeModel]], Graph]
     generation: bool
 
     id: str = None
@@ -19,15 +25,17 @@ class RequestModel(BaseModel):
     blocking: bool = False
     include_output: bool = False
 
-    @field_serializer("intervention_graph")
-    def intervention_graph_serialize(self, value, _info) -> bytes:
-        value.compile(None)
+    def compile(self) -> RequestModel:
+        graph = Graph(None, validate=False)
+
+        for node in self.intervention_graph.values():
+            node.compile(graph, self.intervention_graph)
+
+        self.intervention_graph = graph
+
+        self.batched_input = self.batched_input.compile(None, None)
 
-        for node in value.nodes.values():
-            node.proxy_value = None
+        return self
 
-
-        return pickle.dumps(value)
-    @field_serializer("batched_input")
-    def serialize(self, value, _info) -> bytes:
-        return pickle.dumps(value)
+
+RequestModel.model_rebuild()
diff --git a/src/nnsight/pydantics/format/__init__.py b/src/nnsight/pydantics/format/__init__.py
new file mode 100644
index 00000000..b2a58312
--- /dev/null
+++ b/src/nnsight/pydantics/format/__init__.py
@@ -0,0 +1 @@
+from .functions import FUNCTIONS_WHITELIST
\ No newline at end of file
diff --git a/src/nnsight/pydantics/format/functions.py b/src/nnsight/pydantics/format/functions.py
new file mode 100644
index 00000000..12bc3cde
--- /dev/null
+++ b/src/nnsight/pydantics/format/functions.py
@@ -0,0 +1,38 @@
+import operator
+from inspect import getmembers, isbuiltin, ismethoddescriptor
+
+import torch
+
+from ... import util
+from ...tracing.Proxy import Proxy
+
+FUNCTIONS_WHITELIST = {}
+FUNCTIONS_WHITELIST.update(
+    {
+        f"_VariableFunctionsClass.{key}": value
+        for key, value in getmembers(torch._C._VariableFunctions, isbuiltin)
+    }
+)
+FUNCTIONS_WHITELIST.update(
+    {
+        f"Tensor.{key}": value
+        for key, value in getmembers(torch._C._TensorBase, ismethoddescriptor)
+    }
+)
+FUNCTIONS_WHITELIST.update(
+    {
+        f"{key}": value
+        for key, value in getmembers(operator, isbuiltin)
+        if not key.startswith("_")
+    }
+)
+FUNCTIONS_WHITELIST.update(
+    {
+        "null": "null",
+        "module": "module",
+        "argument": "argument",
+        "swp": "swp",
+        "fetch_attr": util.fetch_attr,
+        "Proxy.proxy_call": Proxy.proxy_call,
+    }
+)
diff --git a/src/nnsight/pydantics/format/types.py b/src/nnsight/pydantics/format/types.py
new file mode 100644
index 00000000..c602ce5e
--- /dev/null
+++ b/src/nnsight/pydantics/format/types.py
@@ -0,0 +1,216 @@
+from __future__ import annotations
+
+from types import BuiltinFunctionType, FunctionType, MethodDescriptorType
+from typing import Dict, List, Literal, Union
+
+import torch
+from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic.functional_validators import AfterValidator
+from typing_extensions import Annotated
+
+from ...tracing.Graph import Graph
+from ...tracing.Node import Node
+from . import FUNCTIONS_WHITELIST
+
+FUNCTION = Union[BuiltinFunctionType, FunctionType, MethodDescriptorType, str]
+PRIMITIVE = Union[int, float, str, bool, None]
+
+
+class NodeModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    class Reference(BaseModel):
+        model_config = ConfigDict(arbitrary_types_allowed=True)
+
+        type_name: Literal["NODE_REFERENCE"] = "NODE_REFERENCE"
+
+        name: str
+
+        def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> Node:
+            return nodes[self.name].compile(graph, nodes)
+
+    name: str
+    target: Union[FunctionModel, FunctionType]
+    args: List[ValueTypes]
+    kwargs: Dict[str, ValueTypes]
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> Node:
+        if self.name in graph.nodes:
+            return graph.nodes[self.name]
+
+        return graph.add(
+            value=None,
+            target=self.target.compile(graph, nodes),
+            args=[value.compile(graph, nodes) for value in self.args],
+            kwargs={
+                key: value.compile(graph, nodes) for key, value in self.kwargs.items()
+            },
+            name=self.name,
+        )
+
+
+class PrimitiveModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    type_name: Literal["PRIMITIVE"] = "PRIMITIVE"
+    value: PRIMITIVE
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> PRIMITIVE:
+        return self.value
+
+
+class TensorModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    type_name: Literal["TENSOR"] = "TENSOR"
+
+    values: List
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> torch.Tensor:
+        return torch.tensor(self.values)
+
+
+class SliceModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    type_name: Literal["SLICE"] = "SLICE"
+
+    start: ValueTypes
+    stop: ValueTypes
+    step: ValueTypes
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> slice:
+        return slice(
+            self.start.compile(graph, nodes),
+            self.stop.compile(graph, nodes),
+            self.step.compile(graph, nodes),
+        )
+
+
+class ListModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    type_name: Literal["LIST"] = "LIST"
+
+    values: List[ValueTypes]
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> list:
+        return [value.compile(graph, nodes) for value in self.values]
+
+
+class TupleModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    type_name: Literal["TUPLE"] = "TUPLE"
+
+    values: List[ValueTypes]
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> tuple:
+        return tuple([value.compile(graph, nodes) for value in self.values])
+
+
+class DictModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    type_name: Literal["DICT"] = "DICT"
+
+    values: Dict[str, ValueTypes]
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> dict:
+        return {
+            key: value.compile(graph, nodes) for key, value in self.values.items()
+        }
+
+
+class FunctionWhitelistError(Exception):
+    pass
+
+
+class FunctionModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    type_name: Literal["FUNCTION"] = "FUNCTION"
+
+    function_name: str
+
+    @field_validator("function_name")
+    @classmethod
+    def check_function_whitelist(cls, qualname: str) -> str:
+        if qualname not in FUNCTIONS_WHITELIST:
+            raise FunctionWhitelistError(
+                f"Function with name `{qualname}` not in function whitelist."
+            )
+
+        return qualname
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> FUNCTION:
+        return FUNCTIONS_WHITELIST[self.function_name]
+
+
+PrimitiveType = Annotated[
+    PRIMITIVE, AfterValidator(lambda value: PrimitiveModel(value=value))
+]
+
+TensorType = Annotated[
+    torch.Tensor, AfterValidator(lambda value: TensorModel(values=value.tolist()))
+]
+
+SliceType = Annotated[
+    slice,
+    AfterValidator(
+        lambda value: SliceModel(start=value.start, stop=value.stop, step=value.step)
+    ),
+]
+
+ListType = Annotated[list, AfterValidator(lambda value: ListModel(values=value))]
+
+TupleType = Annotated[
+    tuple, AfterValidator(lambda value: TupleModel(values=list(value)))
+]
+
+DictType = Annotated[dict, AfterValidator(lambda value: DictModel(values=value))]
+
+FunctionType = Annotated[
+    FUNCTION,
+    AfterValidator(
+        lambda value: FunctionModel(
+            function_name=value.__qualname__ if not isinstance(value, str) else value
+        )
+    ),
+]
+
+NodeReferenceType = Annotated[
+    Node, AfterValidator(lambda value: NodeModel.Reference(name=value.name))
+]
+
+NodeType = Annotated[
+    Node,
+    AfterValidator(
+        lambda value: NodeModel(
+            name=value.name, target=value.target, args=value.args, kwargs=value.kwargs
+        )
+    ),
+]
+
+ValueTypes = Union[
+    Annotated[
+        Union[
+            NodeModel.Reference,
+            SliceModel,
+            TensorModel,
+            PrimitiveModel,
+            ListModel,
+            TupleModel,
+            DictModel,
+        ],
+        Field(discriminator="type_name"),
+    ],
+    Union[
+        NodeReferenceType,
+        SliceType,
+        TensorType,
+        PrimitiveType,
+        ListType,
+        TupleType,
+        DictType,
+    ],
+]
diff --git a/tests/test_lm.py b/tests/test_lm.py
index 4c959f80..65d51018 100644
--- a/tests/test_lm.py
+++ b/tests/test_lm.py
@@ -2,6 +2,9 @@
 import torch
 
 import nnsight
+from nnsight.contexts.Runner import Runner
+from nnsight.pydantics import RequestModel
+from nnsight.tracing.Graph import Graph
 
 
 @pytest.fixture(scope="module")
@@ -14,6 +17,26 @@ def MSG_prompt():
     return "Madison Square Garden is located in the city of"
 
 
+def _test_serialize(runner: Runner):
+    request = RequestModel(
+        args=runner.args,
+        kwargs=runner.kwargs,
+        repo_id=runner.model.repoid_path_clsname,
+        generation=runner.generation,
+        intervention_graph=runner.graph.nodes,
+        batched_input=runner.batched_input,
+    )
+
+    request_json = request.model_dump(
+        mode="json", exclude=["session_id", "received", "blocking", "id"]
+    )
+
+    request2 = RequestModel(**request_json)
+    request2.compile()
+
+    assert isinstance(request2.intervention_graph, Graph)
+
+
 def test_generation(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     with gpt2.generate(max_new_tokens=3) as generator:
         with generator.invoke(MSG_prompt) as invoker:
@@ -23,6 +46,8 @@ def test_generation(gpt2: nnsight.LanguageModel, MSG_prompt: str):
 
     assert output == "Madison Square Garden is located in the city of New York City"
 
+    _test_serialize(generator)
+
 
 def test_save(gpt2: nnsight.LanguageModel):
     with gpt2.generate(max_new_tokens=1) as generator:
@@ -38,6 +63,8 @@ def test_save(gpt2: nnsight.LanguageModel):
     assert isinstance(hs_input.value, torch.Tensor)
     assert hs_input.value.ndim == 3
 
+    _test_serialize(generator)
+
 
 def test_set1(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     with gpt2.generate(max_new_tokens=1) as generator:
@@ -54,6 +81,8 @@ def test_set1(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     assert (post.value == 0).all().item()
     assert output != "Madison Square Garden is located in the city of New"
 
+    _test_serialize(generator)
+
 
 def test_set2(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     with gpt2.generate(max_new_tokens=1) as generator:
@@ -70,6 +99,8 @@ def test_set2(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     assert (post.value == 0).all().item()
     assert output != "Madison Square Garden is located in the city of New"
 
+    _test_serialize(generator)
+
 
 def test_adhoc_module(gpt2: nnsight.LanguageModel):
     with gpt2.generate() as generator:
@@ -82,6 +113,8 @@ def test_adhoc_module(gpt2: nnsight.LanguageModel):
 
     assert output == "\n-el Tower is a the middle centre Paris"
 
+    _test_serialize(generator)
+
 
 def test_embeddings_set1(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     with gpt2.generate(max_new_tokens=3) as generator:
@@ -97,6 +130,8 @@ def test_embeddings_set1(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     assert output1 == "Madison Square Garden is located in the city of New York City"
     assert output2 == "_ _ _ _ _ _ _ _ _ New York City"
 
+    _test_serialize(generator)
+
 
 def test_embeddings_set2(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     with gpt2.generate(max_new_tokens=3) as generator:
@@ -114,6 +149,8 @@ def test_embeddings_set2(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     assert output1 == "Madison Square Garden is located in the city of New York City"
     assert output2 == "_ _ _ _ _ _ _ _ _ New York City"
 
+    _test_serialize(generator)
+
 
 def test_grad(gpt2: nnsight.LanguageModel):
     with gpt2.forward(inference=False) as runner:
@@ -125,6 +162,8 @@ def test_grad(gpt2: nnsight.LanguageModel):
 
             logits.sum().backward()
 
+    _test_serialize(runner)
+
     with gpt2.forward(inference=False) as runner:
         with runner.invoke("Hello World") as invoker:
             hidden_states_grad = gpt2.transformer.h[-1].backward_output[0].save()
@@ -133,4 +172,6 @@ def test_grad(gpt2: nnsight.LanguageModel):
 
             logits.sum().backward()
 
+    _test_serialize(runner)
+
     assert (hidden_states_grad.value == hidden_states.value.grad).all().item()