-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Graphs as States #210
base: master
Are you sure you want to change the base?
Add Graphs as States #210
Conversation
…o graph-states
…o graph-states
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the big PR. Here are few questions, comments, suggestions.
I am unable to run the main script.
I don't know why CI isn't triggered in this PR.
I'd be happy to merge once the main issues are resolved, and the smaller ones defined as github issues
@@ -255,21 +257,22 @@ def _step( | |||
) | |||
|
|||
new_sink_states_idx = actions.is_exit | |||
new_states.tensor[new_sink_states_idx] = self.sf | |||
sf_tensor = self.States.make_sink_states_tensor((new_sink_states_idx.sum(),)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious about the reason for this change? Is it specific to GraphStates?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason is because of how graphs are represented in the tensor, i.e:
tensor = TensorDict({
'node_features': shape (N, F1)
'edge_features': shape (M, F2)
'edge_index': shape (2, M)
})
Notice that tensor[some_index]
doesn't make sense, and doesn't work. There is a more complex behavior defined in GraphStates.__setitem__
, to do it correctly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think a comment would be worth adding here.
src/gfn/modules.py
Outdated
dists["action_type"] = CategoricalActionType(probs=action_type_probs) | ||
|
||
edge_index_logits = module_output["edge_index"] | ||
if states.tensor["node_feature"].shape[0] > 1 and torch.any(edge_index_logits != -float("inf")): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check is per batch, not per graph. Shouldn't you use batch_ptr to check each graph separately
BTW this might also need to handle masks for valid edges
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At the moment, all the nets return the same action type, so this works, however, you are right, this is not general,.
I restrict to this case because when the number of nodes varies across the batch, also the outputs of the nets vary (e.g. for one batch you will have probs of (N1, N1) for edge_indexes and the other (N2, N2) breaking batching.
I am not sure how to overcome this problem while being reasonably efficient.
BTW this might also need to handle masks for valid edges
Perfectly right, this is the reason for CI failing; see the other comment
) * edge_index_probs + epsilon * uniform_dist_probs | ||
dists["edge_index"] = CategoricalIndexes(probs=edge_index_probs) | ||
|
||
dists["features"] = Normal(module_output["features"], temperature) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is no need to mask invalid features here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean by invalid features?
In case the action is an EXIT action, these features are ignored... but otherwise, I don't see any invalid feature(?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@saleml we may need to mask the invalid edge_index
indeed.
At the moment:
- GraphBuilding env requires the edge not to exist (no multiple edges from & to the same node). This is because it is then easier for backward_step (removing the edge).
- There is no real masking for it, so the Estimator can sample an invalid action
This is the reason tests in CI are failing. Two solutions I can see:
1) Add mask in the estimator (the code you commented here)
This is easy to do, but is it general enough that multiple edges (from & to the same node) are not allowed?
Also, I already wanted to raise that we have forward_mask
and backward_mask
in States, while also having isActionValid
method. This is a repetition of code (and now we would add another repetition here).
2) Improving the forward & backward masks in State
The current implementation only masks for the type of action (e.g. if there are no nodes, you cannot add an edge).
This is a bit more complicated to code but avoids the repetition of the code of the above solution.
3) Allow GraphBuilding to have multiple same edges
We need to use the edge features to check which edge to remove in the backward step.
What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest adding basic edge masking in the estimator for now to fix the CI, but we should create a follow-up issue to properly consolidate all action validation logic into the States class. This will eliminate code duplication and provide a single source of truth for action validity. The States class should be responsible for providing comprehensive masks that cover both action types and edge validity.
This approach:
- Gets CI passing quickly
- Acknowledges the technical debt
- Sets a clear path forward
- Keeps the current PR scope manageable
SO
Short-term fix (for this PR):
- Go with solution 1 (add masking in the estimator) as a temporary fix to get CI passing
- Add a TODO/issue comment indicating this is temporary
- Document the current limitation in the docstring
Long-term solution (next PR): - Consolidate all action validation logic into the States class (solution 2)
- The States class should provide comprehensive masks that include:
- Action type validity (current implementation)
- Edge validity (currently missing)
- Any other environment-specific constraints
LMK what you think
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good!
Actually, if I make the States masks comprehensive, then I should able to just use them in the estimator for masking (so solution 2).
I will do it soon!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I disagree that we should fix it here (i.e., in the Estimator
), or if we do so, we should implement masking properly in the States
class before the release of V2 because this is a clear design pattern violation.
There is no problem adding multiple edges to-from the same node in principle (in say, a multi-attribute graph) but that does make things way more complex, and I think we can safely avoid that complexity here, as single edges to-from the same nodes will cover a lot of AI for Science applications in the near term.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I am trying to implement the complete forward/backward mask in states.
I have the following problem:
In general, the number of nodes for each graph can vary. The mask for edge_index is (B, N, N), where N is the number of nodes. However, the number of nodes varies across batch...
Two possible solutions:
- Use (B, N, N), with N the total number of nodes (not in one graph, but across the graphes in the batch). This is more memory consuming, but general.
- Enforce the number of nodes to be the same across the batch, which means enforcing the action to be the same across the batch (you cannot add a node to one graph while adding an edge to another).
|
||
return torch.cat([action_type, edge_actions], dim=-1) | ||
|
||
class RingGraphBuilding(GraphBuilding): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I cannot run this file. I get this error
ValueError: batch size was not specified when creating the TensorDict instance and it could not be retrieved from source.
Can you run it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I can run it...
Which version of TensorDict are you using? I have 0.6.2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ERROR: Could not find a version that satisfies the requirement tensordict==0.6.2 (from versions: 0.0.1a0, 0.0.1b0, 0.0.1rc0, 0.0.2a0, 0.0.2b0, 0.0.3, 0.1.0, 0.1.1, 0.1.2, 0.2.0, 0.2.1, 0.3.0, 0.3.1, 0.3.2)
ERROR: No matching distribution found for tensordict==0.6.2
Althout in their github repo there is a 0.6.2 release. I will try to debug this.
It looks like GitHub is down, I will rerun CI tomorrow, but they are green locally |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First pass at a review. Thanks so much for your amazing work. I am going to mess around with the test cases and environment next to get a better understanding of how things function together. In the meantime, I have some questions.
@@ -255,21 +257,22 @@ def _step( | |||
) | |||
|
|||
new_sink_states_idx = actions.is_exit | |||
new_states.tensor[new_sink_states_idx] = self.sf | |||
sf_tensor = self.States.make_sink_states_tensor((new_sink_states_idx.sum(),)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think a comment would be worth adding here.
""" | ||
self.s0 = s0.to(device_str) | ||
self.features_dim = s0["node_feature"].shape[-1] | ||
self.sf = sf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps we could have a special NoneTensorDict
GraphState
which acts like None
but passes the relevant checks?
self.check_output_dim(out) | ||
self._output_dim_is_checked = True | ||
|
||
assert out.shape[-1] == 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This also seems like a much harder constraint.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why? The expected_output_dim
is 1 in this class, so it seems the same to me (actually softer as I don't check the dtype).
) * edge_index_probs + epsilon * uniform_dist_probs | ||
dists["edge_index"] = CategoricalIndexes(probs=edge_index_probs) | ||
|
||
dists["features"] = Normal(module_output["features"], temperature) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I disagree that we should fix it here (i.e., in the Estimator
), or if we do so, we should implement masking properly in the States
class before the release of V2 because this is a clear design pattern violation.
There is no problem adding multiple edges to-from the same node in principle (in say, a multi-attribute graph) but that does make things way more complex, and I think we can safely avoid that complexity here, as single edges to-from the same nodes will cover a lot of AI for Science applications in the near term.
The code runs, thanks for the last change. A few suggestions/questions:
|
Description:
Unlike the current States object that necessitates appending dummy states to batch trajectories of varying lengths, our approach aims to support Trajectories through a nested Batch object representation. The Data class in Torch Geometric represents the graph structure, while the Batch class, which encapsulates batching of Data objects and their efficient indexing, represents the GraphStates object.
The current implementation of Trajectory supports the indexing dimensions: (Num time steps, Num trajectories, State Size). By using a nested Batch of Batch object to represent state Trajectories, the indexing would inherently take the form (Num trajectories, Num timesteps, State size). This approach requires implementing logic within
_getitem_()
and_setitem_()
to internally.To Do:
Compatibility check with Trajectories, Transition class