Skip to content

Commit

Permalink
Merge pull request #141 from marpaia/marpaia/fix-fix
Browse files Browse the repository at this point in the history
Minimally fix typing issues
  • Loading branch information
josephdviviano authored Oct 17, 2023
2 parents 2784c54 + 13737f6 commit d9ca558
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/gfn/gflownet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def loss(self, env: Env, training_objects):
"""Computes the loss given the training objects."""


class PFBasedGFlowNet(GFlowNet, Generic[TrainingSampleType]):
class PFBasedGFlowNet(GFlowNet[TrainingSampleType]):
r"""Base class for gflownets that explicitly uses $P_F$.
Attributes:
Expand All @@ -75,9 +75,7 @@ def sample_trajectories(self, env: Env, n_samples: int) -> Trajectories:
return trajectories


class TrajectoryBasedGFlowNet(
PFBasedGFlowNet[Trajectories], Generic[TrainingSampleType]
):
class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
def get_pfs_and_pbs(
self,
trajectories: Trajectories,
Expand Down

0 comments on commit d9ca558

Please sign in to comment.