Skip to content

Commit

Permalink
refactor(clustering)!: use new DbscanCluster type instead of isize
Browse files Browse the repository at this point in the history
Add a new `DbscanCluster` type to represent the cluster labels, which
can be either a cluster ID or a special value indicating noise. The
`Cluster` variant is now a `NonZeroU32` instead of an `isize`.

We could use a `u32` but using a `NonZeroU32` allows us to make use
of the niche optimizations that `NonZeroU32` provides and halves the
size of the type. Questionable whether it's worth it though?

Closes #229.
  • Loading branch information
sd2k committed Jan 12, 2025
1 parent e098590 commit 5ee35f3
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 56 deletions.
13 changes: 11 additions & 2 deletions crates/augurs-clustering/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ A crate such as [`augurs-dtw`] must be used to calculate the distance matrix for
## Usage

```rust
use augurs::clustering::{DbscanClusterer, DistanceMatrix};
use augurs::clustering::{DbscanCluster, DbscanClusterer, DistanceMatrix};

# fn main() -> Result<(), Box<dyn std::error::Error>> {
// Start with a distance matrix.
Expand All @@ -32,7 +32,16 @@ let min_cluster_size = 2;
// Use DBSCAN to detect clusters of series.
// Note that we don't need to specify the number of clusters in advance.
let clusters = DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, 1, 1]);
assert_eq!(
clusters,
vec![
DbscanCluster::Cluster(1.try_into().unwrap()),
DbscanCluster::Cluster(1.try_into().unwrap()),
DbscanCluster::Cluster(1.try_into().unwrap()),
DbscanCluster::Cluster(2.try_into().unwrap()),
DbscanCluster::Cluster(2.try_into().unwrap()),
],
);
# Ok(())
# }
```
Expand Down
139 changes: 97 additions & 42 deletions crates/augurs-clustering/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,64 @@
#![doc = include_str!("../README.md")]

use std::collections::VecDeque;
use std::{collections::VecDeque, num::NonZeroU32};

pub use augurs_core::DistanceMatrix;

/// A cluster identified by the DBSCAN algorithm.
///
/// This is either a noise cluster, or a cluster with a specific ID.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DbscanCluster {
/// A noise cluster.
Noise,
/// A cluster with the given ID.
///
/// The ID is not guaranteed to remain the same between runs of the algorithm.
///
/// We use a `NonZeroU32` here to ensure that the ID is never zero. This is mostly
/// just a size optimization.
Cluster(NonZeroU32),
}

impl DbscanCluster {
/// Returns true if this cluster is a noise cluster.
pub fn is_noise(&self) -> bool {
matches!(self, Self::Noise)
}

/// Returns true if this cluster is a cluster with the given ID.
pub fn is_cluster(&self) -> bool {
matches!(self, Self::Cluster(_))
}

/// Returns the ID of the cluster, if it is a cluster, or `-1` if it is a noise cluster.
pub fn as_i32(&self) -> i32 {
match self {
Self::Noise => -1,
Self::Cluster(id) => id.get() as i32,
}
}

fn increment(&mut self) {
match self {
Self::Noise => unreachable!(),
Self::Cluster(id) => *id = id.checked_add(1).expect("cluster ID overflow"),
}
}
}

// Simplify tests by allowing comparisons with i32.
#[cfg(test)]
impl PartialEq<i32> for DbscanCluster {
fn eq(&self, other: &i32) -> bool {
if self.is_noise() {
*other == -1
} else {
self.as_i32() == *other
}
}
}

/// DBSCAN clustering algorithm.
#[derive(Debug)]
pub struct DbscanClusterer {
Expand Down Expand Up @@ -40,11 +95,11 @@ impl DbscanClusterer {

/// Run the DBSCAN clustering algorithm.
///
/// The return value is a vector of cluster assignments, with `-1` indicating noise.
pub fn fit(&self, distance_matrix: &DistanceMatrix) -> Vec<isize> {
/// The return value is a vector of cluster assignments, with `DbscanCluster::Noise` indicating noise.
pub fn fit(&self, distance_matrix: &DistanceMatrix) -> Vec<DbscanCluster> {
let n = distance_matrix.shape().0;
let mut clusters = vec![-1; n];
let mut cluster = 0;
let mut clusters = vec![DbscanCluster::Noise; n];
let mut cluster = DbscanCluster::Cluster(NonZeroU32::new(1).unwrap());
let mut visited = vec![false; n];
let mut to_visit = VecDeque::with_capacity(n);

Expand All @@ -53,7 +108,7 @@ impl DbscanClusterer {

for (i, d) in distance_matrix.iter().enumerate() {
// Skip if already assigned to a cluster.
if clusters[i] != -1 {
if clusters[i].is_cluster() {
continue;
}
self.find_neighbours(i, d, &mut neighbours);
Expand All @@ -67,7 +122,7 @@ impl DbscanClusterer {
clusters[i] = cluster;
// Mark all noise neighbours as visited and add them to the queue.
for neighbour in neighbours.drain(..) {
if clusters[neighbour] == -1 {
if clusters[neighbour].is_noise() {
visited[neighbour] = true;
to_visit.push_back(neighbour);
}
Expand All @@ -87,7 +142,7 @@ impl DbscanClusterer {
}
}
}
cluster += 1;
cluster.increment();
}
clusters
}
Expand Down Expand Up @@ -123,19 +178,19 @@ mod test {
assert_eq!(clusters, vec![-1, -1, -1, -1]);

let clusters = DbscanClusterer::new(1.0, 2).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, -1, -1]);
assert_eq!(clusters, vec![1, 1, -1, -1]);

let clusters = DbscanClusterer::new(1.0, 3).fit(&distance_matrix);
assert_eq!(clusters, vec![-1, -1, -1, -1]);

let clusters = DbscanClusterer::new(2.0, 2).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, -1]);
assert_eq!(clusters, vec![1, 1, 1, -1]);

let clusters = DbscanClusterer::new(2.0, 3).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, -1]);
assert_eq!(clusters, vec![1, 1, 1, -1]);

let clusters = DbscanClusterer::new(3.0, 3).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, 0]);
assert_eq!(clusters, vec![1, 1, 1, 1]);
}

