Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new: fully support parameterized db object #70

Merged
merged 5 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions _nx_arangodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,7 @@ def get_info():
for key in info_keys:
del d[key]

d["default_config"] = {
"host": None,
"username": None,
"password": None,
"db_name": None,
"read_parallelism": None,
"read_batch_size": None,
"use_gpu": True,
}
d["default_config"] = {"use_gpu": True}

return d

Expand Down
10 changes: 10 additions & 0 deletions nx_arangodb/classes/dict/adj.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def adjlist_outer_dict_factory(
db: StandardDatabase,
graph: Graph,
default_node_type: str,
read_parallelism: int,
read_batch_size: int,
edge_type_key: str,
edge_type_func: Callable[[str, str], str],
graph_type: str,
Expand All @@ -115,6 +117,8 @@ def adjlist_outer_dict_factory(
db,
graph,
default_node_type,
read_parallelism,
read_batch_size,
edge_type_key,
edge_type_func,
graph_type,
Expand Down Expand Up @@ -1467,6 +1471,8 @@ def __init__(
db: StandardDatabase,
graph: Graph,
default_node_type: str,
read_parallelism: int,
read_batch_size: int,
edge_type_key: str,
edge_type_func: Callable[[str, str], str],
graph_type: str,
Expand All @@ -1489,6 +1495,8 @@ def __init__(
self.edge_type_key = edge_type_key
self.edge_type_func = edge_type_func
self.default_node_type = default_node_type
self.read_parallelism = read_parallelism
self.read_batch_size = read_batch_size
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(
db,
graph,
Expand Down Expand Up @@ -1853,6 +1861,8 @@ def _fetch_all(self) -> None:
is_directed=True,
is_multigraph=self.is_multigraph,
symmetrize_edges_if_directed=self.symmetrize_edges_if_directed,
read_parallelism=self.read_parallelism,
read_batch_size=self.read_batch_size,
)

# Even if the Graph is undirected,
Expand Down
21 changes: 19 additions & 2 deletions nx_arangodb/classes/dict/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,20 @@


def node_dict_factory(
db: StandardDatabase, graph: Graph, default_node_type: str
db: StandardDatabase,
graph: Graph,
default_node_type: str,
read_parallelism: int,
read_batch_size: int,
) -> Callable[..., NodeDict]:
"""Factory function for creating a NodeDict."""
return lambda: NodeDict(db, graph, default_node_type)
return lambda: NodeDict(
db,
graph,
default_node_type,
read_parallelism,
read_batch_size,
)


def node_attr_dict_factory(
Expand Down Expand Up @@ -262,6 +272,8 @@ def __init__(
db: StandardDatabase,
graph: Graph,
default_node_type: str,
read_parallelism: int,
read_batch_size: int,
*args: Any,
**kwargs: Any,
):
Expand All @@ -271,6 +283,9 @@ def __init__(
self.db = db
self.graph = graph
self.default_node_type = default_node_type
self.read_parallelism = read_parallelism
self.read_batch_size = read_batch_size

self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph)

self.FETCHED_ALL_DATA = False
Expand Down Expand Up @@ -472,6 +487,8 @@ def _fetch_all(self):
is_directed=False, # not used
is_multigraph=False, # not used
symmetrize_edges_if_directed=False, # not used
read_parallelism=self.read_parallelism,
read_batch_size=self.read_batch_size,
)

for node_id, node_data in node_dict.items():
Expand Down
23 changes: 12 additions & 11 deletions nx_arangodb/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def get_arangodb_graph(
is_directed: bool,
is_multigraph: bool,
symmetrize_edges_if_directed: bool,
read_parallelism: int,
read_batch_size: int,
) -> Tuple[
NodeDict,
GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict,
Expand Down Expand Up @@ -142,11 +144,10 @@ def get_arangodb_graph(
if not load_adj_dict and not load_coo:
metagraph["edgeCollections"] = {}

config = nx.config.backends.arangodb
assert config.db_name
assert config.host
assert config.username
assert config.password
hosts = adb_graph._conn._hosts
hosts = hosts.split(",") if type(hosts) is str else hosts
db_name = adb_graph._conn._db_name
username, password = adb_graph._conn._auth

(
node_dict,
Expand All @@ -157,20 +158,20 @@ def get_arangodb_graph(
vertex_ids_to_index,
edge_values,
) = NetworkXLoader.load_into_networkx(
config.db_name,
database=db_name,
metagraph=metagraph,
hosts=[config.host],
username=config.username,
password=config.password,
hosts=hosts,
username=username,
password=password,
load_adj_dict=load_adj_dict,
load_coo=load_coo,
load_all_vertex_attributes=load_all_vertex_attributes,
load_all_edge_attributes=load_all_edge_attributes,
is_directed=is_directed,
is_multigraph=is_multigraph,
symmetrize_edges_if_directed=symmetrize_edges_if_directed,
parallelism=config.read_parallelism,
batch_size=config.read_batch_size,
parallelism=read_parallelism,
batch_size=read_batch_size,
)

return (
Expand Down
59 changes: 27 additions & 32 deletions nx_arangodb/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,13 @@ def __init__(
self.use_nxcg_cache = True
self.nxcg_graph = None

self.edge_type_key = edge_type_key
self.read_parallelism = read_parallelism
self.read_batch_size = read_batch_size

# Does not apply to undirected graphs
self.symmetrize_edges = symmetrize_edges

self.edge_type_key = edge_type_key

# TODO: Consider this
# if not self.__graph_name:
# if incoming_graph_data is not None:
Expand All @@ -227,8 +229,8 @@ def __init__(

self._loaded_incoming_graph_data = False
if self.graph_exists_in_db:
self._set_factory_methods()
self.__set_arangodb_backend_config(read_parallelism, read_batch_size)
self._set_factory_methods(read_parallelism, read_batch_size)
self.__set_arangodb_backend_config()

if overwrite_graph:
logger.info("Overwriting graph...")
Expand Down Expand Up @@ -284,7 +286,7 @@ def __init__(
# Init helper methods #
#######################

def _set_factory_methods(self) -> None:
def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None:
"""Set the factory methods for the graph, _node, and _adj dictionaries.

The ArangoDB CRUD operations are handled by the modified dictionaries.
Expand All @@ -299,39 +301,29 @@ def _set_factory_methods(self) -> None:
"""

base_args = (self.db, self.adb_graph)

node_args = (*base_args, self.default_node_type)
adj_args = (
*node_args,
self.edge_type_key,
self.edge_type_func,
self.__class__.__name__,
node_args_with_read = (*node_args, read_parallelism, read_batch_size)

adj_args = (self.edge_type_key, self.edge_type_func, self.__class__.__name__)
adj_inner_args = (*node_args, *adj_args)
adj_outer_args = (
*node_args_with_read,
*adj_args,
self.symmetrize_edges,
)

self.graph_attr_dict_factory = graph_dict_factory(*base_args)

self.node_dict_factory = node_dict_factory(*node_args)
self.node_dict_factory = node_dict_factory(*node_args_with_read)
self.node_attr_dict_factory = node_attr_dict_factory(*base_args)

self.edge_attr_dict_factory = edge_attr_dict_factory(*base_args)
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_args)
self.adjlist_outer_dict_factory = adjlist_outer_dict_factory(
*adj_args, self.symmetrize_edges
)

def __set_arangodb_backend_config(
self, read_parallelism: int, read_batch_size: int
) -> None:
if not all([self._host, self._username, self._password, self._db_name]):
m = "Must set all environment variables to use the ArangoDB Backend with an existing graph" # noqa: E501
raise OSError(m)
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(*adj_inner_args)
self.adjlist_outer_dict_factory = adjlist_outer_dict_factory(*adj_outer_args)

def __set_arangodb_backend_config(self) -> None:
config = nx.config.backends.arangodb
config.host = self._host
config.username = self._username
config.password = self._password
config.db_name = self._db_name
config.read_parallelism = read_parallelism
config.read_batch_size = read_batch_size
config.use_gpu = True # Only used by default if nx-cugraph is available

def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None:
Expand All @@ -345,7 +337,7 @@ def __set_edge_collections_attributes(self, attributes: set[str] | None) -> None
self._edge_collections_attributes.add("_id")

def __set_db(self, db: Any = None) -> None:
self._host = os.getenv("DATABASE_HOST")
self._hosts = os.getenv("DATABASE_HOST", "").split(",")
self._username = os.getenv("DATABASE_USERNAME")
self._password = os.getenv("DATABASE_PASSWORD")
self._db_name = os.getenv("DATABASE_NAME")
Expand All @@ -355,17 +347,20 @@ def __set_db(self, db: Any = None) -> None:
m = "arango.database.StandardDatabase"
raise TypeError(m)

db.version()
db.version() # make sure the connection is valid
self.__db = db
self._db_name = db.name
self._hosts = db._conn._hosts
self._username, self._password = db._conn._auth
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting environment variables (DATABASE_HOST, DATABASE_USERNAME, DATABASE_PASSWORD, DATABASE_NAME) is still supported, but passing a custom db object takes priority over the environment variables.

return

if not all([self._host, self._username, self._password, self._db_name]):
if not all([self._hosts, self._username, self._password, self._db_name]):
m = "Database environment variables not set. Can't connect to the database"
logger.warning(m)
self.__db = None
return

self.__db = ArangoClient(hosts=self._host, request_timeout=None).db(
self.__db = ArangoClient(hosts=self._hosts, request_timeout=None).db(
self._db_name, self._username, self._password, verify=True
)

Expand Down
4 changes: 2 additions & 2 deletions nx_arangodb/classes/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ def __init__(
# Init helper methods #
#######################

def _set_factory_methods(self) -> None:
super()._set_factory_methods()
def _set_factory_methods(self, read_parallelism: int, read_batch_size: int) -> None:
super()._set_factory_methods(read_parallelism, read_batch_size)
self.edge_key_dict_factory = edge_key_dict_factory(
self.db,
self.adb_graph,
Expand Down
4 changes: 4 additions & 0 deletions nx_arangodb/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
is_directed=G.is_directed(),
is_multigraph=G.is_multigraph(),
symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False,
read_parallelism=G.read_parallelism,
read_batch_size=G.read_batch_size,
)

logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
Expand Down Expand Up @@ -337,6 +339,8 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
symmetrize_edges_if_directed=(
G.symmetrize_edges if G.is_directed() else False
),
read_parallelism=G.read_parallelism,
read_batch_size=G.read_batch_size,
)

logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
Expand Down
Loading