diff --git a/nemo_aligner/algorithms/self_rewarding.py b/nemo_aligner/algorithms/self_rewarding.py
index 5f50b4c4f..71ff253c7 100644
--- a/nemo_aligner/algorithms/self_rewarding.py
+++ b/nemo_aligner/algorithms/self_rewarding.py
@@ -86,7 +86,7 @@ def self_rewarding_custom_collate(batch, eos_id):
         "answers_only": answer_ids,
         "prompt_lengths": context_lengths,
         "combined_lengths": combined_lengths,
-        "dataset_mask": batch[0]['metadata']['mask'] if 'metadata' in batch[0] else None,
+        "dataset_mask": batch[0]['metadata']['mask'] if 'metadata' in batch[0] else "",
     }
 
     return output
@@ -588,6 +588,8 @@ def get_rewards_meta(self, list_of_batches):
         reward_scores = [[] for _ in range(sum([len(b["prompt_lengths"]) for b in list_of_batches]))]
         reward_scores = []
         reward_responses, prompt_lengths, resp_lengths, is_end = self.get_generations(list_of_batches)
+        if torch.distributed.get_rank() == 0 and torch.distributed.get_rank() == parallel_state.get_data_parallel_src_rank():
+            print(f"*** META_PROMPT_AND_RESP [ {self.tokenizer.ids_to_text(reward_responses[0].tolist())} ]")
         batch_responses_str = []
         for t, s, e in zip(reward_responses, prompt_lengths.tolist(), resp_lengths.tolist()):
             response = self.tokenizer.ids_to_text(t[s:e].tolist())
@@ -1056,6 +1058,7 @@ def augment_dataloader(self, dataloader):
                         orig_response_str = self.tokenizer.ids_to_text(
                             cand_for_meta[1][cand_for_meta[2] : cand_for_meta[3]].tolist()
                         )
+                        norm_prompt_str, norm_response_str = self.normalise_prompt(orig_prompt_str, orig_response_str, buffer[0]["dataset_mask"])
                         meta_batch = []
                         for a, b in itertools.combinations(
                             [self.tokenizer.ids_to_text(s[0][s[1] : s[2]].tolist()) for s in reward_tokens_raw], 2
@@ -1069,10 +1072,10 @@
                             a = re.sub("(?i)(?:Score|Points): ([0-9\.]+)", "", a)
                             b = re.sub("(?i)(?:Score|Points): ([0-9\.]+)", "", b)
                             meta_str_ab = self.meta_judge_template_fn(
-                                prompt=orig_prompt_str, response=orig_response_str, judgement_a=a, judgement_b=b
+                                prompt=norm_prompt_str, response=norm_response_str, judgement_a=a, judgement_b=b
                             )
                             meta_str_ba = self.meta_judge_template_fn(
-                                prompt=orig_prompt_str, response=orig_response_str, judgement_a=b, judgement_b=a
+                                prompt=norm_prompt_str, response=norm_response_str, judgement_a=b, judgement_b=a
                             )
                             meta_tokens_ab = self.model.tokenizer.text_to_ids(meta_str_ab)
                             meta_tokens_ba = self.model.tokenizer.text_to_ids(meta_str_ba)
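
Note: the diff calls self.normalise_prompt but does not include its definition; it lives elsewhere in self_rewarding.py. The sketch below is only an assumption about its intent, inferred from the dataset_mask default changing from None to "" and from the normalised strings being fed into the meta-judge template: it strips the chat-template role marker carried in the batch metadata so the judge sees plain prompt/response text. Names and behaviour here are hypothetical, not the actual implementation.

    # Hypothetical sketch of normalise_prompt (assumed behaviour, not the real method).
    # dataset_mask is assumed to be a role-marker string from batch metadata
    # (e.g. "<extra_id_1>User"), or "" when the batch carried no metadata,
    # which is why the collate fallback above changes None -> "".
    def normalise_prompt(self, prompt: str, response: str, dataset_mask: str):
        norm_prompt = prompt
        if dataset_mask:
            # Remove template role markers so the meta-judge prompt is plain text.
            norm_prompt = norm_prompt.replace(dataset_mask, "")
        return norm_prompt.strip(), response.strip()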