Skip to content

Commit

Permalink
Custom serialization format. Dropping pickle on the request side
Browse files Browse the repository at this point in the history
  • Loading branch information
JadenFiotto-Kaufman committed Jan 4, 2024
1 parent 47208ef commit 8d4de2f
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 16 deletions.
6 changes: 4 additions & 2 deletions src/nnsight/contexts/Runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def run_server(self):
kwargs=self.kwargs,
repo_id=self.model.repoid_path_clsname,
batched_input=self.batched_input,
intervention_graph=self.graph,
intervention_graph=self.graph.nodes,
generation=self.generation,
include_output=self.remote_include_output,
)
Expand Down Expand Up @@ -161,7 +161,9 @@ def blocking_response(data):

sio.emit(
"blocking_request",
request.model_dump(exclude_defaults=True, exclude_none=True),
request.model_dump(
mode="json", exclude=["session_id", "received", "blocking", "id"]
),
)

sio.wait()
Expand Down
36 changes: 22 additions & 14 deletions src/nnsight/pydantics/Request.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import pickle
from __future__ import annotations

from datetime import datetime
from typing import Any, Dict, List, Union
from typing import Dict, List, Union

from pydantic import BaseModel, field_serializer
from pydantic import BaseModel, ConfigDict

from ..tracing.Graph import Graph
from .format import types
from .format.types import *

class RequestModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

args: List
kwargs: Dict
repo_id: str
batched_input: Union[bytes, Any]
intervention_graph: Union[bytes, Any]
batched_input: types.ValueTypes
intervention_graph: Union[Dict[str, Union[types.NodeType, types.NodeModel]], Graph]
generation: bool

id: str = None
Expand All @@ -19,15 +25,17 @@ class RequestModel(BaseModel):
blocking: bool = False
include_output: bool = False

@field_serializer("intervention_graph")
def intervention_graph_serialize(self, value, _info) -> bytes:
value.compile(None)
def compile(self) -> RequestModel:
graph = Graph(None, validate=False)

for node in self.intervention_graph.values():
node.compile(graph, self.intervention_graph)

self.intervention_graph = graph

self.batched_input = self.batched_input.compile(None, None)

for node in value.nodes.values():
node.proxy_value = None
return self

return pickle.dumps(value)

@field_serializer("batched_input")
def serialize(self, value, _info) -> bytes:
return pickle.dumps(value)
RequestModel.model_rebuild()
1 change: 1 addition & 0 deletions src/nnsight/pydantics/format/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .functions import FUNCTIONS_WHITELIST
38 changes: 38 additions & 0 deletions src/nnsight/pydantics/format/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import operator
from inspect import getmembers, isbuiltin, ismethoddescriptor

import torch

from ... import util
from ...tracing.Proxy import Proxy

# Registry of every target a serialized node may name. A name that is not a
# key of this mapping is rejected. The sections are merged in the order
# listed, so a later section would win on a duplicate key.
FUNCTIONS_WHITELIST = {
    # Builtins exposed on torch._C._VariableFunctions.
    **{
        f"_VariableFunctionsClass.{name}": member
        for name, member in getmembers(torch._C._VariableFunctions, isbuiltin)
    },
    # Method descriptors defined on torch._C._TensorBase.
    **{
        f"Tensor.{name}": member
        for name, member in getmembers(torch._C._TensorBase, ismethoddescriptor)
    },
    # Public builtins of the `operator` module, keyed by their bare name.
    **{
        f"{name}": member
        for name, member in getmembers(operator, isbuiltin)
        if not name.startswith("_")
    },
    # String targets map to themselves; the last two entries are
    # project-level callables.
    **{
        "null": "null",
        "module": "module",
        "argument": "argument",
        "swp": "swp",
        "fetch_attr": util.fetch_attr,
        "Proxy.proxy_call": Proxy.proxy_call,
    },
}
216 changes: 216 additions & 0 deletions src/nnsight/pydantics/format/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from __future__ import annotations

from types import BuiltinFunctionType, FunctionType, MethodDescriptorType
from typing import Dict, List, Literal, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic.functional_validators import AfterValidator
from typing_extensions import Annotated

from ...tracing.Graph import Graph
from ...tracing.Node import Node
from . import FUNCTIONS_WHITELIST

# Accepted spellings of a node target: a builtin, a plain Python function, a
# method descriptor, or a string name. `FunctionType` is still
# `types.FunctionType` at this point; the name is rebound to an Annotated
# alias further down in this module.
FUNCTION = Union[BuiltinFunctionType, FunctionType, MethodDescriptorType, str]
# Scalar values that PrimitiveModel stores and returns unchanged.
PRIMITIVE = Union[int, float, str, bool, None]


