From 47de7a2fd5ee837b5593d1cbe0c8e1f2e858d315 Mon Sep 17 00:00:00 2001
From: neverix
Date: Wed, 3 Nov 2021 19:16:36 +0300
Subject: [PATCH] optimize generation caching (#12)

Over 10x speedup, adds MLP caching and optimizes attention caching.
Uses changes from https://t.co/BTwo6NKq9H.
---
 rudalle/dalle/transformer.py | 57 ++++++++++++++++++++++++++++++------
 1 file changed, 48 insertions(+), 9 deletions(-)

diff --git a/rudalle/dalle/transformer.py b/rudalle/dalle/transformer.py
index 5a62aa0..e2d969b 100755
--- a/rudalle/dalle/transformer.py
+++ b/rudalle/dalle/transformer.py
@@ -146,7 +146,7 @@ def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
         layernorm_output = self.input_layernorm(hidden_states)
 
         # Self attention.
-        attention_output, has_cache = self.attention(
+        attention_output, att_has_cache = self.attention(
             layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache)
 
         if self.cogview_sandwich_layernorm:
@@ -159,7 +159,8 @@ def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
         layernorm_output = self.post_attention_layernorm(layernorm_input)
 
         # MLP.
-        mlp_output = self.mlp(layernorm_output)
+        mlp_output, mlp_has_cache = self.mlp(
+            layernorm_output, has_cache=has_cache, use_cache=use_cache)
 
         if self.cogview_sandwich_layernorm:
             mlp_output = self.before_second_addition_layernorm(mlp_output)
@@ -167,7 +168,7 @@ def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
         # Second residual connection.
         output = layernorm_input + mlp_output
 
-        return output, has_cache
+        return output, att_has_cache and mlp_has_cache
 
 
 class DalleSelfAttention(torch.nn.Module):
@@ -212,6 +213,11 @@ def __init__(self, hidden_size, num_attention_heads,
         self.dense = torch.nn.Linear(hidden_size, hidden_size)
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
 
+        # Cache
+        self.past_key = None
+        self.past_value = None
+        self.past_output = None
+
     def _transpose_for_scores(self, tensor):
         """ Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """
         new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head)
@@ -227,6 +233,7 @@ def _calculate_attention_scores(self, query_layer, key_layer, ltor_mask):
             )
         else:
             attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head)
+        ltor_mask = ltor_mask[:, :, -attention_scores.shape[-2]:]
         attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
         if self.cogview_pb_relax:
             # normalize attention scores. Should not affect resulting softmax value
@@ -258,10 +265,10 @@ def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
         key_layer = self._transpose_for_scores(mixed_key_layer)
         value_layer = self._transpose_for_scores(mixed_value_layer)
 
+        # Can be simplified, but I didn't for readability's sake
         if use_cache and has_cache:
-            value_layer = torch.cat((self.past_value, value_layer), dim=-2)
-            query_layer = torch.cat((self.past_query, query_layer), dim=-2)
             key_layer = torch.cat((self.past_key, key_layer), dim=-2)
+            value_layer = torch.cat((self.past_value, value_layer), dim=-2)
         attention_scores = self._calculate_attention_scores(
             query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
         )
@@ -271,13 +278,17 @@ def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
         )
 
         if use_cache:
-            self.past_query = query_layer
             self.past_key = key_layer
             self.past_value = value_layer
-            has_cache = True
         else:
+            self.past_key = None
+            self.past_value = None
+            self.past_output = None
             has_cache = False
 
+        if use_cache and has_cache:
+            attention_scores = attention_scores[..., -1:, :]
+
         # Attention probabilities. [b, np, s, s]
         attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
 
@@ -298,6 +309,16 @@ def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
 
         # Output. [b, s, h]
         output = self.dense(context_layer)
+
+        if use_cache:
+            # Can be simplified, but I didn't for readability's sake
+            if has_cache:
+                output = torch.cat((self.past_output, output), dim=-2)
+                self.past_output = output
+            else:
+                self.past_output = output
+                has_cache = True
+
         output = self.output_dropout(output)
 
         return output, has_cache
@@ -321,12 +342,30 @@ def __init__(self, hidden_size, output_dropout_prob):
         # Project back to h.
         self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size)
         self.dropout = torch.nn.Dropout(output_dropout_prob)
+        # MLP cache
+        self.past_x = None
+
+    def forward(self, hidden_states, has_cache=False, use_cache=False):
+        if has_cache and use_cache:
+            hidden_states = hidden_states[:, -1:]
 
-    def forward(self, hidden_states):
         # [b, s, 4hp]
         x = self.dense_h_to_4h(hidden_states)
         x = gelu(x)
         # [b, s, h]
         x = self.dense_4h_to_h(x)
+        if use_cache:
+            # Can be simplified, but I didn't for readability's sake
+            if has_cache:
+                x = torch.cat((self.past_x, x), dim=-2)
+                self.past_x = x
+            else:
+                self.past_x = x
+
+            has_cache = True
+        else:
+            self.past_x = None
+            has_cache = False
         output = self.dropout(x)
-        return output
+
+        return output, has_cache
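
Note (reviewer sketch, not part of the patch): the snippet below is a minimal
illustration of how a caller threads the use_cache / has_cache flags through a
stack of patched layers during autoregressive decoding. `layers`, `embed` and
`sample_next_token` are hypothetical stand-ins, not rudalle's actual sampling
code.

    import torch

    @torch.no_grad()
    def generate(layers, embed, sample_next_token, prompt_ids, steps):
        ids = prompt_ids          # [b, s] tokens generated so far
        has_cache = False         # per-layer caches are empty before the first pass
        for _ in range(steps):
            s = ids.shape[1]
            # Full causal mask; the patched _calculate_attention_scores slices it
            # to match the scores' query dimension when the cache is active.
            ltor_mask = torch.tril(torch.ones(1, 1, s, s, device=ids.device))
            hidden = embed(ids)   # [b, s, h]
            for layer in layers:
                # Each patched layer returns (output, has_cache); the flag stays
                # True only when both the attention and the MLP caches are filled.
                hidden, has_cache = layer(
                    hidden, ltor_mask, has_cache=has_cache, use_cache=True)
            next_id = sample_next_token(hidden[:, -1])  # expected shape [b, 1]
            ids = torch.cat((ids, next_id), dim=-1)
        return ids

Calling a layer with use_cache=False clears its cached keys, values and outputs
(the else branches set them to None), so a new prompt can be decoded afterwards
without rebuilding the model.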