Bug in RandomProjectionForest due to Nx.dot with EXLA Backend #230

Closed
krstopro opened this issue Jan 10, 2024 · 5 comments · Fixed by #231

@krstopro
Member

krstopro commented Jan 10, 2024

Currently, there is a bug in RandomProjectionForest that I am not exactly sure how to resolve. Namely, when querying the tree, the projection value might not be the same as when fitting the tree. This happens due to the inner workings of XLA, as mentioned here. For example, if we have

Nx.global_default_backend(EXLA.Backend)
key = Nx.Random.key(12)
{tensor, key} = Nx.Random.uniform(key, shape: {10, 7})
forest = Scholar.Neighbors.RandomProjectionForest.fit(tensor, num_neighbors: 1, num_trees: 1, key: key)

then doing

Nx.dot(forest.hyperplanes[[.., 1]], [1], tensor, [1])[0][3]

gives

#Nx.Tensor<
  f32
  EXLA.Backend<host:0, 0.1762387501.1096941588.214884>
  1.3310179710388184
>

On the other hand, doing

query = Nx.new_axis(tensor[3], 0)
Nx.dot(forest.hyperplanes[[.., 1]], [1], query, [1])[0][0]

gives

#Nx.Tensor<
  f32
  EXLA.Backend<host:0, 0.1762387501.1096941588.214899>
  1.331018090248108
>

Also, forest.medians[0][1] gives

#Nx.Tensor<
  f32
  EXLA.Backend<host:0, 0.1762387501.1096941588.214903>
  1.3310179710388184
>

which is the same as the first result.
This is a problem when comparing projections to medians, as done in the module.
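To make the failure concrete, here is a minimal pure-Elixir sketch of a single routing decision using the values printed above. It assumes, as a hypothetical convention rather than Scholar's actual code, that a point goes to the left child when its projection is less than or equal to the node's median; `Routing.side/2` is an illustrative name. Because the median was computed from the fit-time projection, the fit-time value goes left while the slightly larger query-time value goes right.

```elixir
defmodule Routing do
  # Hypothetical convention: projections <= median go left, larger ones go right.
  def side(projection, median) when projection <= median, do: :left
  def side(_projection, _median), do: :right
end

# Values copied from the outputs above (f32 results printed as f64 decimals).
fit_projection = 1.3310179710388184
query_projection = 1.331018090248108
median = fit_projection

Routing.side(fit_projection, median)
# => :left
Routing.side(query_projection, median)
# => :right
```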

One way to resolve this could be to use :f64 for both hyperplanes and medians, albeit at a higher memory cost.
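For a sense of what :f64 buys, the sketch below measures the gap between a value and the next representable float in each precision, using only Elixir's binary syntax (no Nx). `Ulp.gap/2` is an illustrative helper, not part of Nx or Scholar. At the magnitude of these projections the f32 gap is about 1.2e-7 and the f64 gap about 2.2e-16, so f64 would make a fit/query mismatch far less likely, but it would not, on its own, guarantee bit-identical results for the two shapes. Each element also takes 8 bytes instead of 4.

```elixir
defmodule Ulp do
  # Distance from the positive, finite float x (rounded to f32) to the next
  # larger f32, found by incrementing its IEEE 754 bit pattern.
  def gap(x, 32) do
    <<bits::unsigned-32>> = <<x::float-32>>
    <<rounded::float-32>> = <<bits::unsigned-32>>
    <<next::float-32>> = <<(bits + 1)::unsigned-32>>
    next - rounded
  end

  # Same idea for f64.
  def gap(x, 64) do
    <<bits::unsigned-64>> = <<x::float-64>>
    <<next::float-64>> = <<(bits + 1)::unsigned-64>>
    next - x
  end
end

Ulp.gap(1.331018, 32)
# => about 1.19e-7 (exactly 2^-23)
Ulp.gap(1.331018, 64)
# => about 2.22e-16 (exactly 2^-52)
```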

Any thoughts on this?

@krstopro krstopro changed the title from "Bug in RandomProjectionForest due to" to "Bug in RandomProjectionForest due to Nx.dot with EXLA Backend" on Jan 10, 2024
@josevalim
Contributor

Keep in mind that Nx shows the wrong representation for f32. There is an open issue about that. I am not sure if it affects this, but it is something to keep in mind. Other than that, I am not sure how much we should sweat about this. Perhaps add a warning that we recommend f64 input tensors for better precision, and just make sure we preserve the input precision?

@polvalente
Contributor

I wouldn't call this a bug. Those numbers are basically the same within 1.0e-5 precision. This is most likely due to some difference in the underlying calculation chosen by the EXLA compiler.

Try comparing calculations in the CPU and GPU, for instance.
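The gap between the two printed projections can be pinned down further. The sketch below, plain Elixir with an illustrative module name, encodes both printed values as IEEE 754 single-precision floats and compares their bit patterns: they differ by exactly one, so the fit-time and query-time projections are neighbouring f32 values, one unit in the last place apart.

```elixir
defmodule F32Bits do
  # Returns the raw 32-bit IEEE 754 pattern of x after rounding it to f32.
  def bits(x) do
    <<b::unsigned-32>> = <<x::float-32>>
    b
  end
end

# Printed results from the issue: query-time projection minus fit-time projection.
F32Bits.bits(1.331018090248108) - F32Bits.bits(1.3310179710388184)
# => 1
```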

@krstopro
Member Author

krstopro commented Jan 10, 2024

> I wouldn't call this a bug. Those numbers are basically the same within 1.0e-5 precision. This is most likely due to some difference in the underlying calculation chosen by the EXLA compiler.
>
> Try comparing calculations in the CPU and GPU, for instance.

By bug, I mean bug in RandomProjectionForest, not Nx.dot.

@polvalente
Contributor

> By bug, I mean bug in RandomProjectionForest, not Nx.dot.

I would still compare values, because Nx.dot can change the return value slightly depending on the shape and the device.
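One way to compare values when results may shift slightly is a tolerance-based check instead of exact equality. A minimal sketch, assuming the absolute-plus-relative form used by NumPy's `allclose` (|a - b| <= atol + rtol * |b|); `Approx.close?/4` and its default tolerances are illustrative choices, not Scholar's or Nx's code.

```elixir
defmodule Approx do
  # True when a and b agree within an absolute plus relative tolerance,
  # the same form as NumPy's allclose: |a - b| <= atol + rtol * |b|.
  def close?(a, b, rtol \\ 1.0e-5, atol \\ 1.0e-8) do
    abs(a - b) <= atol + rtol * abs(b)
  end
end

Approx.close?(1.3310179710388184, 1.331018090248108)
# => true
Approx.close?(1.331, 1.332)
# => false
```

The relative term scales with the magnitude of the values being compared, whereas a fixed additive offset does not.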

@krstopro
Member Author

krstopro commented Jan 10, 2024

@josevalim

> Other than that, I am not sure how much we should sweat about this.

The problem is that, when constructing the k-NN graph, I want x[i] to end up in the same leaf as x[i] itself. If the projection of x[i] at some level happens to be one of the medians, then the projection computed when querying the tree might not equal that median but be slightly larger. The point will then be routed to the right subtree instead of the left one, and I don't like that.
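The effect on a whole tree can be sketched by routing one point down several levels. This uses a hypothetical layout, not Scholar's internals: nodes are numbered heap-style (root 0, children 2n + 1 and 2n + 2), the list holds the point's projection at each level, a map holds the median of each node, and ties go left. A one-ulp change at the root is enough to send the point to a different leaf.

```elixir
defmodule Descend do
  # Walks from the root to a leaf and returns the index of the leaf reached.
  # Heap-style numbering: children of node n are 2n + 1 (left) and 2n + 2 (right).
  def leaf(projections, medians) do
    Enum.reduce(projections, 0, fn projection, node ->
      if projection <= Map.fetch!(medians, node), do: 2 * node + 1, else: 2 * node + 2
    end)
  end
end

# Two-level toy tree; the root median equals this point's fit-time projection.
medians = %{0 => 1.3310179710388184, 1 => 0.5, 2 => 0.5}

Descend.leaf([1.3310179710388184, 0.2], medians)
# => 3 (fit time: left, then left)
Descend.leaf([1.331018090248108, 0.2], medians)
# => 5 (query time: right, then left)
```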

Switching to :f64 could resolve this. Adding a small number to the medians could be another solution, though I don't like it, as the data itself can be small as well.
Perhaps another solution is adding the edge from each point to itself when constructing the graph.
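The self-edge option could be applied as a post-processing step on the candidate lists. A minimal sketch under stated assumptions: neighbours are plain lists of indices per point, sorted nearest first (a hypothetical representation, not Scholar's tensor output); `SelfEdges.ensure_self/2` puts the point itself first and trims each list back to k entries, dropping the farthest candidate when needed.

```elixir
defmodule SelfEdges do
  # For row i, put i first, drop any other copy of i, and keep at most k entries.
  def ensure_self(neighbors, k) do
    neighbors
    |> Enum.with_index()
    |> Enum.map(fn {row, i} -> Enum.take([i | Enum.reject(row, &(&1 == i))], k) end)
  end
end

# Point 0 was routed away from its own leaf and lost itself as a candidate.
SelfEdges.ensure_self([[2, 1], [1, 0], [2, 0]], 2)
# => [[0, 2], [1, 0], [2, 0]]
```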
