diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index ec2ae1877a..c93f478e8e 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -543,6 +543,10 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, else: if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning: rotary_pos_emb = self.rotary_pos_emb(args.curriculum_seqlen) + elif args.ds_sequence_parallel_size > 1: + parallel_seq_len = self.seq_length / args.ds_sequence_parallel_size + ds_sp_offset = mpu.get_sequence_parallel_rank() * parallel_seq_len + rotary_pos_emb = self.rotary_pos_emb(parallel_seq_len, ds_sp_offset) else: rotary_pos_emb = self.rotary_pos_emb(self.seq_length)