
Hierarchical Clustering #187

Merged: 61 commits into elixir-nx:main on Nov 28, 2023

Conversation

@billylanchantin (Contributor) commented Oct 10, 2023:

Description

This PR adds hierarchical clustering. It's one of the items with a help wanted label, so I thought I'd give it a go :)

Background

Hierarchical clustering is a family of unsupervised learning methods that build clusters in a greedy fashion, either bottom-up (agglomerative) or top-down (divisive). If you're unfamiliar, I think the following tutorial is probably approachable:

https://www.kdnuggets.com/2019/09/hierarchical-clustering.html

This PR adds support for agglomerative clustering, which seems to be the more common of the two.

Approach

The plan is to follow the approach outlined in the paper Modern hierarchical, agglomerative clustering algorithms by Daniel Müllner. It's still state of the art AFAIK, and it serves as the basis for SciPy's and scikit-learn's hierarchical clustering (though both of those have extra bells and whistles).

The core of hierarchical clustering, which the Müllner paper denotes the "primitive" algorithm, is this (a rough Elixir sketch follows the list):

  1. Start by assuming each data point is in a cluster by itself.
  2. Build a (square) pairwise dissimilarity matrix from the data points.
  3. Choose from among 7 dissimilarity update functions.
  4. Use the update function to
    a. group the closest pair of clusters into a new cluster and
    b. recalculate the distances from the new cluster to the rest of the clusters
  5. Repeat step 4 until there is a single cluster containing all data points.
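
To make the steps concrete, here's a minimal plain-Elixir sketch, assuming single linkage and precomputed point distances in a map. It is not this PR's API; every name is illustrative:

```elixir
# A hypothetical sketch of the "primitive" algorithm with single linkage.
# `point_dist` maps {i, j} (with i < j) to a precomputed distance.
defmodule PrimitiveHC do
  def cluster(point_dist, n) do
    # Step 1: every point starts in its own cluster.
    clusters = Enum.map(0..(n - 1), &[&1])
    merge(point_dist, clusters, [])
  end

  # Step 5: stop once a single cluster remains; return the merge history.
  defp merge(_dist, [_single], merges), do: Enum.reverse(merges)

  defp merge(dist, clusters, merges) do
    # Step 4a: find the closest pair of clusters under the linkage.
    candidates =
      for ca <- clusters, cb <- clusters, ca < cb do
        {ca, cb, linkage(dist, ca, cb)}
      end

    {a, b, height} = Enum.min_by(candidates, fn {_ca, _cb, d} -> d end)

    # Step 4b: merge the pair; distances are recomputed on the next pass.
    clusters = [a ++ b | clusters -- [a, b]]
    merge(dist, clusters, [{a, b, height} | merges])
  end

  # Single linkage: the minimum pointwise distance between two clusters.
  defp linkage(dist, ca, cb) do
    pairs = for i <- ca, j <- cb, do: dist[{min(i, j), max(i, j)}]
    Enum.min(pairs)
  end
end
```

The merge history ({a, b, height} triples) is also exactly what a dendrogram is built from.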

The primitive algorithm is O(n^3) even in the best case. The paper's main contribution is to describe 3 optimizations that improve the complexity. Which optimization to use depends on which of the 7 update functions was chosen.

So far, I've only implemented the primitive algorithm, which isn't to be used in practice. If what I've described so far is the direction we want to take, I can start working on the 3 optimizations.

To-do

  • Dissimilarity functions
    • Euclidean
    • Pre-computed
      • Future work
  • Dissimilarity update functions
    • Average
    • Centroid
      • Future work
    • Complete
    • Median
      • Future work
    • Single
    • Ward
    • Weighted
  • Cluster functions
    • "parallel-nearest-neighbor"
    • "primitive" (not to be used in practice)
      • (Available in a commit)
    • "generic"
      • Future work
    • "minimum-spanning-tree"
    • "nearest-neighbor-chain"
      • (Available in a commit)

Livebook demo

I'm including a Livebook demo since the dendrogram visualization is very useful for understanding hierarchical clusters.

[Screenshot: Livebook demo showing a dendrogram visualization]

@msluszniak (Contributor):

Hi Billy, thank you for the PR. From a scan of the PR, you mostly use def functions. It is very important that all of the functionality is implemented with defn: it gives order-of-magnitude performance improvements, and only other defns and deftransforms can be used inside a defn.
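
For illustration only (not this PR's code; the module and function names are made up), the same computation written both ways:

```elixir
# Illustrative only: the same row-norm computation as `def` vs `defn`.
defmodule Example do
  import Nx.Defn

  # `def` runs eagerly, dispatching each Nx op to the backend one at a time.
  def row_norms(x), do: Nx.sqrt(Nx.sum(Nx.pow(x, 2), axes: [1]))

  # `defn` builds a computation graph that a compiler such as EXLA can
  # JIT-compile and fuse; only defn/deftransform calls are allowed inside.
  defn row_norms_n(x), do: Nx.sqrt(Nx.sum(x ** 2, axes: [1]))
end
```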

@billylanchantin (Author):

From a scan of the PR, you mostly use def functions. It is very important that all of the functionality is implemented with defn.

@msluszniak Will do! I mostly wanted a gut check on the direction before I started to "tensor-ify" the operations.

On that note, I may need some help making the optimizations Nx-friendly as I get into them. But I'll call out any parts I get stuck on in the PR.

@josevalim (Contributor):

@billylanchantin issue #189 contains a good discussion of the challenges of optimizing an algorithm to run on Nx/defn. However, for HCA, it seems there already exists an algorithm for the GPU here: https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=94392889916911b13e8a40e9cffc1d5924e93fc2

I am not familiar with HCA, but I believe the computation of the distance matrix is straightforward and already done. So you need to implement the linkage algorithm in section 2.4. Step 1 seems to be an argsort (you may need to initialize the distance matrix with infinity/neg-infinity to get the proper indexes). The other steps (2-5) would be done in a while loop, so not that parallelizable, but at least it should work in Nx. You could even do them in Elixir and we can help convert to Nx later. The important thing is to rely only on tensors (no maps, etc).
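
As a hedged sketch of that shape only (tensor-only state in a defn while loop; none of this is the linked paper's actual linkage logic):

```elixir
# A hypothetical skeleton: steps 2-5 of the paper would live in the loop body.
defmodule LinkageSketch do
  import Nx.Defn

  defn linkage_loop(pairwise) do
    n = Nx.axis_size(pairwise, 0)
    # One {a, b} link per merge; -1 marks "not yet computed".
    links = Nx.broadcast(-1, {n - 1, 2})

    {_pairwise, links, _i} =
      while {pairwise, links, i = 0}, i < n - 1 do
        # ...find the closest pair, record it in `links`,
        # and update `pairwise` here...
        {pairwise, links, i + 1}
      end

    links
  end
end
```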

@billylanchantin (Author):

issue #189 contains a good discussion of the challenges of optimizing an algorithm to run on Nx/defn.

Oh that looks like a great resource, I'll give it a read!

However, for HCA, it seems there already exists an algorithm for the GPU here: https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=94392889916911b13e8a40e9cffc1d5924e93fc2

I wasn't aware of this one. I'll take a look at it as well.

It does appear to predate the work I was planning to implement. But it may be worth having an implementation that's less featureful but more GPU-friendly. I'll have to see how they compare.

I am not familiar with HCA, but I believe the computation of the distance matrix is straightforward and already done.

Yeah, unfortunately the computation of the distance matrix is the "easy" part. It's the various while loops that are the challenge.

At each iteration of the cluster analysis, we merge two clusters. This decreases the number of clusters by one (which changes the shape of the distance matrix) and alters all distances involving the two merged clusters. It's a lot of bookkeeping.

Plus, in the Müllner paper, he's using not-necessarily-tensor-friendly things like priority queues to speed up the lookups.

@josevalim (Contributor) commented Nov 5, 2023:

@billylanchantin I may be very wrong, but inspired by the paper, I wrote a version of linking for minimum HCA in a Livebook. The linking is parallel and takes log2(n) steps, so it seems better than the paper's approach: https://gist.github.com/josevalim/00990eb921bc1f767e8870798b5e786f

You can see it is recursive and requires ceil(:math.log2(size + 1)) operations (or 32 - Nx.count_leading_zeros(Nx.u32(size)) if you prefer integer arithmetic). I can convert it to a proper defn tomorrow if you don't beat me to it. The most complex part will be the logic to update the pairwise matrix, but maybe the existing papers have better ideas.
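
As a quick sanity check (illustrative value) that those two step counts agree:

```elixir
size = 100

# Float route:
ceil(:math.log2(size + 1))
#=> 7

# Integer route, staying on tensors:
Nx.subtract(32, Nx.count_leading_zeros(Nx.u32(size))) |> Nx.to_number()
#=> 7
```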

@josevalim (Contributor):

The approach above assumes binary trees, so only two elements are merged at a time. My algorithm requires stable sorting, so we need to add a TODO, because the default will change in Nx 0.7.

@billylanchantin (Author):

@josevalim See my comment on your gist. I may have been mistaken, though. I'll take a closer look tomorrow!

@josevalim (Contributor) commented Nov 6, 2023:

Oh, I thought the simple did not update the matrix, but you are right. I will work on it. In that case, it is worth noting that my algorithm will only work if points do not get closer to clusters as clusters are merged. According to this paper, this is called the reducibility property (section 2.2), and luckily it applies to single, average, complete, and minimum variance.


For the other metrics, we will need to implement the paper I linked earlier. The important thing to highlight is that argsort/argmin acts as your priority queue, and indexing into their results gives you the values. The algorithm is roughly like this:

  1. Make the pairwise matrix an upper triangular matrix where the diagonal and the lower part are filled with infinity.
  2. Compute the argmin of the matrix along each row; that gives the closest point to each point i.
  3. Take the values at those argmin indexes; these are the minimal_distances.
  4. Argmin the minimal_distances; the first element is the closest pair of clusters you have.
  5. Set the value for this cluster to infinity in minimal_distances, update the pairwise matrix and then minimal_distances, and go back to step 2 (n-1 times).

There is an optimization: we don't need to argmin again if the next entry in minimal_distances is smaller than the updates to the pairwise matrix, which means we can skip steps 2, 3, and 4. This is probably the priority-queue optimization done by other papers.

This algorithm should work for all metrics; it is just slower. The good news is that once we implement one version, we can test and benchmark against it.
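
A one-iteration sketch of steps 1-4 in plain Nx calls (toy data; every name is illustrative):

```elixir
# Toy 4-point pairwise matrix; all values are illustrative.
d =
  Nx.tensor([
    [0.0, 1.0, 4.0, 9.0],
    [1.0, 0.0, 2.0, 8.0],
    [4.0, 2.0, 0.0, 3.0],
    [9.0, 8.0, 3.0, 0.0]
  ])

n = Nx.axis_size(d, 0)
inf = Nx.Constants.infinity()

# Step 1: keep only the strict upper triangle; the rest becomes infinity.
mask = Nx.greater(Nx.iota({n, n}, axis: 1), Nx.iota({n, n}, axis: 0))
upper = Nx.select(mask, d, inf)

# Step 2: for each row i, the column of its nearest neighbor.
nearest = Nx.argmin(upper, axis: 1)

# Step 3: the corresponding minimal distances per row.
minimal_distances = Nx.reduce_min(upper, axes: [1])

# Step 4: the row holding the globally closest pair.
i = minimal_distances |> Nx.argmin() |> Nx.to_number()
j = nearest[i] |> Nx.to_number()
# {i, j} = {0, 1}: the pair to merge. Step 5 would update `upper`
# and `minimal_distances`, then loop back to step 2.
```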

@josevalim (Contributor):

I am delayed a bit because we are adding new functions to Nx to make updating symmetric matrices efficient.

@billylanchantin (Author):

No rush! I've been finding it hard to get the long blocks of time I need for this PR anyhow.

@josevalim (Contributor):

Oh, I thought the simple did not update the matrix, but you are right.

I have implemented the full algorithm for simple and I can confirm that we don't need to update the matrix. The simple version always updates the distances to the minimum distance of all points, which we already get from argsort. So if we put my previous gist in a loop, that will be enough to compute the single-link minimal version. :)

I will provide a gist with the full version of other algorithms later. If you want, I can also provide the loop version for minimal.

@billylanchantin (Author):

I will provide a gist with the full version of other algorithms later. If you want, I can also provide the loop version for minimal.

That'd be fantastic! Thanks for looking into this :)

(And sorry about the delay. It was actually the impetus for the Nx exercises PR: I kept getting stuck so I decided I needed some practice.)

And I admit I'm curious how this works. Without updated distances, I don't think the dendrogram object will have the right tree heights. But your gist will answer that!

@josevalim (Contributor):

@billylanchantin you need the distances between clusters as well? I can work on that. :)

@billylanchantin (Author):

@josevalim Yeah there are two goals with HC:

  1. The clusters themselves
  2. The dendrogram visualization

For the clusters themselves, you don't necessarily need to maintain the distances. But for the dendrogram, you need them to build the heights of the tree.

One of the advantages of HC is that you can use the dendrogram to help you choose post-hoc how many clusters you want. (I can explain this further if requested 😄)
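
As a quick taste: given merge records with heights (a hypothetical format, not this PR's actual output), cutting at a threshold undoes every merge above it.

```elixir
# Hypothetical merge records {left, right, height}, ordered by height.
# Cutting the dendrogram at height 2.5 undoes the two merges above it,
# leaving three clusters: [0, 1], [2], and [3].
merges = [
  {[0], [1], 1.0},
  {[2], [3], 3.0},
  {[0, 1], [2, 3], 5.0}
]

cut = 2.5

clusters =
  Enum.reduce(merges, Enum.map(0..3, &[&1]), fn
    {left, right, h}, acc when h < cut -> [left ++ right | acc -- [left, right]]
    _merge, acc -> acc
  end)

#=> [[0, 1], [2], [3]]
```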

@josevalim (Contributor):

@billylanchantin here is the complete code for single links:

https://gist.github.com/josevalim/00990eb921bc1f767e8870798b5e786f

I can also make it work for average, complete, and minimum variance. Your goal would be to traverse the links and distances and build the clusters. I can optimize the code a bit, which would require a tiny modification to the output, but it should be straightforward to adapt. So with the above, we should have everything in hand for an MVP of hierarchical clustering.

For centroids and median, we will need the second half of this comment, which will also have a different output for links and distances (most likely two {n}-sized vectors). So we can consider it something completely separate.

@billylanchantin (Author) commented Nov 26, 2023:

@josevalim @msluszniak I need access to the pairwise functions before I can drop the euclidean functions. I was gonna merge main, but José said rebase. Which do y'all prefer?

@josevalim (Contributor):

Either rebase or merge is fine. Whatever you prefer!

@billylanchantin (Author) commented Nov 27, 2023:

Ok I think we're basically good to go!

One last thing: I'm a bit worried about the complexity. I think it's actually $O(\frac{n^2}{p} \cdot \log(n))$. But I'm not confident.

I'd prefer no docs over incorrect docs. Unless anyone can say definitively what the complexity is, I want to drop the complexity discussion from the moduledoc.

Updated: 9a2db56

[Resolved review thread on lib/scholar/cluster/hierarchical.ex]
@msluszniak (Contributor):

One last thing: I'm a bit worried about the complexity. I think it's actually $O(\frac{n^2}{p} \cdot \log(n))$. But I'm not confident.

Yeah, the time complexity seems legit to me ;)

@josevalim merged commit 8826219 into elixir-nx:main on Nov 28, 2023. 2 checks passed.
@josevalim (Contributor):

Beautifully done!

@josevalim (Contributor):

💚 💙 💜 💛 ❤️
