Removed casts to fp32
Signed-off-by: Chris Granger <[email protected]>
CGranger-sorenson authored Jul 15, 2024
1 parent a1e7739 commit c6eef97
Showing 1 changed file with 0 additions and 9 deletions.
9 changes: 0 additions & 9 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -135,9 +135,6 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
         """
         key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
 
-        if torch.is_autocast_enabled():
-            query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
-
         q, k, v = self.forward_qkv(query, key, value)
         scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k
         out = self.forward_attention(v, scores, mask)
@@ -217,9 +214,6 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
         """
         key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
 
-        if torch.is_autocast_enabled():
-            query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
-
         q, k, v = self.forward_qkv(query, key, value)
         q = q.transpose(1, 2)  # (batch, time1, head, d_k)
 
@@ -325,9 +319,6 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):
 
         key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
 
-        if torch.is_autocast_enabled():
-            query, key, value = query.to(torch.float32), key.to(torch.float32), value=query.to(torch.float32) if False else value.to(torch.float32)
-
         q, k, v = self.forward_qkv(query, key, value)
         n_batch, _, T, _ = q.size()
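
The nine deleted lines are three copies of the same guard: when autocast is enabled, upcast query/key/value to float32 before computing attention. Below is a minimal, self-contained sketch of the dtype behaviour around that guard, assuming a CUDA device and a PyTorch version with torch.autocast (>= 1.10). This is illustrative code, not part of the commit or of NeMo:

import torch

q = torch.randn(2, 4, 8, 16, device="cuda")
k = torch.randn(2, 4, 8, 16, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    assert torch.is_autocast_enabled()

    # After this commit: matmul is an autocast-eligible op, so the
    # attention scores are computed in the autocast dtype.
    scores = torch.matmul(q, k.transpose(-2, -1))
    print(scores.dtype)  # torch.float16

    # Note that an explicit .to(torch.float32) alone does not pin the
    # matmul to float32, because autocast recasts the inputs of eligible
    # ops; forcing float32 also requires leaving the autocast region:
    with torch.autocast(device_type="cuda", enabled=False):
        scores32 = torch.matmul(q.float(), k.float().transpose(-2, -1))
    print(scores32.dtype)  # torch.float32

With the guard gone, the attention matmuls simply follow the ambient autocast policy instead of being upcast first.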
