Skip to content

Commit

Permalink
Add medoid centers (#35)
Browse files Browse the repository at this point in the history
* Add Medoid center calculation

* Added medoids + release

* update doc string

* trailing whitespace
  • Loading branch information
tom-whitehead authored Nov 13, 2024
1 parent c8d36c9 commit 35b51a2
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 12 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Version 0.9.0 2024-11-13
## Changes
- Added `Center::Medoid` as a means of calculating the center of clusters. The medoid is the point in a cluster with
the minimum total distance to all other points in that cluster. It is computationally more expensive than centroids, as
it requires the calculation of pairwise distances (using the selected distance metric). The output will be an observed
data point in the cluster.

# Version 0.8.3 2024-10-15
## Changes
- Fix for a bug that occurred when `allow_single_cluster` was set to true and the root cluster is the only one
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "hdbscan"
version = "0.8.3"
version = "0.9.0"
edition = "2021"
authors = [ "Tom Whitehead <[email protected]>", ]
description = "HDBSCAN clustering in pure Rust. A huge improvement on DBSCAN, capable of identifying clusters of varying densities."
Expand Down
58 changes: 55 additions & 3 deletions src/centers.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use num_traits::Float;
use std::cmp::Ordering;
use std::collections::HashSet;

/// Possible methodologies for calculating the center of clusters
Expand All @@ -7,22 +8,32 @@ pub enum Center {
/// The elementwise mean of all data points in a cluster.
/// The output is not guaranteed to be an observed data point.
Centroid,
/// Calculates the geographical centeroid for lat/lon coordinates.
/// Calculates the geographical centroid for lat/lon coordinates.
/// Assumes input coordinates are in degrees (latitude, longitude).
/// Output coordinates are also in degrees.
GeoCentroid,
/// The point in a cluster with the minimum total distance to all other points in that cluster.
/// Computationally more expensive than centroids, as it requires the calculation of pairwise
/// distances (using the selected distance metric). The output will be an observed data point
/// in the cluster.
Medoid,
}

impl Center {
pub(crate) fn calc_centers<T: Float>(&self, data: &[Vec<T>], labels: &[i32]) -> Vec<Vec<T>> {
pub(crate) fn calc_centers<T: Float, F: Fn(&[T], &[T]) -> T>(
&self,
data: &[Vec<T>],
labels: &[i32],
dist_func: F,
) -> Vec<Vec<T>> {
match self {
Center::Centroid => self.calc_centroids(data, labels),
Center::GeoCentroid => self.calc_geo_centroids(data, labels),
Center::Medoid => self.calc_medoids(data, labels, dist_func),
}
}

fn calc_centroids<T: Float>(&self, data: &[Vec<T>], labels: &[i32]) -> Vec<Vec<T>> {
// All points weighted equally
// All points weighted equally for now
let weights = vec![T::one(); data.len()];
Center::calc_weighted_centroids(data, labels, &weights)
}
Expand Down Expand Up @@ -121,4 +132,45 @@ impl Center {

centers
}

fn calc_medoids<T: Float, F: Fn(&[T], &[T]) -> T>(
&self,
data: &[Vec<T>],
labels: &[i32],
dist_func: F,
) -> Vec<Vec<T>> {
let n_clusters = labels
.iter()
.filter(|&&label| label != -1)
.collect::<HashSet<_>>()
.len();
let mut medoids = Vec::with_capacity(n_clusters);

for cluster_id in 0..n_clusters as i32 {
let cluster_data = data
.iter()
.zip(labels.iter())
.filter(|(_datapoint, &label)| label == cluster_id)
.map(|(datapoint, _label)| datapoint)
.collect::<Vec<&Vec<_>>>();

let n_samples = cluster_data.len();
let medoid_idx = (0..n_samples)
.map(|i| {
(0..n_samples)
.map(|j| dist_func(cluster_data[i], cluster_data[j]))
.fold(T::zero(), std::ops::Add::add)
})
.enumerate()
.min_by(|(_idx_a, sum_a), (_idx_b, sum_b)| {
sum_a.partial_cmp(sum_b).unwrap_or(Ordering::Equal)
})
.map(|(idx, _sum)| idx)
.unwrap_or(0);

medoids.push(cluster_data[medoid_idx].clone())
}

medoids
}
}
51 changes: 43 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,18 +218,14 @@ impl<'a, T: Float> Hdbscan<'a, T> {
/// Calculates the centers of the clusters just calculated.
///
/// # Parameters
/// * `center` - the type of center to calculate. Currently only centroid (the element wise mean
/// of all the data points in a cluster) is supported.
/// * `center` - the type of center to calculate.
/// * `labels` - a reference to the labels calculated by a call to `Hdbscan::cluster`.
///
/// # Returns
/// * A vector of the cluster centers, of shape num clusters by num dimensions/features. The
/// index of a center is its cluster label. For example, the center of the cluster with
/// label 0 will be the first center in the vector of centers.
///
/// # Panics
/// * If the labels are of different length to the data passed to the `Hdbscan` constructor
///
/// # Examples
/// ```
///use hdbscan::{Center, Hdbscan};
Expand Down Expand Up @@ -258,14 +254,22 @@ impl<'a, T: Float> Hdbscan<'a, T> {
center: Center,
labels: &[i32],
) -> Result<Vec<Vec<T>>, HdbscanError> {
assert_eq!(labels.len(), self.data.len());
if labels.len() != self.data.len() {
return Err(HdbscanError::WrongDimension(String::from(
"The length of the labels must equal the length of the original clustering data.",
)));
}
if self.hp.dist_metric != DistanceMetric::Haversine && center == Center::GeoCentroid {
// TODO: Implement a more appropriate error variant when doing a major version bump
return Err(HdbscanError::WrongDimension(String::from(
"Geographical centroids can only be used with geographical coordinates.",
)));
}
Ok(center.calc_centers(self.data, labels))
Ok(center.calc_centers(
self.data,
labels,
distance::get_dist_func(&self.hp.dist_metric),
))
}

fn validate_input_data(&self) -> Result<(), HdbscanError> {
Expand Down Expand Up @@ -1101,7 +1105,7 @@ mod tests {
}

#[test]
fn calc_centers() {
fn calc_centroids() {
let data = cluster_test_data();
let clusterer = Hdbscan::default_hyper_params(&data);
let labels = clusterer.cluster().unwrap();
Expand All @@ -1110,6 +1114,37 @@ mod tests {
assert!(centroids.contains(&vec![3.8, 4.0]) && centroids.contains(&vec![1.12, 1.34]));
}

/// Clusters two well-separated groups of five points and checks that one medoid is returned
/// per cluster, that every medoid is an observed data point, and that the medoids are the
/// expected points in label order.
#[test]
fn calc_medoids() {
    let data: Vec<Vec<f32>> = vec![
        // First group
        vec![1.3, 1.2],
        vec![1.2, 1.3],
        vec![1.5, 1.5],
        vec![1.6, 1.7],
        vec![1.7, 1.6],
        // Second group
        vec![6.3, 6.2],
        vec![6.2, 6.3],
        vec![6.5, 6.5],
        vec![6.6, 6.7],
        vec![6.7, 6.6],
    ];
    let clusterer = Hdbscan::default_hyper_params(&data);
    let labels = clusterer.cluster().unwrap();
    let medoids = clusterer.calc_centers(Center::Medoid, &labels).unwrap();

    // One medoid per non-noise cluster.
    let n_clusters = labels
        .iter()
        .filter(|&&label| label != -1)
        .collect::<HashSet<_>>()
        .len();
    assert_eq!(medoids.len(), n_clusters);

    // A medoid must always be one of the observed data points.
    for medoid in &medoids {
        assert!(data.contains(medoid));
    }
    assert_eq!(vec![1.5, 1.5], medoids[0]);
    assert_eq!(vec![6.5, 6.5], medoids[1]);
}

fn cluster_test_data() -> Vec<Vec<f32>> {
vec![
vec![1.5, 2.2],
Expand Down

0 comments on commit 35b51a2

Please sign in to comment.