Skip to content

Commit

Permalink
band -> bound, clz -> log
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Nov 6, 2023
1 parent 8a4ee41 commit ba86a27
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 65 deletions.
4 changes: 2 additions & 2 deletions benchmarks/kd_tree.exs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ key = Nx.Random.key(System.os_time())

Benchee.run(
%{
"unbanded" => fn -> Scholar.Neighbors.KDTree.unbanded(uniform) end,
"banded" => fn -> Scholar.Neighbors.KDTree.banded(uniform, 2) end
"unbound" => fn -> Scholar.Neighbors.KDTree.unbound(uniform) end,
"bound" => fn -> Scholar.Neighbors.KDTree.bound(uniform, 2) end
},
time: 10,
memory_time: 2
Expand Down
88 changes: 34 additions & 54 deletions lib/scholar/neighbors/kd_tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ defmodule Scholar.Neighbors.KDTree do
Two construction modes are available:
* `banded/2` - the tensor has min and max values with an amplitude given by `max - min`.
* `bound/2` - the tensor has min and max values with an amplitude given by `max - min`.
It is also guaranteed that the `amplitude * levels(tensor) + 1` does not overflow
the tensor. See `amplitude/1` to verify if this holds. This implementation happens
fully within `defn`. This version is orders of magnitude faster than the `unbanded/2`
fully within `defn`. This version is orders of magnitude faster than the `unbound/2`
one.
* `unbanded/2` - there are no known bands (min and max values) to the tensor.
* `unbound/2` - there are no known bounds (min and max values) to the tensor.
This implementation is recursive and goes in and out of the `defn`, therefore
it cannot be called inside `defn`.
Expand All @@ -36,8 +36,8 @@ defmodule Scholar.Neighbors.KDTree do
@doc """
Builds a KDTree without known min-max bounds.
If your tensor has a known band (for example, -1 and 1),
consider using the `banded/2` version which is often orders of
If your tensor has known bounds (for example, -1 and 1),
consider using the `bound/2` version which is often orders of
magnitude more efficient.
## Options
Expand All @@ -46,21 +46,21 @@ defmodule Scholar.Neighbors.KDTree do
## Examples
iex> Scholar.Neighbors.KDTree.unbanded(Nx.iota({5, 2}), compiler: EXLA.Defn)
iex> Scholar.Neighbors.KDTree.unbound(Nx.iota({5, 2}), compiler: EXLA.Defn)
%Scholar.Neighbors.KDTree{
data: Nx.iota({5, 2}),
levels: 3,
indexes: Nx.u32([3, 1, 4, 0, 2])
}
"""
def unbanded(tensor, opts \\ []) do
def unbound(tensor, opts \\ []) do
levels = levels(tensor)
{size, _dims} = Nx.shape(tensor)

indexes =
if size > 2 do
subtree_size = unbanded_subtree_size(1, levels, size)
subtree_size = unbound_subtree_size(1, levels, size)
{left, mid, right} = Nx.Defn.jit_apply(&root_slice(&1, subtree_size), [tensor], opts)

acc = <<Nx.to_number(mid)::32-unsigned-native-integer>>
Expand Down Expand Up @@ -88,7 +88,7 @@ defmodule Scholar.Neighbors.KDTree do
defp recur([{i, indexes} | rest], next, acc, tensor, level, levels, opts) do
%Nx.Tensor{shape: {size, dims}} = tensor
k = rem(level, dims)
subtree_size = unbanded_subtree_size(left_child(i), levels, size)
subtree_size = unbound_subtree_size(left_child(i), levels, size)

{left, mid, right} =
Nx.Defn.jit_apply(&recur_slice(&1, &2, &3, subtree_size), [tensor, indexes, k], opts)
Expand Down Expand Up @@ -121,39 +121,42 @@ defmodule Scholar.Neighbors.KDTree do
Nx.slice(indexes, [subtree_size + 1], [Nx.size(indexes) - subtree_size - 1])}
end

defp unbanded_subtree_size(i, levels, size) do
defp unbound_subtree_size(i, levels, size) do
import Bitwise
diff = levels - unbanded_level(i) - 1
diff = levels - unbound_level(i) - 1
shifted = 1 <<< diff
fllc_s = (i <<< diff) + shifted - 1
shifted - 1 + min(max(0, size - fllc_s), shifted)
end

defp unbanded_level(i) when is_integer(i), do: 31 - clz32(i + 1)
defp unbound_level(i) when is_integer(i), do: floor(:math.log2(i + 1))

@doc """
Builds a KDTree with known min-max bounds entirely within `defn`.
This requires the amplitude `|max - min|` of the tensor to be given.
This requires the amplitude `|max - min|` of the tensor to be given
such that `max + (amplitude + 1) * (size - 1)` does not overflow the
maximum tensor type.
For example, a tensor where all values are between 0 and 1 has amplitude
1. Values between -1 and 1 has amplitude 2. If your tensor is normalized,
then you know the amplitude. Otherwise you can use `amplitude/1` to check
it.
1. Values between -1 and 1 has amplitude 2. If your tensor is normalized
to floating points, then it is most likely bound (given their high
precision). You can use `amplitude/1` to check your assumptions.
## Examples
iex> Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10)
iex> Scholar.Neighbors.KDTree.bound(Nx.iota({5, 2}), 10)
%Scholar.Neighbors.KDTree{
data: Nx.iota({5, 2}),
levels: 3,
indexes: Nx.u32([3, 1, 4, 0, 2])
}
"""
deftransform banded(tensor, amplitude) do
%__MODULE__{levels: levels(tensor), indexes: banded_n(tensor, amplitude), data: tensor}
deftransform bound(tensor, amplitude) do
%__MODULE__{levels: levels(tensor), indexes: bound_n(tensor, amplitude), data: tensor}
end

defnp banded_n(tensor, amplitude) do
defnp bound_n(tensor, amplitude) do
levels = levels(tensor)
{size, dims} = Nx.shape(tensor)
band = amplitude + 1
Expand All @@ -175,8 +178,8 @@ defmodule Scholar.Neighbors.KDTree do
pos = Nx.argsort(indexes, type: :u32)

pivot =
banded_segment_begin(tags, levels, size) +
banded_subtree_size(left_child(tags), levels, size)
bound_segment_begin(tags, levels, size) +
bound_subtree_size(left_child(tags), levels, size)

Nx.select(
pos < (1 <<< level) - 1,
Expand All @@ -193,17 +196,17 @@ defmodule Scholar.Neighbors.KDTree do
)
end

defnp banded_subtree_size(i, levels, size) do
diff = levels - banded_level(i) - 1
defnp bound_subtree_size(i, levels, size) do
diff = levels - bound_level(i) - 1
shifted = 1 <<< diff
first_lowest_level = (i <<< diff) + shifted - 1
# Use select instead of max to deal with overflows
lowest_level = Nx.select(first_lowest_level > size, Nx.u32(0), size - first_lowest_level)
shifted - 1 + min(lowest_level, shifted)
end

defnp banded_segment_begin(i, levels, size) do
level = banded_level(i)
defnp bound_segment_begin(i, levels, size) do
level = bound_level(i)
top = (1 <<< level) - 1
diff = levels - level - 1
shifted = 1 <<< diff
Expand All @@ -214,15 +217,15 @@ defmodule Scholar.Neighbors.KDTree do
end

# Since this property relies on u32, let's check the tensor type.
deftransformp banded_level(%Nx.Tensor{type: {:u, 32}} = i) do
deftransformp bound_level(%Nx.Tensor{type: {:u, 32}} = i) do
Nx.subtract(31, Nx.count_leading_zeros(Nx.add(i, 1)))
end

