Skip to content

Commit

Permalink
On swap operation, move new values to device of old values
Browse files Browse the repository at this point in the history
  • Loading branch information
JadenFiotto-Kaufman committed Dec 31, 2023
1 parent 69689f0 commit 5509288
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions src/nnsight/intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):

# Key to module activation argument nodes has format: <module path>.<output/input>.<generation index>
module_path = f"{module_path}.{key}.{graph.generation_idx}"

if module_path in graph.argument_node_names:
argument_node_names = graph.argument_node_names[module_path]

Expand All @@ -219,8 +219,24 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):
# Check if through the previous value injection, there was a 'swp' intervention.
# This would mean we want to replace activations for this batch with some other ones.
if graph.swap is not None:
device = None

def _device(value: torch.Tensor):
nonlocal device

device = value.device

util.apply(value, _device, torch.Tensor)

value = util.apply(graph.swap.args[1], lambda x: x.value, Node)

if device is not None:

def _to(value: torch.Tensor):
return value.to(device)

value = util.apply(value, _to, torch.Tensor)

# Set value of 'swp' node so it destroys itself and listeners.
graph.swap.set_value(True)

Expand Down Expand Up @@ -284,7 +300,9 @@ def __enter__(self) -> HookModel:
def input_hook(module, input, kwargs, module_path=module_path):
return self.input_hook((input, kwargs), module_path)

self.handles.append(module.register_forward_pre_hook(input_hook, with_kwargs=True))
self.handles.append(
module.register_forward_pre_hook(input_hook, with_kwargs=True)
)

elif hook_type == "output":

Expand Down

0 comments on commit 5509288

Please sign in to comment.