From 38bb9078fab2cb35fca8007ba726bd7367beca6b Mon Sep 17 00:00:00 2001 From: JraphDev Date: Mon, 15 Nov 2021 09:10:43 +0000 Subject: [PATCH] Avoid implicit int64 conversion in padding. The concatenation resulted in an int64 array due to the prior int() cast. PiperOrigin-RevId: 409910620 Change-Id: I43d27b837683cc575113e1a10cf5ab3da99da0ff --- jraph/_src/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jraph/_src/utils.py b/jraph/_src/utils.py index 4b54b9b..822be1d 100644 --- a/jraph/_src/utils.py +++ b/jraph/_src/utils.py @@ -615,10 +615,10 @@ def pad_with_graphs(graph: gn_graph.GraphsTuple, padding_graph = gn_graph.GraphsTuple( n_node=np.concatenate( - [np.array([pad_n_node]), + [np.array([pad_n_node], dtype=np.int32), np.zeros(pad_n_empty_graph, dtype=np.int32)]), n_edge=np.concatenate( - [np.array([pad_n_edge]), + [np.array([pad_n_edge], dtype=np.int32), np.zeros(pad_n_empty_graph, dtype=np.int32)]), nodes=tree.tree_map(tree_nodes_pad, graph.nodes), edges=tree.tree_map(tree_edges_pad, graph.edges),