Skip to content

Commit

Permalink
Added option to save the figs
Browse files Browse the repository at this point in the history
  • Loading branch information
tatp22 committed Jun 23, 2020
1 parent 9c83535 commit 1d15c59
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ model = Linformer(
x = torch.randn(1, 512, 16) # What input you want to visualize
y = model(x, visualize=True)
vis = Visualizer(model)
vis.plot_all_heads()
vis.plot_all_heads(title="All P_bar matrices", # Change the title if you'd like
show=True, # Show the picture
save_file="./heads.png", # If not None, save the picture to a file
)
```

## E and F matrices
Expand Down
3 changes: 2 additions & 1 deletion examples/example_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@
x = torch.randn(1, 512, 16)
y = model(x, visualize=True)
vis = Visualizer(model)
vis.plot_all_heads()
vis.plot_all_heads(title="All P_bar matrices",
show=True)
print(y) # (1, 512, 16)
Binary file modified head_vis.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions linformer_pytorch/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ def get_head_visualization(self, depth_no, max_depth, head_no, axs):

def plot_all_heads(self, title="Visualization of Attention Heads", show=True, save_file=None):
"""
Showcases all of the heads on a grid. It shows the W^Q*E*W^K matrices for each head,
which turns out to be an NxK matrix
Showcases all of the heads on a grid. It shows the P_bar matrices for each head,
which turns out to be an NxK matrix for each of them.
"""

self.depth = self.net.depth
self.heads = self.net.nhead

fig, axs = plt.subplots(self.depth, self.heads)
axs = axs.reshape((self.depth, self.heads)) # In case heads or nheads are 1, bug i think
axs = axs.reshape((self.depth, self.heads)) # In case depth or nheads are 1, bug i think

fig.suptitle(title, fontsize=26)

Expand All @@ -59,4 +59,4 @@ def plot_all_heads(self, title="Visualization of Attention Heads", show=True, sa
plt.show()

if save_file is not None:
pass #TODO
fig.savefig(save_file)

0 comments on commit 1d15c59

Please sign in to comment.