Commit

Revert "Simplify multi headed attention (#1153)" (#1155)
This reverts commit 3d9a418.
francoishernandez authored and vince62s committed Dec 31, 2018
1 parent 3d9a418 commit 27c6fd5
Showing 1 changed file with 20 additions and 8 deletions.

onmt/modules/multi_headed_attn.py
@@ -120,16 +120,14 @@ def unshape(x):
             return x.transpose(1, 2).contiguous() \
                     .view(batch_size, -1, head_count * dim_per_head)
 
-        query, key, value = self.linear_query(query),\
-                            self.linear_keys(key),\
-                            self.linear_values(value)
-        key = shape(key)
-        value = shape(value)
-        query = shape(query)
-
         # 1) Project key, value, and query.
         if layer_cache is not None:
             if type == "self":
+                query, key, value = self.linear_query(query),\
+                                    self.linear_keys(query),\
+                                    self.linear_values(query)
+                key = shape(key)
+                value = shape(value)
                 device = key.device
                 if layer_cache["self_keys"] is not None:
                     key = torch.cat(
@@ -142,11 +140,25 @@ def unshape(x):
                 layer_cache["self_keys"] = key
                 layer_cache["self_values"] = value
             elif type == "context":
-                if layer_cache["memory_keys"] is not None:
+                query = self.linear_query(query)
+                if layer_cache["memory_keys"] is None:
+                    key, value = self.linear_keys(key),\
+                                 self.linear_values(value)
+                    key = shape(key)
+                    value = shape(value)
+                else:
                     key, value = layer_cache["memory_keys"],\
                                layer_cache["memory_values"]
                 layer_cache["memory_keys"] = key
                 layer_cache["memory_values"] = value
+        else:
+            key = self.linear_keys(key)
+            value = self.linear_values(value)
+            query = self.linear_query(query)
+            key = shape(key)
+            value = shape(value)
+
+        query = shape(query)
 
         key_len = key.size(2)
         query_len = query.size(2)
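What this revert restores is per-branch projection during incremental decoding: "self" attention projects keys and values from the current step's token and appends them to the cache, while "context" attention projects the encoder memory only on the first step and reuses the cached projections afterwards. Below is a minimal standalone sketch of that caching pattern, not OpenNMT-py's actual module: the helper names are hypothetical, heads are not split, and the cache is concatenated along dim=1 (length) instead of the real module's (batch, heads, len, dim) layout.

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins for linear_query / linear_keys / linear_values.
    dim = 8
    linear_q, linear_k, linear_v = (nn.Linear(dim, dim) for _ in range(3))

    # Same cache slots the reverted code uses.
    cache = {"self_keys": None, "self_values": None,
             "memory_keys": None, "memory_values": None}

    def self_attn_step(x):
        """Project k/v from the *new* token only and extend the cache."""
        k, v = linear_k(x), linear_v(x)
        if cache["self_keys"] is not None:
            # This sketch concatenates along dim=1 (length); the real module
            # uses dim=2 because its tensors are (batch, heads, len, dim).
            k = torch.cat((cache["self_keys"], k), dim=1)
            v = torch.cat((cache["self_values"], v), dim=1)
        cache["self_keys"], cache["self_values"] = k, v
        return linear_q(x), k, v

    def context_attn_step(x, memory):
        """Project the encoder memory once; later steps reuse the cache."""
        if cache["memory_keys"] is None:
            cache["memory_keys"] = linear_k(memory)
            cache["memory_values"] = linear_v(memory)
        return linear_q(x), cache["memory_keys"], cache["memory_values"]

    # Two decoding steps: the self-attention cache grows from length 1 to 2,
    # while the context projections are computed exactly once.
    memory = torch.randn(1, 5, dim)
    for step in range(2):
        x = torch.randn(1, 1, dim)
        q, k, v = self_attn_step(x)
        q, mk, mv = context_attn_step(x, memory)
        assert k.size(1) == step + 1 and mk.size(1) == 5

The restored `else` branch (no `layer_cache`) corresponds to full-sequence training, where query, key, and value are all projected in one shot.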
