
Commit

Merge pull request #7 from tatp22/linear_change
Changed things as mentioned in issue 6
tatp22 authored Jun 28, 2020
2 parents bfb9f27 + 6b103b9 commit 7c5c3a0
Showing 2 changed files with 36 additions and 38 deletions.
README.md: 6 changes (3 additions, 3 deletions)
@@ -95,13 +95,13 @@ y = model(x, x, x)
 print(y) # (1, 512, 64)
 ```
 
-An easy way to get the `E` and `F` matrices can be done by calling the `get_EF` function. As an example, for an `n` of `1000` and a `k` of `100`:
+An easy way to get the `E` and `F` matrices can be done by calling the `get_linear` function. As an example, for an `n` of `1000` and a `k` of `100`:
 
 ```python
-from linformer_pytorch import get_EF
+from linformer_pytorch import get_linear
 import torch
 
-E = get_EF(1000, 100)
+E = get_linear(1000, 100)
 ```
 
 ## Checkpoint levels
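As a side note on the renamed helper: the module returned by `get_linear` acts on the last dimension, so in the attention code it is applied to a transposed tensor to shrink the sequence length from `n` to `k`. A minimal sketch, assuming a batch-first key tensor of shape `(batch, n, dim)` with the README's `n = 1000`, `k = 100` and an assumed head dimension of 64:

```python
import torch
from linformer_pytorch import get_linear

E = get_linear(1000, 100)        # an nn.Linear(1000, 100), xavier-initialized

K = torch.randn(1, 1000, 64)     # (batch, sequence length n, head dim)
K_proj = E(K.transpose(1, 2))    # transpose to (1, 64, 1000), then project n -> k
print(K_proj.shape)              # torch.Size([1, 64, 100])
```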
linformer_pytorch/linformer_pytorch.py: 68 changes (33 additions, 35 deletions)
@@ -11,14 +11,14 @@ def get_act(activation):
         return F.relu
     return None
 
-def get_EF(input_size, dim):
+def get_linear(input_size, dim, bias=True):
     """
     Returns the E or F matrix, initialized via xavier initialization.
     This is the recommended way to do it according to the authors of the paper.
     """
-    EF = nn.Linear(input_size, dim)
-    torch.nn.init.xavier_normal_(EF.weight)
-    return EF
+    lin = nn.Linear(input_size, dim, bias)
+    torch.nn.init.xavier_normal_(lin.weight)
+    return lin
 
 class PositionalEmbedding(nn.Module):
     """
@@ -43,8 +43,8 @@ class FeedForward(nn.Module):
     """
     def __init__(self, channels, ff_dim, dropout=0.0, activation="gelu"):
         super(FeedForward, self).__init__()
-        self.w_1 = nn.Linear(channels, ff_dim)
-        self.w_2 = nn.Linear(ff_dim, channels)
+        self.w_1 = get_linear(channels, ff_dim)
+        self.w_2 = get_linear(ff_dim, channels)
         self.activation = get_act(activation)
         self.dropout = nn.Dropout(dropout)
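Note that `FeedForward` keeps `get_linear`'s default `bias=True`, while `MHAttention` further down builds its Q/K/V projections with `bias=False`. A quick illustration of the flag; the sizes here are arbitrary, not the repo's defaults:

```python
from linformer_pytorch import get_linear

w_1 = get_linear(128, 256)               # bias enabled by default
to_q = get_linear(128, 64, bias=False)   # bias dropped, as for the Q/K/V projections
print(w_1.bias is not None, to_q.bias is None)   # True True
```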

@@ -62,9 +62,6 @@ class LinearAttentionHead(nn.Module):
     """
     def __init__(self, dim, dropout, E_proj, F_proj, full_attention=False):
         super(LinearAttentionHead, self).__init__()
-        self.w_k = nn.Linear(dim, dim)
-        self.w_q = nn.Linear(dim, dim)
-        self.w_v = nn.Linear(dim, dim)
         self.E = E_proj
         self.F = F_proj
         self.dim = dim
@@ -77,14 +74,12 @@ def forward(self, Q, K, V, **kwargs):
         Assume Q, K, V have same dtype
         E, F are `nn.Linear` modules
         """
-        KW = self.w_k(K)
-        KW = torch.transpose(KW, 1, 2)
+        K = torch.transpose(K, 1, 2)
         if not self.full_attention:
-            KW = self.E(KW)
-        QW = self.w_q(Q)
-        QW = torch.matmul(QW, KW)
+            K = self.E(K)
+        Q = torch.matmul(Q, K)
 
-        P_bar = QW/torch.sqrt(torch.tensor(self.dim).type(Q.type()))
+        P_bar = Q/torch.sqrt(torch.tensor(self.dim).type(Q.type()))
         P_bar = P_bar.softmax(dim=-1)
 
         # Only save this when visualizing
@@ -93,13 +88,11 @@ def forward(self, Q, K, V, **kwargs):
 
         P_bar = self.dropout(P_bar)
 
-        VW = self.w_v(V)
-
         if not self.full_attention:
-            VW = torch.transpose(VW, 1, 2)
-            VW = self.F(VW)
-            VW = torch.transpose(VW, 1, 2)
-        out_tensor = torch.matmul(P_bar, VW)
+            V = torch.transpose(V, 1, 2)
+            V = self.F(V)
+            V = torch.transpose(V, 1, 2)
+        out_tensor = torch.matmul(P_bar, V)
 
         return out_tensor
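Taken together, the two hunks above mean `LinearAttentionHead.forward` now expects Q, K and V that were already projected by the caller, and the Linformer saving shows up directly in the shapes: the `(n, n)` attention matrix is replaced by an `(n, k)` one. A hedged walkthrough of the same arithmetic with made-up sizes (batch 1, sequence length n = 512, head dim d = 64, projection dim k = 128), mirroring the diff rather than calling the class itself:

```python
import torch
import torch.nn as nn

b, n, d, k = 1, 512, 64, 128
E = nn.Linear(n, k)        # stand-in for the E projection
F_proj = nn.Linear(n, k)   # stand-in for the F projection

Q = torch.randn(b, n, d)   # assumed to be pre-projected, as in the new MHAttention
K = torch.randn(b, n, d)
V = torch.randn(b, n, d)

K = E(K.transpose(1, 2))                        # (b, d, n) -> (b, d, k)
P_bar = torch.matmul(Q, K) / (d ** 0.5)         # (b, n, k) instead of (b, n, n)
P_bar = P_bar.softmax(dim=-1)

V = F_proj(V.transpose(1, 2)).transpose(1, 2)   # (b, k, d)
out = torch.matmul(P_bar, V)                    # (b, n, d)
print(out.shape)                                # torch.Size([1, 512, 64])
```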

@@ -115,29 +108,34 @@ def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation,
         self.dim_k = dim_k
         self.checkpoint_level = checkpoint_level
         if parameter_sharing != "layerwise":
-            E_proj = get_EF(input_size, dim_k)
-            F_proj = get_EF(input_size, dim_k) if parameter_sharing == "none" or parameter_sharing == "headwise" else E_proj
+            E_proj = get_linear(input_size, dim_k)
+            F_proj = get_linear(input_size, dim_k) if parameter_sharing == "none" or parameter_sharing == "headwise" else E_proj
 
+        self.get_linear = lambda: get_linear(channels, dim, bias=False)
+        self.to_q = nn.ModuleList()
+        self.to_k = nn.ModuleList()
+        self.to_v = nn.ModuleList()
+
         for head in range(nhead):
             if parameter_sharing == "none":
-                E_proj = get_EF(input_size, dim_k)
-                F_proj = get_EF(input_size, dim_k)
+                E_proj = get_linear(input_size, dim_k)
+                F_proj = get_linear(input_size, dim_k)
             attn = LinearAttentionHead(dim, dropout, E_proj, F_proj, full_attention)
             self.heads.append(attn)
-        self.w_o = nn.Linear(dim*nhead, channels)
-        self.to_q = nn.Linear(channels, dim, bias=False)
-        self.to_k = nn.Linear(channels, dim, bias=False)
-        self.to_v = nn.Linear(channels, dim, bias=False)
+            self.to_q.append(self.get_linear())
+            self.to_k.append(self.get_linear())
+            self.to_v.append(self.get_linear())
+        self.w_o = get_linear(dim*nhead, channels)
         self.activation = get_act(activation)
 
     def forward(self, tensor, **kwargs):
         head_outputs = []
-        for head in self.heads:
-            Q = self.to_q(tensor)
-            K = self.to_k(tensor)
-            V = self.to_v(tensor)
+        for index, head in enumerate(self.heads):
+            Q = self.to_q[index](tensor)
+            K = self.to_k[index](tensor)
+            V = self.to_v[index](tensor)
             if self.checkpoint_level == "C2":
-                head_outputs.append(checkpoint(head,Q,K,V,**kwargs))
+                head_outputs.append(checkpoint(head,Q,K,V))
             else:
                 head_outputs.append(head(Q,K,V,**kwargs))
         out = torch.cat(head_outputs, dim=2)
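The `MHAttention` change swaps one shared `nn.Linear` per projection for a per-head `nn.ModuleList`, filled inside the head loop via the `self.get_linear` lambda and indexed in `forward`. The pattern in isolation, with illustrative sizes rather than the repo's defaults:

```python
import torch
import torch.nn as nn

channels, head_dim, nhead = 128, 64, 4
make_proj = lambda: nn.Linear(channels, head_dim, bias=False)

to_q = nn.ModuleList()
for _ in range(nhead):
    to_q.append(make_proj())              # one independent projection per head

x = torch.randn(1, 512, channels)
per_head = [proj(x) for proj in to_q]     # each is (1, 512, head_dim)
out = torch.cat(per_head, dim=2)          # (1, 512, head_dim * nhead), as in forward
print(out.shape)                          # torch.Size([1, 512, 256])
```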
@@ -168,7 +166,7 @@ def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=No
 
         head_dim = channels // nhead if dim_d is None else dim_d
 
-        self.E = get_EF(input_size, dim_k)
+        self.E = get_linear(input_size, dim_k)
         self.F = self.E
 
         get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, self.E, self.F, full_attention)
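Finally, `self.F = self.E` is what implements layerwise parameter sharing: the same module object, and therefore the same weights, is handed to every layer and head. A minimal illustration, using the defaults visible in the signature above (`input_size=8192`, `dim_k=64`):

```python
import torch.nn as nn

E = nn.Linear(8192, 64)   # what get_linear(input_size, dim_k) builds internally
F_proj = E                # same object, so the projection weights are shared

n_params = sum(p.numel() for p in E.parameters())
print(F_proj is E, n_params)   # True 524352  (8192*64 weights + 64 biases)
```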
