'States' object has no attribute 'shape' #228

Open
philmar1 opened this issue Jan 13, 2025 · 2 comments

Comments

@philmar1

Hi!

I'm trying to build a simple autoregressive model that generates sequences of strings using torchgfn.

I get the following error: 'States' object has no attribute 'shape' (raised at line 390 of env.py, in DiscreteEnv.__init__(), by: assert s0.shape == state_shape).
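The assertion only succeeds if `s0` exposes a `.shape` attribute, which a tensor has but, judging by the error, a `States` object does not. The plain-Python sketch below reproduces the failure. `FakeTensor`, `StatesWrapper`, and `check_s0` are hypothetical stand-ins, not torchgfn classes; `check_s0` simply mirrors the assert quoted from env.py.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor: it only carries a .shape tuple."""

    def __init__(self, shape):
        self.shape = tuple(shape)


class StatesWrapper:
    """Hypothetical stand-in for a States-like container that has no .shape attribute."""

    def __init__(self, tensor):
        self.tensor = tensor


def check_s0(s0, state_shape):
    # Mirrors the check quoted from env.py: assert s0.shape == state_shape
    assert s0.shape == state_shape
    return True


state_shape = (8,)
print(check_s0(FakeTensor(state_shape), state_shape))  # True

try:
    check_s0(StatesWrapper(FakeTensor(state_shape)), state_shape)
except AttributeError as err:
    # Accessing .shape on the wrapper fails before the comparison even runs
    print(err)
```

Passing the raw tensor-like object satisfies the check; wrapping it first triggers the same kind of AttributeError reported in this issue.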

I'm using torchgfn 1.1.1

@saleml
Collaborator

saleml commented Jan 14, 2025

Hi @philmar1, on the main branch it looks like s0 is expected to be a Tensor, not a States object, when DiscreteEnv is initialized. In your code, are you initializing it with a tensor or with a States object?

@philmar1
Author

Thank you, indeed. I didn't look closely enough, and the name s0 confused me into initializing it as a States object. After correcting that, I get a new error from the next assert (line 391 in env.py): assert dummy_action.shape == action_shape

Here is my code:

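import torch
from gfn.actions import Actions
from gfn.env import DiscreteEnv
from gfn.states import States

SEQ_LEN = 8
VOCAB_SIZE = 5
# BATCH_SIZE is used below, but its definition is not shown in this report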

class MyEnv(DiscreteEnv):
    def __init__(self,**kwargs):
        super().__init__(**kwargs)

    def backward_step(self, states: States, actions: Actions) -> torch.Tensor:
        # Issue: need to find the iteration we're at to define backward step
        return states
    
    def update_masks(self, states: States) -> None:
        return torch.ones_like(states)
    
    def step(self, states: States, actions: Actions) -> torch.Tensor:
        return states
    
env = MyEnv(n_actions=VOCAB_SIZE, 
            state_shape=(BATCH_SIZE, SEQ_LEN),
            action_shape=(BATCH_SIZE),
            dummy_action=torch.ones(BATCH_SIZE),
            s0=torch.zeros([BATCH_SIZE,SEQ_LEN]))

If I provide action_shape as a tuple instead, the comparison against the torch.Size object still fails.
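One likely culprit in the snippet above is `action_shape=(BATCH_SIZE)`: in Python, parentheses alone do not create a tuple, so this passes a plain `int`, and a shape compared with an `int` is never equal. A one-element tuple needs a trailing comma. `torch.Size` is a subclass of `tuple`, so a real tuple such as `(BATCH_SIZE,)` does compare equal to a matching `torch.Size`. The sketch below uses plain tuples in place of `torch.Size` so it runs without PyTorch; the value 16 is illustrative only.

```python
BATCH_SIZE = 16  # illustrative value; the report does not give one

not_a_tuple = (BATCH_SIZE)   # parentheses only: this is just the int 16
one_tuple = (BATCH_SIZE,)    # trailing comma: a one-element tuple

print(type(not_a_tuple).__name__, type(one_tuple).__name__)  # int tuple

# Stand-in for dummy_action.shape, i.e. torch.ones(BATCH_SIZE).shape == torch.Size([16]).
# torch.Size subclasses tuple, so a plain tuple behaves the same in this comparison.
dummy_action_shape = (BATCH_SIZE,)

print(dummy_action_shape == not_a_tuple)  # False: tuple vs int
print(dummy_action_shape == one_tuple)    # True
```

Fixing the trailing comma only addresses this particular mismatch; whether the batch dimension belongs in these shapes at all is the separate question raised next.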

I've also tried different configurations, with and without BATCH_SIZE in the shapes. It's unclear to me how the batch size should be handled when creating an environment.
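On the batch-size question, one common convention is for `state_shape` and `action_shape` to describe a single state or action, with the batch dimension kept separate and prepended when states are sampled. This is an assumption: nothing in this thread confirms that torchgfn's `DiscreteEnv` works this way. The sketch shows that bookkeeping with plain tuples. `make_batch_shape` and `split_batch_shape` are hypothetical helpers, and `action_shape = (1,)` is likewise an assumed per-action shape.

```python
SEQ_LEN = 8
BATCH_SIZE = 16  # illustrative value; not given in the report

# Assumption: constructor shapes describe ONE state / ONE action, with no batch dimension.
state_shape = (SEQ_LEN,)
action_shape = (1,)


def make_batch_shape(batch_shape, item_shape):
    """Full tensor shape for a batch: batch dimensions first, then the per-item shape."""
    return tuple(batch_shape) + tuple(item_shape)


def split_batch_shape(full_shape, item_shape):
    """Recover the batch dimensions by stripping the trailing per-item shape."""
    full_shape, item_shape = tuple(full_shape), tuple(item_shape)
    n_batch = len(full_shape) - len(item_shape)
    assert full_shape[n_batch:] == item_shape, "trailing dimensions must match the item shape"
    return full_shape[:n_batch]


states_shape = make_batch_shape((BATCH_SIZE,), state_shape)
actions_shape = make_batch_shape((BATCH_SIZE,), action_shape)
print(states_shape)                                  # (16, 8)
print(actions_shape)                                 # (16, 1)
print(split_batch_shape(states_shape, state_shape))  # (16,)
```

Under this convention, s0 would have shape (SEQ_LEN,) rather than (BATCH_SIZE, SEQ_LEN), and the batch size would only appear when sampling trajectories.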
