I'm trying to build a simple autoregressive model that generates string sequences using torchgfn.
I get the following error: 'States' object has no attribute 'shape' (raised at line 390 of env.py, in __init__() of class DiscreteEnv: assert s0.shape == state_shape).
I'm using torchgfn 1.1.1.
Hi @philmar1, looking at the main branch, it looks like s0 is a Tensor object, not a States object, when DiscreteEnv is initialized. In your code, are you initializing it with a tensor or with a States object?
Thank you, indeed. I hadn't looked closely, and the name s0 misled me into initializing it as a States object. After correcting that, I get another error from the next assertion (line 391 in env.py): assert dummy_action.shape == action_shape
Here is my code:
import torch
from gfn.actions import Actions
from gfn.env import DiscreteEnv
from gfn.states import States

BATCH_SIZE = 16  # hypothetical value: BATCH_SIZE is not defined in the snippet as posted
SEQ_LEN = 8
VOCAB_SIZE = 5

class MyEnv(DiscreteEnv):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def backward_step(self, states: States, actions: Actions) -> torch.Tensor:
        # Issue: need to find the iteration we're at to define backward step
        return states

    def update_masks(self, states: States) -> None:
        return torch.ones_like(states)

    def step(self, states: States, actions: Actions) -> torch.Tensor:
        return states

env = MyEnv(n_actions=VOCAB_SIZE,
            state_shape=(BATCH_SIZE, SEQ_LEN),
            action_shape=(BATCH_SIZE),  # note: (BATCH_SIZE) is an int, not a 1-tuple
            dummy_action=torch.ones(BATCH_SIZE),
            s0=torch.zeros([BATCH_SIZE, SEQ_LEN]))
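The comment in backward_step asks how to know which position of the sequence is currently being filled. One option, sketched here on plain tensors rather than torchgfn States/Actions objects, is to pad unfilled positions with a sentinel (the hypothetical PAD = -1) and recover the current position by counting non-padding entries. The helpers current_position, forward_fill, backward_remove and action_mask are hypothetical names for illustration; the mask ignores any exit action, and wiring these into MyEnv.step/backward_step (presumably through the states' underlying tensor) is an untested assumption about torchgfn 1.1.1.

```python
import torch

SEQ_LEN = 8
VOCAB_SIZE = 5
PAD = -1  # hypothetical sentinel marking positions that are not filled yet


def current_position(seqs: torch.Tensor) -> torch.Tensor:
    """Number of filled tokens per sequence, i.e. the next index to write."""
    return (seqs != PAD).sum(dim=-1)


def forward_fill(seqs: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Write one token per sequence at its first empty position."""
    pos = current_position(seqs)
    assert bool((pos < seqs.shape[-1]).all()), "sequence already full"
    out = seqs.clone()
    out[torch.arange(seqs.shape[0]), pos] = tokens
    return out


def backward_remove(seqs: torch.Tensor) -> torch.Tensor:
    """Undo the last forward step by clearing the last filled position."""
    pos = current_position(seqs)
    assert bool((pos > 0).all()), "cannot step back from the initial state"
    out = seqs.clone()
    out[torch.arange(seqs.shape[0]), pos - 1] = PAD
    return out


def action_mask(seqs: torch.Tensor) -> torch.Tensor:
    """True for every token while the sequence still has an empty position."""
    not_full = current_position(seqs) < seqs.shape[-1]
    return not_full.unsqueeze(-1).expand(-1, VOCAB_SIZE)


if __name__ == "__main__":
    s0 = torch.full((2, SEQ_LEN), PAD, dtype=torch.long)  # 2 empty sequences
    s1 = forward_fill(s0, torch.tensor([3, 1]))
    s2 = forward_fill(s1, torch.tensor([0, 4]))
    print(s2[:, :3])                              # first three positions
    print(torch.equal(backward_remove(s2), s1))   # True
```

Because the position is derived from the state itself, the backward step needs no separate iteration counter.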
If I provide action_shape as a tuple, it is still treated as different from the torch.Size object.
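A quick check in plain PyTorch may explain the mismatch: (BATCH_SIZE) is just the integer BATCH_SIZE, since parentheses alone do not create a tuple (the trailing comma does), while torch.Size is a subclass of tuple and compares equal to a tuple with the same entries. This sketch only exercises the comparison used by the assertion quoted above; BATCH_SIZE = 16 is a hypothetical value.

```python
import torch

BATCH_SIZE = 16  # hypothetical; the original post does not define BATCH_SIZE

dummy_action = torch.ones(BATCH_SIZE)

print(isinstance(dummy_action.shape, tuple))  # True: torch.Size subclasses tuple
print(dummy_action.shape == (BATCH_SIZE))     # False: (BATCH_SIZE) is the int 16
print(dummy_action.shape == (BATCH_SIZE,))    # True: a one-element tuple matches
```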
Moreover, I've tried different configurations, with and without BATCH_SIZE. It's unclear to me how the batch size should be handled when creating an environment.
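In torchgfn's bundled environments (for example HyperGrid, which passes state_shape=(ndim,)) the shapes given to DiscreteEnv appear to describe a single state and a single action, with the batch dimension only appearing once states are sampled. Assuming that convention holds in 1.1.1, the two constructor checks would pass with per-state shapes, as in the hypothetical helper check_env_shapes below, which copies the quoted assertions onto plain tensors. The dummy_action value -1 and action_shape=(1,) are assumed library defaults and should be checked against the installed version.

```python
import torch

SEQ_LEN = 8


def check_env_shapes(s0, state_shape, dummy_action, action_shape):
    """Hypothetical helper mirroring the two assertions quoted from env.py."""
    assert s0.shape == state_shape, f"s0 {tuple(s0.shape)} != {state_shape}"
    assert dummy_action.shape == action_shape, (
        f"dummy_action {tuple(dummy_action.shape)} != {action_shape}"
    )


# Per-state shapes with no batch dimension (assumed torchgfn convention).
s0 = torch.zeros(SEQ_LEN, dtype=torch.long)
dummy_action = torch.tensor([-1])  # assumed default: one index per action
check_env_shapes(s0, (SEQ_LEN,), dummy_action, (1,))

# A batch of initial states is then a stack of s0 along a new leading dim.
batch = s0.unsqueeze(0).expand(4, -1)
print(batch.shape)  # torch.Size([4, 8])
```

If that convention holds, the constructor call above would presumably drop BATCH_SIZE entirely (s0=torch.zeros(SEQ_LEN), state_shape=(SEQ_LEN,)), leaving batching to the sampler; this has not been run against torchgfn here.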