Updated README, changed default dim_d behavior
tatp22 committed Jun 21, 2020
1 parent 66dadd4 commit 8e2f0f9
Showing 3 changed files with 57 additions and 7 deletions.
51 changes: 49 additions & 2 deletions README.md
@@ -24,6 +24,7 @@ cd linformer-pytorch
```

## Code example
Linformer self attention, built as stacks of `MHAttention`s and `FeedForward`s

```python
from linformer_pytorch import Linformer
# @@ -50,6 +51,54 @@ (constructor arguments collapsed in this diff view)
y = model(x)
print(y) # (1, 262144, 64)
```

Linformer Multihead attention

```python
from linformer_pytorch import MHAttention, get_EF
import torch

# The E and F projections are passed in ready-made; here they are built
# with get_EF (see the get_EF section below), sampling the input length,
# 512, down to dim_k=128.
E_proj = get_EF(512, 128)
F_proj = get_EF(512, 128)

model = MHAttention(
        input_size=512, # Dimension 1 of the input
        channels=64, # Dimension 2 of the input
        dim=8, # Dim of each attn head
        dim_k=128, # What to sample the input length down to
        nhead=8, # Number of heads
        dropout=0, # Dropout for each of the heads
        activation="gelu", # Activation after attention has been concat'd
        checkpoint_level="C2", # If C2, checkpoint each of the heads
        parameter_sharing="layerwise", # What level of parameter sharing to do
        E_proj=E_proj, # The E projection matrix
        F_proj=F_proj, # The F projection matrix
        )
x = torch.randn(1, 512, 64)
y = model(x)
print(y) # (1, 512, 64)
```

The Linear attention head, the novelty of the paper

```python
from linformer_pytorch import LinearAttentionHead, get_EF
import torch

# The E and F layers sample the input length, 512, down to k (128 here).
E_proj = get_EF(512, 128)
F_proj = get_EF(512, 128)

model = LinearAttentionHead(
        dim=64, # Dim 2 of the input
        dropout=0.1, # Dropout of the P matrix
        E_proj=E_proj, # The E layer
        F_proj=F_proj, # The F layer
        )
x = torch.randn(1, 512, 64)
y = model(x, x, x)
print(y) # (1, 512, 64)
```

An easy way to get the `E` and `F` matrices is to call the `get_EF` function. As an example, for an `n` of `1000` and a `k` of `100`:

```python
from linformer_pytorch import get_EF
import torch

E = get_EF(1000, 100)
```

## Checkpoint levels
In an attempt to introduce further memory savings, the concept of checkpoint levels has been introduced. The three current checkpoint levels are `C0`, `C1`, and `C2`. Going up a checkpoint level sacrifices speed for memory savings: `C0` is the fastest but takes up the most space on the GPU, while `C2` is the slowest but takes up the least space on the GPU. The details of each checkpoint level are as follows:
* `C0`: No checkpointing. The model runs while keeping all of the attention heads and ff layers in GPU memory.
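
The level is passed to the constructor via the `checkpoint_level` flag; here is a minimal usage sketch, borrowing the sizes from `examples/example_small.py` in this commit:

```python
from linformer_pytorch import Linformer
import torch

model = Linformer(
        input_size=512,
        channels=16,
        dim_k=16,
        checkpoint_level="C2", # slowest of the three, but lightest on GPU memory
        )
x = torch.randn(1, 512, 16)
y = model(x) # same output as with C0, just recomputed instead of stored
```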
@@ -105,11 +154,9 @@ print(y) # (1, 500, 16)
* In practice, I found that the memory and time requirements are more on the order of O(nkd), with n=`input_size`, k=`dim_k`, and d=`dim_d`.
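  * For a rough sense of scale, with the n=262144 and d=64 of the first code example and a k of 128, nkd ≈ 2.1 × 10⁹, versus n² ≈ 6.9 × 10¹⁰ entries for the full n × n attention matrix.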

## Future work
* ~~Add option to change the `E` and `F` downsampling matrices~~
* Run some benchmark tests to see what the performance is
* Instead of matrix multiplication to bring the dimensions down to k (with EKW and FVW), try convolution, as mentioned in the paper, with a stride length and kernel size of n/k (a sketch of this idea appears after this list).
* Right now, all that is implemented is the encoder. Add the decoder at a future point in time.
* ~~In the paper, empirical studies showed that one can reduce the value of k when increasing depth, because the eigenvalues went up. Add some option to decrease k more per layer, saving even more memory.~~
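
A minimal sketch of that convolutional downsampling idea (hypothetical; not implemented in this repo): a 1D convolution over the sequence dimension with kernel size and stride of n/k maps a length-n input down to length k.

```python
import torch
import torch.nn as nn

n, k, d = 512, 128, 64
# One convolution replacing the E (or F) linear projection: stride and
# kernel size of n // k shrink the sequence length from n to k.
conv = nn.Conv1d(d, d, kernel_size=n // k, stride=n // k)

x = torch.randn(1, n, d)       # (batch, seq_len, dim)
out = conv(x.transpose(1, 2))  # Conv1d expects (batch, dim, seq_len)
out = out.transpose(1, 2)      # back to (batch, k, dim)
print(out.shape)               # torch.Size([1, 128, 64])
```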

## Disclaimer
This is the first time that I am reproducing a result from a paper, so some things may be wrong. If you see a problem, please open up an issue, and I will attempt to work on it.
3 changes: 1 addition & 2 deletions examples/example_small.py
```diff
@@ -7,10 +7,9 @@
 model = Linformer(
         input_size=512,
         channels=16,
-        dim_d=32,
         dim_k=16,
         dim_ff=32,
-        nhead=6,
+        nhead=4,
         depth=3,
         activation="relu",
         checkpoint_level="C2",
```
10 changes: 7 additions & 3 deletions linformer_pytorch/linformer_pytorch.py
```diff
@@ -141,21 +141,25 @@ class Linformer(nn.Module):
     My attempt at reproducing the Linformer Paper
     https://arxiv.org/pdf/2006.04768.pdf
     """
-    def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=512, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", use_pos_emb=True, checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0):
+    def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", use_pos_emb=True, checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0):
         super(Linformer, self).__init__()
         assert activation == "gelu" or activation == "relu", "Only gelu and relu activations supported for now"
         assert checkpoint_level == "C0" or checkpoint_level == "C1" or checkpoint_level == "C2", "Checkpoint level has to be either C0, C1, or C2."
         assert parameter_sharing == "none" or parameter_sharing == "headwise" or parameter_sharing == "kv" or parameter_sharing == "layerwise", "The `parameter_sharing` flag has to be either 'none', 'headwise', 'kv', or 'layerwise'."
+        assert channels % nhead == 0 if dim_d is None else True, "If `dim_d` is not set to a custom value, `channels` must be divisible by `nhead`!"
 
         self.layers = nn.ModuleList()
         self.input_size = input_size
         self.channels = channels
         self.checkpoint_level = checkpoint_level
         self.pos_emb = PositionalEmbedding(channels) if use_pos_emb else None
-        self.E = get_EF(input_size, dim_d)
+
+        head_dim = channels // nhead if dim_d is None else dim_d
+
+        self.E = get_EF(input_size, head_dim)
         self.F = self.E
 
-        get_attn = lambda curr_dim_k: MHAttention(input_size, dim_d, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, self.E, self.F)
+        get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, self.E, self.F)
         get_ff = lambda: FeedForward(channels, dim_ff, dropout_ff)
         norm_attn = lambda: nn.LayerNorm(channels)
         norm_ff = lambda: nn.LayerNorm(channels)
```
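
In short, this is the `dim_d` behavior change from this commit: when `dim_d` is left at its new default of `None`, the per-head dimension falls back to `channels // nhead`, which is why the new divisibility assert was added. A tiny sketch of the fallback (hypothetical values):

```python
# Mirrors the new default: head_dim = channels // nhead if dim_d is None else dim_d
channels, nhead, dim_d = 128, 4, None

head_dim = channels // nhead if dim_d is None else dim_d
print(head_dim) # 32: with dim_d=None, each head gets channels // nhead dims
```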
