Skip to content

Commit

Permalink
Editing documentation for swapping
Browse files Browse the repository at this point in the history
  • Loading branch information
JadenFiotto-Kaufman committed Dec 15, 2023
1 parent 1a5028f commit 7229cf0
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 67 deletions.
8 changes: 4 additions & 4 deletions docs/source/notebooks/features/setting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"source": [
"We often want not only to see what's happening during computation, but also to intervene and edit the flow of information.\n",
"\n",
"In this example, we create a tensor of noise to add to the hidden states. We then add it, use the assigment `=` operator to update the tensors of `.output[0]` with these new noised values."
"In this example, we create a tensor of noise to add to the hidden states. We then add it and use the assignment `=` operator to update the tensors of `.output[0][:]` with these new noised values."
]
},
{
Expand All @@ -30,13 +30,13 @@
"with model.generate(max_new_tokens=1) as generator:\n",
" with generator.invoke('The Eiffel Tower is in the city of') as invoker:\n",
"\n",
" hidden_states_pre = model.transformer.h[-1].output[0].save()\n",
" hidden_states_pre = model.transformer.h[-1].output[0][:].save()\n",
"\n",
" noise = (0.001**0.5)*torch.randn(hidden_states_pre.shape)\n",
"\n",
" model.transformer.h[-1].output[0] = hidden_states_pre + noise\n",
" model.transformer.h[-1].output[0][:] = hidden_states_pre + noise\n",
"\n",
" hidden_states_post = model.transformer.h[-1].output[0].save()"
" hidden_states_post = model.transformer.h[-1].output[0][:].save()"
]
},
{
Expand Down
19 changes: 11 additions & 8 deletions src/nnsight/intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

from . import util
from .tracing.Graph import Graph
from .tracing.Proxy import Proxy
from .tracing.Node import Node
from .tracing.Proxy import Proxy


class InterventionProxy(Proxy):

Expand Down Expand Up @@ -138,9 +139,8 @@ def value(self) -> Any:

def check_swap(graph: Graph, activations: Any, batch_start: int, batch_size: int):
if graph.swap is not None:

def concat(values):

if isinstance(values[0], torch.Tensor):
return torch.concatenate(values)
elif isinstance(values[0], list) or isinstance(values[0], tuple):
Expand All @@ -164,13 +164,14 @@ def concat(values):
torch.Tensor,
)

value = graph.swap
def get_value(node: Node):
value = node.value

node.set_value(True)

if isinstance(value, Node):
return value

_value = value.value
value.set_value(True)
value = _value
value = util.apply(graph.swap, get_value, Node)

activations = concat([pre, value, post])

Expand Down Expand Up @@ -225,6 +226,8 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):
)
)

# Check if through the previous value injection, there was a 'swp' intervention.
# This would mean we want to replace activations for this batch with some other ones.
activations = check_swap(graph, activations, batch_start, batch_size)

return activations
Expand Down
16 changes: 4 additions & 12 deletions src/nnsight/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ def output(self, value: Union[InterventionProxy, Any]) -> None:
"""

self.output.node.graph.add(
target='swp',
args=[self.output.node, value],
value=True
target="swp", args=[self.output.node, value], value=True
)

self._output = None
Expand Down Expand Up @@ -169,9 +167,7 @@ def input(self, value: Union[InterventionProxy, Any]) -> None:
"""

self.input.node.graph.add(
target='swp',
args=[self.input.node, value],
value=True
target="swp", args=[self.input.node, value], value=True
)

self._input = None
Expand Down Expand Up @@ -212,9 +208,7 @@ def backward_output(self, value: Union[InterventionProxy, Any]) -> None:
"""

self.backward_output.node.graph.add(
target='swp',
args=[self.backward_output.node, value],
value=True
target="swp", args=[self.backward_output.node, value], value=True
)

self._backward_output = None
Expand Down Expand Up @@ -255,9 +249,7 @@ def backward_input(self, value: Union[InterventionProxy, Any]) -> None:
"""

self.backward_input.node.graph.add(
target='swp',
args=[self.backward_input.node, value],
value=True
target="swp", args=[self.backward_input.node, value], value=True
)

