diff --git a/src/nnsight/contexts/Runner.py b/src/nnsight/contexts/Runner.py
index fdc79f48..7b1a2c73 100644
--- a/src/nnsight/contexts/Runner.py
+++ b/src/nnsight/contexts/Runner.py
@@ -91,7 +91,7 @@ def run_server(self):
             kwargs=self.kwargs,
             repo_id=self.model.repoid_path_clsname,
             batched_input=self.batched_input,
-            intervention_graph=self.graph,
+            intervention_graph=self.graph.nodes,
             generation=self.generation,
             include_output=self.remote_include_output,
         )
@@ -161,7 +161,9 @@ def blocking_response(data):
 
         sio.emit(
             "blocking_request",
-            request.model_dump(exclude_defaults=True, exclude_none=True),
+            request.model_dump(
+                mode="json", exclude=["session_id", "received", "blocking", "id"]
+            ),
         )
 
         sio.wait()
diff --git a/src/nnsight/pydantics/Request.py b/src/nnsight/pydantics/Request.py
index 6b77652b..b84997eb 100644
--- a/src/nnsight/pydantics/Request.py
+++ b/src/nnsight/pydantics/Request.py
@@ -1,16 +1,22 @@
-import pickle
+from __future__ import annotations
+
 from datetime import datetime
-from typing import Any, Dict, List, Union
+from typing import Dict, List, Union
 
-from pydantic import BaseModel, field_serializer
+from pydantic import BaseModel, ConfigDict
+
+from ..tracing.Graph import Graph
+from .format import types
+from .format.types import *
 
 
 class RequestModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     args: List
     kwargs: Dict
     repo_id: str
-    batched_input: Union[bytes, Any]
-    intervention_graph: Union[bytes, Any]
+    batched_input: types.ValueTypes
+    intervention_graph: Union[Dict[str, Union[types.NodeType, types.NodeModel]], Graph]
     generation: bool
 
     id: str = None
@@ -19,15 +25,17 @@
     blocking: bool = False
     include_output: bool = False
 
-    @field_serializer("intervention_graph")
-    def intervention_graph_serialize(self, value, _info) -> bytes:
-        value.compile(None)
+    def compile(self) -> RequestModel:
+        graph = Graph(None, validate=False)
+
+        for node in self.intervention_graph.values():
+            node.compile(graph, self.intervention_graph)
+
+        self.intervention_graph = graph
+
+        self.batched_input = self.batched_input.compile(None, None)
 
-        for node in value.nodes.values():
-            node.proxy_value = None
+        return self
 
-        return pickle.dumps(value)
-
-    @field_serializer("batched_input")
-    def serialize(self, value, _info) -> bytes:
-        return pickle.dumps(value)
+
+RequestModel.model_rebuild()
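Note on the wire format: with this change the request leaves the client as plain JSON (via `model_dump(mode="json", ...)` above) rather than as pickled bytes, and the receiving side calls `RequestModel.compile()` to rebuild a live `Graph`. As a rough illustration, a single entry of the dumped `intervention_graph` might look like the sketch below; the node name, target, and argument values are hypothetical, not captured from a real run.

    # Hypothetical shape of one serialized graph node (illustrative only):
    # a NodeModel dumps to its name, a whitelisted target, and tagged values.
    {
        "name": "argument_0",
        "target": {"type_name": "FUNCTION", "function_name": "argument"},
        "args": [{"type_name": "PRIMITIVE", "value": "input.input"}],
        "kwargs": {},
    }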
diff --git a/src/nnsight/pydantics/format/__init__.py b/src/nnsight/pydantics/format/__init__.py
new file mode 100644
index 00000000..b2a58312
--- /dev/null
+++ b/src/nnsight/pydantics/format/__init__.py
@@ -0,0 +1 @@
+from .functions import FUNCTIONS_WHITELIST
\ No newline at end of file
diff --git a/src/nnsight/pydantics/format/functions.py b/src/nnsight/pydantics/format/functions.py
new file mode 100644
index 00000000..12bc3cde
--- /dev/null
+++ b/src/nnsight/pydantics/format/functions.py
@@ -0,0 +1,38 @@
+import operator
+from inspect import getmembers, isbuiltin, ismethoddescriptor
+
+import torch
+
+from ... import util
+from ...tracing.Proxy import Proxy
+
+FUNCTIONS_WHITELIST = {}
+FUNCTIONS_WHITELIST.update(
+    {
+        f"_VariableFunctionsClass.{key}": value
+        for key, value in getmembers(torch._C._VariableFunctions, isbuiltin)
+    }
+)
+FUNCTIONS_WHITELIST.update(
+    {
+        f"Tensor.{key}": value
+        for key, value in getmembers(torch._C._TensorBase, ismethoddescriptor)
+    }
+)
+FUNCTIONS_WHITELIST.update(
+    {
+        f"{key}": value
+        for key, value in getmembers(operator, isbuiltin)
+        if not key.startswith("_")
+    }
+)
+FUNCTIONS_WHITELIST.update(
+    {
+        "null": "null",
+        "module": "module",
+        "argument": "argument",
+        "swp": "swp",
+        "fetch_attr": util.fetch_attr,
+        "Proxy.proxy_call": Proxy.proxy_call,
+    }
+)
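The whitelist above doubles as the lookup table for deserializing callables: a function is transmitted as a string key and resolved back through `FUNCTIONS_WHITELIST` on arrival, so only listed callables can ever be executed server-side. A minimal sketch of that round trip, assuming the import path from this diff:

    import operator

    from nnsight.pydantics.format.functions import FUNCTIONS_WHITELIST

    # Serializing: a callable is reduced to its qualified name, which for the
    # operator builtins matches the bare whitelist key.
    name = operator.getitem.__qualname__  # "getitem"

    # Deserializing: the name is resolved through the whitelist; a name that
    # is not present is rejected (see FunctionModel's validator in types.py).
    fn = FUNCTIONS_WHITELIST[name]
    assert fn((10, 20, 30), 1) == 20

The torch entries are stored under class-prefixed keys such as "_VariableFunctionsClass.add" so that the `__qualname__`-based key produced during serialization lines up with the whitelist.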
diff --git a/src/nnsight/pydantics/format/types.py b/src/nnsight/pydantics/format/types.py
new file mode 100644
index 00000000..c602ce5e
--- /dev/null
+++ b/src/nnsight/pydantics/format/types.py
@@ -0,0 +1,216 @@
+from __future__ import annotations
+
+from types import BuiltinFunctionType, FunctionType, MethodDescriptorType
+from typing import Dict, List, Literal, Union
+
+import torch
+from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic.functional_validators import AfterValidator
+from typing_extensions import Annotated
+
+from ...tracing.Graph import Graph
+from ...tracing.Node import Node
+from . import FUNCTIONS_WHITELIST
+
+FUNCTION = Union[BuiltinFunctionType, FunctionType, MethodDescriptorType, str]
+PRIMITIVE = Union[int, float, str, bool, None]
+
+
+class NodeModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    class Reference(BaseModel):
+        model_config = ConfigDict(arbitrary_types_allowed=True)
+
+        type_name: Literal["NODE_REFERENCE"] = "NODE_REFERENCE"
+
+        name: str
+
+        def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> Node:
+            return nodes[self.name].compile(graph, nodes)
+
+    name: str
+    target: Union[FunctionModel, FunctionType]
+    args: List[ValueTypes]
+    kwargs: Dict[str, ValueTypes]
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> Node:
+        if self.name in graph.nodes:
+            return graph.nodes[self.name]
+
+        return graph.add(
+            value=None,
+            target=self.target.compile(graph, nodes),
+            args=[value.compile(graph, nodes) for value in self.args],
+            kwargs={
+                key: value.compile(graph, nodes) for key, value in self.kwargs.items()
+            },
+            name=self.name,
+        )
+
+
+class PrimitiveModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    type_name: Literal["PRIMITIVE"] = "PRIMITIVE"
+    value: PRIMITIVE
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> PRIMITIVE:
+        return self.value
+
+
+class TensorModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    type_name: Literal["TENSOR"] = "TENSOR"
+
+    values: List
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> torch.Tensor:
+        return torch.tensor(self.values)
+
+
+class SliceModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    type_name: Literal["SLICE"] = "SLICE"
+
+    start: ValueTypes
+    stop: ValueTypes
+    step: ValueTypes
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> slice:
+        return slice(
+            self.start.compile(graph, nodes),
+            self.stop.compile(graph, nodes),
+            self.step.compile(graph, nodes),
+        )
+
+
+class ListModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    type_name: Literal["LIST"] = "LIST"
+
+    values: List[ValueTypes]
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> list:
+        return [value.compile(graph, nodes) for value in self.values]
+
+
+class TupleModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    type_name: Literal["TUPLE"] = "TUPLE"
+
+    values: List[ValueTypes]
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> tuple:
+        return tuple([value.compile(graph, nodes) for value in self.values])
+
+
+class DictModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    type_name: Literal["DICT"] = "DICT"
+
+    values: Dict[str, ValueTypes]
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> dict:
+        return {
+            key: value.compile(graph, nodes) for key, value in self.values.items()
+        }
+
+
+class FunctionWhitelistError(Exception):
+    pass
+
+
+class FunctionModel(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    type_name: Literal["FUNCTION"] = "FUNCTION"
+
+    function_name: str
+
+    @field_validator("function_name")
+    @classmethod
+    def check_function_whitelist(cls, qualname: str) -> str:
+        if qualname not in FUNCTIONS_WHITELIST:
+            raise FunctionWhitelistError(
+                f"Function with name `{qualname}` not in function whitelist."
+            )
+
+        return qualname
+
+    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> FUNCTION:
+        return FUNCTIONS_WHITELIST[self.function_name]
+
+
+PrimitiveType = Annotated[
+    PRIMITIVE, AfterValidator(lambda value: PrimitiveModel(value=value))
+]
+
+TensorType = Annotated[
+    torch.Tensor, AfterValidator(lambda value: TensorModel(values=value.tolist()))
+]
+
+SliceType = Annotated[
+    slice,
+    AfterValidator(
+        lambda value: SliceModel(start=value.start, stop=value.stop, step=value.step)
+    ),
+]
+
+ListType = Annotated[list, AfterValidator(lambda value: ListModel(values=value))]
+
+TupleType = Annotated[
+    tuple, AfterValidator(lambda value: TupleModel(values=list(value)))
+]
+
+DictType = Annotated[dict, AfterValidator(lambda value: DictModel(values=value))]
+
+FunctionType = Annotated[
+    FUNCTION,
+    AfterValidator(
+        lambda value: FunctionModel(
+            function_name=value.__qualname__ if not isinstance(value, str) else value
+        )
+    ),
+]
+
+NodeReferenceType = Annotated[
+    Node, AfterValidator(lambda value: NodeModel.Reference(name=value.name))
+]
+
+NodeType = Annotated[
+    Node,
+    AfterValidator(
+        lambda value: NodeModel(
+            name=value.name, target=value.target, args=value.args, kwargs=value.kwargs
+        )
+    ),
+]
+
+ValueTypes = Union[
+    Annotated[
+        Union[
+            NodeModel.Reference,
+            SliceModel,
+            TensorModel,
+            PrimitiveModel,
+            ListModel,
+            TupleModel,
+            DictModel,
+        ],
+        Field(discriminator="type_name"),
+    ],
+    Union[
+        NodeReferenceType,
+        SliceType,
+        TensorType,
+        PrimitiveType,
+        ListType,
+        TupleType,
+        DictType,
+    ],
+]
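The pattern in types.py is the core of the new format: an `AfterValidator` rewrites each raw Python value into a tagged model during validation, and the `type_name` discriminator picks the right model class when the JSON comes back, after which `compile()` rebuilds the native value. A minimal sketch of that round trip for a slice, using the models directly (import path assumed from this diff):

    from nnsight.pydantics.format.types import PrimitiveModel, SliceModel

    # Validation direction: a concrete slice becomes a tree of tagged models.
    model = SliceModel(
        start=PrimitiveModel(value=0),
        stop=PrimitiveModel(value=5),
        step=PrimitiveModel(value=None),
    )
    payload = model.model_dump(mode="json")
    # {"type_name": "SLICE", "start": {"type_name": "PRIMITIVE", "value": 0}, ...}

    # Deserialization direction: the discriminator re-selects SliceModel, and
    # compile() rebuilds the native value (leaves need no graph, hence None).
    rebuilt = SliceModel(**payload).compile(None, None)
    assert rebuilt == slice(0, 5, None)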
diff --git a/tests/test_lm.py b/tests/test_lm.py
index 4c959f80..65d51018 100644
--- a/tests/test_lm.py
+++ b/tests/test_lm.py
@@ -2,6 +2,9 @@
 import torch
 
 import nnsight
+from nnsight.contexts.Runner import Runner
+from nnsight.pydantics import RequestModel
+from nnsight.tracing.Graph import Graph
 
 
 @pytest.fixture(scope="module")
@@ -14,6 +17,26 @@ def MSG_prompt():
     return "Madison Square Garden is located in the city of"
 
 
+def _test_serialize(runner: Runner):
+    request = RequestModel(
+        args=runner.args,
+        kwargs=runner.kwargs,
+        repo_id=runner.model.repoid_path_clsname,
+        generation=runner.generation,
+        intervention_graph=runner.graph.nodes,
+        batched_input=runner.batched_input,
+    )
+
+    request_json = request.model_dump(
+        mode="json", exclude=["session_id", "received", "blocking", "id"]
+    )
+
+    request2 = RequestModel(**request_json)
+    request2.compile()
+
+    assert isinstance(request2.intervention_graph, Graph)
+
+
 def test_generation(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     with gpt2.generate(max_new_tokens=3) as generator:
         with generator.invoke(MSG_prompt) as invoker:
@@ -23,6 +46,8 @@ def test_generation(gpt2: nnsight.LanguageModel, MSG_prompt: str):
 
     assert output == "Madison Square Garden is located in the city of New York City"
 
+    _test_serialize(generator)
+
 
 def test_save(gpt2: nnsight.LanguageModel):
     with gpt2.generate(max_new_tokens=1) as generator:
@@ -38,6 +63,8 @@ def test_save(gpt2: nnsight.LanguageModel):
     assert isinstance(hs_input.value, torch.Tensor)
     assert hs_input.value.ndim == 3
 
+    _test_serialize(generator)
+
 
 def test_set1(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     with gpt2.generate(max_new_tokens=1) as generator:
@@ -54,6 +81,8 @@ def test_set1(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     assert (post.value == 0).all().item()
     assert output != "Madison Square Garden is located in the city of New"
 
+    _test_serialize(generator)
+
 
 def test_set2(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     with gpt2.generate(max_new_tokens=1) as generator:
@@ -70,6 +99,8 @@ def test_set2(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     assert (post.value == 0).all().item()
     assert output != "Madison Square Garden is located in the city of New"
 
+    _test_serialize(generator)
+
 
 def test_adhoc_module(gpt2: nnsight.LanguageModel):
     with gpt2.generate() as generator:
@@ -82,6 +113,8 @@ def test_adhoc_module(gpt2: nnsight.LanguageModel):
 
     assert output == "\n-el Tower is a the middle centre Paris"
 
+    _test_serialize(generator)
+
 
 def test_embeddings_set1(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     with gpt2.generate(max_new_tokens=3) as generator:
@@ -97,6 +130,8 @@ def test_embeddings_set1(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     assert output1 == "Madison Square Garden is located in the city of New York City"
     assert output2 == "_ _ _ _ _ _ _ _ _ New York City"
 
+    _test_serialize(generator)
+
 
 def test_embeddings_set2(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     with gpt2.generate(max_new_tokens=3) as generator:
@@ -114,6 +149,8 @@ def test_embeddings_set2(gpt2: nnsight.LanguageModel, MSG_prompt: str):
     assert output1 == "Madison Square Garden is located in the city of New York City"
     assert output2 == "_ _ _ _ _ _ _ _ _ New York City"
 
+    _test_serialize(generator)
+
 
 def test_grad(gpt2: nnsight.LanguageModel):
     with gpt2.forward(inference=False) as runner:
@@ -125,6 +162,8 @@ def test_grad(gpt2: nnsight.LanguageModel):
 
         logits.sum().backward()
 
+        _test_serialize(runner)
+
     with gpt2.forward(inference=False) as runner:
         with runner.invoke("Hello World") as invoker:
             hidden_states_grad = gpt2.transformer.h[-1].backward_output[0].save()
@@ -133,4 +172,6 @@ def test_grad(gpt2: nnsight.LanguageModel):
 
         logits.sum().backward()
 
+        _test_serialize(runner)
+
     assert (hidden_states_grad.value == hidden_states.value.grad).all().item()
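For completeness, a server receiving the `blocking_request` payload would presumably mirror `_test_serialize`: re-validate the JSON into a `RequestModel`, then `compile()` it into an executable graph. The handler below is a hedged sketch, not code from this diff; only `RequestModel` and its `compile()` method are real, while the handler name and payload variable are hypothetical.

    from nnsight.pydantics import RequestModel

    def handle_blocking_request(payload: dict):
        # Validation re-tags every value and rejects any function target that
        # is not in FUNCTIONS_WHITELIST, so untrusted JSON is checked up front.
        request = RequestModel(**payload)

        # Rebuild the live intervention Graph and the batched input, then hand
        # them off to whatever executes the model.
        request.compile()
        graph, batched_input = request.intervention_graph, request.batched_input
        ...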