Commit 2961a04

Add einops and torch.nn.functional to the function whitelist, and proxy-wrap einops (plus follow-up changes to make it work)
JadenFiotto-Kaufman committed Jan 22, 2024
1 parent 2631325 commit 2961a04
Showing 6 changed files with 55 additions and 28 deletions.
30 changes: 20 additions & 10 deletions src/nnsight/__init__.py
@@ -9,12 +9,13 @@
with open(os.path.join(PATH, "config.yaml"), "r") as file:
    CONFIG = ConfigModel(**yaml.safe_load(file))

-from .models.NNsightModel import NNsightModel
+from .logger import logger
from .models.DiffuserModel import DiffuserModel
from .models.LanguageModel import LanguageModel
+from .models.NNsightModel import NNsightModel
from .module import Module
from .patching import Patch, Patcher
-from .logger import logger
+from .tracing.Proxy import proxy_wrapper

logger.disabled = not CONFIG.APP.LOGGING

@@ -23,9 +24,17 @@
DEFAULT_PATCHER = Patcher()

from functools import wraps
+from inspect import getmembers, isfunction

+import einops
import torch

+for key, value in getmembers(einops.einops, isfunction):
+    DEFAULT_PATCHER.add(Patch(einops.einops, proxy_wrapper(value), key))
+
+
+DEFAULT_PATCHER.add(Patch(torch, proxy_wrapper(torch.gather), "gather"))


# Need to patch repeat_interleave to work with meta tensors
# Computes appropriate shape if meta. Otherwise just call repeat_interleave
@@ -130,8 +139,12 @@ def meta_where_wrapper(fn):
    def where(input: torch.Tensor, *args, **kwargs):
        if input.device.type == "meta":
            if len(args) > 0:
-                dtype = args[0].dtype if isinstance(args[0], torch.Tensor) else type(args[0])
-                return torch.zeros_like(input, dtype=input.dtype, device='meta')
+                dtype = (
+                    args[0].dtype
+                    if isinstance(args[0], torch.Tensor)
+                    else type(args[0])
+                )
+                return torch.zeros_like(input, dtype=input.dtype, device="meta")
            return meta_nonzero(input, as_tuple=True)

        else:
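
Why where needs special-casing: meta tensors carry shape and dtype but no data, so the data-dependent nonzero behind single-argument torch.where can never produce a real result and has to be replaced with a shape-only stand-in. A small illustration of the underlying limitation (the exact error type may vary across torch versions):

import torch

x = torch.ones(2, 3, device="meta")
print(x.shape, x.dtype)  # torch.Size([2, 3]) torch.float32 -- metadata only

try:
    # Output shape depends on the (nonexistent) data, so this cannot run
    # on a meta tensor; the wrapper above substitutes a shape-only result.
    torch.nonzero(x)
except Exception as exc:
    print(type(exc).__name__)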
@@ -145,12 +158,9 @@

DEFAULT_PATCHER.__enter__()

-from torch._meta_registrations import (
-    _meta_lib_dont_use_me_use_register_meta,
-    aten,
-    global_decomposition_table,
-    register_meta,
-)
+from torch._meta_registrations import (_meta_lib_dont_use_me_use_register_meta,
+                                       aten, global_decomposition_table,
+                                       register_meta)


# Function which "activates" the most recent meta registered function.
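The getmembers loop added above patches every function in einops.einops with a proxy-wrapped version, so a call such as einops.rearrange made on a tracing Proxy is recorded in the intervention graph instead of executed eagerly. A minimal sketch of that pattern, using FakeProxy and proxy_wrapper_sketch as illustrative stand-ins rather than nnsight's real Proxy and proxy_wrapper:

from functools import wraps
from inspect import getmembers, isfunction

import einops

class FakeProxy:
    # Illustrative stand-in for nnsight's Proxy: records calls in a list
    # instead of adding nodes to a tracing graph.
    def __init__(self):
        self.trace = []

def proxy_wrapper_sketch(fn):
    @wraps(fn)
    def patched(*args, **kwargs):
        # Defer the call if any argument is a proxy; otherwise run it as-is.
        proxies = [a for a in list(args) + list(kwargs.values())
                   if isinstance(a, FakeProxy)]
        if proxies:
            proxies[0].trace.append((fn.__name__, args, kwargs))
            return proxies[0]
        return fn(*args, **kwargs)
    return patched

# Mirror of the patch loop above, applied to a local dict rather than
# mutating the einops module in place.
wrapped = {key: proxy_wrapper_sketch(value)
           for key, value in getmembers(einops.einops, isfunction)}

p = FakeProxy()
wrapped["rearrange"](p, "b s d -> s b d")  # recorded, not executed
print(p.trace[0][0])                       # rearrange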
2 changes: 1 addition & 1 deletion src/nnsight/pydantics/format/__init__.py
@@ -1 +1 @@
-from .functions import FUNCTIONS_WHITELIST
+from .functions import FUNCTIONS_WHITELIST, get_function_name
33 changes: 27 additions & 6 deletions src/nnsight/pydantics/format/functions.py
@@ -1,39 +1,60 @@
import operator
-from inspect import getmembers, isbuiltin, ismethoddescriptor
+from inspect import getmembers, isbuiltin, isfunction, ismethoddescriptor

+import einops
import torch

from ... import util
from ...tracing.Proxy import Proxy


+def get_function_name(fn):
+    if isinstance(fn, str):
+        return fn
+
+    return f"{getattr(fn, '__module__', '')}.{fn.__qualname__}"
+
+
FUNCTIONS_WHITELIST = {}
FUNCTIONS_WHITELIST.update(
    {
-        f"_VariableFunctionsClass.{key}": value
+        get_function_name(value): value
        for key, value in getmembers(torch._C._VariableFunctions, isbuiltin)
    }
)
FUNCTIONS_WHITELIST.update(
    {
-        f"Tensor.{key}": value
+        get_function_name(value): value
+        for key, value in getmembers(torch.nn.functional, isfunction)
+    }
+)
+FUNCTIONS_WHITELIST.update(
+    {
+        get_function_name(value): value
        for key, value in getmembers(torch._C._TensorBase, ismethoddescriptor)
    }
)
FUNCTIONS_WHITELIST.update(
    {
-        f"{key}": value
+        get_function_name(value): value
        for key, value in getmembers(operator, isbuiltin)
        if not key.startswith("_")
    }
)
+FUNCTIONS_WHITELIST.update(
+    {
+        get_function_name(value): value
+        for key, value in getmembers(einops.einops, isfunction)
+    }
+)
FUNCTIONS_WHITELIST.update(
    {
        "null": "null",
        "module": "module",
        "argument": "argument",
        "swp": "swp",
        "grad": "grad",
-        "fetch_attr": util.fetch_attr,
-        "Proxy.proxy_call": Proxy.proxy_call,
+        get_function_name(util.fetch_attr): util.fetch_attr,
+        get_function_name(Proxy.proxy_call): Proxy.proxy_call,
    }
)
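
get_function_name replaces the hand-built key prefixes (f"_VariableFunctionsClass.{key}", f"Tensor.{key}") with a single module-qualified naming scheme, so every whitelist entry is keyed the same way no matter where the function lives. A quick illustration of the keys it produces; the exact strings depend on the installed Python and library versions, so treat the printed values as expected rather than guaranteed:

import operator

import einops

def get_function_name(fn):
    if isinstance(fn, str):
        return fn
    return f"{getattr(fn, '__module__', '')}.{fn.__qualname__}"

print(get_function_name(operator.add))      # _operator.add
print(get_function_name(einops.rearrange))  # einops.einops.rearrange
print(get_function_name("grad"))            # grad (string keys pass through)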
8 changes: 2 additions & 6 deletions src/nnsight/pydantics/format/types.py
@@ -12,7 +12,7 @@

from ...tracing.Graph import Graph
from ...tracing.Node import Node
-from . import FUNCTIONS_WHITELIST
+from . import FUNCTIONS_WHITELIST, get_function_name

FUNCTION = Union[BuiltinFunctionType, FuncType, MethodDescriptorType, str]
PRIMITIVE = Union[int, float, str, bool, None]
@@ -168,11 +168,7 @@ def compile(self, graph: Graph, nodes: Dict[str, NodeModel]) -> FUNCTION:

FunctionType = Annotated[
    FUNCTION,
-    AfterValidator(
-        lambda value: FunctionModel(
-            function_name=value.__qualname__ if not isinstance(value, str) else value
-        )
-    ),
+    AfterValidator(lambda value: FunctionModel(function_name=get_function_name(value))),
]

NodeReferenceType = Annotated[
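Collapsing the validator onto get_function_name means a serialized function reduces to the same module-qualified key that FUNCTIONS_WHITELIST now uses, so deserialization stays a plain dict lookup. A self-contained sketch of the Annotated/AfterValidator pattern under pydantic v2, with FunctionModel as a simplified stand-in for the real model:

import operator
from typing import Annotated, Any

from pydantic import AfterValidator, BaseModel, TypeAdapter

def get_function_name(fn):
    if isinstance(fn, str):
        return fn
    return f"{getattr(fn, '__module__', '')}.{fn.__qualname__}"

class FunctionModel(BaseModel):
    function_name: str

FunctionType = Annotated[
    Any,  # simplified; the real alias unions the concrete function types
    AfterValidator(lambda value: FunctionModel(function_name=get_function_name(value))),
]

print(TypeAdapter(FunctionType).validate_python(operator.add))
# function_name='_operator.add'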
9 changes: 5 additions & 4 deletions src/nnsight/tracing/Proxy.py
@@ -231,11 +231,12 @@ def patched(*args, **kwargs):

        node = None

-        for arg in arguments:
-            if isinstance(arg, Proxy):
-                node = arg.node
+        def get_node(proxy: Proxy):
+            nonlocal node

-                break
+            node = proxy.node
+
+        util.apply(list(args) + list(kwargs.values()), get_node, Proxy)

        if node is not None:
            return node.graph.add(target=fn, args=args, kwargs=kwargs)
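The old loop scanned only the top level of arguments, so a Proxy nested inside a list or dict never triggered graph recording; routing the search through util.apply walks the whole structure. A minimal sketch of that traversal pattern, assuming (not verified against nnsight's source) that util.apply(data, fn, cls) calls fn on every cls instance found in nested lists, tuples, and dicts:

class Proxy:
    # Stand-in for nnsight's Proxy: just carries a graph node reference.
    def __init__(self, node):
        self.node = node

def apply(data, fn, cls):
    # Recursively visit nested containers, calling fn on every cls instance.
    if isinstance(data, cls):
        fn(data)
    elif isinstance(data, (list, tuple)):
        for item in data:
            apply(item, fn, cls)
    elif isinstance(data, dict):
        for value in data.values():
            apply(value, fn, cls)

node = None

def get_node(proxy):
    global node
    node = proxy.node

# The old `for arg in arguments` loop would miss this doubly nested Proxy.
apply([1, {"hidden": [Proxy("node_7")]}], get_node, Proxy)
print(node)  # node_7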
1 change: 0 additions & 1 deletion src/nnsight/util.py
@@ -133,7 +133,6 @@ class WrapperModule(torch.nn.Module):
"""

def forward(self, *args, **kwargs):
# TODO document
if len(args) == 1:
args = args[0]

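The deleted line was only a TODO marker; the single-argument unwrap beneath it is the entire visible behavior of WrapperModule.forward. A simplified mirror for illustration; the return statement is an assumption, since the diff truncates the body here:

import torch

class WrapperModuleSketch(torch.nn.Module):
    # Pass-through module: a lone positional argument is unwrapped from
    # the args tuple. The `return args` is assumed; the diff cuts off here.
    def forward(self, *args, **kwargs):
        if len(args) == 1:
            args = args[0]
        return args

m = WrapperModuleSketch()
print(m(torch.ones(2)).shape)                # torch.Size([2]) -- unwrapped
print(len(m(torch.ones(2), torch.ones(3))))  # 2 -- multiple args stay a tuple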
