Random Projection Forest Update #231
  {size, dim} = Nx.shape(tensor)
  num_nodes = 2 ** depth - 1

  {hyperplanes, _key} =
-   Nx.Random.normal(key, type: type, shape: {num_trees, depth, dim})
+   Nx.Random.normal(key, type: :f64, shape: {num_trees, depth, dim})
I don't think we should default to f64. What if we allow the type as an option, so we do `type = opts[:type] || to_float_type(tensor)`?
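For illustration, the fallback suggested above could be sketched roughly like this. `TypeOptionSketch` and `resolve_type/2` are hypothetical names, `Nx.Type.to_floating(Nx.type(tensor))` stands in for `to_float_type(tensor)` on the assumption that the helper behaves that way, and the `Mix.install` version is an assumption too.

```elixir
# Assumption: a recent Nx release; the version requirement is illustrative.
Mix.install([{:nx, "~> 0.7"}])

defmodule TypeOptionSketch do
  # Hypothetical helper, not part of the PR: use the :type option when
  # given, otherwise fall back to the floating-point type of the input.
  def resolve_type(tensor, opts \\ []) do
    case opts[:type] do
      nil -> Nx.Type.to_floating(Nx.type(tensor))
      type -> Nx.Type.normalize!(type)
    end
  end
end

ints = Nx.tensor([[1, 2], [3, 4]])

# Integer input with no option: falls back to a float type.
IO.inspect(TypeOptionSketch.resolve_type(ints))

# Explicit option wins over the input type.
IO.inspect(TypeOptionSketch.resolve_type(ints, type: :f64))
```

With this shape, f64 stays available to callers who want it without becoming the default.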
It seems that initializing hyperplanes and medians to f64 might not be necessary. I changed the code in `predict/2` from
proj = Nx.dot(query, hyperplanes[level])
(where hyperplanes were vectorized along the first axis, which corresponds to trees) to
proj = Nx.dot(hyperplanes[[.., level]], [1], query, [1]) |> Nx.vectorize(:trees)
which is kinda the same as what is done in `fit/2`. Now the projections are the same inside `fit/2` and `predict/2`, even with f32; it seems the order of the arguments was important as well. Should have done this before, sorry about that!
It will work for constructing the k-NN graph, i.e. when `query = tensor`, where `tensor` is used to grow the forest. It might not work for an arbitrary query (i.e. there might be some issues with projections as described in #230). Perhaps adding a note in the docs might be a good idea?
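A small sketch of the new projection, just to make the shapes concrete. `ProjectionSketch` and `project/3` are hypothetical names, the tensors are toy `Nx.iota/2` data rather than real hyperplanes, and the `Mix.install` version is an assumption.

```elixir
# Assumption: a recent Nx release; the version requirement is illustrative.
Mix.install([{:nx, "~> 0.7"}])

defmodule ProjectionSketch do
  # hyperplanes: {num_trees, depth, dim}, query: {size, dim}.
  # Contract the dim axis of both operands, then mark the leading
  # axis of the {num_trees, size} result as the :trees vectorized axis.
  def project(hyperplanes, query, level) do
    hyperplanes[[.., level]]
    |> Nx.dot([1], query, [1])
    |> Nx.vectorize(:trees)
  end
end

hyperplanes = Nx.iota({2, 3, 4}, type: :f32)
query = Nx.iota({5, 4}, type: :f32)

proj = ProjectionSketch.project(hyperplanes, query, 1)

# One vectorized axis per tree, one projection per query point.
IO.inspect(proj.vectorized_axes)
IO.inspect(Nx.shape(proj))
```

On this toy data, each tree's slice matches `Nx.dot(query, hyperplanes[tree][level])`; the f32 differences mentioned above would only show up with real floating-point inputs.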
Beautiful, just one tiny comment about the type.
More docs are always welcome too! 💚 💙 💜 💛 ❤️
I didn't make it on time but great job!
Changelog:
- Fixed the code in `predict/2` that assumed that the query is of the same size as the tensor used to grow the forest.
- Hyperplanes and medians are now f64.
- `left_child` and `right_child` are private now.

One thing I am still not sure about: `num_neighbors` is passed as an argument to `fit/2`. I prefer it like this, but it also makes sense to pass it as an argument to `predict/2`. There is also a small inconsistency with the option name used for the number of neighbors: KNearestNeighbors uses `num_neighbors`, while KDTree uses `k`.