Skip to content

Commit

Permalink
Merge pull request #243 from ndif-team/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
JadenFiotto-Kaufman authored Sep 22, 2024
2 parents c70cf15 + 47758d6 commit ad06cca
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/nnsight/contexts/Invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Invoker(AbstractContextManager):
scan (bool): If to execute the model using `FakeTensor` in order to update the potential sizes/dtypes of all modules' Envoys' inputs/outputs as well as validate things work correctly.
Scanning is not free computation wise so you may want to turn this to false when running in a loop.
When making interventions, you made get shape errors if scan is false as it validates operations based on shapes so
for looped calls where shapes are consistent, you may want to have scan=True for the first loop. Defaults to True.
for looped calls where shapes are consistent, you may want to have scan=True for the first loop. Defaults to False.
kwargs (Dict[str,Any]): Keyword arguments passed to the model's _prepare_inputs method.
scanning (bool): If currently scanning.
"""
Expand Down
31 changes: 25 additions & 6 deletions src/nnsight/schema/Request.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,50 @@
from __future__ import annotations

import json
from datetime import datetime
from typing import TYPE_CHECKING, Dict, List, Union

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, TypeAdapter, field_serializer

from .. import NNsight
from .format.types import *

if TYPE_CHECKING:
from ..contexts.backends.RemoteBackend import RemoteMixin

OBJECT_TYPES = Union[SessionType, TracerType, SessionModel, TracerModel]


class RequestModel(BaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

object: Union[SessionType, TracerType, SessionModel, TracerModel]
model_config = ConfigDict(
arbitrary_types_allowed=True, protected_namespaces=()
)

object: str | OBJECT_TYPES
model_key: str

id: str = None
received: datetime = None

session_id: Optional[str] = None

@field_serializer("object")
def serialize_object(
self, object: Union[SessionType, TracerType, SessionModel, TracerModel]
) -> str:

if isinstance(object, str):
return object

return object.model_dump_json()

def deserialize(self, model: NNsight) -> "RemoteMixin":

handler = DeserializeHandler(model=model)

return self.object.deserialize(handler)
object = TypeAdapter(
OBJECT_TYPES, config=RequestModel.model_config
).validate_python(json.loads(self.object))

return object.deserialize(handler)
85 changes: 49 additions & 36 deletions src/nnsight/schema/format/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from types import BuiltinFunctionType
from types import FunctionType as FuncType
from types import MethodDescriptorType
from typing import Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, Strict, field_validator
from pydantic import (BaseModel, ConfigDict, Field, Strict, field_validator,
model_serializer)
from pydantic.functional_validators import AfterValidator
from typing_extensions import Annotated

Expand Down Expand Up @@ -50,6 +51,14 @@ class BaseNNsightModel(BaseModel):
def deserialize(self, handler: DeserializeHandler):
raise NotImplementedError()

def try_deserialize(value: BaseNNsightModel | Any, handler: DeserializeHandler):

if isinstance(value, BaseNNsightModel):

return value.deserialize(handler)

return value


### Custom Pydantic types for all supported base types
class NodeModel(BaseNNsightModel):
Expand All @@ -66,11 +75,30 @@ def deserialize(self, handler: DeserializeHandler) -> Node:

name: str
target: Union[FunctionModel, FunctionType]
args: List[ValueTypes]
kwargs: Dict[str, ValueTypes]
condition: Union[
NodeReferenceType, NodeModel.Reference, PrimitiveModel, PrimitiveType
]
args: List[ValueTypes] = []
kwargs: Dict[str, ValueTypes] = {}
condition: None | Union[
NodeReferenceType, NodeModel.Reference
] = None

@model_serializer(mode='wrap')
def serialize_model(self, handler):

dump = handler(self)

if self.condition is None:

dump.pop('condition')

if not self.kwargs:

dump.pop('kwargs')

if not self.args:

dump.pop('args')

return dump

def deserialize(self, handler: DeserializeHandler) -> Node:

Expand All @@ -80,14 +108,15 @@ def deserialize(self, handler: DeserializeHandler) -> Node:
node = handler.graph.create(
proxy_value=None,
target=self.target.deserialize(handler),
args=[value.deserialize(handler) for value in self.args],
args=[try_deserialize(value, handler) for value in self.args],
kwargs={
key: value.deserialize(handler) for key, value in self.kwargs.items()
key: try_deserialize(value, handler) for key, value in self.kwargs.items()
},
name=self.name,
).node

node.cond_dependency = self.condition.deserialize(handler)
node.cond_dependency = try_deserialize(self.condition, handler)

if isinstance(node.cond_dependency, Node):
node.cond_dependency.listeners.append(weakref.proxy(node))

Expand All @@ -99,17 +128,6 @@ def deserialize(self, handler: DeserializeHandler) -> Node:

return node


class PrimitiveModel(BaseNNsightModel):

type_name: Literal["PRIMITIVE"] = "PRIMITIVE"

value: PRIMITIVE

def deserialize(self, handler: DeserializeHandler) -> PRIMITIVE:
return self.value


class TensorModel(BaseNNsightModel):

type_name: Literal["TENSOR"] = "TENSOR"
Expand All @@ -133,9 +151,9 @@ class SliceModel(BaseNNsightModel):
def deserialize(self, handler: DeserializeHandler) -> slice:

return slice(
self.start.deserialize(handler),
self.stop.deserialize(handler),
self.step.deserialize(handler),
try_deserialize(self.start, handler),
try_deserialize(self.stop, handler),
try_deserialize(self.step, handler)
)


Expand All @@ -158,7 +176,7 @@ class ListModel(BaseNNsightModel):
values: List[ValueTypes]

def deserialize(self, handler: DeserializeHandler) -> list:
return [value.deserialize(handler) for value in self.values]
return [try_deserialize(value, handler) for value in self.values]


class TupleModel(BaseNNsightModel):
Expand All @@ -168,7 +186,7 @@ class TupleModel(BaseNNsightModel):
values: List[ValueTypes]

def deserialize(self, handler: DeserializeHandler) -> tuple:
return tuple([value.deserialize(handler) for value in self.values])
return tuple([try_deserialize(value, handler) for value in self.values])


class DictModel(BaseNNsightModel):
Expand All @@ -178,7 +196,7 @@ class DictModel(BaseNNsightModel):
values: Dict[str, ValueTypes]

def deserialize(self, handler: DeserializeHandler) -> dict:
return {key: value.deserialize(handler) for key, value in self.values.items()}
return {key: try_deserialize(value, handler) for key, value in self.values.items()}


class FunctionWhitelistError(Exception):
Expand Down Expand Up @@ -253,10 +271,10 @@ def deserialize(self, handler: DeserializeHandler) -> Tracer:

handler.graph = graph

kwargs = {key: value.deserialize(handler) for key, value in self.kwargs.items()}
kwargs = {key: try_deserialize(value, handler) for key, value in self.kwargs.items()}

invoker_inputs = [
invoker_input.deserialize(handler) for invoker_input in self.invoker_inputs
try_deserialize(invoker_input, handler) for invoker_input in self.invoker_inputs
]

tracer = Tracer(
Expand Down Expand Up @@ -287,7 +305,7 @@ def deserialize(self, handler: DeserializeHandler) -> Iterator:

handler.graph = graph

data = self.data.deserialize(handler)
data = try_deserialize(self.data, handler)

iterator = Iterator(data, None, bridge=handler.bridge, graph=graph)

Expand Down Expand Up @@ -329,10 +347,6 @@ def deserialize(self, handler: DeserializeHandler) -> Session:
),
]

PrimitiveType = Annotated[
PRIMITIVE, AfterValidator(lambda value: PrimitiveModel(value=value))
]

TensorType = Annotated[
torch.Tensor,
AfterValidator(
Expand Down Expand Up @@ -414,7 +428,6 @@ def deserialize(self, handler: DeserializeHandler) -> Session:
NodeModel.Reference,
SliceModel,
TensorModel,
PrimitiveModel,
TupleModel,
ListModel,
DictModel,
Expand All @@ -428,7 +441,6 @@ def deserialize(self, handler: DeserializeHandler) -> Session:
NodeReferenceType,
SliceType,
TensorType,
PrimitiveType,
TupleType,
ListType,
DictType,
Expand All @@ -437,6 +449,7 @@ def deserialize(self, handler: DeserializeHandler) -> Session:

### Final registration
ValueTypes = Union[
PRIMITIVE,
Annotated[
TOTYPES,
Field(discriminator="type_name"),
Expand Down

0 comments on commit ad06cca

Please sign in to comment.