From 1a5028f7cb65b2febb77e9dd00460665a4538b87 Mon Sep 17 00:00:00 2001 From: JadenFiottoKaufman Date: Thu, 14 Dec 2023 20:01:18 -0500 Subject: [PATCH 1/2] Simplified the set operator (=) !! No need for proxy_update. If directly setting an output or input, utilizes the 'swap' attribute on graph. Still need to update and double check all docs because setting will no longer work on tuples. --- README.md | 10 +++---- src/nnsight/intervention.py | 52 +++++++++++++++++++++++++++++++++--- src/nnsight/module.py | 20 +++++++++++--- src/nnsight/tracing/Graph.py | 2 ++ src/nnsight/tracing/Node.py | 5 ++++ src/nnsight/tracing/Proxy.py | 26 +++--------------- tests/test_lm.py | 20 ++++++++++++-- 7 files changed, 98 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 45aad53e..ac82aa01 100644 --- a/README.md +++ b/README.md @@ -198,18 +198,18 @@ model = LanguageModel('gpt2', device_map='cuda') with model.generate(max_new_tokens=1) as generator: with generator.invoke('The Eiffel Tower is in the city of') as invoker: - hidden_states_pre = model.transformer.h[-1].output[0].save() + hidden_states_pre = model.transformer.h[-1].mlp.output.save() noise = (0.001**0.5)*torch.randn(hidden_states_pre.shape) - model.transformer.h[-1].output[0] = hidden_states_pre + noise + model.transformer.h[-1].mlp.output = hidden_states_pre + noise - hidden_states_post = model.transformer.h[-1].output[0].save() + hidden_states_post = model.transformer.h[-1].mlp.output.save() print(hidden_states_pre.value) print(hidden_states_post.value) ``` -In this example, we create a tensor of noise to add to the hidden states. We then add it, use the assigment `=` operator to update the tensors of `.output[0]` with these new noised values. +In this example, we create a tensor of noise to add to the hidden states. We then add it, use the assigment `=` operator to update the value of `.output` with these new noised activations. We can see the change in the results: @@ -232,8 +232,6 @@ tensor([[[ 0.0674, -0.1741, -0.1771, ..., -0.9811, 0.1972, -1.0645], device='cuda:0') ``` -Note: Only assigment updates of tensors works with this functionality. - --- ###### Multiple Token Generation diff --git a/src/nnsight/intervention.py b/src/nnsight/intervention.py index 78c081f4..309a68c8 100644 --- a/src/nnsight/intervention.py +++ b/src/nnsight/intervention.py @@ -8,8 +8,8 @@ """ from __future__ import annotations -from contextlib import AbstractContextManager import inspect +from contextlib import AbstractContextManager from typing import Any, Callable, Collection, List, Tuple, Union import torch @@ -18,7 +18,7 @@ from . import util from .tracing.Graph import Graph from .tracing.Proxy import Proxy - +from .tracing.Node import Node class InterventionProxy(Proxy): @@ -81,7 +81,7 @@ def retain_grad(self): # We need to set the values of self to values of self to add this into the computation graph so grad flows through it # This is because in intervene(), we call .narrow on activations which removes it from the grad path - self.node.graph.add(target=Proxy.proxy_update, args=[self.node, self.node]) + self[:] = self @property def token(self) -> TokenIndexer: @@ -136,6 +136,49 @@ def value(self) -> Any: return self.node.value +def check_swap(graph: Graph, activations: Any, batch_start: int, batch_size: int): + if graph.swap is not None: + + def concat(values): + + if isinstance(values[0], torch.Tensor): + return torch.concatenate(values) + elif isinstance(values[0], list) or isinstance(values[0], tuple): + return [ + concat([value[value_idx] for value in values]) + for value_idx in range(len(values[0])) + ] + elif isinstance(values[0], dict): + return { + key: concat([value[key] for value in values]) + for key in values[0].keys() + } + + pre = util.apply( + activations, lambda x: x.narrow(0, 0, batch_start), torch.Tensor + ) + post_batch_start = batch_start + batch_size + post = util.apply( + activations, + lambda x: x.narrow(0, post_batch_start, x.shape[0] - post_batch_start), + torch.Tensor, + ) + + value = graph.swap + + if isinstance(value, Node): + + _value = value.value + value.set_value(True) + value = _value + + activations = concat([pre, value, post]) + + graph.swap = None + + return activations + + def intervene(activations: Any, module_path: str, graph: Graph, key: str): """Entry to intervention graph. This should be hooked to all modules involved in the intervention graph. @@ -181,6 +224,9 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str): torch.Tensor, ) ) + + activations = check_swap(graph, activations, batch_start, batch_size) + return activations diff --git a/src/nnsight/module.py b/src/nnsight/module.py index 8c649be7..2d102371 100644 --- a/src/nnsight/module.py +++ b/src/nnsight/module.py @@ -126,10 +126,13 @@ def output(self, value: Union[InterventionProxy, Any]) -> None: """ self.output.node.graph.add( - target=Proxy.proxy_update, + target='swp', args=[self.output.node, value], + value=True ) + self._output = None + @property def input(self) -> InterventionProxy: """ @@ -166,10 +169,13 @@ def input(self, value: Union[InterventionProxy, Any]) -> None: """ self.input.node.graph.add( - target=Proxy.proxy_update, + target='swp', args=[self.input.node, value], + value=True ) + self._input = None + @property def backward_output(self) -> InterventionProxy: """ @@ -206,10 +212,13 @@ def backward_output(self, value: Union[InterventionProxy, Any]) -> None: """ self.backward_output.node.graph.add( - target=Proxy.proxy_update, + target='swp', args=[self.backward_output.node, value], + value=True ) + self._backward_output = None + @property def backward_input(self) -> InterventionProxy: """ @@ -246,10 +255,13 @@ def backward_input(self, value: Union[InterventionProxy, Any]) -> None: """ self.backward_input.node.graph.add( - target=Proxy.proxy_update, + target='swp', args=[self.backward_input.node, value], + value=True ) + self._backward_input = None + @property def graph(self) -> Graph: if self._graph is None: diff --git a/src/nnsight/tracing/Graph.py b/src/nnsight/tracing/Graph.py index 2ab780e0..0b8a1aa3 100644 --- a/src/nnsight/tracing/Graph.py +++ b/src/nnsight/tracing/Graph.py @@ -151,6 +151,8 @@ def __init__( self.generation_idx = 0 + self.swap:Proxy = None + def increment(self) -> None: """Increments the generation_idx by one. Should be called by a forward hook on the model being used for generation.""" self.generation_idx += 1 diff --git a/src/nnsight/tracing/Node.py b/src/nnsight/tracing/Node.py index 846a7f9d..eb40eacf 100644 --- a/src/nnsight/tracing/Node.py +++ b/src/nnsight/tracing/Node.py @@ -193,6 +193,11 @@ def execute(self) -> None: # We se a nodes target to 'null' if we don't want it to be executed and therefore never done if self.target == "null": return + elif self.target == "swp": + + self.graph.swap = self.args[1] + + return # Prepare arguments. args, kwargs = self.prepare_inputs() diff --git a/src/nnsight/tracing/Proxy.py b/src/nnsight/tracing/Proxy.py index 7ffe5167..45bde9d8 100644 --- a/src/nnsight/tracing/Proxy.py +++ b/src/nnsight/tracing/Proxy.py @@ -20,23 +20,6 @@ class Proxy: node (Node): This proxy's node. """ - @staticmethod - def proxy_update(value1: Any, value2: Any) -> None: - """Updates Tensor values with other Tensor values. - - Args: - value1 (Any): Collection with Tensors to update. - value2 (Any): Collection with Tensors to pull values from. - """ - if isinstance(value1, torch.Tensor): - value1[:] = value2 - elif isinstance(value1, list) or isinstance(value1, tuple): - for value_idx in range(len(value1)): - Proxy.proxy_update(value1[value_idx], value2[value_idx]) - elif isinstance(value1, dict): - for key in value1: - Proxy.proxy_update(value1[key], value2[key]) - @staticmethod def proxy_call(callable: Callable, *args, **kwargs) -> None: return callable(*args, **kwargs) @@ -76,11 +59,10 @@ def __getitem__(self, key: Union[Proxy, Any]) -> Proxy: ) def __setitem__(self, key: Union[Proxy, Any], value: Union[Proxy, Any]) -> None: - item_proxy = self[key] - - item_proxy.node.graph.add( - target=Proxy.proxy_update, - args=[item_proxy.node, value], + + self.node.graph.add( + target=operator.setitem, + args=[self.node, key, value], ) def __getattr__(self, key: Union[Proxy, Any]) -> Proxy: diff --git a/tests/test_lm.py b/tests/test_lm.py index a43cd452..17f5cc5c 100644 --- a/tests/test_lm.py +++ b/tests/test_lm.py @@ -39,12 +39,12 @@ def test_save(gpt2: nnsight.LanguageModel): assert hs_input.value.ndim == 3 -def test_set(gpt2: nnsight.LanguageModel): +def test_set1(gpt2: nnsight.LanguageModel): with gpt2.generate(max_new_tokens=1) as generator: with generator.invoke("Hello world") as invoker: pre = gpt2.transformer.h[-1].output[0].clone().save() - gpt2.transformer.h[-1].output[0] = 0 + gpt2.transformer.h[-1].output[0][:] = 0 post = gpt2.transformer.h[-1].output[0].save() @@ -54,6 +54,22 @@ def test_set(gpt2: nnsight.LanguageModel): assert (post.value == 0).all().item() assert output != "Madison Square Garden is located in the city of New" +def test_set2(gpt2: nnsight.LanguageModel): + with gpt2.generate(max_new_tokens=1) as generator: + with generator.invoke("Hello world") as invoker: + pre = gpt2.transformer.h[-1].mlp.output.clone().save() + + gpt2.transformer.h[-1].mlp.output = gpt2.transformer.h[-1].mlp.output * 0 + + post = gpt2.transformer.h[-1].mlp.output.save() + + output = gpt2.tokenizer.decode(generator.output[0]) + + assert not (pre.value == 0).all().item() + assert (post.value == 0).all().item() + assert output != "Madison Square Garden is located in the city of New" + + def test_adhoc_module(gpt2: nnsight.LanguageModel): with gpt2.generate() as generator: From 7229cf0ec36b8bb404b6e4d1d706db72e64a6981 Mon Sep 17 00:00:00 2001 From: JadenFiottoKaufman Date: Fri, 15 Dec 2023 12:24:24 -0500 Subject: [PATCH 2/2] Editing documentation for swapping --- docs/source/notebooks/features/setting.ipynb | 8 ++-- src/nnsight/intervention.py | 19 +++++---- src/nnsight/module.py | 16 ++------ src/nnsight/tracing/Graph.py | 42 +++++--------------- src/nnsight/tracing/Node.py | 1 - src/nnsight/tracing/Proxy.py | 1 - tests/test_lm.py | 16 ++++---- 7 files changed, 36 insertions(+), 67 deletions(-) diff --git a/docs/source/notebooks/features/setting.ipynb b/docs/source/notebooks/features/setting.ipynb index eb0e3ae0..d489854d 100644 --- a/docs/source/notebooks/features/setting.ipynb +++ b/docs/source/notebooks/features/setting.ipynb @@ -13,7 +13,7 @@ "source": [ "We often not only want to see whats happening during computation, but intervene and edit the flow of information.\n", "\n", - "In this example, we create a tensor of noise to add to the hidden states. We then add it, use the assigment `=` operator to update the tensors of `.output[0]` with these new noised values." + "In this example, we create a tensor of noise to add to the hidden states. We then add it, use the assigment `=` operator to update the tensors of `.output[0][:]` with these new noised values." ] }, { @@ -30,13 +30,13 @@ "with model.generate(max_new_tokens=1) as generator:\n", " with generator.invoke('The Eiffel Tower is in the city of') as invoker:\n", "\n", - " hidden_states_pre = model.transformer.h[-1].output[0].save()\n", + " hidden_states_pre = model.transformer.h[-1].output[0][:].save()\n", "\n", " noise = (0.001**0.5)*torch.randn(hidden_states_pre.shape)\n", "\n", - " model.transformer.h[-1].output[0] = hidden_states_pre + noise\n", + " model.transformer.h[-1].output[0][:] = hidden_states_pre + noise\n", "\n", - " hidden_states_post = model.transformer.h[-1].output[0].save()" + " hidden_states_post = model.transformer.h[-1].output[0][:].save()" ] }, { diff --git a/src/nnsight/intervention.py b/src/nnsight/intervention.py index 309a68c8..0eeae020 100644 --- a/src/nnsight/intervention.py +++ b/src/nnsight/intervention.py @@ -17,8 +17,9 @@ from . import util from .tracing.Graph import Graph -from .tracing.Proxy import Proxy from .tracing.Node import Node +from .tracing.Proxy import Proxy + class InterventionProxy(Proxy): @@ -138,9 +139,8 @@ def value(self) -> Any: def check_swap(graph: Graph, activations: Any, batch_start: int, batch_size: int): if graph.swap is not None: - + def concat(values): - if isinstance(values[0], torch.Tensor): return torch.concatenate(values) elif isinstance(values[0], list) or isinstance(values[0], tuple): @@ -164,13 +164,14 @@ def concat(values): torch.Tensor, ) - value = graph.swap + def get_value(node: Node): + value = node.value + + node.set_value(True) - if isinstance(value, Node): + return value - _value = value.value - value.set_value(True) - value = _value + value = util.apply(graph.swap, get_value, Node) activations = concat([pre, value, post]) @@ -225,6 +226,8 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str): ) ) + # Check if through the previous value injection, there was a 'swp' intervention. + # This would mean we want to replace activations for this batch with some other ones. activations = check_swap(graph, activations, batch_start, batch_size) return activations diff --git a/src/nnsight/module.py b/src/nnsight/module.py index 2d102371..0c2d31eb 100644 --- a/src/nnsight/module.py +++ b/src/nnsight/module.py @@ -126,9 +126,7 @@ def output(self, value: Union[InterventionProxy, Any]) -> None: """ self.output.node.graph.add( - target='swp', - args=[self.output.node, value], - value=True + target="swp", args=[self.output.node, value], value=True ) self._output = None @@ -169,9 +167,7 @@ def input(self, value: Union[InterventionProxy, Any]) -> None: """ self.input.node.graph.add( - target='swp', - args=[self.input.node, value], - value=True + target="swp", args=[self.input.node, value], value=True ) self._input = None @@ -212,9 +208,7 @@ def backward_output(self, value: Union[InterventionProxy, Any]) -> None: """ self.backward_output.node.graph.add( - target='swp', - args=[self.backward_output.node, value], - value=True + target="swp", args=[self.backward_output.node, value], value=True ) self._backward_output = None @@ -255,9 +249,7 @@ def backward_input(self, value: Union[InterventionProxy, Any]) -> None: """ self.backward_input.node.graph.add( - target='swp', - args=[self.backward_input.node, value], - value=True + target="swp", args=[self.backward_input.node, value], value=True ) self._backward_input = None diff --git a/src/nnsight/tracing/Graph.py b/src/nnsight/tracing/Graph.py index 0b8a1aa3..0e8e71bf 100644 --- a/src/nnsight/tracing/Graph.py +++ b/src/nnsight/tracing/Graph.py @@ -18,7 +18,7 @@ class Graph: * 'module' : There should only be the single root module as a node in the graph for tracing. Added on __init__ and when compiling, the node's value is set to to be whatever module that is being interleaved with this computation graph. * 'argument' : There can be multiple argument nodes. Their first argument needs to be the argument name which acts as a key in graph.argument_node_names which maps to a list of names for nodes that depend on it's value. These nodes values need to be set outside of the computation graph as entry points to kick of the execution of the graph. - * 'rtn' : Should only be one 'rtn' target named node as this is what is used. + * 'swp' : swp nodes indicate populating the graph's swap attribute. When executed, its value is not set. Logic involving the swap value should set its value after using it. * 'null' : Null nodes never get executed and therefore their listeners never get destroyed. Attributes: @@ -30,6 +30,7 @@ class Graph: module_proxy (Proxy): Proxy for given root meta module. argument_node_names (Dict[str, List[str]]): Map of name of argument to name of nodes that depend on it. generation_idx (int): Current generation index. + swap (Any): Attribute to store swap values from 'swp' nodes. """ @staticmethod @@ -104,36 +105,13 @@ def get_argument_value(param: inspect.Parameter, idx: int): # Run forward with root module proxy and arguments output: Proxy = forward(graph.module_proxy, *arguments) - # Get proxy_value for return - value = util.apply(output, lambda x: x.node.proxy_value, Proxy) - - # Create the 'rtn_0' return proxy + # Create the 'swap' return proxy return_proxy = graph.add( - graph=graph, value=value, target=Graph.rtn, args=output - ) - - # This is how we tell the graph not to destroy a proxy after it's listeners are completed. - # Create a 'null' proxy. The return proxy listens to the 'null' proxy with args=[return_proxy.node] but 'null' will never be completed. - graph.add( - graph=graph, - value=None, - target="null", - args=[return_proxy.node], + graph=graph, value=True, target="swp", args=[output.node, output.node] ) return graph - @staticmethod - def rtn(*args, **kwargs): - """ - Function to just pass through data for returning data in a graph forward method. - - Returns: - _type_: _description_ - """ - - return args - def __init__( self, module: torch.nn.Module, @@ -151,7 +129,7 @@ def __init__( self.generation_idx = 0 - self.swap:Proxy = None + self.swap: Any = None def increment(self) -> None: """Increments the generation_idx by one. Should be called by a forward hook on the model being used for generation.""" @@ -201,7 +179,7 @@ def add( Proxy: Proxy for the added node. Raises: - ValueError: If more than one reserved "rtn" or "module" nodes are added to the graph. + ValueError: If more than one reserved "module" nodes are added to the graph. """ # If we're validating and the user did not provide a value, execute the given target with meta proxy values to compute new proxy_value. @@ -219,8 +197,6 @@ def add( if target_name not in self.name_idx: self.name_idx[target_name] = 0 else: - if target_name == "rtn": - raise ValueError("Can only have one return ('rtn') node.") if target_name == "module": raise ValueError("Can only have one module node.") @@ -295,9 +271,9 @@ def forward(*args, **kwargs): if key in self.argument_node_names: self.nodes[self.argument_node_names[key][0]].set_value(arg) - # 'rtn_0' should have the value we need to return. - return_value = self.nodes["rtn_0"].value - self.nodes["rtn_0"].destroy() + # should have the value we need to return. + return_value = self.swap + self.swap.set_value(True) return return_value # Replace forward method with custom graph execution method. diff --git a/src/nnsight/tracing/Node.py b/src/nnsight/tracing/Node.py index eb40eacf..b13677b5 100644 --- a/src/nnsight/tracing/Node.py +++ b/src/nnsight/tracing/Node.py @@ -194,7 +194,6 @@ def execute(self) -> None: if self.target == "null": return elif self.target == "swp": - self.graph.swap = self.args[1] return diff --git a/src/nnsight/tracing/Proxy.py b/src/nnsight/tracing/Proxy.py index 45bde9d8..d86b4939 100644 --- a/src/nnsight/tracing/Proxy.py +++ b/src/nnsight/tracing/Proxy.py @@ -59,7 +59,6 @@ def __getitem__(self, key: Union[Proxy, Any]) -> Proxy: ) def __setitem__(self, key: Union[Proxy, Any], value: Union[Proxy, Any]) -> None: - self.node.graph.add( target=operator.setitem, args=[self.node, key, value], diff --git a/tests/test_lm.py b/tests/test_lm.py index 17f5cc5c..c66c1ba4 100644 --- a/tests/test_lm.py +++ b/tests/test_lm.py @@ -39,9 +39,9 @@ def test_save(gpt2: nnsight.LanguageModel): assert hs_input.value.ndim == 3 -def test_set1(gpt2: nnsight.LanguageModel): +def test_set1(gpt2: nnsight.LanguageModel, MSG_prompt: str): with gpt2.generate(max_new_tokens=1) as generator: - with generator.invoke("Hello world") as invoker: + with generator.invoke(MSG_prompt) as invoker: pre = gpt2.transformer.h[-1].output[0].clone().save() gpt2.transformer.h[-1].output[0][:] = 0 @@ -54,14 +54,15 @@ def test_set1(gpt2: nnsight.LanguageModel): assert (post.value == 0).all().item() assert output != "Madison Square Garden is located in the city of New" -def test_set2(gpt2: nnsight.LanguageModel): + +def test_set2(gpt2: nnsight.LanguageModel, MSG_prompt: str): with gpt2.generate(max_new_tokens=1) as generator: - with generator.invoke("Hello world") as invoker: - pre = gpt2.transformer.h[-1].mlp.output.clone().save() + with generator.invoke(MSG_prompt) as invoker: + pre = gpt2.transformer.wte.output.clone().save() - gpt2.transformer.h[-1].mlp.output = gpt2.transformer.h[-1].mlp.output * 0 + gpt2.transformer.wte.output = gpt2.transformer.wte.output * 0 - post = gpt2.transformer.h[-1].mlp.output.save() + post = gpt2.transformer.wte.output.save() output = gpt2.tokenizer.decode(generator.output[0]) @@ -70,7 +71,6 @@ def test_set2(gpt2: nnsight.LanguageModel): assert output != "Madison Square Garden is located in the city of New" - def test_adhoc_module(gpt2: nnsight.LanguageModel): with gpt2.generate() as generator: with generator.invoke("The Eiffel Tower is in the city of") as invoker: