From fd4d1d943adbb119a0e9efe5e477874d8e9c21f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 5 Nov 2023 17:31:29 +0100 Subject: [PATCH] Banded works --- lib/scholar/neighbors/kd_tree.ex | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index d53ee1ff..01e2202d 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -26,7 +26,6 @@ defmodule Scholar.Neighbors.KDTree do import Nx.Defn # TODO: Benchmark - # TODO: Add tagged/amplitude version @derive {Nx.Container, keep: [:levels], containers: [:indexes]} @enforce_keys [:levels, :indexes] @@ -181,31 +180,17 @@ defmodule Scholar.Neighbors.KDTree do shifted - 1 + min(lowest_level, shifted) end - defn banded_segment_begin(t, levels, size) do - while t, j <- 0..(size - 1) do - s = t[j] - i = (1 <<< banded_level(s)) - 1 - - {_, _, acc} = - while {i, s, acc = i}, i + 1 <= s do - {i + 1, s, acc + banded_subtree_size(i, levels, size)} - end + defn banded_segment_begin(i, levels, size) do + level = banded_level(i) + top = (1 <<< level) - 1 + diff = levels - level - 1 + shifted = 1 <<< diff + left_siblings = i - top - Nx.put_slice(t, [j], Nx.stack(acc)) - end + top + left_siblings * (shifted - 1) + + min(left_siblings * shifted, size - (1 <<< (levels - 1)) + 1) end - # defn banded_segment_begin(i, levels, size) do - # level = banded_level(i) - # top = (1 <<< level) - 1 - # diff = levels - level - 1 - # shifted = 1 <<< diff - # left_siblings = i - top - - # top + left_siblings * (shifted - 1) + - # min(left_siblings * shifted, size - (1 <<< (levels - 1)) - 1) - # end - # Since this property relies on u32, let's check the tensor type. deftransformp banded_level(%Nx.Tensor{type: {:u, 32}} = i) do Nx.subtract(31, Nx.count_leading_zeros(Nx.add(i, 1)))