Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Random Projection Forest Update #231

Merged
merged 4 commits into from
Jan 16, 2024
Merged

Random Projection Forest Update #231

merged 4 commits into from
Jan 16, 2024

Conversation

krstopro
Copy link
Member

@krstopro krstopro commented Jan 15, 2024

Changelog:

  • Fixed the bug inside predict/2 that assumed that query is of the same size as tensor used to grow the forest.
  • Predict now returns the indices of (approximate) k-nearest neigbors as well as distances from the query points.
  • Hyperplanes and medians are now f64.
  • Incorporated some feedback from Random Projection Forest #215 I didn't manage to before the pull request got closed. E.g. left_child and right_child are private now.
  • Fixes Bug in RandomProjectionForest due to Nx.dot with EXLA Backend #230.

One thing I am still not sure about. num_neighbors is passed as an argument to fit/2. I prefer it like this, but it also makes sense to pass it as an argument to predict/2. There is also a small inconsistency with the option name used for the number of neighbors: KNearestNeighbors uses num_neighbors, while KDTree uses k.

Krsto Proroković added 3 commits January 15, 2024 22:37
{size, dim} = Nx.shape(tensor)
num_nodes = 2 ** depth - 1

{hyperplanes, _key} =
Nx.Random.normal(key, type: type, shape: {num_trees, depth, dim})
Nx.Random.normal(key, type: :f64, shape: {num_trees, depth, dim})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should default to f64, What if we allow the type as an argument? So we do: type = opts[:type] || to_float_type(tensor)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that initializing hyperplanes and medians to f64 might not be necessary. I changed the code in predict/2 from

proj = Nx.dot(query, hyperplanes[level])

(hyperplanes were vectorized by first coordinates which corresponds to trees) to

proj = Nx.dot(hyperplanes[[.., level]], [1], query, [1]) |> Nx.vectorize(:trees)

which is kinda the same as done in fit/2. Now the projections are the same inside fit/2 and predict/2, even with f32; it seems the order of the arguments was important as well. Should have done this before, sorry about that!
It will work for constructing the k-NN graph, i.e. when query = tensor, where tensor is used to grow the forest. Might not work for arbitrary query (i.e. there might be some issues with projections as described in #230). Perhaps adding a note in the docs might be a good idea?

Copy link
Contributor

@josevalim josevalim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful, just one tiny comment about the type.

@josevalim josevalim merged commit db9efd7 into elixir-nx:main Jan 16, 2024
2 checks passed
@josevalim
Copy link
Contributor

More docs are always welcome too! 💚 💙 💜 💛 ❤️

@msluszniak
Copy link
Contributor

I didn't make it on time but great job!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Bug in RandomProjectionForest due to Nx.dot with EXLA Backend
3 participants