Skip to content

Commit

Permalink
refactor(clustering)!: use new DbscanCluster type instead of isize (#233
Browse files Browse the repository at this point in the history
)
  • Loading branch information
sd2k authored Jan 12, 2025
1 parent e098590 commit 867f261
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 64 deletions.
3 changes: 3 additions & 0 deletions book/book.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ language = "en"
multilingual = false
src = "src"
title = "augurs - a time series toolkit"

[rust]
edition = "2021"
18 changes: 15 additions & 3 deletions book/src/tutorials/clustering.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ Let's start with a simple example using DBSCAN clustering:

```rust
# extern crate augurs;
use augurs::{clustering::DbscanClusterer, dtw::Dtw};
use augurs::{
clustering::{DbscanCluster, DbscanClusterer},
dtw::Dtw,
};

// Sample time series data
const SERIES: &[&[f64]] = &[
Expand All @@ -39,11 +42,20 @@ fn main() {
let min_cluster_size = 2;

// Perform clustering
let clusters: Vec<isize> = DbscanClusterer::new(epsilon, min_cluster_size)
let clusters = DbscanClusterer::new(epsilon, min_cluster_size)
.fit(&distance_matrix);

// Clusters are labeled: `DbscanCluster::Noise` for noise, `DbscanCluster::Cluster(id)` (IDs start at 1) for cluster membership
assert_eq!(clusters, vec![0, 0, 1, 1, -1]);
assert_eq!(
clusters,
vec![
DbscanCluster::Cluster(1.try_into().unwrap()),
DbscanCluster::Cluster(1.try_into().unwrap()),
DbscanCluster::Cluster(2.try_into().unwrap()),
DbscanCluster::Cluster(2.try_into().unwrap()),
DbscanCluster::Noise,
]
);
}
```

Expand Down
13 changes: 11 additions & 2 deletions crates/augurs-clustering/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ A crate such as [`augurs-dtw`] must be used to calculate the distance matrix for
## Usage

```rust
use augurs::clustering::{DbscanClusterer, DistanceMatrix};
use augurs::clustering::{DbscanCluster, DbscanClusterer, DistanceMatrix};

# fn main() -> Result<(), Box<dyn std::error::Error>> {
// Start with a distance matrix.
Expand All @@ -32,7 +32,16 @@ let min_cluster_size = 2;
// Use DBSCAN to detect clusters of series.
// Note that we don't need to specify the number of clusters in advance.
let clusters = DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, 1, 1]);
assert_eq!(
clusters,
vec![
DbscanCluster::Cluster(1.try_into().unwrap()),
DbscanCluster::Cluster(1.try_into().unwrap()),
DbscanCluster::Cluster(1.try_into().unwrap()),
DbscanCluster::Cluster(2.try_into().unwrap()),
DbscanCluster::Cluster(2.try_into().unwrap()),
],
);
# Ok(())
# }
```
Expand Down
139 changes: 97 additions & 42 deletions crates/augurs-clustering/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,64 @@
#![doc = include_str!("../README.md")]

use std::collections::VecDeque;
use std::{collections::VecDeque, num::NonZeroU32};

pub use augurs_core::DistanceMatrix;

/// A cluster identified by the DBSCAN algorithm.
///
/// This is either a noise cluster, or a cluster with a specific ID.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DbscanCluster {
    /// A noise cluster.
    Noise,
    /// A cluster with the given ID.
    ///
    /// The ID is not guaranteed to remain the same between runs of the algorithm.
    ///
    /// We use a `NonZeroU32` here to ensure that the ID is never zero. This is mostly
    /// just a size optimization.
    Cluster(NonZeroU32),
}

impl DbscanCluster {
    /// Returns true if this cluster is a noise cluster.
    pub fn is_noise(&self) -> bool {
        matches!(self, Self::Noise)
    }

    /// Returns true if this is a real cluster (with any ID) rather than noise.
    pub fn is_cluster(&self) -> bool {
        matches!(self, Self::Cluster(_))
    }

    /// Returns the ID of the cluster, if it is a cluster, or `-1` if it is a noise cluster.
    ///
    /// Cluster IDs are always positive in the returned value. An ID larger than
    /// `i32::MAX` cannot be represented and saturates to `i32::MAX`, so that it can
    /// never wrap around to a negative number and be mistaken for the `-1` noise label.
    pub fn as_i32(&self) -> i32 {
        match self {
            Self::Noise => -1,
            // Clamp before casting: once the value is at most `i32::MAX` the cast is lossless.
            Self::Cluster(id) => id.get().min(i32::MAX as u32) as i32,
        }
    }

    /// Advances a cluster label to the next ID.
    ///
    /// # Panics
    ///
    /// Panics if called on `Noise`, or if the next ID would overflow a `u32`.
    fn increment(&mut self) {
        match self {
            Self::Noise => unreachable!("only `Cluster` labels can be incremented"),
            Self::Cluster(id) => *id = id.checked_add(1).expect("cluster ID overflow"),
        }
    }
}

