Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix for swap #17

Merged
merged 1 commit into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/nnsight/intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def value(self) -> Any:


def check_swap(graph: Graph, activations: Any, batch_start: int, batch_size: int):
# If swap is populated due to a 'swp' intervention.
if graph.swap is not None:

def concat(values):
Expand All @@ -154,27 +155,31 @@ def concat(values):
for key in values[0].keys()
}

# As interventions are scoped only to their relevant batch, if we want to swap in values for this batch
# we need to concatenate the batches before and after the relevant batch with the new values.
# Getting batch data before.
pre = util.apply(
activations, lambda x: x.narrow(0, 0, batch_start), torch.Tensor
)
post_batch_start = batch_start + batch_size
# Getting batch data after.
post = util.apply(
activations,
lambda x: x.narrow(0, post_batch_start, x.shape[0] - post_batch_start),
torch.Tensor,
)

def get_value(node: Node):
value = node.value

node.set_value(True)

return value

value = util.apply(graph.swap, get_value, Node)
# Second argument of 'swp' interventions is the new value.
# Convert all Nodes in the value to their value.
value = util.apply(graph.swap.args[1], lambda x: x.value, Node)

# Concatenate
activations = concat([pre, value, post])

# Set value of 'swp' node so it destroys itself and listeners.
graph.swap.set_value(True)

# Un-set swap.
graph.swap = None

return activations
Expand Down
4 changes: 2 additions & 2 deletions src/nnsight/tracing/Graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Graph:
module_proxy (Proxy): Proxy for given root meta module.
argument_node_names (Dict[str, List[str]]): Map of name of argument to name of nodes that depend on it.
generation_idx (int): Current generation index.
swap (Any): Attribute to store swap values from 'swp' nodes.
swap (Node): Attribute to store swap values from 'swp' nodes.
"""

@staticmethod
Expand Down Expand Up @@ -129,7 +129,7 @@ def __init__(

self.generation_idx = 0

self.swap: Any = None
self.swap: Node = None

def increment(self) -> None:
"""Increments the generation_idx by one. Should be called by a forward hook on the model being used for generation."""
Expand Down
4 changes: 2 additions & 2 deletions src/nnsight/tracing/Node.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def execute(self) -> None:
if self.target == "null":
return
elif self.target == "swp":
self.graph.swap = self.args[1]
self.graph.swap = self

return

Expand Down Expand Up @@ -230,7 +230,7 @@ def set_value(self, value: Any):
if dependency.redundant():
dependency.destroy()

if self.value is not None and self.redundant():
if self.value is not inspect._empty and self.redundant():
self.destroy()

def destroy(self) -> None:
Expand Down