#[test]
Expand All @@ -151,36 +206,36 @@ mod test {
let distance_matrix = DistanceMatrix::try_from_square(distance_matrix).unwrap();
let clusters = DbscanClusterer::new(10.0, 3).fit(&distance_matrix);
let expected = vec![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 2, -1, 2, -1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 3, -1, 3, -1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
];
assert_eq!(clusters, expected);
}
Expand Down
15 changes: 9 additions & 6 deletions crates/augurs/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ fn test_changepoint() {
#[cfg(feature = "clustering")]
#[test]
fn test_clustering() {
fn convert_clusters(clusters: Vec<augurs_clustering::DbscanCluster>) -> Vec<i32> {
clusters.into_iter().map(|c| c.as_i32()).collect()
}
use augurs::{clustering::DbscanClusterer, DistanceMatrix};
let distance_matrix = vec![
vec![0.0, 1.0, 2.0, 3.0],
Expand All @@ -27,22 +30,22 @@ fn test_clustering() {
];
let distance_matrix = DistanceMatrix::try_from_square(distance_matrix).unwrap();
let clusters = DbscanClusterer::new(0.5, 2).fit(&distance_matrix);
assert_eq!(clusters, vec![-1, -1, -1, -1]);
assert_eq!(convert_clusters(clusters), vec![-1, -1, -1, -1]);

let clusters = DbscanClusterer::new(1.0, 2).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, -1, -1]);
assert_eq!(convert_clusters(clusters), vec![1, 1, -1, -1]);

let clusters = DbscanClusterer::new(1.0, 3).fit(&distance_matrix);
assert_eq!(clusters, vec![-1, -1, -1, -1]);
assert_eq!(convert_clusters(clusters), vec![-1, -1, -1, -1]);

let clusters = DbscanClusterer::new(2.0, 2).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, -1]);
assert_eq!(convert_clusters(clusters), vec![1, 1, 1, -1]);

let clusters = DbscanClusterer::new(2.0, 3).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, -1]);
assert_eq!(convert_clusters(clusters), vec![1, 1, 1, -1]);

let clusters = DbscanClusterer::new(3.0, 3).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, 0]);
assert_eq!(convert_clusters(clusters), vec![1, 1, 1, 1]);
}

#[cfg(feature = "dtw")]
Expand Down
11 changes: 9 additions & 2 deletions crates/pyaugurs/src/clustering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,15 @@ impl Dbscan {
&self,
py: Python<'_>,
distance_matrix: InputDistanceMatrix<'_>,
) -> PyResult<Py<PyArray1<isize>>> {
) -> PyResult<Py<PyArray1<i32>>> {
let distance_matrix = distance_matrix.try_into()?;
Ok(self.inner.fit(&distance_matrix).into_pyarray(py).into())
Ok(self
.inner
.fit(&distance_matrix)
.into_iter()
.map(|x| x.as_i32())
.collect::<Vec<_>>()
.into_pyarray(py)
.into())
}
}
20 changes: 16 additions & 4 deletions examples/clustering/examples/dbscan_clustering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
//! The resulting clusters are assigned a label of -1 for noise, 0 for the first cluster, and 1 for
//! the second cluster.
use augurs::{clustering::DbscanClusterer, dtw::Dtw};
use augurs::{
clustering::{DbscanCluster, DbscanClusterer},
dtw::Dtw,
};

// This is a very trivial example dataset containing 5 time series which
// form two obvious clusters, plus a noise cluster.
Expand Down Expand Up @@ -49,7 +52,16 @@ fn main() {
let min_cluster_size = 2;

// Run DBSCAN clustering on the distance matrix.
let clusters: Vec<isize> =
DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 1, 1, -1]);
let clusters = DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix);
println!("Clusters: {:?}", clusters);
assert_eq!(
clusters,
vec![
DbscanCluster::Cluster(1.try_into().unwrap()),
DbscanCluster::Cluster(1.try_into().unwrap()),
DbscanCluster::Cluster(2.try_into().unwrap()),
DbscanCluster::Cluster(2.try_into().unwrap()),
DbscanCluster::Noise,
]
);
}

0 comments on commit 5ee35f3

Please sign in to comment.