diff --git a/README.md b/README.md index b7cc7fd..b0aa5da 100644 --- a/README.md +++ b/README.md @@ -166,13 +166,16 @@ import nx_arangodb as nxadb G = nxadb.Graph(name="MyGraph") +# Option 1: Use Global Config nx.config.backends.arangodb.use_gpu = False - nx.pagerank(G) nx.betweenness_centrality(G) # ... - nx.config.backends.arangodb.use_gpu = True + +# Option 2: Use Local Config +nx.pagerank(G, use_gpu=False) +nx.betweenness_centrality(G, use_gpu=False) ```
diff --git a/_nx_arangodb/__init__.py b/_nx_arangodb/__init__.py index 616e961..9f6c4d4 100644 --- a/_nx_arangodb/__init__.py +++ b/_nx_arangodb/__init__.py @@ -74,15 +74,7 @@ def get_info(): for key in info_keys: del d[key] - d["default_config"] = { - "host": None, - "username": None, - "password": None, - "db_name": None, - "read_parallelism": None, - "read_batch_size": None, - "use_gpu": True, - } + d["default_config"] = {"use_gpu": True} return d diff --git a/doc/algorithms/index.rst b/doc/algorithms/index.rst index 9adf6f7..d041b60 100644 --- a/doc/algorithms/index.rst +++ b/doc/algorithms/index.rst @@ -43,14 +43,17 @@ You can also force-run algorithms on CPU even if ``nx-cugraph`` is installed: G = nxadb.Graph(name="MyGraph") + # Option 1: Use Global Config nx.config.backends.arangodb.use_gpu = False - nx.pagerank(G) nx.betweenness_centrality(G) # ... - nx.config.backends.arangodb.use_gpu = True + # Option 2: Use Local Config + nx.pagerank(G, use_gpu=False) + nx.betweenness_centrality(G, use_gpu=False) + .. image:: ../_static/dispatch.png :align: center diff --git a/doc/nx_arangodb.ipynb b/doc/nx_arangodb.ipynb index 003524e..71249b5 100644 --- a/doc/nx_arangodb.ipynb +++ b/doc/nx_arangodb.ipynb @@ -236,9 +236,7 @@ "outputs": [], "source": [ "# 5. Run an algorithm (CPU)\n", - "nx.config.backends.arangodb.use_gpu = False # Optional\n", - "\n", - "res = nx.pagerank(G)" + "res = nx.pagerank(G, use_gpu=False)" ] }, { @@ -357,8 +355,6 @@ "source": [ "# 4. Run an algorithm (GPU)\n", "# See *Package Installation* to install nx-cugraph ^\n", - "nx.config.backends.arangodb.use_gpu = True\n", - "\n", "res = nx.pagerank(G)" ] }, diff --git a/nx_arangodb/classes/dict/adj.py b/nx_arangodb/classes/dict/adj.py index b1b2ddc..c290742 100644 --- a/nx_arangodb/classes/dict/adj.py +++ b/nx_arangodb/classes/dict/adj.py @@ -105,6 +105,8 @@ def adjlist_outer_dict_factory( db: StandardDatabase, graph: Graph, default_node_type: str, + read_parallelism: int, + read_batch_size: int, edge_type_key: str, edge_type_func: Callable[[str, str], str], graph_type: str, @@ -115,6 +117,8 @@ def adjlist_outer_dict_factory( db, graph, default_node_type, + read_parallelism, + read_batch_size, edge_type_key, edge_type_func, graph_type, @@ -1454,6 +1458,12 @@ class AdjListOuterDict(UserDict[str, AdjListInnerDict]): symmetrize_edges_if_directed : bool Whether to add the reverse edge if the graph is directed. + read_parallelism : int + The number of parallel threads to use for reading data in _fetch_all. + + read_batch_size : int + The number of documents to read in each batch in _fetch_all. + Example ------- >>> g = nxadb.Graph(name="MyGraph") @@ -1467,6 +1477,8 @@ def __init__( db: StandardDatabase, graph: Graph, default_node_type: str, + read_parallelism: int, + read_batch_size: int, edge_type_key: str, edge_type_func: Callable[[str, str], str], graph_type: str, @@ -1489,6 +1501,8 @@ def __init__( self.edge_type_key = edge_type_key self.edge_type_func = edge_type_func self.default_node_type = default_node_type + self.read_parallelism = read_parallelism + self.read_batch_size = read_batch_size self.adjlist_inner_dict_factory = adjlist_inner_dict_factory( db, graph, @@ -1853,6 +1867,8 @@ def _fetch_all(self) -> None: is_directed=True, is_multigraph=self.is_multigraph, symmetrize_edges_if_directed=self.symmetrize_edges_if_directed, + read_parallelism=self.read_parallelism, + read_batch_size=self.read_batch_size, ) # Even if the Graph is undirected, diff --git a/nx_arangodb/classes/dict/node.py b/nx_arangodb/classes/dict/node.py index 872b158..587014d 100644 --- a/nx_arangodb/classes/dict/node.py +++ b/nx_arangodb/classes/dict/node.py @@ -40,10 +40,20 @@ def node_dict_factory( - db: StandardDatabase, graph: Graph, default_node_type: str + db: StandardDatabase, + graph: Graph, + default_node_type: str, + read_parallelism: int, + read_batch_size: int, ) -> Callable[..., NodeDict]: """Factory function for creating a NodeDict.""" - return lambda: NodeDict(db, graph, default_node_type) + return lambda: NodeDict( + db, + graph, + default_node_type, + read_parallelism, + read_batch_size, + ) def node_attr_dict_factory( @@ -250,6 +260,12 @@ class NodeDict(UserDict[str, NodeAttrDict]): default_node_type : str The default node type for the graph. + read_parallelism : int + The number of parallel threads to use for reading data in _fetch_all. + + read_batch_size : int + The number of documents to read in each batch in _fetch_all. + Example ------- >>> G = nxadb.Graph("MyGraph") @@ -262,6 +278,8 @@ def __init__( db: StandardDatabase, graph: Graph, default_node_type: str, + read_parallelism: int, + read_batch_size: int, *args: Any, **kwargs: Any, ): @@ -271,6 +289,9 @@ def __init__( self.db = db self.graph = graph self.default_node_type = default_node_type + self.read_parallelism = read_parallelism + self.read_batch_size = read_batch_size + self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph) self.FETCHED_ALL_DATA = False @@ -472,6 +493,8 @@ def _fetch_all(self): is_directed=False, # not used is_multigraph=False, # not used symmetrize_edges_if_directed=False, # not used + read_parallelism=self.read_parallelism, + read_batch_size=self.read_batch_size, ) for node_id, node_data in node_dict.items(): diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index 491c0cd..11316f9 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -47,6 +47,8 @@ def get_arangodb_graph( is_directed: bool, is_multigraph: bool, symmetrize_edges_if_directed: bool, + read_parallelism: int, + read_batch_size: int, ) -> Tuple[ NodeDict, GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict, @@ -142,11 +144,10 @@ def get_arangodb_graph( if not load_adj_dict and not load_coo: metagraph["edgeCollections"] = {} - config = nx.config.backends.arangodb - assert config.db_name - assert config.host - assert config.username - assert config.password + hosts = adb_graph._conn._hosts + hosts = hosts.split(",") if type(hosts) is str else hosts + db_name = adb_graph._conn._db_name + username, password = adb_graph._conn._auth ( node_dict, @@ -157,11 +158,11 @@ def get_arangodb_graph( vertex_ids_to_index, edge_values, ) = NetworkXLoader.load_into_networkx( - config.db_name, + database=db_name, metagraph=metagraph, - hosts=[config.host], - username=config.username, - password=config.password, + hosts=hosts, + username=username, + password=password, load_adj_dict=load_adj_dict, load_coo=load_coo, load_all_vertex_attributes=load_all_vertex_attributes, @@ -169,8 +170,8 @@ def get_arangodb_graph( is_directed=is_directed, is_multigraph=is_multigraph, symmetrize_edges_if_directed=symmetrize_edges_if_directed, - parallelism=config.read_parallelism, - batch_size=config.read_batch_size, + parallelism=read_parallelism, + batch_size=read_batch_size, ) return ( diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 336cc3b..ed7cbda 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -214,11 +214,13 @@ def __init__( self.use_nxcg_cache = True self.nxcg_graph = None + self.edge_type_key = edge_type_key + self.read_parallelism = read_parallelism + self.read_batch_size = read_batch_size + # Does not apply to undirected graphs self.symmetrize_edges = symmetrize_edges - self.edge_type_key = edge_type_key - # TODO: Consider this # if not self.__graph_name: # if incoming_graph_data is not None: @@ -227,8 +229,8 @@ def __init__( self._loaded_incoming_graph_data = False if self.graph_exists_in_db: - self._set_factory_methods() - self.__set_arangodb_backend_config(read_parallelism, read_batch_size) + self._set_factory_methods(read_parallelism, read_batch_size) + self.__set_arangodb_backend_config() if overwrite_graph: logger.info("Overwriting graph...") @@ -284,7 +286,7 @@ def __init__( # Init helper methods # ####################### - def _set_factory_methods(self) -> None: + def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None: """Set the factory methods for the graph, _node, and _adj dictionaries. The ArangoDB CRUD operations are handled by the modified dictionaries. @@ -299,39 +301,29 @@ def _set_factory_methods(self) -> None: """ base_args = (self.db, self.adb_graph) + node_args = (*base_args, self.default_node_type) - adj_args = ( - *node_args, - self.edge_type_key, - self.edge_type_func, - self.__class__.__name__, + node_args_with_read = (*node_args, read_parallelism, read_batch_size) + + adj_args = (self.edge_type_key, self.edge_type_func, self.__class__.__name__) + adj_inner_args = (*node_args, *adj_args) + adj_outer_args = ( + *node_args_with_read, + *adj_args, + self.symmetrize_edges, ) self.graph_attr_dict_factory = graph_dict_factory(*base_args) - self.node_dict_factory = node_dict_factory(*node_args) + self.node_dict_factory = node_dict_factory(*node_args_with_read) self.node_attr_dict_factory = node_attr_dict_factory(*base_args) self.edge_attr_dict_factory = edge_attr_dict_factory(*base_args) - self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_args) - self.adjlist_outer_dict_factory = adjlist_outer_dict_factory( - *adj_args, self.symmetrize_edges - ) - - def __set_arangodb_backend_config( - self, read_parallelism: int, read_batch_size: int - ) -> None: - if not all([self._host, self._username, self._password, self._db_name]): - m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501 - raise OSError(m) + self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_inner_args) + self.adjlist_outer_dict_factory = adjlist_outer_dict_factory(*adj_outer_args) + def __set_arangodb_backend_config(self) -> None: config = nx.config.backends.arangodb - config.host = self._host - config.username = self._username - config.password = self._password - config.db_name = self._db_name - config.read_parallelism = read_parallelism - config.read_batch_size = read_batch_size config.use_gpu = True # Only used by default if nx-cugraph is available def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None: @@ -345,7 +337,7 @@ def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None self._edge_collections_attributes.add("_id") def __set_db(self, db: Any = None) -> None: - self._host = os.getenv("DATABASE_HOST") + self._hosts = os.getenv("DATABASE_HOST", "").split(",") self._username = os.getenv("DATABASE_USERNAME") self._password = os.getenv("DATABASE_PASSWORD") self._db_name = os.getenv("DATABASE_NAME") @@ -355,17 +347,20 @@ def __set_db(self, db: Any = None) -> None: m = "arango.database.StandardDatabase" raise TypeError(m) - db.version() + db.version() # make sure the connection is valid self.__db = db + self._db_name = db.name + self._hosts = db._conn._hosts + self._username, self._password = db._conn._auth return - if not all([self._host, self._username, self._password, self._db_name]): + if not all([self._hosts, self._username, self._password, self._db_name]): m = "Database environment variables not set. Can't connect to the database" logger.warning(m) self.__db = None return - self.__db = ArangoClient(hosts=self._host, request_timeout=None).db( + self.__db = ArangoClient(hosts=self._hosts, request_timeout=None).db( self._db_name, self._username, self._password, verify=True ) diff --git a/nx_arangodb/classes/multigraph.py b/nx_arangodb/classes/multigraph.py index c494d34..7d7db59 100644 --- a/nx_arangodb/classes/multigraph.py +++ b/nx_arangodb/classes/multigraph.py @@ -229,8 +229,8 @@ def __init__( # Init helper methods # ####################### - def _set_factory_methods(self) -> None: - super()._set_factory_methods() + def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None: + super()._set_factory_methods(read_parallelism, read_batch_size) self.edge_key_dict_factory = edge_key_dict_factory( self.db, self.adb_graph, diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index 8eda47b..33df773 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -256,6 +256,8 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph: is_directed=G.is_directed(), is_multigraph=G.is_multigraph(), symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False, + read_parallelism=G.read_parallelism, + read_batch_size=G.read_batch_size, ) logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s") @@ -337,6 +339,8 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph: symmetrize_edges_if_directed=( G.symmetrize_edges if G.is_directed() else False ), + read_parallelism=G.read_parallelism, + read_batch_size=G.read_batch_size, ) logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s") diff --git a/nx_arangodb/interface.py b/nx_arangodb/interface.py index c756a32..ef2d2a6 100644 --- a/nx_arangodb/interface.py +++ b/nx_arangodb/interface.py @@ -63,7 +63,9 @@ def _auto_func(func_name: str, /, *args: Any, **kwargs: Any) -> Any: dfunc = _registered_algorithms[func_name] backend_priority: list[str] = [] - if nxadb.convert.GPU_AVAILABLE and nx.config.backends.arangodb.use_gpu: + + use_gpu = bool(kwargs.pop("use_gpu", nx.config.backends.arangodb.use_gpu)) + if nxadb.convert.GPU_AVAILABLE and use_gpu: backend_priority.append("cugraph") for backend in backend_priority: diff --git a/tests/conftest.py b/tests/conftest.py index 7f2eec3..1868390 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,6 @@ import logging import os -import sys -from io import StringIO -from typing import Any +from typing import Any, Dict import networkx as nx import pytest @@ -15,6 +13,8 @@ logger.setLevel(logging.INFO) +con: Dict[str, Any] +client: ArangoClient db: StandardDatabase run_gpu_tests: bool @@ -30,6 +30,7 @@ def pytest_addoption(parser: Any) -> None: def pytest_configure(config: Any) -> None: + global con con = { "url": config.getoption("url"), "username": config.getoption("username"), @@ -43,10 +44,11 @@ def pytest_configure(config: Any) -> None: print("Password: " + con["password"]) print("Database: " + con["dbName"]) + global client + client = ArangoClient(hosts=con["url"]) + global db - db = ArangoClient(hosts=con["url"]).db( - con["dbName"], con["username"], con["password"], verify=True - ) + db = client.db(con["dbName"], con["username"], con["password"], verify=True) print("Version: " + db.version()) print("----------------------------------------") @@ -99,6 +101,12 @@ def load_two_relation_graph() -> None: ) +def get_db(db_name: str) -> StandardDatabase: + global con + global client + return client.db(db_name, con["username"], con["password"], verify=True) + + def create_line_graph(load_attributes: set[str]) -> nxadb.Graph: G = nx.Graph() G.add_edge(1, 2, my_custom_weight=1) diff --git a/tests/test.py b/tests/test.py index db1e7e8..4b694fb 100644 --- a/tests/test.py +++ b/tests/test.py @@ -16,7 +16,7 @@ from nx_arangodb.classes.dict.graph import GRAPH_FIELD from nx_arangodb.classes.dict.node import NodeAttrDict, NodeDict -from .conftest import create_grid_graph, create_line_graph, db, run_gpu_tests +from .conftest import create_grid_graph, create_line_graph, db, get_db, run_gpu_tests G_NX: nx.Graph = nx.karate_club_graph() G_NX_digraph = nx.DiGraph(G_NX) @@ -88,6 +88,39 @@ def test_adb_graph_init(graph_cls: type[nxadb.Graph]) -> None: G.name = "RenamedTestGraph" +def test_multiple_graph_sessions(): + db_1_name = "test_db_1" + db_2_name = "test_db_2" + + db.delete_database(db_1_name, ignore_missing=True) + db.delete_database(db_2_name, ignore_missing=True) + + db.create_database(db_1_name) + db.create_database(db_2_name) + + db_1 = get_db(db_1_name) + db_2 = get_db(db_2_name) + + G_1 = nxadb.Graph(name="TestGraph", db=db_1) + G_2 = nxadb.Graph(name="TestGraph", db=db_2) + + G_1.add_node(1, foo="bar") + G_1.add_node(2) + G_1.add_edge(1, 2) + + G_2.add_node(1) + G_2.add_node(2) + G_2.add_node(3) + G_2.add_edge(1, 2) + G_2.add_edge(2, 3) + + res_1 = nx.pagerank(G_1) + res_2 = nx.pagerank(G_2) + + assert len(res_1) == 2 + assert len(res_2) == 3 + + def test_load_graph_from_nxadb(): graph_name = "KarateGraph" @@ -447,7 +480,12 @@ def test_gpu_pagerank(graph_cls: type[nxadb.Graph]) -> None: assert gpu_cached_time < gpu_no_cache_time assert_pagerank(res_gpu_cached, res_gpu_no_cache, 10) - # 4. CPU + # 4. CPU (with use_gpu=False) + start_cpu_force_no_gpu = time.time() + res_cpu_force_no_gpu = nx.pagerank(graph, use_gpu=False) + cpu_force_no_gpu_time = time.time() - start_cpu_force_no_gpu + + # 5. CPU assert graph.nxcg_graph is not None graph.clear_nxcg_cache() assert graph.nxcg_graph is None @@ -456,12 +494,14 @@ def test_gpu_pagerank(graph_cls: type[nxadb.Graph]) -> None: start_cpu = time.time() res_cpu = nx.pagerank(graph) cpu_time = time.time() - start_cpu + assert_pagerank(res_cpu, res_cpu_force_no_gpu, 10) assert graph.nxcg_graph is None - m = "GPU execution should be faster than CPU execution" assert gpu_time < cpu_time, m + assert gpu_time < cpu_force_no_gpu_time, m assert gpu_no_cache_time < cpu_time, m + assert gpu_no_cache_time < cpu_force_no_gpu_time, m assert_pagerank(res_gpu_no_cache, res_cpu, 10)