Skip to content

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Nov 6, 2023
1 parent ba86a27 commit 81337bd
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion lib/scholar/neighbors/kd_tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ defmodule Scholar.Neighbors.KDTree do
## Examples
iex> Scholar.Neighbors.KDTree.unbound(Nx.iota({5, 2}), compiler: EXLA.Defn)
iex> Scholar.Neighbors.KDTree.unbound(Nx.iota({5, 2}), compiler: EXLA)
%Scholar.Neighbors.KDTree{
data: Nx.iota({5, 2}),
levels: 3,
Expand Down
8 changes: 4 additions & 4 deletions test/scholar/neighbors/kd_tree_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,28 @@ defmodule Scholar.Neighbors.KDTreeTest do
describe "unbound" do
test "sample" do
assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} =
Scholar.Neighbors.KDTree.unbound(example(), compiler: EXLA.Defn)
Scholar.Neighbors.KDTree.unbound(example(), compiler: EXLA)

assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4]
end

test "float" do
assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} =
Scholar.Neighbors.KDTree.unbound(example() |> Nx.as_type(:f32),
compiler: EXLA.Defn
compiler: EXLA
)

assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4]
end

test "corner cases" do
assert %Scholar.Neighbors.KDTree{levels: 1, indexes: indexes} =
Scholar.Neighbors.KDTree.unbound(Nx.iota({1, 2}), compiler: EXLA.Defn)
Scholar.Neighbors.KDTree.unbound(Nx.iota({1, 2}), compiler: EXLA)

assert indexes == Nx.u32([0])

assert %Scholar.Neighbors.KDTree{levels: 2, indexes: indexes} =
Scholar.Neighbors.KDTree.unbound(Nx.iota({2, 2}), compiler: EXLA.Defn)
Scholar.Neighbors.KDTree.unbound(Nx.iota({2, 2}), compiler: EXLA)

assert indexes == Nx.u32([1, 0])
end
Expand Down

0 comments on commit 81337bd

Please sign in to comment.