diff --git a/crates/augurs-clustering/README.md b/crates/augurs-clustering/README.md index d6658e29..f7b47dd3 100644 --- a/crates/augurs-clustering/README.md +++ b/crates/augurs-clustering/README.md @@ -8,7 +8,7 @@ A crate such as [`augurs-dtw`] must be used to calculate the distance matrix for ## Usage ```rust -use augurs::clustering::{DbscanClusterer, DistanceMatrix}; +use augurs::clustering::{DbscanCluster, DbscanClusterer, DistanceMatrix}; # fn main() -> Result<(), Box> { // Start with a distance matrix. @@ -32,7 +32,16 @@ let min_cluster_size = 2; // Use DBSCAN to detect clusters of series. // Note that we don't need to specify the number of clusters in advance. let clusters = DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix); -assert_eq!(clusters, vec![0, 0, 0, 1, 1]); +assert_eq!( + clusters, + vec![ + DbscanCluster::Cluster(1.try_into().unwrap()), + DbscanCluster::Cluster(1.try_into().unwrap()), + DbscanCluster::Cluster(1.try_into().unwrap()), + DbscanCluster::Cluster(2.try_into().unwrap()), + DbscanCluster::Cluster(2.try_into().unwrap()), + ], +); # Ok(()) # } ``` diff --git a/crates/augurs-clustering/src/lib.rs b/crates/augurs-clustering/src/lib.rs index b52b9055..8a02a151 100644 --- a/crates/augurs-clustering/src/lib.rs +++ b/crates/augurs-clustering/src/lib.rs @@ -1,9 +1,64 @@ #![doc = include_str!("../README.md")] -use std::collections::VecDeque; +use std::{collections::VecDeque, num::NonZeroU32}; pub use augurs_core::DistanceMatrix; +/// A cluster identified by the DBSCAN algorithm. +/// +/// This is either a noise cluster, or a cluster with a specific ID. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum DbscanCluster { + /// A noise cluster. + Noise, + /// A cluster with the given ID. + /// + /// The ID is not guaranteed to remain the same between runs of the algorithm. + /// + /// We use a `NonZeroU32` here to ensure that the ID is never zero. This is mostly + /// just a size optimization. + Cluster(NonZeroU32), +} + +impl DbscanCluster { + /// Returns true if this cluster is a noise cluster. + pub fn is_noise(&self) -> bool { + matches!(self, Self::Noise) + } + + /// Returns true if this cluster is a cluster with the given ID. + pub fn is_cluster(&self) -> bool { + matches!(self, Self::Cluster(_)) + } + + /// Returns the ID of the cluster, if it is a cluster, or `-1` if it is a noise cluster. + pub fn as_i32(&self) -> i32 { + match self { + Self::Noise => -1, + Self::Cluster(id) => id.get() as i32, + } + } + + fn increment(&mut self) { + match self { + Self::Noise => unreachable!(), + Self::Cluster(id) => *id = id.checked_add(1).expect("cluster ID overflow"), + } + } +} + +// Simplify tests by allowing comparisons with i32. +#[cfg(test)] +impl PartialEq for DbscanCluster { + fn eq(&self, other: &i32) -> bool { + if self.is_noise() { + *other == -1 + } else { + self.as_i32() == *other + } + } +} + /// DBSCAN clustering algorithm. #[derive(Debug)] pub struct DbscanClusterer { @@ -40,11 +95,11 @@ impl DbscanClusterer { /// Run the DBSCAN clustering algorithm. /// - /// The return value is a vector of cluster assignments, with `-1` indicating noise. - pub fn fit(&self, distance_matrix: &DistanceMatrix) -> Vec { + /// The return value is a vector of cluster assignments, with `DbscanCluster::Noise` indicating noise. + pub fn fit(&self, distance_matrix: &DistanceMatrix) -> Vec { let n = distance_matrix.shape().0; - let mut clusters = vec![-1; n]; - let mut cluster = 0; + let mut clusters = vec![DbscanCluster::Noise; n]; + let mut cluster = DbscanCluster::Cluster(NonZeroU32::new(1).unwrap()); let mut visited = vec![false; n]; let mut to_visit = VecDeque::with_capacity(n); @@ -53,7 +108,7 @@ impl DbscanClusterer { for (i, d) in distance_matrix.iter().enumerate() { // Skip if already assigned to a cluster. - if clusters[i] != -1 { + if clusters[i].is_cluster() { continue; } self.find_neighbours(i, d, &mut neighbours); @@ -67,7 +122,7 @@ impl DbscanClusterer { clusters[i] = cluster; // Mark all noise neighbours as visited and add them to the queue. for neighbour in neighbours.drain(..) { - if clusters[neighbour] == -1 { + if clusters[neighbour].is_noise() { visited[neighbour] = true; to_visit.push_back(neighbour); } @@ -87,7 +142,7 @@ impl DbscanClusterer { } } } - cluster += 1; + cluster.increment(); } clusters } @@ -123,19 +178,19 @@ mod test { assert_eq!(clusters, vec![-1, -1, -1, -1]); let clusters = DbscanClusterer::new(1.0, 2).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, -1, -1]); + assert_eq!(clusters, vec![1, 1, -1, -1]); let clusters = DbscanClusterer::new(1.0, 3).fit(&distance_matrix); assert_eq!(clusters, vec![-1, -1, -1, -1]); let clusters = DbscanClusterer::new(2.0, 2).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, -1]); + assert_eq!(clusters, vec![1, 1, 1, -1]); let clusters = DbscanClusterer::new(2.0, 3).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, -1]); + assert_eq!(clusters, vec![1, 1, 1, -1]); let clusters = DbscanClusterer::new(3.0, 3).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, 0]); + assert_eq!(clusters, vec![1, 1, 1, 1]); } #[test] @@ -151,36 +206,36 @@ mod test { let distance_matrix = DistanceMatrix::try_from_square(distance_matrix).unwrap(); let clusters = DbscanClusterer::new(10.0, 3).fit(&distance_matrix); let expected = vec![ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 2, -1, 2, -1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 3, -1, 3, -1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ]; assert_eq!(clusters, expected); } diff --git a/crates/augurs/tests/integration.rs b/crates/augurs/tests/integration.rs index 068132aa..a8d992c0 100644 --- a/crates/augurs/tests/integration.rs +++ b/crates/augurs/tests/integration.rs @@ -18,6 +18,9 @@ fn test_changepoint() { #[cfg(feature = "clustering")] #[test] fn test_clustering() { + fn convert_clusters(clusters: Vec) -> Vec { + clusters.into_iter().map(|c| c.as_i32()).collect() + } use augurs::{clustering::DbscanClusterer, DistanceMatrix}; let distance_matrix = vec![ vec![0.0, 1.0, 2.0, 3.0], @@ -27,22 +30,22 @@ fn test_clustering() { ]; let distance_matrix = DistanceMatrix::try_from_square(distance_matrix).unwrap(); let clusters = DbscanClusterer::new(0.5, 2).fit(&distance_matrix); - assert_eq!(clusters, vec![-1, -1, -1, -1]); + assert_eq!(convert_clusters(clusters), vec![-1, -1, -1, -1]); let clusters = DbscanClusterer::new(1.0, 2).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, -1, -1]); + assert_eq!(convert_clusters(clusters), vec![1, 1, -1, -1]); let clusters = DbscanClusterer::new(1.0, 3).fit(&distance_matrix); - assert_eq!(clusters, vec![-1, -1, -1, -1]); + assert_eq!(convert_clusters(clusters), vec![-1, -1, -1, -1]); let clusters = DbscanClusterer::new(2.0, 2).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, -1]); + assert_eq!(convert_clusters(clusters), vec![1, 1, 1, -1]); let clusters = DbscanClusterer::new(2.0, 3).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, -1]); + assert_eq!(convert_clusters(clusters), vec![1, 1, 1, -1]); let clusters = DbscanClusterer::new(3.0, 3).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 0, 0]); + assert_eq!(convert_clusters(clusters), vec![1, 1, 1, 1]); } #[cfg(feature = "dtw")] diff --git a/crates/pyaugurs/src/clustering.rs b/crates/pyaugurs/src/clustering.rs index adad947e..b96086aa 100644 --- a/crates/pyaugurs/src/clustering.rs +++ b/crates/pyaugurs/src/clustering.rs @@ -82,8 +82,15 @@ impl Dbscan { &self, py: Python<'_>, distance_matrix: InputDistanceMatrix<'_>, - ) -> PyResult>> { + ) -> PyResult>> { let distance_matrix = distance_matrix.try_into()?; - Ok(self.inner.fit(&distance_matrix).into_pyarray(py).into()) + Ok(self + .inner + .fit(&distance_matrix) + .into_iter() + .map(|x| x.as_i32()) + .collect::>() + .into_pyarray(py) + .into()) } } diff --git a/examples/clustering/examples/dbscan_clustering.rs b/examples/clustering/examples/dbscan_clustering.rs index 08fab215..fb93b1e3 100644 --- a/examples/clustering/examples/dbscan_clustering.rs +++ b/examples/clustering/examples/dbscan_clustering.rs @@ -7,7 +7,10 @@ //! The resulting clusters are assigned a label of -1 for noise, 0 for the first cluster, and 1 for //! the second cluster. -use augurs::{clustering::DbscanClusterer, dtw::Dtw}; +use augurs::{ + clustering::{DbscanCluster, DbscanClusterer}, + dtw::Dtw, +}; // This is a very trivial example dataset containing 5 time series which // form two obvious clusters, plus a noise cluster. @@ -49,7 +52,16 @@ fn main() { let min_cluster_size = 2; // Run DBSCAN clustering on the distance matrix. - let clusters: Vec = - DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix); - assert_eq!(clusters, vec![0, 0, 1, 1, -1]); + let clusters = DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix); + println!("Clusters: {:?}", clusters); + assert_eq!( + clusters, + vec![ + DbscanCluster::Cluster(1.try_into().unwrap()), + DbscanCluster::Cluster(1.try_into().unwrap()), + DbscanCluster::Cluster(2.try_into().unwrap()), + DbscanCluster::Cluster(2.try_into().unwrap()), + DbscanCluster::Noise, + ] + ); }