-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
No more class factories #149
Conversation
…r() function, for linter compatibility
…with Env method call
…o be passed to subclasses.
…_step functions private. this maybe isn't the best solution as they are accessed externally by other elements of the library. mask updating is now handled by the DiscreteEnv. A generic make_States_class and make_Actions_class method is added to both Env and DiscreteEnv.
…e_random_state_tensor is now a function passed to the States class as inheritance can no longer be relied on to overwrite the default method.
FYI @josephdviviano I changed the "base" branch to rethinking_sampling instead of master. This allows us to view this PRs changes in isolation. When you merge #147, this PR will automatically update to be based off of master again! Alternatively, you can merge this PR into #147 and then merge #147 into master and it will have the same effect. I would suggest merging #147 first though and then iterating on / merging this PR in isolation 🙌 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is a fantastic PR and that the new API for defining environments represents a huge improvement to this codebase 🙌 I left a few comments about the naming of the functions make_xxx_class
which I think would improve the API as well but in general, this is awesome. Thank you for making this change!
src/gfn/containers/trajectories.py
Outdated
@@ -16,11 +16,6 @@ | |||
from gfn.containers.transitions import Transitions | |||
|
|||
|
|||
def is_tensor(t) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice change 👌
src/gfn/env.py
Outdated
raise NotImplementedError | ||
|
||
# Optionally implemented by the user when advanced functionality is required. | ||
def make_States_class(self) -> type[States]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know that the class is States
but I still would advocate that this method should be make_states_class
to be more inline with PEP 8.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes makes sense to me. That naming bugged me as well. thanks.
src/gfn/env.py
Outdated
make_random_states_tensor = env.make_random_states_tensor | ||
|
||
return DefaultEnvState | ||
|
||
def make_Actions_class(self) -> type[Actions]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment here - I know that the class is Actions
but I still would advocate that this method should be make_actions_class
to be more inline with PEP 8.
src/gfn/env.py
Outdated
n_actions = env.n_actions | ||
device = env.device | ||
|
||
return DiscreteEnvStates | ||
|
||
def make_Actions_class(self) -> type[Actions]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment here - I know that the class is Actions
but I still would advocate that this method should be make_actions_class
to be more inline with PEP 8.
tutorials/examples/train_line.py
Outdated
# if p.ndim > 0 and p.grad is not None: # We do not clip logZ grad. | ||
# p.grad.data.clamp_( | ||
# -gradient_clip_value, gradient_clip_value | ||
# ).nan_to_num_(0.0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would suggest deleting this code or adding a comment explaining why it's not included in the example.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oups -- this is a mistake - it shouldn't be commented out - good catch :)
I implemented the renaming and also realized I needed to update the documentation which is now fixed. |
To be merged after #147
make_States_class
andmake_Actions_class
no longer need to be defined by the user - all relevant logic is submitted directly to theEnv
orDiscreteEnv
subclass. (Of course, the user could overwrite the defaultDefaultEnvState
andDefaultEnvAction
classes returned bymake_States_class
andmake_Actions_class
IIF they require boutique functionality, but this is not expected to be a normal workflow).maskless_
naming fromstep
andbackward_step
.As a result of this, multiple methods are offloaded from the
States
class into theEnv
, andmake_random_states_tensor
must be passed from theEnv
to theStates
class, which accounts for a large number of these diffs.As an example, see the below Env definition for the line environment, which is complete: