SubTB simple running example? #185

Closed
TomAvrech opened this issue Aug 29, 2024 · 4 comments

@TomAvrech

The main TorchGFN example runs fine as given, including the line:

# 4 - We define the GFlowNet.
gfn = TBGFlowNet(init_logZ=0., pf=pf_estimator, pb=pb_estimator)  # We initialize logZ to 0

But when I replace just that line with:

# 4 - We define the GFlowNet.
logF = DiscretePolicyEstimator(module=module_PF, n_actions=env.n_actions, preprocessor=env.preprocessor)
gfn = SubTBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF, lamda=0.9)

The error is:

[screenshot of the error traceback]

Is there a working example for SubTB, or how can I fix the above?

Thanks :)

@josephdviviano
Collaborator

Thank you for bringing this to my attention. It looks like a bug. I'll try to prioritize it asap.

@josephdviviano self-assigned this Aug 29, 2024
@josephdviviano added the "bug" and "high priority" labels Aug 29, 2024
@TomAvrech
Author

TomAvrech commented Sep 1, 2024

> Thank you for bringing this to my attention. It looks like a bug. I'll try to prioritize it asap.

I have also sent you an email at [email protected] :)

@josephdviviano
Collaborator

Thanks - I've looked into it, and in fact there is no bug. However, this prompted me to add some quality-of-life improvements to the code, so I appreciate you bringing it to my attention.

The line logF = DiscretePolicyEstimator(...) is the source of the issue. It should be logF = ScalarEstimator(...), which outputs a single value (a scalar) per state rather than one logit per action. This fixes the dimension issues.
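For intuition, here is a minimal shape sketch of the mismatch using plain tensors (the batch size is illustrative, and n_actions = 5 assumes the 4-dimensional HyperGrid below, with one step action per dimension plus an exit action):

import torch

batch_size = 16
n_actions = 5  # illustrative: 4-dimensional HyperGrid, one step per dimension plus exit

# A DiscretePolicyEstimator's module outputs one logit per action per state:
policy_shaped = torch.randn(batch_size, n_actions)  # shape (16, 5)

# SubTB's logF estimator must output a single log-flow value per state:
scalar_shaped = torch.randn(batch_size, 1)          # shape (16, 1)

# Passing the (batch_size, n_actions) tensor where SubTB expects a
# (batch_size, 1) log-flow is what produces the dimension error above.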

Here's a full working example:

import torch
from tqdm import tqdm

from gfn.gflownet import TBGFlowNet, SubTBGFlowNet
from gfn.gym import HyperGrid  # We use the hyper grid environment
from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
from gfn.samplers import Sampler
from gfn.utils import NeuralNet  # NeuralNet is a simple multi-layer perceptron (MLP)


USE_TB = False

if __name__ == "__main__":

    # 1 - We define the environment.
    env = HyperGrid(ndim=4, height=8, R0=0.01)  # Grid of size 8x8x8x8

    # 2 - We define the needed modules (neural networks).
    # The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator
    module_PF = NeuralNet(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions
    )  # Neural network for the forward policy, with as many outputs as there are actions

    module_PB = NeuralNet(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions - 1,
        torso=module_PF.torso  # We share all the parameters of P_F and P_B, except for the last layer
    )

    # 3 - We define the estimators.
    pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
    pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)

    # 4 - We define the GFlowNet.
    if USE_TB:
        gfn = TBGFlowNet(init_logZ=0., pf=pf_estimator, pb=pb_estimator)  # We initialize logZ to 0
    else:
        # Define the logF estimator.
        module_logF = NeuralNet(
            input_dim=env.preprocessor.output_dim,
            output_dim=1,
        )
        logF = ScalarEstimator(module=module_logF, preprocessor=env.preprocessor)
        gfn = SubTBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF, lamda=0.9)

    # 5 - We define the sampler and the optimizer.
    sampler = Sampler(estimator=pf_estimator)  # We use an on-policy sampler, based on the forward policy

    # Different policy parameters can have their own LR.
    if USE_TB:
        # Log Z gets dedicated learning rate (typically higher).
        optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
        optimizer.add_param_group({"params": gfn.logz_parameters(), "lr": 1e-1})
    else:
        # Log F gets dedicated learning rate (typically higher).
        optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
        optimizer.add_param_group({"params": gfn.logF_parameters(), "lr": 1e-2})

    # 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
    for i in (pbar := tqdm(range(1000))):
        trajectories = sampler.sample_trajectories(env=env, n_trajectories=16)
        optimizer.zero_grad()
        loss = gfn.loss(env, trajectories)
        loss.backward()
        optimizer.step()
        if i % 25 == 0:
            pbar.set_postfix({"loss": loss.item()})

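After training, you can sample fresh trajectories to see what the learned policy produces. A minimal sketch continuing the script above (it assumes the returned Trajectories object exposes its terminal states as last_states; treat that attribute name as an assumption and check it against your torchgfn version):

    # Sample from the trained forward policy without tracking gradients,
    # then inspect the terminal states of the sampled trajectories.
    with torch.no_grad():
        eval_trajectories = sampler.sample_trajectories(env=env, n_trajectories=10)
    print(eval_trajectories.last_states)  # terminal states reached by the policy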
@josephdviviano added the "documentation" and "invalid" labels and removed the "bug" and "high priority" labels Sep 20, 2024
@josephdviviano
Collaborator

I pushed a simple fix here: #186. Please reopen the issue if you continue to have problems.
