From 6c410279f6a12acdadac8651f87458b60b2c916c Mon Sep 17 00:00:00 2001 From: JadenFiottoKaufman Date: Sat, 16 Dec 2023 13:20:06 -0500 Subject: [PATCH] Bug fix for swap --- src/nnsight/intervention.py | 21 +++++++++++++-------- src/nnsight/tracing/Graph.py | 4 ++-- src/nnsight/tracing/Node.py | 4 ++-- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/nnsight/intervention.py b/src/nnsight/intervention.py index 0eeae020..e4a2f89f 100644 --- a/src/nnsight/intervention.py +++ b/src/nnsight/intervention.py @@ -138,6 +138,7 @@ def value(self) -> Any: def check_swap(graph: Graph, activations: Any, batch_start: int, batch_size: int): + # If swap is populated due to a 'swp' intervention. if graph.swap is not None: def concat(values): @@ -154,27 +155,31 @@ def concat(values): for key in values[0].keys() } + # As interventions are scoped only to their relevant batch, if we want to swap in values for this batch + # we need to concatenate the batches before and after the relevant batch with the new values. + # Getting batch data before. pre = util.apply( activations, lambda x: x.narrow(0, 0, batch_start), torch.Tensor ) post_batch_start = batch_start + batch_size + # Getting batch data after. post = util.apply( activations, lambda x: x.narrow(0, post_batch_start, x.shape[0] - post_batch_start), torch.Tensor, ) - def get_value(node: Node): - value = node.value - - node.set_value(True) - - return value - - value = util.apply(graph.swap, get_value, Node) + # Second argument of 'swp' interventions is the new value. + # Convert all Nodes in the value to their value. + value = util.apply(graph.swap.args[1], lambda x: x.value, Node) + # Concatenate activations = concat([pre, value, post]) + # Set value of 'swp' node so it destroys itself and listeners. + graph.swap.set_value(True) + + # Un-set swap. graph.swap = None return activations diff --git a/src/nnsight/tracing/Graph.py b/src/nnsight/tracing/Graph.py index 0e8e71bf..eed9073e 100644 --- a/src/nnsight/tracing/Graph.py +++ b/src/nnsight/tracing/Graph.py @@ -30,7 +30,7 @@ class Graph: module_proxy (Proxy): Proxy for given root meta module. argument_node_names (Dict[str, List[str]]): Map of name of argument to name of nodes that depend on it. generation_idx (int): Current generation index. - swap (Any): Attribute to store swap values from 'swp' nodes. + swap (Node): Attribute to store swap values from 'swp' nodes. """ @staticmethod @@ -129,7 +129,7 @@ def __init__( self.generation_idx = 0 - self.swap: Any = None + self.swap: Node = None def increment(self) -> None: """Increments the generation_idx by one. Should be called by a forward hook on the model being used for generation.""" diff --git a/src/nnsight/tracing/Node.py b/src/nnsight/tracing/Node.py index b13677b5..f3f723dd 100644 --- a/src/nnsight/tracing/Node.py +++ b/src/nnsight/tracing/Node.py @@ -194,7 +194,7 @@ def execute(self) -> None: if self.target == "null": return elif self.target == "swp": - self.graph.swap = self.args[1] + self.graph.swap = self return @@ -230,7 +230,7 @@ def set_value(self, value: Any): if dependency.redundant(): dependency.destroy() - if self.value is not None and self.redundant(): + if self.value is not inspect._empty and self.redundant(): self.destroy() def destroy(self) -> None: