Skip to content

Commit

Permalink
new: support use_gpu algorithm parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna committed Jan 3, 2025
1 parent 81aba1a commit ed34107
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 12 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,16 @@ import nx_arangodb as nxadb

G = nxadb.Graph(name="MyGraph")

# Option 1: Use Global Config
nx.config.backends.arangodb.use_gpu = False

nx.pagerank(G)
nx.betweenness_centrality(G)
# ...

nx.config.backends.arangodb.use_gpu = True

# Option 2: Use Local Config
nx.pagerank(G, use_gpu=False)
nx.betweenness_centrality(G, use_gpu=False)
```

<p align="center">
Expand Down
7 changes: 5 additions & 2 deletions doc/algorithms/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,17 @@ You can also force-run algorithms on CPU even if ``nx-cugraph`` is installed:
G = nxadb.Graph(name="MyGraph")
# Option 1: Use Global Config
nx.config.backends.arangodb.use_gpu = False
nx.pagerank(G)
nx.betweenness_centrality(G)
# ...
nx.config.backends.arangodb.use_gpu = True
# Option 2: Use Local Config
nx.pagerank(G, use_gpu=False)
nx.betweenness_centrality(G, use_gpu=False)
.. image:: ../_static/dispatch.png
:align: center
Expand Down
6 changes: 1 addition & 5 deletions doc/nx_arangodb.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,7 @@
"outputs": [],
"source": [
"# 5. Run an algorithm (CPU)\n",
"nx.config.backends.arangodb.use_gpu = False # Optional\n",
"\n",
"res = nx.pagerank(G)"
"res = nx.pagerank(G, use_gpu=False)"
]
},
{
Expand Down Expand Up @@ -357,8 +355,6 @@
"source": [
"# 4. Run an algorithm (GPU)\n",
"# See *Package Installation* to install nx-cugraph ^\n",
"nx.config.backends.arangodb.use_gpu = True\n",
"\n",
"res = nx.pagerank(G)"
]
},
Expand Down
4 changes: 3 additions & 1 deletion nx_arangodb/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def _auto_func(func_name: str, /, *args: Any, **kwargs: Any) -> Any:
dfunc = _registered_algorithms[func_name]

backend_priority: list[str] = []
if nxadb.convert.GPU_AVAILABLE and nx.config.backends.arangodb.use_gpu:

use_gpu = bool(kwargs.pop("use_gpu", nx.config.backends.arangodb.use_gpu))
if nxadb.convert.GPU_AVAILABLE and use_gpu:
backend_priority.append("cugraph")

for backend in backend_priority:
Expand Down
11 changes: 9 additions & 2 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,12 @@ def test_gpu_pagerank(graph_cls: type[nxadb.Graph]) -> None:
assert gpu_cached_time < gpu_no_cache_time
assert_pagerank(res_gpu_cached, res_gpu_no_cache, 10)

# 4. CPU
# 4. CPU (with use_gpu=False)
start_cpu_force_no_gpu = time.time()
res_cpu_force_no_gpu = nx.pagerank(graph, use_gpu=False)
cpu_force_no_gpu_time = time.time() - start_cpu_force_no_gpu

# 5. CPU
assert graph.nxcg_graph is not None
graph.clear_nxcg_cache()
assert graph.nxcg_graph is None
Expand All @@ -456,12 +461,14 @@ def test_gpu_pagerank(graph_cls: type[nxadb.Graph]) -> None:
start_cpu = time.time()
res_cpu = nx.pagerank(graph)
cpu_time = time.time() - start_cpu
assert_pagerank(res_cpu, res_cpu_force_no_gpu, 10)

assert graph.nxcg_graph is None

m = "GPU execution should be faster than CPU execution"
assert gpu_time < cpu_time, m
assert gpu_time < cpu_force_no_gpu_time, m
assert gpu_no_cache_time < cpu_time, m
assert gpu_no_cache_time < cpu_force_no_gpu_time, m
assert_pagerank(res_gpu_no_cache, res_cpu, 10)


Expand Down

0 comments on commit ed34107

Please sign in to comment.