Skip to content

Commit

Permalink
Banded works
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Nov 5, 2023
1 parent 24aca45 commit fd4d1d9
Showing 1 changed file with 8 additions and 23 deletions.
31 changes: 8 additions & 23 deletions lib/scholar/neighbors/kd_tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ defmodule Scholar.Neighbors.KDTree do
import Nx.Defn

# TODO: Benchmark
# TODO: Add tagged/amplitude version

@derive {Nx.Container, keep: [:levels], containers: [:indexes]}
@enforce_keys [:levels, :indexes]
Expand Down Expand Up @@ -181,31 +180,17 @@ defmodule Scholar.Neighbors.KDTree do
shifted - 1 + min(lowest_level, shifted)
end

defn banded_segment_begin(t, levels, size) do
while t, j <- 0..(size - 1) do
s = t[j]
i = (1 <<< banded_level(s)) - 1

{_, _, acc} =
while {i, s, acc = i}, i + 1 <= s do
{i + 1, s, acc + banded_subtree_size(i, levels, size)}
end
defn banded_segment_begin(i, levels, size) do
level = banded_level(i)
top = (1 <<< level) - 1
diff = levels - level - 1
shifted = 1 <<< diff
left_siblings = i - top

Nx.put_slice(t, [j], Nx.stack(acc))
end
top + left_siblings * (shifted - 1) +
min(left_siblings * shifted, size - (1 <<< (levels - 1)) + 1)
end

# defn banded_segment_begin(i, levels, size) do
# level = banded_level(i)
# top = (1 <<< level) - 1
# diff = levels - level - 1
# shifted = 1 <<< diff
# left_siblings = i - top

# top + left_siblings * (shifted - 1) +
# min(left_siblings * shifted, size - (1 <<< (levels - 1)) - 1)
# end

# Since this property relies on u32, let's check the tensor type.
deftransformp banded_level(%Nx.Tensor{type: {:u, 32}} = i) do
Nx.subtract(31, Nx.count_leading_zeros(Nx.add(i, 1)))
Expand Down

0 comments on commit fd4d1d9

Please sign in to comment.