Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No more class factories #149

Merged
merged 30 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0535e9f
removed one order of magnitude precision required
josephdviviano Nov 27, 2023
7753041
merge conflicts
josephdviviano Nov 27, 2023
f7a562e
replaced State method call with Env method call
josephdviviano Nov 29, 2023
742bfb2
replaced State method call with Env method call, and removed is_tenso…
josephdviviano Nov 29, 2023
f545167
replaced State method call with Env method call
josephdviviano Nov 29, 2023
f945b39
switch name of backward/forward step, and replaced State method call …
josephdviviano Nov 29, 2023
2a51704
removed States/Actions class definition, added the appropriate args t…
josephdviviano Nov 29, 2023
cdd425c
moved environment to Gym
josephdviviano Nov 29, 2023
6ae846b
renamed maskless_?_step functions, and made the generic step/backward…
josephdviviano Nov 29, 2023
a09c9a5
removed comment
josephdviviano Nov 29, 2023
93f6a65
States methods moved to Env methods, also, name change for step
josephdviviano Nov 29, 2023
2ab5885
changes to the handling of forward / backward masks. in addition, mak…
josephdviviano Nov 29, 2023
bfd6bbf
method renaming
josephdviviano Nov 29, 2023
d6d30fe
docs update (TOOD: this might need a full rework)
josephdviviano Nov 29, 2023
25b7527
changes to support new API
josephdviviano Nov 29, 2023
f12cbec
tweaks (TODO: fix in follow up PR)
josephdviviano Nov 29, 2023
cdffab1
black / isort
josephdviviano Nov 29, 2023
3b756a5
cleanup
josephdviviano Nov 29, 2023
c98f423
gradient clipping added back in
Nov 30, 2023
6af395f
renaming make_States_class to follow pep
josephdviviano Dec 8, 2023
6c0d8aa
updated documentation
josephdviviano Dec 8, 2023
364e52d
rename methods
josephdviviano Dec 8, 2023
8d8a4c1
rename method
josephdviviano Dec 8, 2023
ebf0db2
deps
josephdviviano Feb 13, 2024
71e6603
requirements
josephdviviano Feb 13, 2024
c393014
deps
josephdviviano Feb 13, 2024
3cb9914
merge
josephdviviano Feb 16, 2024
b85f1eb
merged
josephdviviano Feb 16, 2024
ad80e7e
update
josephdviviano Feb 16, 2024
ae3fa2e
update
josephdviviano Feb 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/requirements_docs.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pre-commit
black
pytest
sphinx==5.3.0
myst-parser==0.18.1
sphinx_rtd_theme==1.1.1
sphinx-math-dollar==1.2.1
sphinx-autoapi==2.0.0
sphinx>=6.2.1
myst-parser
sphinx_rtd_theme
sphinx-math-dollar
sphinx-autoapi>=3.0.0
renku-sphinx-theme
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ myst-parser = { version = "*", optional = true }
pre-commit = { version = "*", optional = true }
pytest = { version = "*", optional = true }
renku-sphinx-theme = { version = "*", optional = true }
sphinx = { version = "*", optional = true }
sphinx = { version = ">=6.2.1", optional = true }
sphinx_rtd_theme = { version = "*", optional = true }
sphinx-autoapi = { version = "*", optional = true }
sphinx-autoapi = { version = ">=3.0.0", optional = true }
sphinx-math-dollar = { version = "*", optional = true }
tox = { version = "*", optional = true }

Expand Down Expand Up @@ -85,8 +85,6 @@ all = [
"Homepage" = "https://gfn.readthedocs.io/en/latest/"
"Bug Tracker" = "https://github.com/saleml/gfn/issues"



[tool.black]
py36 = true
include = '\.pyi?$'
Expand Down
4 changes: 2 additions & 2 deletions src/gfn/containers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def __init__(
self.training_objects = Transitions(env)
self.objects_type = "transitions"
elif objects_type == "states":
self.training_objects = env.States.from_batch_shape((0,))
self.terminating_states = env.States.from_batch_shape((0,))
self.training_objects = env.states_from_batch_shape((0,))
self.terminating_states = env.states_from_batch_shape((0,))
self.objects_type = "states"
else:
raise ValueError(f"Unknown objects_type: {objects_type}")
Expand Down
14 changes: 8 additions & 6 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,11 @@ def __init__(
self.states = (
states
if states is not None
else env.States.from_batch_shape(batch_shape=(0, 0))
else env.states_from_batch_shape((0, 0))
)
assert len(self.states.batch_shape) == 2
self.actions = (
actions
if actions is not None
else env.Actions.make_dummy_actions(batch_shape=(0, 0))
actions if actions is not None else env.actions_from_batch_shape((0, 0))
)
assert len(self.actions.batch_shape) == 2
self.when_is_done = (
Expand Down Expand Up @@ -253,9 +251,13 @@ def extend(self, other: Trajectories) -> None:

# Either set, or append, estimator outputs if they exist in the submitted
# trajectory.
if self.estimator_outputs is None and is_tensor(other.estimator_outputs):
if self.estimator_outputs is None and isinstance(
other.estimator_outputs, Tensor
):
self.estimator_outputs = other.estimator_outputs
elif is_tensor(self.estimator_outputs) and is_tensor(other.estimator_outputs):
elif isinstance(self.estimator_outputs, Tensor) and isinstance(
other.estimator_outputs, Tensor
):
batch_shape = self.actions.batch_shape
n_bs = len(batch_shape)
output_dtype = self.estimator_outputs.dtype
Expand Down
8 changes: 3 additions & 5 deletions src/gfn/containers/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,12 @@ def __init__(
self.states = (
states
if states is not None
else env.States.from_batch_shape(batch_shape=(0,))
else env.states_from_batch_shape(batch_shape=(0,))
)
assert len(self.states.batch_shape) == 1

self.actions = (
actions
if actions is not None
else env.Actions.make_dummy_actions(batch_shape=(0,))
actions if actions is not None else env.actions_from_batch_shape((0,))
)
self.is_done = (
is_done
Expand All @@ -85,7 +83,7 @@ def __init__(
self.next_states = (
next_states
if next_states is not None
else env.States.from_batch_shape(batch_shape=(0,))
else env.states_from_batch_shape(batch_shape=(0,))
)
assert (
len(self.next_states.batch_shape) == 1
Expand Down
Loading
Loading