class NodeModel(BaseModel):
    """Serializable description of one node of an intervention graph.

    The node is described by its name, its target and its positional and
    keyword arguments. ``compile`` rebuilds it inside a ``Graph``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    class Reference(BaseModel):
        """Stand-in for a node used as an argument of another node.

        Only the referenced node's name is stored; the node itself is looked
        up in the ``nodes`` mapping when compiling.
        """

        model_config = ConfigDict(arbitrary_types_allowed=True)

        type_name: Literal["NODE_REFERENCE"] = "NODE_REFERENCE"

        name: str

        def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> Node:
            """Compile the node this reference names.

            Args:
                graph: Graph that compiled nodes are added to.
                nodes: Every node model of the request, keyed by node name.

            Returns:
                The result of compiling ``nodes[self.name]``.

            Raises:
                KeyError: If ``self.name`` is not a key of ``nodes``.
            """
            return nodes[self.name].compile(graph, nodes)

    name: str
    target: Union[FunctionModel, FunctionType]
    args: List[ValueTypes]
    kwargs: Dict[str, ValueTypes]

    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> Node:
        """Add this node to ``graph``, compiling its target and arguments first.

        Args:
            graph: Graph that the node is added to.
            nodes: Every node model of the request, keyed by node name. It is
                passed down so that references can be resolved.

        Returns:
            The entry already stored under ``self.name`` in ``graph.nodes``
            if there is one, otherwise whatever ``graph.add`` returns.
        """
        # A node can be reached several times (directly and through
        # references); only add it to the graph once.
        if self.name in graph.nodes:
            return graph.nodes[self.name]

        compiled_args = [value.compile(graph, nodes) for value in self.args]
        compiled_kwargs = {
            key: value.compile(graph, nodes) for key, value in self.kwargs.items()
        }

        return graph.add(
            value=None,
            target=self.target.compile(graph, nodes),
            args=compiled_args,
            kwargs=compiled_kwargs,
            name=self.name,
        )


class PrimitiveModel(BaseModel):
    """Serializable wrapper around a single ``PRIMITIVE`` value."""

    model_config = ConfigDict(arbitrary_types_allowed=True)
    type_name: Literal["PRIMITIVE"] = "PRIMITIVE"
    value: PRIMITIVE

    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> PRIMITIVE:
        """Return the wrapped value unchanged.

        ``graph`` and ``nodes`` are unused; they are accepted so that every
        value model exposes the same ``compile(graph, nodes)`` call.
        """
        return self.value


class TensorModel(BaseModel):
    """Serializable form of a tensor, stored as a (nested) list of values."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    type_name: Literal["TENSOR"] = "TENSOR"

    values: List

    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> torch.Tensor:
        """Rebuild a tensor from the stored values.

        ``graph`` and ``nodes`` are unused; they are accepted so that every
        value model exposes the same ``compile(graph, nodes)`` call.
        """
        # NOTE(review): only the values are stored in this model, so the
        # result is whatever ``torch.tensor`` builds from the list alone.
        return torch.tensor(self.values)


class SliceModel(BaseModel):
    """Serializable form of a ``slice``; each bound is itself a value model."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    type_name: Literal["SLICE"] = "SLICE"

    start: ValueTypes
    stop: ValueTypes
    step: ValueTypes

    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> slice:
        """Compile the three bounds and assemble them into a ``slice``.

        Args:
            graph: Graph passed through to the bounds' ``compile`` calls.
            nodes: Node models keyed by name, passed through likewise.
        """
        start, stop, step = (
            bound.compile(graph, nodes)
            for bound in (self.start, self.stop, self.step)
        )
        return slice(start, stop, step)


class ListModel(BaseModel):
    """Serializable form of a ``list`` whose items are value models."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    type_name: Literal["LIST"] = "LIST"

    values: List[ValueTypes]

    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> list:
        """Return a new list holding each item's compiled value, in order.

        Args:
            graph: Graph passed through to the items' ``compile`` calls.
            nodes: Node models keyed by name, passed through likewise.
        """
        return [item.compile(graph, nodes) for item in self.values]