@doc """
Returns the amplitude of a tensor for banding.
Returns the amplitude of a bounded tensor.
If -1 is returned, it means the tensor cannot use the `banded` algorithm
to generate a KDTree and `unbanded/2` must be used instead.
If -1 is returned, it means the tensor cannot use the `bound` algorithm
to generate a KDTree and `unbound/2` must be used instead.
This cannot be invoked inside a `defn`.
Expand Down Expand Up @@ -261,7 +264,7 @@ defmodule Scholar.Neighbors.KDTree do
"""
deftransform levels(%Nx.Tensor{} = tensor) do
case Nx.shape(tensor) do
{size, _dims} -> 32 - clz32(size)
{size, _dims} -> ceil(:math.log2(size + 1))
_ -> raise ArgumentError, "KDTrees requires a tensor of rank 2"
end
end
Expand Down Expand Up @@ -344,27 +347,4 @@ defmodule Scholar.Neighbors.KDTree do
"""
deftransform right_child(i) when is_integer(i), do: 2 * i + 2
deftransform right_child(%Nx.Tensor{} = t), do: Nx.add(Nx.multiply(2, t), 2)

@clz_lookup {32, 31, 30, 30, 29, 29, 29, 29, 28, 28, 28, 28, 28, 28, 28, 28}

defp clz32(x) when is_integer(x) do
import Bitwise

n =
if x >= 1 <<< 16 do
if x >= 1 <<< 24 do
if x >= 1 <<< 28, do: 28, else: 24
else
if x >= 1 <<< 20, do: 20, else: 16
end
else
if x >= 1 <<< 8 do
if x >= 1 <<< 12, do: 12, else: 8
else
if x >= 1 <<< 4, do: 4, else: 0
end
end

elem(@clz_lookup, x >>> n) - n
end
end
18 changes: 9 additions & 9 deletions test/scholar/neighbors/kd_tree_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ defmodule Scholar.Neighbors.KDTreeTest do
])
end

describe "unbanded" do
describe "unbound" do
test "sample" do
assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} =
Scholar.Neighbors.KDTree.unbanded(example(), compiler: EXLA.Defn)
Scholar.Neighbors.KDTree.unbound(example(), compiler: EXLA.Defn)

assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4]
end

test "float" do
assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} =
Scholar.Neighbors.KDTree.unbanded(example() |> Nx.as_type(:f32),
Scholar.Neighbors.KDTree.unbound(example() |> Nx.as_type(:f32),
compiler: EXLA.Defn
)

Expand All @@ -36,35 +36,35 @@ defmodule Scholar.Neighbors.KDTreeTest do

test "corner cases" do
assert %Scholar.Neighbors.KDTree{levels: 1, indexes: indexes} =
Scholar.Neighbors.KDTree.unbanded(Nx.iota({1, 2}), compiler: EXLA.Defn)
Scholar.Neighbors.KDTree.unbound(Nx.iota({1, 2}), compiler: EXLA.Defn)

assert indexes == Nx.u32([0])

assert %Scholar.Neighbors.KDTree{levels: 2, indexes: indexes} =
Scholar.Neighbors.KDTree.unbanded(Nx.iota({2, 2}), compiler: EXLA.Defn)
Scholar.Neighbors.KDTree.unbound(Nx.iota({2, 2}), compiler: EXLA.Defn)

assert indexes == Nx.u32([1, 0])
end
end

describe "banded" do
describe "bound" do
test "iota" do
assert %Scholar.Neighbors.KDTree{levels: 3, indexes: indexes} =
Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10)
Scholar.Neighbors.KDTree.bound(Nx.iota({5, 2}), 10)

assert indexes == Nx.u32([3, 1, 4, 0, 2])
end

test "float" do
assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} =
Scholar.Neighbors.KDTree.banded(example() |> Nx.as_type(:f32), 100)
Scholar.Neighbors.KDTree.bound(example() |> Nx.as_type(:f32), 100)

assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4]
end

test "sample" do
assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} =
Scholar.Neighbors.KDTree.banded(example(), 100)
Scholar.Neighbors.KDTree.bound(example(), 100)

assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4]
end
Expand Down

0 comments on commit ba86a27

Please sign in to comment.