diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index d2511b330e..951aedeb91 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -570,15 +570,14 @@ def minimum_cycle_basis(graph, edge_cost_fn): [2] de Pina, J. 1995. Applications of shortest path methods. Ph.D. thesis, University of Amsterdam, Netherlands - :param graph: The input graph to use. Can either be a + :param graph: The input graph to use. Can be either a :class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph` :param edge_cost_fn: A callable object that acts as a weight function for an edge. It will accept a single positional argument, the edge's weight object and will return a float which will be used to represent the weight/cost of the edge - :return: A list of cycles where each cycle is a list of node indices - + :returns: A list of cycles where each cycle is a list of node indices :rtype: list """ raise TypeError("Invalid Input Type %s for graph" % type(graph)) diff --git a/rustworkx/__init__.pyi b/rustworkx/__init__.pyi index 140757a265..c571369780 100644 --- a/rustworkx/__init__.pyi +++ b/rustworkx/__init__.pyi @@ -84,6 +84,7 @@ from .rustworkx import digraph_core_number as digraph_core_number from .rustworkx import graph_core_number as graph_core_number from .rustworkx import stoer_wagner_min_cut as stoer_wagner_min_cut from .rustworkx import graph_minimum_cycle_basis as graph_minimum_cycle_basis +from .rustworkx import digraph_minimum_cycle_basis as digraph_minimum_cycle_basis from .rustworkx import simple_cycles as simple_cycles from .rustworkx import digraph_isolates as digraph_isolates from .rustworkx import graph_isolates as graph_isolates diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 7482f8fb05..398d5c678c 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -15,6 +15,7 @@ from typing import ( Callable, Iterable, Iterator, + Union, final, Sequence, Any, @@ -248,8 +249,15 @@ def stoer_wagner_min_cut( weight_fn: Callable[[_T], float] | None = ..., ) -> tuple[float, NodeIndices] | None: ... def graph_minimum_cycle_basis( - graph: PyGraph[_S, _T], /, weight_fn: Callable[[_T], float] | None = ... -) -> list[list[NodeIndices]] | None: ... + graph: PyGraph[_S, _T], + edge_cost: Callable[[_T], float], + /, +) -> list[list[NodeIndices]]: ... +def digraph_minimum_cycle_basis( + graph: PyDiGraph[_S, _T], + edge_cost: Callable[[_T], float], + /, +) -> list[list[NodeIndices]]: ... def simple_cycles(graph: PyDiGraph, /) -> Iterator[NodeIndices]: ... def graph_isolates(graph: PyGraph) -> NodeIndices: ... def digraph_isolates(graph: PyDiGraph) -> NodeIndices: ... diff --git a/src/connectivity/minimum_cycle_basis.rs b/src/connectivity/min_cycle_basis.rs similarity index 69% rename from src/connectivity/minimum_cycle_basis.rs rename to src/connectivity/min_cycle_basis.rs index 1349706939..880ef0af21 100644 --- a/src/connectivity/minimum_cycle_basis.rs +++ b/src/connectivity/min_cycle_basis.rs @@ -4,18 +4,17 @@ use pyo3::exceptions::PyIndexError; use pyo3::prelude::*; use pyo3::Python; -use petgraph::graph::NodeIndex; +use crate::iterators::NodeIndices; +use crate::{CostFn, StablePyGraph}; use petgraph::prelude::*; use petgraph::visit::EdgeIndexable; use petgraph::EdgeType; -use crate::{CostFn, StablePyGraph}; - -pub fn minimum_cycle_basis_map( +pub fn minimum_cycle_basis( py: Python, graph: &StablePyGraph, edge_cost_fn: PyObject, -) -> PyResult>> { +) -> PyResult>> { if graph.node_count() == 0 || graph.edge_count() == 0 { return Ok(vec![]); } @@ -35,5 +34,17 @@ pub fn minimum_cycle_basis_map( } }; let cycle_basis = minimal_cycle_basis(graph, |e| edge_cost(e.id())).unwrap(); - Ok(cycle_basis) + // Convert the cycle basis to a list of lists of node indices + let result: Vec> = cycle_basis + .into_iter() + .map(|cycle| { + cycle + .into_iter() + .map(|node| NodeIndices { + nodes: vec![node.index()], + }) + .collect() + }) + .collect(); + Ok(result) } diff --git a/src/connectivity/mod.rs b/src/connectivity/mod.rs index 3ed475934c..3d2aff850c 100644 --- a/src/connectivity/mod.rs +++ b/src/connectivity/mod.rs @@ -14,7 +14,7 @@ mod all_pairs_all_simple_paths; mod johnson_simple_cycles; -mod minimum_cycle_basis; +mod min_cycle_basis; mod subgraphs; use super::{ @@ -919,6 +919,20 @@ pub fn stoer_wagner_min_cut( })) } +/// Find a minimum cycle basis of an undirected graph. +/// All weights must be nonnegative. If the input graph does not have +/// any nodes or edges, this function returns ``None``. +/// If the input graph does not any weight, this function will find the +/// minimum cycle basis with the weight of 1.0 for all edges. +/// +/// :param PyGraph: The undirected graph to be used +/// :param Callable edge_cost_fn: An optional callable object (function, lambda, etc) which +/// will be passed the edge object and expected to return a ``float``. +/// Edges with ``NaN`` weights will be considered to have 1.0 weight. +/// If ``edge_cost_fn`` is not specified a default value of ``1.0`` will be used for all edges. +/// +/// :returns: A list of cycles, where each cycle is a list of node indices +/// :rtype: list #[pyfunction] #[pyo3(text_signature = "(graph, edge_cost_fn, /)")] pub fn graph_minimum_cycle_basis( @@ -926,18 +940,33 @@ pub fn graph_minimum_cycle_basis( graph: &graph::PyGraph, edge_cost_fn: PyObject, ) -> PyResult>> { - let basis = minimum_cycle_basis::minimum_cycle_basis_map(py, &graph.graph, edge_cost_fn); - Ok(basis - .into_iter() - .map(|cycle| { - cycle - .into_iter() - .map(|node| NodeIndices { - nodes: node.iter().map(|nx| nx.index()).collect(), - }) - .collect() - }) - .collect()) + min_cycle_basis::minimum_cycle_basis(py, &graph.graph, edge_cost_fn) +} + +/// Find a minimum cycle basis of a directed graph (which is not of interest in the context +/// of minimum cycle basis). This function will return the minimum cycle basis of the +/// underlying undirected graph of the input directed graph. +/// All weights must be nonnegative. If the input graph does not have +/// any nodes or edges, this function returns ``None``. +/// If the input graph does not any weight, this function will find the +/// minimum cycle basis with the weight of 1.0 for all edges. +/// +/// :param PyDiGraph: The directed graph to be used +/// :param Callable edge_cost_fn: An optional callable object (function, lambda, etc) which +/// will be passed the edge object and expected to return a ``float``. +/// Edges with ``NaN`` weights will be considered to have 1.0 weight. +/// If ``edge_cost_fn`` is not specified a default value of ``1.0`` will be used for all edges. +/// +/// :returns: A list of cycles, where each cycle is a list of node indices +/// :rtype: list +#[pyfunction] +#[pyo3(text_signature = "(graph, edge_cost_fn, /)")] +pub fn digraph_minimum_cycle_basis( + py: Python, + graph: &digraph::PyDiGraph, + edge_cost_fn: PyObject, +) -> PyResult>> { + min_cycle_basis::minimum_cycle_basis(py, &graph.graph, edge_cost_fn) } /// Return the articulation points of an undirected graph. diff --git a/src/lib.rs b/src/lib.rs index 3d9eb63e01..bdf404f6fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -571,6 +571,7 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(metric_closure))?; m.add_wrapped(wrap_pyfunction!(stoer_wagner_min_cut))?; m.add_wrapped(wrap_pyfunction!(graph_minimum_cycle_basis))?; + m.add_wrapped(wrap_pyfunction!(digraph_minimum_cycle_basis))?; m.add_wrapped(wrap_pyfunction!(steiner_tree::steiner_tree))?; m.add_wrapped(wrap_pyfunction!(digraph_dfs_search))?; m.add_wrapped(wrap_pyfunction!(graph_dfs_search))?;