Skip to content

Commit

Permalink
fix pytorch tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanhhughes committed Jul 19, 2024
1 parent 1d3e377 commit 6973a8d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 24 deletions.
26 changes: 9 additions & 17 deletions python/src/spark_dsg/torch_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,14 @@
graphs. Note that `DynamicSceneGraph.to_torch()` calls into the relevant
homogeneous or heterogeneous conversion function.
"""
from spark_dsg._dsg_bindings import (
DynamicSceneGraph,
SceneGraphLayer,
SceneGraphNode,
SceneGraphEdge,
DsgLayers,
LayerView,
)
from typing import Callable, Optional, Dict, Union
import numpy as np
import importlib
from typing import Callable, Dict, Optional, Union

import numpy as np

from spark_dsg._dsg_bindings import (DsgLayers, DynamicSceneGraph, LayerView,
SceneGraphEdge, SceneGraphLayer,
SceneGraphNode)

NodeConversionFunc = Callable[[DynamicSceneGraph, SceneGraphNode], np.ndarray]
EdgeConversionFunc = Callable[[DynamicSceneGraph, SceneGraphEdge], np.ndarray]
Expand Down Expand Up @@ -136,17 +132,13 @@ def scene_graph_layer_to_torch(
scene graph layer.
"""
torch, torch_geometric = _get_torch()

# output torch tensor data types
if double_precision:
dtype_float = torch.float64
else:
dtype_float = torch.float32
dtype_float = torch.float64 if double_precision else torch.float32

N = G.num_nodes()

node_features = []
node_positions = torch.zeros((N, 3), dtype=torch.float64)
node_positions = torch.zeros((N, 3), dtype=dtype_float)
id_map = {}

for node in G.nodes:
Expand All @@ -168,7 +160,7 @@ def scene_graph_layer_to_torch(
edge_features.append(edge_converter(G, edge))

if edge_converter is not None:
edge_features = torch.tensor(np.array(edge_features), dtype_float)
edge_features = torch.tensor(np.array(edge_features), dtype=dtype_float)

if edge_index.size(dim=1) > 0:
if edge_converter is None:
Expand Down
14 changes: 7 additions & 7 deletions python/tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _check_interlayer_edges(G, data, to_check, has_edge_attrs=False):
edge_name = f"{source}_to_{target}"
assert (source, edge_name, target) in metadata[1]
assert data[source, edge_name, target].edge_index.size(dim=0) == 2
assert data[source, edge_name, target].edge_index.size(dim=1) >= 2
assert data[source, edge_name, target].edge_index.size(dim=1) >= 1
if has_edge_attrs:
assert data[source, edge_name, target].edge_attr.size(dim=1) == 20
assert data[source, edge_name, target].edge_attr.size(dim=0) == data[
Expand All @@ -86,7 +86,7 @@ def test_torch_layer(resource_dir, has_torch):
if not has_torch:
return pytest.skip(reason="requires pytorch and pytorch geometric")

G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_igx_dsg.json"))
G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_dsg.json"))
places = G.get_layer(dsg.DsgLayers.PLACES)
assert places.num_nodes() > 0
assert places.num_edges() > 0
Expand All @@ -104,7 +104,7 @@ def test_torch_layer_edge_features(resource_dir, has_torch):
if not has_torch:
return pytest.skip(reason="requires pytorch and pytorch geometric")

G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_igx_dsg.json"))
G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_dsg.json"))
places = G.get_layer(dsg.DsgLayers.PLACES)
assert places.num_nodes() > 0
assert places.num_edges() > 0
Expand All @@ -122,7 +122,7 @@ def test_torch_homogeneous(resource_dir, has_torch):
if not has_torch:
return pytest.skip(reason="requires pytorch and pytorch geometric")

G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_igx_dsg.json"))
G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_dsg.json"))
assert G.num_nodes() > 0
assert G.num_edges() > 0

Expand All @@ -139,7 +139,7 @@ def test_torch_homogeneous_edge_features(resource_dir, has_torch):
if not has_torch:
return pytest.skip(reason="requires pytorch and pytorch geometric")

G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_igx_dsg.json"))
G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_dsg.json"))
assert G.num_nodes() > 0
assert G.num_edges() > 0

Expand All @@ -160,7 +160,7 @@ def test_torch_hetereogeneous(resource_dir, has_torch):
if not has_torch:
return pytest.skip(reason="requires pytorch and pytorch geometric")

G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_igx_dsg.json"))
G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_dsg.json"))
assert G.num_nodes() > 0
assert G.num_edges() > 0

Expand Down Expand Up @@ -190,7 +190,7 @@ def test_torch_hetereogeneous_edge_features(resource_dir, has_torch):
if not has_torch:
return pytest.skip(reason="requires pytorch and pytorch geometric")

G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_igx_dsg.json"))
G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_dsg.json"))
assert G.num_nodes() > 0
assert G.num_edges() > 0

Expand Down

0 comments on commit 6973a8d

Please sign in to comment.