diff --git a/README.md b/README.md index 33cefec..c785c17 100644 --- a/README.md +++ b/README.md @@ -172,7 +172,10 @@ model = Linformer( x = torch.randn(1, 512, 16) # What input you want to visualize y = model(x, visualize=True) vis = Visualizer(model) -vis.plot_all_heads() +vis.plot_all_heads(title="All P_bar matrices", # Change the title if you'd like + show=True, # Show the picture + save_file="./heads.png", # If not None, save the picture to a file + ) ``` ## E and F matrices diff --git a/examples/example_vis.py b/examples/example_vis.py index 98e152e..0af7d4f 100644 --- a/examples/example_vis.py +++ b/examples/example_vis.py @@ -19,5 +19,6 @@ x = torch.randn(1, 512, 16) y = model(x, visualize=True) vis = Visualizer(model) -vis.plot_all_heads() +vis.plot_all_heads(title="All P_bar matrices", + show=True) print(y) # (1, 512, 16) diff --git a/head_vis.png b/head_vis.png index 1e147f5..c650f53 100644 Binary files a/head_vis.png and b/head_vis.png differ diff --git a/linformer_pytorch/visualizer.py b/linformer_pytorch/visualizer.py index 5c4fa20..d2f6e31 100644 --- a/linformer_pytorch/visualizer.py +++ b/linformer_pytorch/visualizer.py @@ -39,15 +39,15 @@ def get_head_visualization(self, depth_no, max_depth, head_no, axs): def plot_all_heads(self, title="Visualization of Attention Heads", show=True, save_file=None): """ - Showcases all of the heads on a grid. It shows the W^Q*E*W^K matrices for each head, - which turns out to be an NxK matrix + Showcases all of the heads on a grid. It shows the P_bar matrices for each head, + which turns out to be an NxK matrix for each of them. """ self.depth = self.net.depth self.heads = self.net.nhead fig, axs = plt.subplots(self.depth, self.heads) - axs = axs.reshape((self.depth, self.heads)) # In case heads or nheads are 1, bug i think + axs = axs.reshape((self.depth, self.heads)) # In case depth or nheads are 1, bug i think fig.suptitle(title, fontsize=26) @@ -59,4 +59,4 @@ def plot_all_heads(self, title="Visualization of Attention Heads", show=True, sa plt.show() if save_file is not None: - pass #TODO + fig.savefig(save_file)