diff --git a/README.md b/README.md
index f85515a..33201e4 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,7 @@ cd linformer-pytorch
 ```
 
 ## Code example
+Linformer self-attention: stacks of `MHAttention` and `FeedForward` layers
 
 ```python
 from linformer_pytorch import Linformer
@@ -50,6 +51,62 @@ y = model(x)
 print(y) # (1, 262144, 64)
 ```
 
+Linformer Multihead attention
+
+```python
+from linformer_pytorch import MHAttention, get_EF
+import torch
+
+E_proj = get_EF(512, 128) # The E projection matrix from get_EF (see below), sampling n=512 down to k=128
+F_proj = get_EF(512, 128) # The F projection matrix
+
+model = MHAttention(
+        input_size=512, # Dimension 1 of the input
+        channels=64, # Dimension 2 of the input
+        dim=8, # Dim of each attn head
+        dim_k=128, # What to sample the input length down to
+        nhead=8, # Number of heads
+        dropout=0, # Dropout for each of the heads
+        activation="gelu", # Activation after attention has been concat'd
+        checkpoint_level="C2", # If C2, checkpoint each of the heads
+        parameter_sharing="layerwise", # What level of parameter sharing to do
+        E_proj=E_proj, # The E projection matrix
+        F_proj=F_proj, # The F projection matrix
+        )
+x = torch.randn(1, 512, 64)
+y = model(x)
+print(y) # (1, 512, 64)
+```
+
+The Linear attention head, the novelty of the paper
+
+```python
+from linformer_pytorch import LinearAttentionHead, get_EF
+import torch
+
+E_proj = get_EF(512, 128) # The E projection, sampling n=512 down to k=128
+F_proj = get_EF(512, 128) # The F projection
+
+model = LinearAttentionHead(
+        dim=64, # Dim 2 of the input
+        dropout=0.1, # Dropout of the P matrix
+        E_proj=E_proj, # The E layer
+        F_proj=F_proj, # The F layer
+        )
+x = torch.randn(1, 512, 64)
+y = model(x, x, x)
+print(y) # (1, 512, 64)
+```
+
+An easy way to get the `E` and `F` matrices is to call the `get_EF` function. As an example, for an `n` of `1000` and a `k` of `100`:
+
+```python
+from linformer_pytorch import get_EF
+import torch
+
+E = get_EF(1000, 100)
+```
+
 ## Checkpoint levels
 As an attempt to further introduce memory savings, the concept of checkpoint levels has been introduced. The current three checkpoint levels are `C0`, `C1`, and `C2`. When going up checkpoint levels, one sacrifices speed for memory savings. That is, checkpoint level `C0` is the fastest, but takes up the most space on the GPU, while `C2` is the slowest, but takes up the least space on the GPU. The details of each checkpoint level are as follows:
 * `C0`: No checkpointing. The model runs while keeping all of the attention heads and ff layers in the GPU memory.
@@ -105,11 +162,9 @@ print(y) # (1, 500, 16)
 * In practice, I found that the memory and time requirements are more on the order of O(nkd), with n=`input_size`, k=`dim_k`, and d=`dim_d`.
 
 ## Future work
-* ~~Add option to change the `E` and `F` downsampling matrices~~
 * Run some benchmark tests to see what the performance is
 * Instead of matrix multiplication to bring the dimensions down to k (with EKW and FVW), try to do convolution, as mentioned in the paper, with a stride length and kernel size of n/k.
 * Right now, all that is implemented is the encoder. Add the decoder at a future point in time.
-* ~~In the paper, empirical studies showed that one can reduce the value of k when increasing depth, because the eigenvalues went up. Add some option to decrease k more per layers, saving even more memory.~~
 
 ## Disclaimer
 This is the first time that I am reproducing a result from a paper, so some things may be wrong. If you see a problem, please open up an issue, and I will attempt to work on it.
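The checkpoint-level hunk above describes the speed/memory trade-off only in prose. As a point of reference, here is a minimal sketch (illustrative only, not part of the patch) that builds the same small model at the fastest (`C0`) and the most memory-frugal (`C2`) level, reusing the constructor arguments from `examples/example_small.py`:

```python
from linformer_pytorch import Linformer
import torch

# Identical models apart from the checkpoint level: "C0" keeps every attention
# head and feed-forward layer in GPU memory (fastest), while "C2" also
# checkpoints the individual heads (slowest, least GPU memory).
common = dict(input_size=512, channels=16, dim_k=16, dim_ff=32, nhead=4, depth=3)
fast_model = Linformer(checkpoint_level="C0", **common)
lean_model = Linformer(checkpoint_level="C2", **common)

x = torch.randn(1, 512, 16)
print(fast_model(x).shape, lean_model(x).shape) # both torch.Size([1, 512, 16])
```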
diff --git a/examples/example_small.py b/examples/example_small.py
index a2c590d..fa3c220 100644
--- a/examples/example_small.py
+++ b/examples/example_small.py
@@ -7,10 +7,9 @@
 model = Linformer(
         input_size=512,
         channels=16,
-        dim_d=32,
         dim_k=16,
         dim_ff=32,
-        nhead=6,
+        nhead=4,
         depth=3,
         activation="relu",
         checkpoint_level="C2",
diff --git a/linformer_pytorch/linformer_pytorch.py b/linformer_pytorch/linformer_pytorch.py
index 1daafca..b103c04 100644
--- a/linformer_pytorch/linformer_pytorch.py
+++ b/linformer_pytorch/linformer_pytorch.py
@@ -141,21 +141,26 @@ class Linformer(nn.Module):
     My attempt at reproducing the Linformer Paper
     https://arxiv.org/pdf/2006.04768.pdf
     """
-    def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=512, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", use_pos_emb=True, checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0):
+    def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", use_pos_emb=True, checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0):
         super(Linformer, self).__init__()
         assert activation == "gelu" or activation == "relu", "Only gelu and relu activations supported for now"
         assert checkpoint_level == "C0" or checkpoint_level == "C1" or checkpoint_level == "C2", "Checkpoint level has to be either C0, C1, or C2."
         assert parameter_sharing == "none" or parameter_sharing == "headwise" or parameter_sharing == "kv" or parameter_sharing == "layerwise", "The `parameter_sharing` flag has to be either 'none', 'headwise', 'kv', or 'layerwise'."
+        assert channels % nhead == 0 if dim_d is None else True, "If `dim_d` is not set to a custom value, `channels` must be divisible by `nhead`!"
 
         self.layers = nn.ModuleList()
         self.input_size = input_size
         self.channels = channels
         self.checkpoint_level = checkpoint_level
         self.pos_emb = PositionalEmbedding(channels) if use_pos_emb else None
-        self.E = get_EF(input_size, dim_d)
+
+        # If `dim_d` is not given, fall back to the per-head dimension `channels // nhead`
+        head_dim = channels // nhead if dim_d is None else dim_d
+
+        self.E = get_EF(input_size, head_dim)
         self.F = self.E
 
-        get_attn = lambda curr_dim_k: MHAttention(input_size, dim_d, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, self.E, self.F)
+        get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, self.E, self.F)
         get_ff = lambda: FeedForward(channels, dim_ff, dropout_ff)
         norm_attn = lambda: nn.LayerNorm(channels)
         norm_ff = lambda: nn.LayerNorm(channels)
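Since `dim_d` now defaults to `None`, the per-head dimension falls back to `channels // nhead`, which is presumably why `examples/example_small.py` drops `dim_d=32` and moves from 6 to 4 heads (16 channels split evenly across 4 heads, but not across 6). A minimal usage sketch of the new default, mirroring the patched example (illustrative only):

```python
from linformer_pytorch import Linformer
import torch

# dim_d is left unset, so the patched constructor derives the per-head
# dimension as channels // nhead = 16 // 4 = 4; the new assert requires
# channels to be divisible by nhead in this case.
model = Linformer(
        input_size=512,
        channels=16,
        dim_k=16,
        dim_ff=32,
        nhead=4,
        depth=3,
        activation="relu",
        checkpoint_level="C2",
        )
x = torch.randn(1, 512, 16)
y = model(x)
print(y.shape) # torch.Size([1, 512, 16])
```

Passing `dim_d` explicitly still overrides the derived value, matching the pre-patch behavior.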