Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
JadenFiotto-Kaufman committed Dec 3, 2023
2 parents b49c98c + 9943216 commit 19b78fc
Show file tree
Hide file tree
Showing 15 changed files with 291 additions and 217 deletions.
17 changes: 10 additions & 7 deletions docs/source/_static/css/custom.css
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
align-content: stretch;
padding-bottom: 20vh;

overflow: hidden;
}

html[data-theme="light"] {
Expand All @@ -28,8 +29,9 @@ html[data-theme="light"] {
}


.page-container {
height: 65vh;
.features {
height: 60vh;
overflow: hidden;
}


Expand All @@ -56,19 +58,20 @@ html[data-theme="light"] {
}

@media only screen and (max-width: 768px) { /* Adjust this value based on your breakpoint for mobile */
.front-container, .container {
.front-container, .hero {
height: auto; /* Change from fixed height to auto */
min-height: 50vh; /* Adjust this as needed */
}
}


@media only screen and (max-width: 768px) {
.features-container {
margin-bottom: 20px; /* Increase bottom margin */
}

.container {
.hero {
margin-bottom: 30px; /* Adjust the bottom margin of the main container */
}

.features {
height: 110vh;
}
}
Binary file modified docs/source/_static/images/icon.ico
Binary file not shown.
Binary file added docs/source/_static/images/intrgraph.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/images/remote_execution.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
76 changes: 76 additions & 0 deletions docs/source/about.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@

About nnsight
=============

An API for transparent science on black-box AI
----------------------------------------------

.. card:: How can you study the internals of a deep network that is too large for you to run?

In this era of large-scale deep learning, the most interesting AI models are massive black boxes
that are hard to run. Ordinary commercial inference service APIs let you interact with huge
models, but they do not let you see model internals.

The nnsight library is different: it gives you full access to all the neural network internals.
When used together with a remote service like the `National Deep Inference Facility <https://ndif.us/>`_ (NDIF),
it lets you run experiments on huge open models easily, with full transparent access.
The nnsight library is also terrific for studying smaller local models.

.. figure:: _static/images/remote_execution.png

An overview of the nnsight/NDIF pipeline. Researchers write simple Python code to run along with the neural network locally or remotely. Unlike commercial inference, the experiment code can read or write any of the internal states of the neural networks being studied. This code creates a computation graph that can be sent to the remote service and interleaved with the execution of the neural network.

How you use nnsight
-------------------

Nnsight is built on pytorch.

Running inference on a huge remote model with nnsight is very similar to running a neural network locally on your own workstation. In fact, with nnsight, the same code for running experiments locally on small models can also be used on large models, just by changing a few arguments.

The difference between nnsight and normal inference is that when you use nnsight, you do not treat the model as an opaque black box.
Instead, you set up a python ``with`` context that enables you to get direct access to model internals while the neural network runs.
Here is how it looks:

.. code-block:: python
:linenos:
from nnsight import LanguageModel
model = LanguageModel('meta-llama/Llama-2-70b-hf')
with model.forward(remote=True) as runner:
with runner.invoke('The Eiffel Tower is in the city of ') as invoker:
hidden_state = model.layers[10].input.save() # save one hidden state
model.layers[11].mlp.output = 0 # change one MLP module output
print('The model predicts', runner.output)
print('The internal state was', hidden_state.value)
The library is easy to use. Any HuggingFace model can be loaded into a ``LanguageModel`` object, as you can wee on line 2. Notice we are loading a 70-billion parameter model, which is ordinarily pretty difficult to load on a regular workstation since it would take 140-280 gigabytes of GPU RAM just to store the parameters.

The trick that lets us work with this huge model is on line 3. We set the flag ``remote=True`` to indicate that we want to actually run the network on the remote service. By default the remote service will be NDIF. If we want to just run a smaller model quickly, we could leave it as ``remote=False``.

Then when we invoke the model on line 4, we do not just call it as a function. Instead, we use it as a ``with`` context manager. The reason is that nnsight does not treat neural network models as black boxes; it provides direct access to model internals.

You can see what simple direct access looks like on lines 5-6. On line 5, we grab a hidden state at layer 10, and on layer 6, we change the output of an MLP module inside the transformer at layer 11.

When you run this ``with``-block code on lines 5 and 6 on your local workstation, it actually creates a computation graph storing all the calculations you want to do. When the outermost ``with`` block is completed, all the defined caculations are sent to the remote server and executed there. Then when it's all done, the results can be accessed on your local workstation as shown on line 7 and 8.

What happens behind the scenes?
-------------------------------
When using nnsight, it is helpful to understand that the operations are not executed immediately but instead adds to an intervention graph that is executed alongside the model's computation graph upon exit of the with block.

An example of one such intervention graph can be seen below:

.. figure:: _static/images/intrgraph.png

An example of an intervention graph. Operations in research code create nodes in the graph which depend on module inputs and outputs as well as other nodes. Then, this intervention graph is interleaved with the normal computation graph of the chosen model, and requested inputs and outputs are injected into the intervention graph for execution.

Basic access to model internals can give you a lot of insight about what is going on inside a large model as it runs. For example, you can use the `logit lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`_ to read internal hidden states as text. And use can use `causal tracing <https://rome.baulab.info/>`_ or `path patching <https://arxiv.org/abs/2304.05969>`_ or `other circuit discovery methods <https://arxiv.org/abs/2310.10348>`_ to locate the layers and components within the network that play a decisive role in making a decision.

And with nnsight, you can use thes methods on large models like Llama-2-70b.

The nnsight library also provies full access to gradients and optimizations methods, out of order module applications, cross prompt interventions and many more features.

See the :doc:`tutorials/basics` and :doc:`tutorials/features` pages for more information on nnsight functionality.

The project is currently in Alpha pre-release and is looking for early users/and contributors!

If you are interested in contributing or being an early user, join the `NDIF Discord <https://discord.gg/ZRPgsf6P>`_ for updates, feature requests, bug reports and opportunities to help with the effort.
8 changes: 4 additions & 4 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ nnsight
start
documentation
tutorials
About <about>

.. grid:: 1 1 2 2
:class-container: hero
Expand All @@ -33,7 +34,7 @@ nnsight

interpretable neural networks

**nnsight** (/ɛn.saɪt/) is a package for interpreting and manipulating the internals of large models.
**nnsight** (/ɛn.saɪt/) is a package for interpreting and manipulating the internals of large models

.. div:: button-group

Expand All @@ -53,16 +54,15 @@ nnsight
:color: primary
:outline:

Documentation
Docs


.. div:: sd-fs-1 sd-font-weight-bold sd-text-center sd-text-primary sd-mb-5

Key Features

.. grid:: 1 1 2 2
:gutter: 5
:class-container: page-container
:class-container: features

.. grid-item::

Expand Down
3 changes: 2 additions & 1 deletion docs/source/tutorials.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Check out the :bdg-link-primary:`main demo <tutorials/notebooks/main_demo.ipynb>
:gutter: 3

.. grid-item-card:: Basics
:link: basics
:link: tutorials/main_demo.ipynb

Walk through the basic functionality of the package.

Expand All @@ -31,5 +31,6 @@ Check out the :bdg-link-primary:`main demo <tutorials/notebooks/main_demo.ipynb>
:hidden:
:maxdepth: 1

tutorials/main_demo.ipynb
tutorials/basics
tutorials/features
2 changes: 1 addition & 1 deletion docs/source/tutorials/basics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ Basics
.. toctree::
:maxdepth: 1

notebooks/main_demo.ipynb
main_demo.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/nnsight/config.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
API:
HOST: ndif.baulab.us
HOST: 10.201.22.179:5550
3 changes: 2 additions & 1 deletion src/nnsight/contexts/Invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def __enter__(self) -> Invoker:
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
pass
if isinstance(exc_val, BaseException):
raise exc_val

def next(self, increment: int = 1) -> None:
"""Designates subsequent interventions should be applied to the next generation for multi-iteration generation runs.
Expand Down
8 changes: 5 additions & 3 deletions src/nnsight/contexts/Runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Runner(Tracer):
def __init__(
self,
*args,
generation:bool = False,
generation: bool = False,
blocking: bool = True,
remote: bool = False,
**kwargs,
Expand All @@ -59,6 +59,8 @@ def __enter__(self) -> Runner:

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""On exit, run and generate using the model whether locally or on the server."""
if isinstance(exc_val, BaseException):
raise exc_val
if self.remote:
self.run_server()
else:
Expand All @@ -83,7 +85,7 @@ def run_server(self):
model_name=self.model.repoid_path_clsname,
batched_input=self.batched_input,
intervention_graph=self.graph,
generation=self.generation
generation=self.generation,
)

if self.blocking:
Expand All @@ -94,7 +96,7 @@ def run_server(self):
def blocking_request(self, request: pydantics.RequestModel):
# Create a socketio connection to the server.
sio = socketio.Client()
sio.connect(f"wss://{CONFIG.API.HOST}")
sio.connect(f"ws://{CONFIG.API.HOST}", transports=['websocket'])

# Called when receiving a response from the server.
@sio.on("blocking_response")
Expand Down
53 changes: 26 additions & 27 deletions src/nnsight/intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class InterventionProxy(Proxy):
Calling ``.shape`` on an InterventionProxy returns the shape or collection of shapes for the tensors traced through this module.
Calling ``.value`` on an InterventionProxy returns the actual populated values, updated during actual execution of the model.
"""

def save(self) -> InterventionProxy:
Expand All @@ -76,18 +76,11 @@ def save(self) -> InterventionProxy:
return self

def retain_grad(self):

self.node.graph.add(
target=torch.Tensor.retain_grad,
args=[self.node]
)
self.node.graph.add(target=torch.Tensor.retain_grad, args=[self.node])

# We need to set the values of self to values of self to add this into the computation graph so grad flows through it
# This is because in intervene(), we call .narrow on activations which removes it from the grad path
self.node.graph.add(
target=Proxy.proxy_update,
args=[self.node, self.node]
)
self.node.graph.add(target=Proxy.proxy_update, args=[self.node, self.node])

@property
def token(self) -> TokenIndexer:
Expand All @@ -96,7 +89,7 @@ def token(self) -> TokenIndexer:
Makes positive indices negative as tokens are padded on the left.
Example:
.. code-block:: python
model.transformer.h[0].mlp.output.token[0]
Expand Down Expand Up @@ -167,7 +160,6 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):
module_path = f"{module_path}.{key}.{graph.generation_idx}"

if module_path in graph.argument_node_names:

argument_node_names = graph.argument_node_names[module_path]

# multiple argument nodes can have same module_path if there are multiple invocations.
Expand Down Expand Up @@ -202,21 +194,19 @@ class HookModel(AbstractContextManager):
Should have signature of [outputs(Any), module_path(str)] -> outputs(Any)
handles (List[RemovableHandle]): Handles returned from registering hooks as to be used when removing hooks on __exit__.
"""
#TODO maybe only apply the necassay hooks (e.x if a module has a input hook, all hooks will be added)

def __init__(
self,
model: torch.nn.Module,
modules: List[str],
module_keys: List[str],
input_hook: Callable = None,
output_hook: Callable = None,
backward_input_hook:Callable = None,
backward_output_hook:Callable = None
backward_input_hook: Callable = None,
backward_output_hook: Callable = None,
) -> None:
self.model = model
self.modules: List[Tuple[torch.nn.Module, str]] = [
(util.fetch_attr(self.model, module_path), module_path)
for module_path in modules
]
self.module_keys = module_keys

self.input_hook = input_hook
self.output_hook = output_hook
self.backward_input_hook = backward_input_hook
Expand All @@ -231,34 +221,43 @@ def __enter__(self) -> HookModel:
HookModel: HookModel object.
"""

for module, module_path in self.modules:
if self.input_hook is not None:
for module_key in self.module_keys:
*module_atoms, hook_type = module_key.split(".")[:-1]
module_path = ".".join(module_atoms)

module: torch.nn.Module = util.fetch_attr(self.model, module_path)

if hook_type == "input":

def input_hook(module, input, module_path=module_path):
return self.input_hook(input, module_path)

self.handles.append(module.register_forward_pre_hook(input_hook))

if self.output_hook is not None:
elif hook_type == "output":

def output_hook(module, input, output, module_path=module_path):
return self.output_hook(output, module_path)

self.handles.append(module.register_forward_hook(output_hook))

if self.backward_input_hook is not None:
elif hook_type == "backward_input":

def backward_input_hook(module, input, output, module_path=module_path):
return self.backward_input_hook(input, module_path)

self.handles.append(module.register_full_backward_hook(backward_input_hook))
self.handles.append(
module.register_full_backward_hook(backward_input_hook)
)

if self.backward_output_hook is not None:
elif hook_type == "backward_output":

def backward_output_hook(module, output, module_path=module_path):
return self.backward_output_hook(output, module_path)

self.handles.append(module.register_full_backward_pre_hook(backward_output_hook))
self.handles.append(
module.register_full_backward_pre_hook(backward_output_hook)
)

return self

Expand Down
13 changes: 2 additions & 11 deletions src/nnsight/models/AbstractModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,26 +173,17 @@ def __call__(
lambda module, input, output: graph.increment()
)

# The intervention graph for running a Model will have the modules that are involved
# in the graph's argument_node_names.
modules = set(
[
".".join(name.split(".")[:-2])
for name in graph.argument_node_names.keys()
]
)

logger.info(f"Running `{self.repoid_path_clsname}`...")

# Send local_model to graph to re-compile
graph.compile(self.local_model)

inputs = self._prepare_inputs(inputs)

with torch.inference_mode(mode=inference):
with HookModel(
self.local_model,
list(modules),
list(graph.argument_node_names.keys()),
input_hook=lambda activations, module_path: intervene(
activations, module_path, graph, "input"
),
Expand Down
Loading

0 comments on commit 19b78fc

Please sign in to comment.