Dont recompute masks #163
…proposed new method for stacking a list of states into a trajectory. As the assert statements show, the tensor is correct but the forward_masks are not.
Some debugging: First, I changed the batch_size to 3 in the script. Then, with a breakpoint at the assertion error, I see that the forward masks of all the steps within the batch of 3 trajectories are the same.

So here is the interesting part. If I add a breakpoint at lines 195 and 199 of The first mask is ok. Initially, we have 3 copies of s0, so the masks should all be True. Once we call So the problem happens here in And looking at the it seems to me that the masks are changed in place and that the problem is due to #149. I don't remember if tests were passing in that PR (I haven't reviewed that PR). Let me know what you think.
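The symptom described above (every step of the accumulated trajectory showing the same forward masks) is what in-place mutation of a shared buffer produces. Here is a minimal sketch in plain Python, with nested lists standing in for mask tensors. `ToyStates`, `step_in_place`, and `rollout` are hypothetical names made up for illustration and are not part of the library:

```python
import copy


class ToyStates:
    """Hypothetical stand-in for a batch of states; not the library's class."""

    def __init__(self, position, forward_masks):
        self.position = position
        self.forward_masks = forward_masks  # one list of booleans per state


def step_in_place(states):
    """Advance every state by one and update the masks in place."""
    new_position = [p + 1 for p in states.position]
    masks = states.forward_masks  # same list objects as the previous step
    for row, p in zip(masks, new_position):
        row[0] = p < 2  # assumed rule: action 0 becomes invalid at position 2
    return ToyStates(new_position, masks)


def rollout(n_steps, snapshot):
    """Collect the masks seen at each step, optionally deep-copying them."""
    states = ToyStates([0, 0, 0], [[True, True] for _ in range(3)])
    collected = []
    for _ in range(n_steps + 1):
        masks = states.forward_masks
        collected.append(copy.deepcopy(masks) if snapshot else masks)
        states = step_in_place(states)
    return collected


aliased = rollout(2, snapshot=False)  # every entry is the same object
copied = rollout(2, snapshot=True)    # every entry is an independent snapshot
```

Without the snapshot, all collected entries point at one buffer, so they all show the final masks. With tensors, calling `.clone()` on the masks before accumulating them would play the role of `copy.deepcopy` here.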
src/gfn/states.py
stacked_states.forward_masks = torch.stack([s.forward_masks for s in states], dim=0)
stacked_states.backward_masks = torch.stack([s.backward_masks for s in states], dim=0)
This should only be implemented for DiscreteStates, not for all States.
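One way to read this suggestion is that the base class stacks only the state tensor, and the discrete subclass overrides the method to also stack the masks. Here is a minimal sketch in plain Python, with lists standing in for tensors. `ToyBaseStates`, `ToyDiscreteStates`, and the `stack` classmethod are illustrative assumptions, not the library's actual API:

```python
class ToyBaseStates:
    """Hypothetical base class: carries only a state tensor (a list here)."""

    def __init__(self, tensor):
        self.tensor = tensor

    @classmethod
    def stack(cls, states):
        # Build an empty instance and fill in only what every subclass has.
        stacked = cls.__new__(cls)
        stacked.tensor = [list(s.tensor) for s in states]
        return stacked


class ToyDiscreteStates(ToyBaseStates):
    """Hypothetical discrete subclass: additionally carries action masks."""

    def __init__(self, tensor, forward_masks, backward_masks):
        super().__init__(tensor)
        self.forward_masks = forward_masks
        self.backward_masks = backward_masks

    @classmethod
    def stack(cls, states):
        stacked = super().stack(states)
        # Copy each mask so later in-place updates cannot alter the stack.
        stacked.forward_masks = [list(s.forward_masks) for s in states]
        stacked.backward_masks = [list(s.backward_masks) for s in states]
        return stacked


base = ToyBaseStates.stack([ToyBaseStates([0]), ToyBaseStates([1])])
discrete = ToyDiscreteStates.stack(
    [
        ToyDiscreteStates([0], [True, True], [False, False]),
        ToyDiscreteStates([1], [True, False], [True, False]),
    ]
)
```

Keeping the mask handling in the subclass means continuous-state classes never need mask attributes to exist.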
Good catch! These tests pass fine - I think the in-place update of masks is desirable behaviour except in this case where we want to accumulate a trajectory of states. To reduce user error, the base When I was messing around, I tried
… in mask updating behaviour to prevent accumulation of errors.
OK @saleml figured it out - check line 413 here 77e7e1b Before setting the This, plus using
checks whether user-defined env.step method returns the expected type
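A check of this kind could look roughly like the sketch below. The wrapper `checked_step`, the class `ToyEnvStates`, and the error message are assumptions made for illustration; the commit's actual implementation and expected type are not shown in this thread:

```python
class ToyEnvStates:
    """Hypothetical states container used only for this illustration."""

    def __init__(self, values):
        self.values = values


def checked_step(step_fn, states, actions, expected_type=ToyEnvStates):
    """Call a user-defined step function and verify the type of its result."""
    result = step_fn(states, actions)
    if not isinstance(result, expected_type):
        raise TypeError(
            f"step returned {type(result).__name__}, "
            f"expected {expected_type.__name__}"
        )
    return result


def good_step(states, actions):
    # Returns the expected container type.
    return ToyEnvStates([v + a for v, a in zip(states.values, actions)])


def bad_step(states, actions):
    # Returns a bare list, which the check should reject.
    return [v + a for v, a in zip(states.values, actions)]


ok = checked_step(good_step, ToyEnvStates([0, 1]), [1, 1])

try:
    checked_step(bad_step, ToyEnvStates([0, 1]), [1, 1])
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

Failing at the call site gives the user an error that names the offending return type, instead of a harder-to-trace failure later in sampling.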
This isn't working @saleml. Please see the issue in samplers.py. You can reproduce the error with tutorials/examples/train_hypergrid_simple.py