// Test-only convenience: lets assertions compare clusters directly against `i32` labels.
#[cfg(test)]
impl PartialEq<i32> for DbscanCluster {
    /// Compares against an integer label: `-1` matches noise, and any other
    /// value is compared with the cluster's `as_i32` representation.
    fn eq(&self, other: &i32) -> bool {
        let label = *other;
        match self {
            Self::Noise => label == -1,
            Self::Cluster(_) => self.as_i32() == label,
        }
    }
}

/// DBSCAN clustering algorithm.
#[derive(Debug)]
pub struct DbscanClusterer {
Expand Down Expand Up @@ -40,11 +95,11 @@ impl DbscanClusterer {

/// Run the DBSCAN clustering algorithm.
///
/// The return value is a vector of cluster assignments, with `-1` indicating noise.
pub fn fit(&self, distance_matrix: &DistanceMatrix) -> Vec<isize> {
/// The return value is a vector of cluster assignments, with `DbscanCluster::Noise` indicating noise.
pub fn fit(&self, distance_matrix: &DistanceMatrix) -> Vec<DbscanCluster> {
let n = distance_matrix.shape().0;
let mut clusters = vec![-1; n];
let mut cluster = 0;
let mut clusters = vec![DbscanCluster::Noise; n];
let mut cluster = DbscanCluster::Cluster(NonZeroU32::new(1).unwrap());
let mut visited = vec![false; n];
let mut to_visit = VecDeque::with_capacity(n);

Expand All @@ -53,7 +108,7 @@ impl DbscanClusterer {

for (i, d) in distance_matrix.iter().enumerate() {
// Skip if already assigned to a cluster.
if clusters[i] != -1 {
if clusters[i].is_cluster() {
continue;
}
self.find_neighbours(i, d, &mut neighbours);
Expand All @@ -67,7 +122,7 @@ impl DbscanClusterer {
clusters[i] = cluster;
// Mark all noise neighbours as visited and add them to the queue.
for neighbour in neighbours.drain(..) {
if clusters[neighbour] == -1 {
if clusters[neighbour].is_noise() {
visited[neighbour] = true;
to_visit.push_back(neighbour);
}
Expand All @@ -87,7 +142,7 @@ impl DbscanClusterer {
}
}
}
cluster += 1;
cluster.increment();
}
clusters
}
Expand Down Expand Up @@ -123,19 +178,19 @@ mod test {
assert_eq!(clusters, vec![-1, -1, -1, -1]);

let clusters = DbscanClusterer::new(1.0, 2).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, -1, -1]);
assert_eq!(clusters, vec![1, 1, -1, -1]);

let clusters = DbscanClusterer::new(1.0, 3).fit(&distance_matrix);
assert_eq!(clusters, vec![-1, -1, -1, -1]);

let clusters = DbscanClusterer::new(2.0, 2).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, -1]);
assert_eq!(clusters, vec![1, 1, 1, -1]);

let clusters = DbscanClusterer::new(2.0, 3).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, -1]);
assert_eq!(clusters, vec![1, 1, 1, -1]);

let clusters = DbscanClusterer::new(3.0, 3).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, 0]);
assert_eq!(clusters, vec![1, 1, 1, 1]);
}

#[test]
Expand All @@ -151,36 +206,36 @@ mod test {
let distance_matrix = DistanceMatrix::try_from_square(distance_matrix).unwrap();
let clusters = DbscanClusterer::new(10.0, 3).fit(&distance_matrix);
let expected = vec![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 2, -1, 2, -1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 3, -1, 3, -1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
];
assert_eq!(clusters, expected);
}
Expand Down
15 changes: 9 additions & 6 deletions crates/augurs/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ fn test_changepoint() {
#[cfg(feature = "clustering")]
#[test]
fn test_clustering() {
fn convert_clusters(clusters: Vec<augurs_clustering::DbscanCluster>) -> Vec<i32> {
clusters.into_iter().map(|c| c.as_i32()).collect()
}
use augurs::{clustering::DbscanClusterer, DistanceMatrix};
let distance_matrix = vec![
vec![0.0, 1.0, 2.0, 3.0],
Expand All @@ -27,22 +30,22 @@ fn test_clustering() {
];
let distance_matrix = DistanceMatrix::try_from_square(distance_matrix).unwrap();
let clusters = DbscanClusterer::new(0.5, 2).fit(&distance_matrix);
assert_eq!(clusters, vec![-1, -1, -1, -1]);
assert_eq!(convert_clusters(clusters), vec![-1, -1, -1, -1]);

let clusters = DbscanClusterer::new(1.0, 2).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, -1, -1]);
assert_eq!(convert_clusters(clusters), vec![1, 1, -1, -1]);

let clusters = DbscanClusterer::new(1.0, 3).fit(&distance_matrix);
assert_eq!(clusters, vec![-1, -1, -1, -1]);
assert_eq!(convert_clusters(clusters), vec![-1, -1, -1, -1]);

let clusters = DbscanClusterer::new(2.0, 2).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, -1]);
assert_eq!(convert_clusters(clusters), vec![1, 1, 1, -1]);

