Added intermediate ff dimension
The model dimension can now be different in the intermediate layers.
This change applies to the ff module, and only in the encoder. If the
flag `ff_intermediate` is not None, the layers will look like this:

```
channels -> ff_dim -> ff_intermediate (For layer 1)
ff_intermediate -> ff_dim -> ff_intermediate (For layers 2 to depth-1)
ff_intermediate -> ff_dim -> channels (For layer depth)
```

As opposed to

```
channels -> ff_dim -> channels (For all layers)
```
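As a quick illustration, here is a minimal sketch (not part of the library; the helper name `ff_channel_schedule` is hypothetical) of the per-layer channel schedule the flag produces in the encoder:

```
# Hypothetical helper that mirrors the encoder logic in this commit:
# the first layer starts at `channels`, the last layer maps back to
# `channels`, and every layer in between works at `ff_intermediate`.
def ff_channel_schedule(channels, ff_dim, depth, ff_intermediate=None):
    schedule = []
    for index in range(depth):
        in_ch = ff_intermediate if (ff_intermediate is not None and index != 0) else channels
        out_ch = ff_intermediate if (ff_intermediate is not None and index != depth - 1) else channels
        schedule.append((in_ch, ff_dim, out_ch))
    return schedule

print(ff_channel_schedule(channels=21, ff_dim=32, depth=3, ff_intermediate=64))
# [(21, 32, 64), (64, 32, 64), (64, 32, 21)]
```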
tatp22 committed Aug 4, 2020
1 parent 5d44af9 commit e21153a
Showing 2 changed files with 52 additions and 19 deletions.
25 changes: 25 additions & 0 deletions examples/example_intermediate_ff.py
@@ -0,0 +1,25 @@
import sys
import torch

sys.path.insert(0, "../")
from linformer_pytorch import Linformer

model = Linformer(
    input_size=510,
    channels=21,
    dim_d=26,
    dim_k=61,
    dim_ff=32,
    nhead=4,
    depth=3,
    activation="relu",
    checkpoint_level="C1",
    parameter_sharing="none",
    k_reduce_by_layer=1,
    include_ff=True,
    convolution=True,
    ff_intermediate=64,
)
x = torch.randn(1, 510, 21)
y = model(x)
print(y) # (1, 510, 21)
46 changes: 27 additions & 19 deletions linformer_pytorch/linformer_pytorch.py
@@ -41,11 +41,14 @@ class Residual(nn.Module):
     Implemenation taken from
     https://github.com/lucidrains/sinkhorn-transformer/blob/master/sinkhorn_transformer/sinkhorn_transformer.py
     """
-    def __init__(self, fn):
+    def __init__(self, fn, input_channels=0, output_channels=0):
         super(Residual, self).__init__()
         self.fn = fn
+        self.resample = nn.Linear(input_channels, output_channels) if input_channels != output_channels else None

     def forward(self, tensor, **kwargs):
+        if self.resample is not None:
+            return self.resample(tensor) + self.fn(tensor, **kwargs)
         return tensor + self.fn(tensor, **kwargs)

 class PreNorm(nn.Module):
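The `Residual` change above lets the skip connection cross a change in channel count: when `input_channels != output_channels`, the identity path is projected with an `nn.Linear` before the addition. A minimal standalone sketch of that behaviour (the class name `ResampleResidual` is hypothetical, not the library's):

```
import torch
import torch.nn as nn

class ResampleResidual(nn.Module):
    # Standalone sketch: project the skip path when the wrapped fn
    # changes the channel dimension, so the two terms can be added.
    def __init__(self, fn, input_channels=0, output_channels=0):
        super().__init__()
        self.fn = fn
        self.resample = nn.Linear(input_channels, output_channels) if input_channels != output_channels else None

    def forward(self, tensor, **kwargs):
        if self.resample is not None:
            return self.resample(tensor) + self.fn(tensor, **kwargs)
        return tensor + self.fn(tensor, **kwargs)

block = ResampleResidual(nn.Linear(21, 64), input_channels=21, output_channels=64)
print(block(torch.randn(1, 510, 21)).shape)  # torch.Size([1, 510, 64])
```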
@@ -99,10 +102,10 @@ class FeedForward(nn.Module):
     """
     Standard Feed Forward Layer
     """
-    def __init__(self, channels, ff_dim, dropout, activation="gelu"):
+    def __init__(self, input_channels, output_channels, ff_dim, dropout, activation="gelu"):
         super(FeedForward, self).__init__()
-        self.w_1 = nn.Linear(channels, ff_dim)
-        self.w_2 = nn.Linear(ff_dim, channels)
+        self.w_1 = nn.Linear(input_channels, ff_dim)
+        self.w_2 = nn.Linear(ff_dim, output_channels)
         self.activation = get_act(activation)
         self.dropout = nn.Dropout(dropout)
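With the generalized signature, the feed-forward block can now map between two different channel counts (`input_channels -> ff_dim -> output_channels`). A rough standalone sketch of the same shape behaviour, assuming a plain relu activation (`FeedForwardSketch` is a hypothetical name, not the library class):

```
import torch
import torch.nn as nn

class FeedForwardSketch(nn.Module):
    # Standalone approximation of the updated FeedForward: the two
    # linear maps no longer need to share the same channel count.
    def __init__(self, input_channels, output_channels, ff_dim, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(input_channels, ff_dim)
        self.w_2 = nn.Linear(ff_dim, output_channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tensor):
        return self.w_2(self.dropout(torch.relu(self.w_1(tensor))))

ff = FeedForwardSketch(input_channels=64, output_channels=21, ff_dim=32)
print(ff(torch.randn(1, 510, 64)).shape)  # torch.Size([1, 510, 21])
```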
@@ -242,12 +245,14 @@ class Linformer(nn.Module):
     My attempt at reproducing the Linformer Paper
     https://arxiv.org/pdf/2006.04768.pdf
     """
-    def __init__(self, input_size, channels, dim_k, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False, include_ff=True, w_o_intermediate_dim=None, decoder_mode=False, causal=False, convolution=False):
+    def __init__(self, input_size, channels, dim_k, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False, include_ff=True, w_o_intermediate_dim=None, decoder_mode=False, causal=False, convolution=False, ff_intermediate=None):
         super(Linformer, self).__init__()
         assert activation == "gelu" or activation == "relu", "Only gelu and relu activations supported for now"
         assert checkpoint_level == "C0" or checkpoint_level == "C1" or checkpoint_level == "C2", "Checkpoint level has to be either C0, C1, or C2."
         assert parameter_sharing == "none" or parameter_sharing == "headwise" or parameter_sharing == "kv" or parameter_sharing == "layerwise", "The `parameter_sharing` flag has to be either 'none', 'headwise', 'kv', or 'layerwise'."
         assert channels % nhead == 0 if dim_d is None else True, "If `dim_d` is not set to a custom value, `channels` must be divisible by `nhead`!"
+        assert not (ff_intermediate and parameter_sharing=="layerwise"), "Parameter sharing must not be layerwise if ff_intermediate is enabled!"
+        assert not (ff_intermediate and decoder_mode), "Raising the dimension in the middle cannot be done in the decoder!"

         layers = nn.ModuleList()
         self.decoder_mode = decoder_mode
@@ -264,15 +269,18 @@ def __init__(self, input_size, channels, dim_k, dim_ff=256, dim_d=None, dropout_
         # If we want causal but only with the encoder
         causal_enc = gen_causal_mask(input_size, dim_k, full_attention) if (causal and not decoder_mode) else None

-        get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_enc, w_o_intermediate_dim, decoder_mode=False, convolution=convolution)
-        get_attn_context = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_mask, w_o_intermediate_dim, decoder_mode=True, convolution=convolution)
-        get_ff = lambda: FeedForward(channels, dim_ff, dropout_ff)
+        get_attn = lambda attn_channels, curr_dim_k: MHAttention(input_size, head_dim, attn_channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_enc, w_o_intermediate_dim, decoder_mode=False, convolution=convolution)
+        get_attn_context = lambda attn_channels, curr_dim_k: MHAttention(input_size, head_dim, attn_channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_mask, w_o_intermediate_dim, decoder_mode=True, convolution=convolution)
+        get_ff = lambda input_channels, output_channels: FeedForward(input_channels, output_channels, dim_ff, dropout_ff)

         for index in range(depth):
-            attn_layer = get_attn(max(1, dim_k - index*k_reduce_by_layer))
-            ff_layer = get_ff()
+            input_channels = ff_intermediate if (index != 0 and ff_intermediate is not None) and not decoder_mode else channels
+            output_channels = ff_intermediate if (index != depth-1 and ff_intermediate is not None) and not decoder_mode else channels
+            # TODO: Change the input and output channels here
+            attn_layer = get_attn(input_channels, max(1, dim_k - index*k_reduce_by_layer))
+            ff_layer = get_ff(input_channels, output_channels)

-            attn_layer, ff_layer = map(lambda fn: Residual(PreNorm(channels, fn)), (attn_layer, ff_layer))
+            attn_layer, ff_layer = map(lambda res_ch_in, res_ch_out, fn: Residual(PreNorm(input_channels, fn), res_ch_in, res_ch_out), (input_channels, input_channels), (input_channels, output_channels), (attn_layer, ff_layer))

             if include_ff:
                 layers.extend([attn_layer, ff_layer])
@@ -282,8 +290,8 @@ def __init__(self, input_size, channels, dim_k, dim_ff=256, dim_d=None, dropout_
             if not self.decoder_mode:
                 continue

-            attn_context = get_attn_context(max(1, dim_k - index*k_reduce_by_layer))
-            ff_context = get_ff()
+            attn_context = get_attn_context(channels, max(1, dim_k - index*k_reduce_by_layer))
+            ff_context = get_ff(channels, channels)

             attn_context, ff_context = map(lambda fn: Residual(PreNorm(channels, fn)), (attn_context, ff_context))
@@ -317,7 +325,7 @@ class LinformerLM(nn.Module):
     """
     def __init__(self, num_tokens, input_size, channels,
                  dim_k=64, dim_ff=1024, dim_d=None,
-                 dropout_ff=0.1, nhead=4, depth=2,
+                 dropout_ff=0.1, nhead=4, depth=2, ff_intermediate=None,
                  dropout=0.05, activation="gelu", checkpoint_level="C0",
                  parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False,
                  include_ff=True, w_o_intermediate_dim=None, emb_dim=None,
@@ -331,10 +339,10 @@ def __init__(self, num_tokens, input_size, channels,
         self.pos_emb = PositionalEmbedding(emb_dim)
         self.linformer = Linformer(input_size, channels, dim_k=dim_k,
                                    dim_ff=dim_ff, dim_d=dim_d, dropout_ff=dropout_ff,
-                                   nhead=nhead, depth=depth, dropout=dropout,
+                                   nhead=nhead, depth=depth, dropout=dropout, ff_intermediate=ff_intermediate,
                                    activation=activation, checkpoint_level=checkpoint_level, parameter_sharing=parameter_sharing,
                                    k_reduce_by_layer=k_reduce_by_layer, full_attention=full_attention, include_ff=include_ff,
-                                   w_o_intermediate_dim=w_o_intermediate_dim, decoder_mode=decoder_mode, causal=causal, convolution=False)
+                                   w_o_intermediate_dim=w_o_intermediate_dim, decoder_mode=decoder_mode, causal=causal, convolution=convolution)

         if emb_dim != channels:
             self.linformer = ProjectInOut(self.linformer, emb_dim, channels)
@@ -356,7 +364,7 @@ class LinformerEncDec(nn.Module):
     A complete seq -> seq translation task. Complete with an encoder and a decoder module.
     """
     def __init__(self, enc_num_tokens, enc_input_size, enc_channels, dec_num_tokens, dec_input_size, dec_channels,
-                 enc_dim_k=64, enc_dim_ff=1024, enc_dim_d=None,
+                 enc_dim_k=64, enc_dim_ff=1024, enc_dim_d=None, enc_ff_intermediate=None, dec_ff_intermediate=None,
                  enc_dropout_ff=0.1, enc_nhead=4, enc_depth=2, enc_dropout=0.05, enc_parameter_sharing="layerwise", enc_k_reduce_by_layer=0,
                  enc_full_attention=False, enc_include_ff=True, enc_w_o_intermediate_dim=None, enc_emb_dim=None, enc_convolution=False,
                  dec_dim_k=64, dec_dim_ff=1024, dec_dim_d=None, dec_dropout_ff=0.1, dec_nhead=4, dec_depth=2, dec_dropout=0.05,
@@ -366,11 +374,11 @@ def __init__(self, enc_num_tokens, enc_input_size, enc_channels, dec_num_tokens,
         super(LinformerEncDec, self).__init__()
         self.encoder = LinformerLM(num_tokens=enc_num_tokens, input_size=enc_input_size, channels=enc_channels, dim_d=enc_dim_d, dim_ff=enc_dim_ff,
                                    dim_k=enc_dim_k, dropout_ff=enc_dropout_ff, nhead=enc_nhead, depth=enc_depth, dropout=enc_dropout,
-                                   parameter_sharing=enc_parameter_sharing, k_reduce_by_layer=enc_k_reduce_by_layer,
+                                   parameter_sharing=enc_parameter_sharing, k_reduce_by_layer=enc_k_reduce_by_layer, ff_intermediate=enc_ff_intermediate,
                                    full_attention=enc_full_attention, include_ff=enc_include_ff, w_o_intermediate_dim=enc_w_o_intermediate_dim,
                                    emb_dim=enc_emb_dim, return_emb=True, activation=activation, checkpoint_level=checkpoint_level, convolution=enc_convolution)
         self.decoder = LinformerLM(num_tokens=dec_num_tokens, input_size=dec_input_size, channels=dec_channels, dim_d=dec_dim_d, dim_ff=dec_dim_ff,
-                                   dim_k=dec_dim_k, dropout_ff=dec_dropout_ff, nhead=dec_nhead, depth=dec_depth, dropout=dec_dropout,
+                                   dim_k=dec_dim_k, dropout_ff=dec_dropout_ff, nhead=dec_nhead, depth=dec_depth, dropout=dec_dropout, ff_intermediate=dec_ff_intermediate,
                                    parameter_sharing=dec_parameter_sharing, k_reduce_by_layer=dec_k_reduce_by_layer, convolution=dec_convolution,
                                    full_attention=dec_full_attention, include_ff=dec_include_ff, w_o_intermediate_dim=dec_w_o_intermediate_dim,
                                    emb_dim=dec_emb_dim, decoder_mode=True, causal=True, activation=activation, checkpoint_level=checkpoint_level)
