Skip to content

Commit

Permalink
Avoid implicit int64 conversion in padding.
Browse files Browse the repository at this point in the history
The concatenation resulted in an int64 array due to the
prior int() cast.

PiperOrigin-RevId: 409910620
Change-Id: I43d27b837683cc575113e1a10cf5ab3da99da0ff
  • Loading branch information
JraphDev authored and jg8610 committed Nov 19, 2021
1 parent e4241c0 commit 38bb907
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jraph/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,10 +615,10 @@ def pad_with_graphs(graph: gn_graph.GraphsTuple,

padding_graph = gn_graph.GraphsTuple(
n_node=np.concatenate(
[np.array([pad_n_node]),
[np.array([pad_n_node], dtype=np.int32),
np.zeros(pad_n_empty_graph, dtype=np.int32)]),
n_edge=np.concatenate(
[np.array([pad_n_edge]),
[np.array([pad_n_edge], dtype=np.int32),
np.zeros(pad_n_empty_graph, dtype=np.int32)]),
nodes=tree.tree_map(tree_nodes_pad, graph.nodes),
edges=tree.tree_map(tree_edges_pad, graph.edges),
Expand Down

0 comments on commit 38bb907

Please sign in to comment.