let clusters = DbscanClusterer::new(2.0, 3).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, -1]);
assert_eq!(convert_clusters(clusters), vec![1, 1, 1, -1]);

let clusters = DbscanClusterer::new(3.0, 3).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 0, 0]);
assert_eq!(convert_clusters(clusters), vec![1, 1, 1, 1]);
}

#[cfg(feature = "dtw")]
Expand Down
11 changes: 9 additions & 2 deletions crates/pyaugurs/src/clustering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,15 @@ impl Dbscan {
&self,
py: Python<'_>,
distance_matrix: InputDistanceMatrix<'_>,
) -> PyResult<Py<PyArray1<isize>>> {
) -> PyResult<Py<PyArray1<i32>>> {
let distance_matrix = distance_matrix.try_into()?;
Ok(self.inner.fit(&distance_matrix).into_pyarray(py).into())
Ok(self
.inner
.fit(&distance_matrix)
.into_iter()
.map(|x| x.as_i32())
.collect::<Vec<_>>()
.into_pyarray(py)
.into())
}
}
20 changes: 16 additions & 4 deletions examples/clustering/examples/dbscan_clustering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
//! The resulting clusters are labeled `DbscanCluster::Noise` for noise, and a
//! `DbscanCluster::Cluster` with ID 1 for the first cluster and ID 2 for the second cluster.
use augurs::{clustering::DbscanClusterer, dtw::Dtw};
use augurs::{
clustering::{DbscanCluster, DbscanClusterer},
dtw::Dtw,
};

// This is a very trivial example dataset containing 5 time series which
// form two obvious clusters, plus a noise cluster.
Expand Down Expand Up @@ -49,7 +52,16 @@ fn main() {
let min_cluster_size = 2;

// Run DBSCAN clustering on the distance matrix.
let clusters: Vec<isize> =
DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix);
assert_eq!(clusters, vec![0, 0, 1, 1, -1]);
let clusters = DbscanClusterer::new(epsilon, min_cluster_size).fit(&distance_matrix);
println!("Clusters: {:?}", clusters);
assert_eq!(
clusters,
vec![
DbscanCluster::Cluster(1.try_into().unwrap()),
DbscanCluster::Cluster(1.try_into().unwrap()),
DbscanCluster::Cluster(2.try_into().unwrap()),
DbscanCluster::Cluster(2.try_into().unwrap()),
DbscanCluster::Noise,
]
);
}
9 changes: 7 additions & 2 deletions js/augurs-clustering-js/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@ impl DbscanClusterer {
/// The return value is an `Int32Array` of cluster IDs, with `-1` indicating noise.
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn fit(&self, distanceMatrix: VecVecF64) -> Result<Vec<isize>, JsError> {
Ok(self.inner.fit(&DistanceMatrix::new(distanceMatrix)?.into()))
pub fn fit(&self, distanceMatrix: VecVecF64) -> Result<Vec<i32>, JsError> {
Ok(self
.inner
.fit(&DistanceMatrix::new(distanceMatrix)?.into())
.into_iter()
.map(|x| x.as_i32())
.collect::<Vec<_>>())
}
}
6 changes: 3 additions & 3 deletions js/testpkg/clustering.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ describe('clustering', () => {
[2, 3, 0, 4],
[3, 3, 4, 0],
]);
expect(labels).toEqual(new Int32Array([0, 0, -1, -1]));
expect(labels).toEqual(new Int32Array([1, 1, -1, -1]));
});

it('can be fit with a raw distance matrix of typed arrays', () => {
Expand All @@ -33,7 +33,7 @@ describe('clustering', () => {
new Float64Array([2, 3, 0, 4]),
new Float64Array([3, 3, 4, 0]),
]);
expect(labels).toEqual(new Int32Array([0, 0, -1, -1]));
expect(labels).toEqual(new Int32Array([1, 1, -1, -1]));
});

it('can be fit with a distance matrix from augurs', () => {
Expand All @@ -46,6 +46,6 @@ describe('clustering', () => {
]);
const clusterer = new DbscanClusterer({ epsilon: 0.5, minClusterSize: 2 });
const labels = clusterer.fit(distanceMatrix);
expect(labels).toEqual(new Int32Array([0, 0, 0, -1]));
expect(labels).toEqual(new Int32Array([1, 1, 1, -1]));
})
});

0 comments on commit 867f261

Please sign in to comment.