self._backward_input = None
Expand Down
42 changes: 9 additions & 33 deletions src/nnsight/tracing/Graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Graph:
* 'module' : There should only be the single root module as a node in the graph for tracing. Added on __init__ and when compiling, the node's value is set to be whatever module that is being interleaved with this computation graph.
* 'argument' : There can be multiple argument nodes. Their first argument needs to be the argument name which acts as a key in graph.argument_node_names which maps to a list of names for nodes that depend on its value. These nodes' values need to be set outside of the computation graph as entry points to kick off the execution of the graph.
* 'rtn' : Should only be one 'rtn' target named node as this is what is used.
* 'swp' : swp nodes indicate populating the graph's swap attribute. When executed, its value is not set. Logic involving the swap value should set its value after using it.
* 'null' : Null nodes never get executed and therefore their listeners never get destroyed.
Attributes:
Expand All @@ -30,6 +30,7 @@ class Graph:
module_proxy (Proxy): Proxy for given root meta module.
argument_node_names (Dict[str, List[str]]): Map of name of argument to name of nodes that depend on it.
generation_idx (int): Current generation index.
swap (Any): Attribute to store swap values from 'swp' nodes.
"""

@staticmethod
Expand Down Expand Up @@ -104,36 +105,13 @@ def get_argument_value(param: inspect.Parameter, idx: int):
# Run forward with root module proxy and arguments
output: Proxy = forward(graph.module_proxy, *arguments)

# Get proxy_value for return
value = util.apply(output, lambda x: x.node.proxy_value, Proxy)

# Create the 'rtn_0' return proxy
# Create the 'swap' return proxy
return_proxy = graph.add(
graph=graph, value=value, target=Graph.rtn, args=output
)

# This is how we tell the graph not to destroy a proxy after it's listeners are completed.
# Create a 'null' proxy. The return proxy listens to the 'null' proxy with args=[return_proxy.node] but 'null' will never be completed.
graph.add(
graph=graph,
value=None,
target="null",
args=[return_proxy.node],
graph=graph, value=True, target="swp", args=[output.node, output.node]
)

return graph

@staticmethod
def rtn(*args, **kwargs):
"""
Function to just pass through data for returning data in a graph forward method.
Returns:
_type_: _description_
"""

return args

def __init__(
self,
module: torch.nn.Module,
Expand All @@ -151,7 +129,7 @@ def __init__(

self.generation_idx = 0

self.swap:Proxy = None
self.swap: Any = None

def increment(self) -> None:
"""Increments the generation_idx by one. Should be called by a forward hook on the model being used for generation."""
Expand Down Expand Up @@ -201,7 +179,7 @@ def add(
Proxy: Proxy for the added node.
Raises:
ValueError: If more than one reserved "rtn" or "module" nodes are added to the graph.
ValueError: If more than one reserved "module" nodes are added to the graph.
"""

# If we're validating and the user did not provide a value, execute the given target with meta proxy values to compute new proxy_value.
Expand All @@ -219,8 +197,6 @@ def add(
if target_name not in self.name_idx:
self.name_idx[target_name] = 0
else:
if target_name == "rtn":
raise ValueError("Can only have one return ('rtn') node.")
if target_name == "module":
raise ValueError("Can only have one module node.")

Expand Down Expand Up @@ -295,9 +271,9 @@ def forward(*args, **kwargs):
if key in self.argument_node_names:
self.nodes[self.argument_node_names[key][0]].set_value(arg)

# 'rtn_0' should have the value we need to return.
return_value = self.nodes["rtn_0"].value
self.nodes["rtn_0"].destroy()
# The graph's swap attribute should have the value we need to return.
return_value = self.swap
self.swap.set_value(True)
return return_value

# Replace forward method with custom graph execution method.
Expand Down
1 change: 0 additions & 1 deletion src/nnsight/tracing/Node.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ def execute(self) -> None:
if self.target == "null":
return
elif self.target == "swp":

self.graph.swap = self.args[1]

return
Expand Down
1 change: 0 additions & 1 deletion src/nnsight/tracing/Proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __getitem__(self, key: Union[Proxy, Any]) -> Proxy:
)

def __setitem__(self, key: Union[Proxy, Any], value: Union[Proxy, Any]) -> None:

self.node.graph.add(
target=operator.setitem,
args=[self.node, key, value],
Expand Down
16 changes: 8 additions & 8 deletions tests/test_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def test_save(gpt2: nnsight.LanguageModel):
assert hs_input.value.ndim == 3


def test_set1(gpt2: nnsight.LanguageModel):
def test_set1(gpt2: nnsight.LanguageModel, MSG_prompt: str):
with gpt2.generate(max_new_tokens=1) as generator:
with generator.invoke("Hello world") as invoker:
with generator.invoke(MSG_prompt) as invoker:
pre = gpt2.transformer.h[-1].output[0].clone().save()

gpt2.transformer.h[-1].output[0][:] = 0
Expand All @@ -54,14 +54,15 @@ def test_set1(gpt2: nnsight.LanguageModel):
assert (post.value == 0).all().item()
assert output != "Madison Square Garden is located in the city of New"

def test_set2(gpt2: nnsight.LanguageModel):

def test_set2(gpt2: nnsight.LanguageModel, MSG_prompt: str):
with gpt2.generate(max_new_tokens=1) as generator:
with generator.invoke("Hello world") as invoker:
pre = gpt2.transformer.h[-1].mlp.output.clone().save()
with generator.invoke(MSG_prompt) as invoker:
pre = gpt2.transformer.wte.output.clone().save()

gpt2.transformer.h[-1].mlp.output = gpt2.transformer.h[-1].mlp.output * 0
gpt2.transformer.wte.output = gpt2.transformer.wte.output * 0

post = gpt2.transformer.h[-1].mlp.output.save()
post = gpt2.transformer.wte.output.save()

output = gpt2.tokenizer.decode(generator.output[0])

Expand All @@ -70,7 +71,6 @@ def test_set2(gpt2: nnsight.LanguageModel):
assert output != "Madison Square Garden is located in the city of New"



def test_adhoc_module(gpt2: nnsight.LanguageModel):
with gpt2.generate() as generator:
with generator.invoke("The Eiffel Tower is in the city of") as invoker:
Expand Down

0 comments on commit 7229cf0

Please sign in to comment.