Unable to reproduce numbers for RhoMath 1.1B K=3 on MATH dataset #11

Open
RahulSChand opened this issue Jan 12, 2025 · 1 comment

@RahulSChand

I tried running the training script for RhoMath 1.1B (configs/polIter_rho1bSft2_vineppo_MATH.jsonnet) with two changes: I set num_mc_rollouts = 3 and replaced (import 'episode_generators/9rolls.jsonnet') with (import 'episode_generators/3rolls.jsonnet'), to get K=3 rather than K=9. I am running on 4x L40S GPUs and can't reproduce the 21.3 pass@1 on MATH reported in the paper. I get around 19.9 pass@1 after 500 iterations, and after that pass@1 keeps fluctuating between 19.5 and 20.0, never getting close to 21.3.

Can you help me figure out what I am missing? Are the config changes I made not the right way to get the K=3 numbers? Or are the default hyperparameters specific to an 8x GPU setup? Thanks.

@MiladInk

Hi Rahul,

What you described for running the K=3 ablation looks correct to me. Could you try running it again? It's possible that what we're observing is due to randomness and you were unlucky with that seed. If rerunning is not possible, here are my thoughts:

As you mentioned, your test accuracy isn't improving. Could you check that your training accuracy is increasing? I am worried that something went wrong during training. It's logged under episodes_metric/scores/mean. In the run from the paper, we observed this curve:
[image: episodes_metric/scores/mean curve from the paper run]

Additionally, remember to first select the best checkpoint based on validation loss and then report the corresponding test accuracy. That said, I understand that you’re not achieving a test accuracy higher than 19.9, so regardless of how the test accuracy is chosen, it won’t reach 21.3.
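
To make the selection step above concrete, here is a minimal sketch in Python. It is not code from this repository: the `Checkpoint` record, its field names, and the numbers are all hypothetical and only illustrate picking the checkpoint by validation loss first and reading off the test pass@1 afterwards.

```python
from dataclasses import dataclass


@dataclass
class Checkpoint:
    # Hypothetical record; the field names are assumptions, not the repo's logging keys.
    iteration: int
    val_loss: float
    test_pass_at_1: float


def select_by_validation(checkpoints):
    """Return the checkpoint with the lowest validation loss."""
    if not checkpoints:
        raise ValueError("no checkpoints to select from")
    return min(checkpoints, key=lambda c: c.val_loss)


# Made-up numbers for illustration only; they are not results from any run.
runs = [
    Checkpoint(iteration=400, val_loss=0.52, test_pass_at_1=19.4),
    Checkpoint(iteration=500, val_loss=0.48, test_pass_at_1=19.9),
    Checkpoint(iteration=600, val_loss=0.50, test_pass_at_1=20.0),
]

best = select_by_validation(runs)
print(best.iteration, best.test_pass_at_1)  # prints: 500 19.9
```

Note that in this made-up example the reported number (19.9) is not the highest test score in the list (20.0), because the test set plays no part in the selection.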

Lastly, how many more iterations did you run after the initial 500? Could you share the training curve here?
