Hierarchical Clustering #187
Conversation
Hi Billy, thank you for the PR. From what I've scanned of your PR, you mostly use
@msluszniak Will do! I mostly wanted a gut check on the direction before I started to "tensor-ify" the operations. On that note, I may need some help as I get into the optimizations with making them
@billylanchantin issue #189 contains a good discussion of the challenges of optimizing an algorithm to run on Nx/defn. However, for HCA, it seems an algorithm for the GPU already exists here: https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=94392889916911b13e8a40e9cffc1d5924e93fc2 I am not familiar with HCA, but I believe the computation of the distance matrix is straightforward and already done. So you need to implement the linkage algorithm in section 2.4. Step 1 seems to be an argsort (you may need to initialize the distance matrix with infinity/neg-infinity to get the proper indexes). The other steps (2-5) would be done in a while loop, so not that parallelizable, but it should at least work in Nx. You could even do them in Elixir and we can help convert to Nx later. The important thing is to rely only on tensors (no maps, etc.).
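For reference, here is a minimal sketch (not from the PR; the module and function names are made up) of what step 1 could look like with plain Nx calls: build the pairwise distance matrix, put infinity on the diagonal so self-distances never win, and argsort the flattened distances.

```elixir
# Illustrative only: a rough cut of step 1 from the comment above.
defmodule HCASketch do
  # x is an {n, d} tensor of points.
  def sorted_pair_indices(x) do
    # pairwise squared Euclidean distances, shape {n, n}
    diff = Nx.subtract(Nx.new_axis(x, 1), Nx.new_axis(x, 0))
    dist = Nx.sum(Nx.multiply(diff, diff), axes: [-1])

    # mask the diagonal with infinity so a point is never "closest" to itself
    {n, _} = Nx.shape(dist)
    dist = Nx.put_diagonal(dist, Nx.broadcast(Nx.Constants.infinity(), {n}))

    # flattened (row-major) indices of pairs, from closest to farthest
    Nx.argsort(Nx.flatten(dist))
  end
end
```

In a real implementation this would presumably live inside `defn`; regular Nx calls just keep the sketch short.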
Oh that looks like a great resource, I'll give it a read!
I wasn't aware of this one. I'll take a look at it as well. It does appear to predate the work I was planning to implement. But it may be worth having an implementation that's less featureful but more GPU-friendly. I'll have to see how they compare.
Yeah, unfortunately the computation of the distance function is the "easy" part. It's the various while loops that are the challenge. At each iteration of cluster analysis, we merge two clusters. This decreases the number of clusters by one (which changes the shape of the distance matrix) and alters all distances involving the two merged clusters. It's a lot of bookkeeping. Plus, in the Müllner paper, he's using not-necessarily-tensor-friendly things like priority queues to speed up the lookups.
@billylanchantin I may be very wrong, but inspired by the paper, I wrote a version of linking for minimum HCA in a Livebook. The linking is parallel, but you can see it is recursive and it requires
The approach above assumes binary trees, so only two elements at a time. My algorithm requires stable sorting, so we need to add a TODO because the default will change in Nx 0.7.
@josevalim See my comment on your gist. I may have been mistaken, though. I'll take a closer look tomorrow!
Oh, I thought the simple version did not update the matrix, but you are right. I will work on it. In that case, it is worth noting that my algorithm will only work if points do not get closer to clusters as clusters are merged. According to this paper, this is called the reducibility property (section 2.2), and luckily it applies to single, average, complete, and minimum variance. For the other metrics, we will need to implement the paper I linked earlier. The important thing to highlight is that
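For readers following along, the reducibility property mentioned here is the standard one from the clustering literature (not something specific to this PR): merging two clusters never brings the result closer to any third cluster than the closer of the two already was. In symbols:

```math
d(I \cup J, K) \ge \min\big(d(I, K),\, d(J, K)\big) \quad \text{for all clusters } I, J, K
```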
There is an optimization: we don't need to argmin again if the next entry in the minimal distances is smaller than the updates to the pairwise matrix, which means we can skip steps 2, 3, and 4. This is probably the priority-queue optimization done by other papers. This algorithm should work on all metrics, it is just slower, so the good news is that once we implement one version, we can test and benchmark against it.
I am delayed a bit because we are adding new functions to Nx to make updating symmetrical matrices efficient.
No rush! I've been finding it hard to get the long blocks of time I need for this PR anyhow.
I have implemented the full algorithm for simple and I can confirm that we don't need to update the matrix. The simple version always updates the distances to the minimum distance of all points, which we already get from the argsort. So if we put my previous gist in a loop, it will be enough to compute the single-link minimal version. :) I will provide a gist with the full version of the other algorithms later. If you want, I can also provide the loop version for minimal.
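One way to see why single linkage gets away without matrix updates (a plain-Elixir illustration with made-up names, not the gist's tensor version): walk the pairs from closest to farthest and union their clusters; the first time two clusters meet is already their single-link merge distance.

```elixir
defmodule SingleLinkSketch do
  # pairs: a list of {distance, i, j} tuples over the original points;
  # n: number of points. Returns merges as {distance, root_i, root_j} in order.
  def merges(pairs, n) do
    parents = Map.new(0..(n - 1), &{&1, &1})

    {_parents, links} =
      pairs
      |> Enum.sort()
      |> Enum.reduce({parents, []}, fn {dist, i, j}, {parents, links} ->
        {ri, parents} = find(parents, i)
        {rj, parents} = find(parents, j)

        if ri == rj do
          # already in the same cluster: a closer pair merged them earlier
          {parents, links}
        else
          {Map.put(parents, ri, rj), [{dist, ri, rj} | links]}
        end
      end)

    Enum.reverse(links)
  end

  # union-find lookup with path compression
  defp find(parents, i) do
    case Map.fetch!(parents, i) do
      ^i ->
        {i, parents}

      p ->
        {root, parents} = find(parents, p)
        {root, Map.put(parents, i, root)}
    end
  end
end
```

Maps are used here only for readability; the point of the actual gist is to express the same idea with tensors.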
That'd be fantastic! Thanks for looking into this :) (And sorry about the delay. It was actually the impetus for the Nx exercises PR: I kept getting stuck so I decided I needed some practice.) And I admit I'm curious how this works. Without updated distances, I don't think the dendrogram object will have the right tree heights. But your gist will answer that!
@billylanchantin do you need the distances between clusters as well? I can work on that. :)
@josevalim Yeah there are two goals with HC:
1. The clusters themselves
2. The dendrogram
For the clusters themselves, you don't necessarily need to maintain the distances. But for the dendrogram, you need them to build the heights of the tree. One of the advantages of HC is that you can use the dendrogram to help you choose, post hoc, how many clusters you want. (I can explain this further if requested 😄)
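As a toy illustration of the post-hoc cut (made-up numbers): with `n` points and the dendrogram's merge heights, cutting at height `t` leaves `n` minus the number of merges below `t` clusters.

```elixir
n = 5
merge_heights = [0.4, 0.7, 1.1, 2.9]

# number of flat clusters obtained by cutting the dendrogram at height t
clusters_at = fn t -> n - Enum.count(merge_heights, &(&1 < t)) end

clusters_at.(1.0)
#=> 3
clusters_at.(2.0)
#=> 2
```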
@billylanchantin here is the complete code for single links: https://gist.github.com/josevalim/00990eb921bc1f767e8870798b5e786f I can also make it work for average, complete, and minimum variance. Your goal would be to traverse the links and distances and build the clusters. I can optimize the code a bit, which would require a tiny modification to the output, but it should be straightforward to adapt. So with the above, we should have everything in hand for an MVP of hierarchical clustering. For centroids and median, we will need the second half of this comment, which will also have a different output for links and distances (most likely two
@josevalim @msluszniak I need access to the pairwise functions before I can drop the
Either rebase or merge is fine. Whatever you prefer!
Ok I think we're basically good to go! One last thing: I'm a bit worried about the complexity. I think it's actually
Updated: 9a2db56
Co-authored-by: José Valim <[email protected]>
Co-authored-by: Mateusz Sluszniak <[email protected]>
Yeah, the time complexity seems legit to me ;)
Co-authored-by: José Valim <[email protected]>
Beautifully done!
💚 💙 💜 💛 ❤️
Description
This PR adds hierarchical clustering. It's one of the items with a "help wanted" label, so I thought I'd give it a go :)

Background
Hierarchical clustering is a family of unsupervised learning methods that build clusters in a greedy fashion, either bottom-up (agglomerative) or top-down (divisive). If you're unfamiliar, I think the following tutorial is probably approachable:
https://www.kdnuggets.com/2019/09/hierarchical-clustering.html
This PR adds support for agglomerative clustering, which seems to be the most common.
Approach
The plan is to follow the approach outlined in the paper Modern hierarchical, agglomerative clustering algorithms by Daniel Müllner. It's still state of the art AFAIK, and it serves as the basis for SciPy's and scikit-learn's hierarchical clustering (though both of those have extra bells and whistles).
The core of hierarchical clustering, which the Müllner paper denotes the "primitive" algorithm, is this: starting with each point in its own cluster, repeat until only one cluster remains (a rough sketch follows the list):
a. group the closest pair of clusters into a new cluster and
b. recalculate the distances from the new cluster to the rest of the clusters
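Here's a naive sketch of that loop (plain Elixir with maps for readability; this is not the PR's implementation, and the names are made up). `linkage` stands in for whichever update function is in use, e.g. `fn dik, djk -> min(dik, djk) end` for single linkage.

```elixir
defmodule PrimitiveSketch do
  # dist: %{{i, j} => distance} with i < j; active: MapSet of cluster ids;
  # linkage: a 2-arity update function. Returns merges as [{i, j, distance}].
  def cluster(dist, active, linkage) do
    if MapSet.size(active) <= 1 do
      []
    else
      # (a) find and merge the closest pair (reuse id i, retire id j)
      {{i, j}, d} = Enum.min_by(dist, fn {_pair, d} -> d end)
      active = MapSet.delete(active, j)

      # (b) recompute distances from the merged cluster to every other cluster
      updated =
        for k <- active, k != i, into: %{} do
          dik = Map.fetch!(dist, key(i, k))
          djk = Map.fetch!(dist, key(j, k))
          {key(i, k), linkage.(dik, djk)}
        end

      # keep distances that don't involve i or j, drop everything else
      kept =
        dist
        |> Enum.reject(fn {{a, b}, _} -> a == i or b == i or a == j or b == j end)
        |> Map.new()

      [{i, j, d} | cluster(Map.merge(kept, updated), active, linkage)]
    end
  end

  defp key(a, b) when a < b, do: {a, b}
  defp key(a, b), do: {b, a}
end
```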
The primitive algorithm is best case O(n^3). The paper's main contribution is to describe 3 optimizations that improve the complexity. Which optimization to use depends on which of the 7 update functions is chosen.
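For context (standard background, not something introduced by this PR): the seven update functions can all be expressed through the Lance-Williams recurrence, with each linkage corresponding to a particular choice of the coefficients:

```math
d(I \cup J, K) = \alpha_I \, d(I, K) + \alpha_J \, d(J, K) + \beta \, d(I, J) + \gamma \, \lvert d(I, K) - d(J, K) \rvert
```

For example, single linkage uses α_I = α_J = 1/2, β = 0, γ = -1/2, which reduces to min(d(I, K), d(J, K)).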
So far, I've only implemented the primitive algorithm, which isn't to be used in practice. If what I've described so far is the direction we want to take, I can start working on the 3 optimizations.
To-do
- Pre-computed
- Centroid
- Median
- "primitive" (not to be used in practice)
- "generic"
- "minimum-spanning-tree"
- "nearest-neighbor-chain"

Livebook demo
I'm including a Livebook demo since the dendrogram visualization is very useful for understanding hierarchical clusters.