class TupleModel(BaseModel):
    """Serializable form of a ``tuple``; the items are stored as a list."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    type_name: Literal["TUPLE"] = "TUPLE"

    values: List[ValueTypes]

    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> tuple:
        """Return a tuple holding each item's compiled value, in order.

        Args:
            graph: Graph passed through to the items' ``compile`` calls.
            nodes: Node models keyed by name, passed through likewise.
        """
        # A generator is enough here; no intermediate list is needed.
        return tuple(item.compile(graph, nodes) for item in self.values)


class DictModel(BaseModel):
    """Serializable form of a ``dict`` with string keys and value-model values."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    type_name: Literal["DICT"] = "DICT"

    values: Dict[str, ValueTypes]

    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> dict:
        """Return a dict mapping each key to its compiled value.

        Args:
            graph: Graph passed through to the values' ``compile`` calls.
            nodes: Node models keyed by name, passed through likewise.
        """
        # The mapping is returned as-is. Wrapping it in ``tuple(...)`` would
        # iterate the dict and keep only its keys, dropping every value.
        return {key: value.compile(graph, nodes) for key, value in self.values.items()}


class FunctionWhitelistError(Exception):
    """Raised when a function name is not a key of ``FUNCTIONS_WHITELIST``."""


class FunctionModel(BaseModel):
    """Serializable reference to a whitelisted function, stored by name."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    type_name: Literal["FUNCTION"] = "FUNCTION"

    function_name: str

    @field_validator("function_name")
    @classmethod
    def check_function_whitelist(cls, qualname: str) -> str:
        """Reject any name that is not a key of ``FUNCTIONS_WHITELIST``.

        Args:
            qualname: The candidate function name.

        Returns:
            ``qualname`` unchanged when it is whitelisted.

        Raises:
            FunctionWhitelistError: If ``qualname`` is not whitelisted.
        """
        if qualname in FUNCTIONS_WHITELIST:
            return qualname

        raise FunctionWhitelistError(
            f"Function with name `{qualname}` not in function whitelist."
        )

    def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> FUNCTION:
        """Return the whitelist entry stored under ``function_name``.

        ``graph`` and ``nodes`` are unused; they are accepted so that this
        model exposes the same ``compile(graph, nodes)`` call as the value
        models.
        """
        return FUNCTIONS_WHITELIST[self.function_name]


# Coercion aliases. Each one pairs a raw Python type with an AfterValidator
# that wraps the validated value in the matching *Model class defined above,
# so a field annotated with one of these ends up holding a model instance.

# Scalar -> PrimitiveModel.
PrimitiveType = Annotated[
PRIMITIVE, AfterValidator(lambda value: PrimitiveModel(value=value))
]

# torch.Tensor -> TensorModel; only the `tolist()` values are kept.
TensorType = Annotated[
torch.Tensor, AfterValidator(lambda value: TensorModel(values=value.tolist()))
]

# slice -> SliceModel, copying start/stop/step.
SliceType = Annotated[
slice,
AfterValidator(
lambda value: SliceModel(start=value.start, stop=value.stop, step=value.step)
),
]

# list -> ListModel.
ListType = Annotated[list, AfterValidator(lambda value: ListModel(values=value))]

# tuple -> TupleModel; the items are handed over as a list.
TupleType = Annotated[
tuple, AfterValidator(lambda value: TupleModel(values=list(value)))
]

# dict -> DictModel.
DictType = Annotated[dict, AfterValidator(lambda value: DictModel(values=value))]

# Callable or string -> FunctionModel. A string is used as the function name
# directly; anything else contributes its `__qualname__`.
# NOTE: this rebinds the module-level name `FunctionType`, shadowing the
# `types.FunctionType` imported at the top of this module. `FUNCTION` was
# built before the rebinding and still refers to the stdlib type.
FunctionType = Annotated[
FUNCTION,
AfterValidator(
lambda value: FunctionModel(
function_name=value.__qualname__ if not isinstance(value, str) else value
)
),
]

# Node -> NodeModel.Reference; only the node's name is kept.
NodeReferenceType = Annotated[
Node, AfterValidator(lambda value: NodeModel.Reference(name=value.name))
]

# Node -> full NodeModel, copying name, target, args and kwargs.
NodeType = Annotated[
Node,
AfterValidator(
lambda value: NodeModel(
name=value.name, target=value.target, args=value.args, kwargs=value.kwargs
)
),
]

# Every kind of value that may appear as a node argument, slice bound or
# container item.
# The first member accepts the model classes themselves, selected by their
# `type_name` discriminator. The second member accepts raw Python objects,
# which the Annotated aliases convert to those same models.
# NOTE(review): presumably the first member serves already-serialized input
# and the second serves client-side objects; confirm against callers.
# FunctionModel is not part of this union.
ValueTypes = Union[
Annotated[
Union[
NodeModel.Reference,
SliceModel,
TensorModel,
PrimitiveModel,
ListModel,
TupleModel,
DictModel,
],
Field(discriminator="type_name"),
],
Union[
NodeReferenceType,
SliceType,
TensorType,
PrimitiveType,
ListType,
TupleType,
DictType,
],
]
Loading

0 comments on commit 8d4de2f

Please sign in to comment.