diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py
index c8ed357f96..b92d3c380d 100644
--- a/onmt/modules/multi_headed_attn.py
+++ b/onmt/modules/multi_headed_attn.py
@@ -18,14 +18,18 @@
 # are both < 2048 tokens.
 
 
-def rotaryembeddings(dim: int, maxseqlen=2048, base=10000):
+def rotaryembeddings(dim: int, maxseqlen=2048, base=10000, device=None):
     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
     tmax = torch.arange(maxseqlen, device=inv_freq.device)
     rope = torch.outer(tmax, inv_freq).float()
     # rope is now matrix [maxseqlen, dim/2]
     rope = torch.polar(torch.ones_like(rope), rope)
     rope = torch.cat((rope, rope), dim=1)
-    return rope
+    if device is not None:
+        rope = rope.to(device)
+    cos = rope[:, : rope.size(1) // 2].real.contiguous().half()
+    sin = rope[:, : rope.size(1) // 2].imag.contiguous().half()
+    return rope, cos, sin
 
 
 def rotate_half(x):
@@ -369,12 +373,8 @@ def __init__(
             self.rotary_dim = self.dim_per_head
         else:
             self.rotary_dim = rotary_dim
-        self.rope = rotaryembeddings(self.rotary_dim, base=rotary_theta)
-        self.cos = (
-            self.rope[:, : self.rope.size(1) // 2].real.contiguous().half()
-        )
-        self.sin = (
-            self.rope[:, : self.rope.size(1) // 2].imag.contiguous().half()
+        self.rope, self.cos, self.sin = rotaryembeddings(
+            self.rotary_dim, base=rotary_theta
         )
         self.rotary_interleave = rotary_interleave
         self.rotary_theta = rotary_theta
@@ -465,11 +465,13 @@ def forward(
                 ):
                     if self.max_relative_positions == -1:  # Rotary Embeddings
                         if seqlen > self.rope.size(0):
-                            self.rope = rotaryembeddings(
+
+                            self.rope, _, _ = rotaryembeddings(
                                 self.rotary_dim,
                                 maxseqlen=(seqlen + 2048),
                                 base=self.rotary_theta,
-                            ).to(self.rope.device)
+                                device=self.rope.device,
+                            )
                         rope = self.rope[start_pos : start_pos + seqlen]
                         query, key = apply_rotary_emb(
                             query, key, rope, interleave=self.rotary_interleave
@@ -486,23 +488,6 @@ def forward(
                     self.layer_cache[1]["values"] = value
                 else:
 
-                    if self.max_relative_positions == -1:  # Rotary Embeddings
-                        if seqlen > self.rope.size(0):
-                            self.rope = rotaryembeddings(
-                                self.rotary_dim,
-                                maxseqlen=(seqlen + 2048),
-                                base=self.rotary_theta,
-                            ).to(self.rope.device)
-                            self.cos = (
-                                self.rope[:, : self.rope.size(1) // 2]
-                                .real.contiguous()
-                                .half()
-                            )
-                            self.sin = (
-                                self.rope[:, : self.rope.size(1) // 2]
-                                .imag.contiguous()
-                                .half()
-                            )
                     if start_pos >= self.layer_cache[1]["keys"].size(2):
                         self.layer_cache[1]["keys"] = torch.cat(
                             [
@@ -528,6 +513,20 @@ def forward(
                             ],
                             dim=-2,
                         )
+                    if (
+                        self.max_relative_positions == -1
+                        and start_pos + 32 >= self.rope.size(0)
+                    ):
+                        # Resize rotary embeddings.
+                        # We take a margin of 32 tokens as the kv_cache
+                        # is incremented by 32 tokens every 32 tokens.
+                        self.rope, self.cos, self.sin = rotaryembeddings(
+                            self.rotary_dim,
+                            maxseqlen=(start_pos + 2048),
+                            base=self.rotary_theta,
+                            device=self.rope.device,
+                        )
+
                     if sliding_window > 0 and key.size(2) > sliding_window:
                         self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][
                             :, :, 1:, :
@@ -593,12 +592,14 @@ def forward(
                 start_pos = 0
                 seqlen = query.size(2)
                 if seqlen > self.rope.size(0):
-                    self.rope = rotaryembeddings(
+                    # Resize rotary embeddings.
+                    self.rope, self.cos, self.sin = rotaryembeddings(
                         self.rotary_dim,
                         maxseqlen=(seqlen + 2048),
                         base=self.rotary_theta,
-                    ).to(self.rope.device)
-                rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
+                        device=query.device,
+                    )
+                rope = self.rope[start_pos : start_pos + seqlen]
                 query, key = apply_rotary_emb(
                     query, key, rope, interleave=self.rotary_interleave
                 )
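
For reference, the refactored helper is self-contained and can be exercised outside the model. The sketch below copies the patched `rotaryembeddings` body from the first hunk verbatim; the head dimension and sequence lengths in the calls are illustrative values, not taken from this patch:

```python
import torch


def rotaryembeddings(dim: int, maxseqlen=2048, base=10000, device=None):
    # Copied from the patched helper above.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    tmax = torch.arange(maxseqlen, device=inv_freq.device)
    rope = torch.outer(tmax, inv_freq).float()
    # rope is now matrix [maxseqlen, dim/2]
    rope = torch.polar(torch.ones_like(rope), rope)
    rope = torch.cat((rope, rope), dim=1)
    if device is not None:
        rope = rope.to(device)
    cos = rope[:, : rope.size(1) // 2].real.contiguous().half()
    sin = rope[:, : rope.size(1) // 2].imag.contiguous().half()
    return rope, cos, sin


# One call now yields the complex table plus the fp16 cos/sin views,
# already on the requested device (dim=64 is an illustrative head size).
rope, cos, sin = rotaryembeddings(64, maxseqlen=4096, device="cpu")
print(rope.shape, cos.shape, sin.shape)
# torch.Size([4096, 64]) torch.Size([4096, 32]) torch.Size([4096, 32])

# Resizing mid-decode follows the same pattern; call sites that only need
# the complex table discard cos/sin, as the first forward() hunk does.
rope, _, _ = rotaryembeddings(64, maxseqlen=4096 + 2048, device=rope.device)
```

Returning `cos`/`sin` alongside the complex table is what lets every resize site refresh all three tensors in a single call, instead of recomputing the half-precision views by hand as the deleted `__init__` and decoding-path code did.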