Brute k-NN #257

Merged · 6 commits · Apr 16, 2024
lib/scholar/neighbors/brute_knn.ex (260 additions, 0 deletions)

@@ -0,0 +1,260 @@
defmodule Scholar.Neighbors.BruteKNN do
@moduledoc """
Brute-Force k-Nearest Neighbor Search Algorithm.

In order to find the k-nearest neighbors, the algorithm computes
the distance between the query point and each of the data samples.
Therefore, its time complexity is $O(MN)$ for $N$ samples and $M$ query points.
It uses $O(BN)$ memory for batch size $B$.
Larger batch sizes will lead to faster predictions, but will consume more memory.
"""
import Nx.Defn
import Scholar.Shared
require Nx

@derive {Nx.Container, keep: [:num_neighbors, :metric, :batch_size], containers: [:data]}
defstruct [:num_neighbors, :metric, :data, :batch_size]

opts = [
num_neighbors: [
required: true,
type: :pos_integer,
doc: "The number of nearest neighbors."
],
metric: [
type: {:or, [{:custom, Scholar.Options, :metric, []}, {:fun, 2}]},
Member Author: Perhaps Scholar.Options.metric should be edited to support functions of arity 2?

Contributor: If everywhere we accept a metric we also accept functions, then yes. The easiest is probably to make it so it always returns a 2-arity function and then we call it. :)
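A minimal sketch (not part of this PR) of how Scholar.Options.metric could become a NimbleOptions custom validator that always returns a 2-arity function; the clause shapes and error message here are hypothetical:

    # Hypothetical validator: normalize every accepted metric into a 2-arity function.
    def metric({:minkowski, p}) when p == :infinity or (is_number(p) and p > 0) do
      {:ok, &Scholar.Metrics.Distance.minkowski(&1, &2, p: p)}
    end

    def metric(:cosine), do: {:ok, &Scholar.Metrics.Distance.cosine/2}

    def metric(fun) when is_function(fun, 2), do: {:ok, fun}

    def metric(other) do
      {:error,
       "expected metric to be {:minkowski, p}, :cosine, or a function of arity 2, got: #{inspect(other)}"}
    end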

default: {:minkowski, 2},
doc: ~S"""
The function that measures distance between two points. Possible values:

* `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`)
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.

* `:cosine` - Cosine metric.

* Anonymous function of arity 2 that takes two rank-1 tensors of the same dimension and returns a scalar.
"""
],
batch_size: [
type: :pos_integer,
doc: "The number of samples in a batch."
]
]

@opts_schema NimbleOptions.new!(opts)

@doc """
Fits a brute-force k-NN model.

## Options

#{NimbleOptions.docs(@opts_schema)}

## Examples

iex> data = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 2)
iex> model.num_neighbors
2
iex> model.data
#Nx.Tensor<
s64[5][2]
[
[1, 2],
[2, 3],
[3, 4],
[4, 5],
[5, 6]
]
>
"""
deftransform fit(data, opts) do
if Nx.rank(data) != 2 do
raise ArgumentError,
"expected input tensor to have shape {num_samples, num_features},
got tensor with shape: #{inspect(Nx.shape(data))}"
end

opts = NimbleOptions.validate!(opts, @opts_schema)
k = opts[:num_neighbors]

if k > Nx.axis_size(data, 0) do
raise ArgumentError,
"""
expected num_neighbors to be less than or equal to \
num_samples = #{Nx.axis_size(data, 0)}, got: #{k}
"""
end

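# Normalize the metric option into a 2-arity distance function, so predict only ever calls a function.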
metric =
case opts[:metric] do
Contributor: Following up above: I would move this normalization to Scholar.Options.metric then. This way everywhere can rely on it being a 2-arity function.

Member Author: You mean instead of returning the atom to return the anonymous function of arity 2? I agree, but some other modules need to be edited as well. For example, k-d tree:

    metric: [
      type: {:custom, Scholar.Options, :metric, []},
      default: {:minkowski, 2},
      doc: ~S"""
      Name of the metric. Possible values:

        * `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`)
          we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.

        * `:cosine` - Cosine metric.
      """
    ]

And here is how it is used there:

    case opts[:metric] do
      {:minkowski, 2} -> Distance.squared_euclidean(x1, x2)
      {:minkowski, p} -> Distance.minkowski(x1, x2, p: p)
      :cosine -> Distance.cosine(x1, x2)
    end

I would honestly prefer metric to be stored as a function inside a field.

Member Author (@krstopro, Apr 13, 2024): The main problem is that different algorithms support different metrics. Brute-force search works with literally any metric, while the current implementation of random projection forest works only with the Euclidean distance. I am not sure about k-d tree; I think it works only with the three metrics specified in the docs.

Contributor: It is fine then to keep this logic only in this module, IMO. :)

Member Author (@krstopro, Apr 13, 2024): Agreed. I would still suggest converting atoms to functions and storing them as fields. It seems cleaner and doesn't require choosing the right function to call each time predict is used. Also, it might be worth considering removing Scholar.Options.metric, given the differences in metrics supported between modules.

{:minkowski, p} ->
&Scholar.Metrics.Distance.minkowski(&1, &2, p: p)

:cosine ->
&Scholar.Metrics.Distance.cosine/2

fun when is_function(fun, 2) ->
fun

Member Author: I am not 100% sure about this. Could there be issues with Nx backends?

Contributor: No, this is fine!
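For illustration, a custom 2-arity metric built from Nx primitives passes straight through this clause; a minimal sketch using only the options documented above:

    data = Nx.tensor([[1, 2], [2, 3], [3, 4]])
    manhattan = fn x, y -> Nx.sum(Nx.abs(Nx.subtract(x, y))) end
    model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 2, metric: manhattan)
    {neighbors, distances} = Scholar.Neighbors.BruteKNN.predict(model, Nx.tensor([[2, 2]]))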

end

%__MODULE__{
num_neighbors: k,
metric: metric,
data: data,
batch_size: opts[:batch_size]
}
end

@doc """
Computes the k-nearest neighbors of the query tensor using brute-force search.
Returns the neighbor indices and the distances from the query points.

## Examples

iex> data = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 2)
iex> query = Nx.tensor([[1, 3], [4, 2], [3, 6]])
iex> {neighbors, distances} = Scholar.Neighbors.BruteKNN.predict(model, query)
iex> neighbors
#Nx.Tensor<
u64[3][2]
[
[0, 1],
[1, 2],
[3, 2]
]
>
iex> distances
#Nx.Tensor<
f32[3][2]
[
[1.0, 1.0],
[2.2360680103302, 2.2360680103302],
[1.4142135381698608, 2.0]
]
>
"""
deftransform predict(%__MODULE__{} = model, query) do
if Nx.rank(query) != 2 do
raise ArgumentError,
"expected query tensor to have shape {num_queries, num_features},
got tensor with shape: #{inspect(Nx.shape(query))}"
end

if Nx.axis_size(model.data, 1) != Nx.axis_size(query, 1) do
raise ArgumentError,
"""
expected query tensor to have the same dimension as tensor used for fitting the model, \
got #{inspect(Nx.axis_size(model.data, 1))} \
and #{inspect(Nx.axis_size(query, 1))}
"""
end

predict_n(model, query)
end

defn predict_n(%__MODULE__{} = model, query) do
k = model.num_neighbors
metric = model.metric
data = model.data
type = Nx.Type.merge(to_float_type(data), to_float_type(query))
query_size = Nx.axis_size(query, 0)

batch_size =
case model.batch_size do
nil -> query_size
_ -> min(model.batch_size, query_size)
end

{batches, leftover} = get_batches(query, batch_size: batch_size)
num_batches = Nx.axis_size(batches, 0)

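# Preallocate the result tensors, then fill them batch by batch with Nx.put_slice.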
{neighbor_indices, neighbor_distances, _} =
while {
neighbor_indices = Nx.broadcast(Nx.u64(0), {query_size, k}),
neighbor_distances =
Nx.broadcast(Nx.as_type(:nan, type), {query_size, k}),
{
data,
batches,
i = Nx.u64(0)
}
},
i < num_batches do
batch = batches[i]

{batch_indices, batch_distances} =
brute_force_search(data, batch, num_neighbors: k, metric: metric)

neighbor_indices = Nx.put_slice(neighbor_indices, [i * batch_size, 0], batch_indices)

neighbor_distances =
Nx.put_slice(neighbor_distances, [i * batch_size, 0], batch_distances)

{neighbor_indices, neighbor_distances, {data, batches, i + 1}}
end

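# The while loop covers only full batches; handle the remaining partial batch, if any, separately.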
{neighbor_indices, neighbor_distances} =
case leftover do
nil ->
{neighbor_indices, neighbor_distances}

_ ->
leftover_size = Nx.axis_size(leftover, 0)

leftover =
Nx.slice_along_axis(query, query_size - leftover_size, leftover_size, axis: 0)

{leftover_indices, leftover_distances} =
brute_force_search(data, leftover, num_neighbors: k, metric: metric)

neighbor_indices =
Nx.put_slice(neighbor_indices, [num_batches * batch_size, 0], leftover_indices)

neighbor_distances =
Nx.put_slice(neighbor_distances, [num_batches * batch_size, 0], leftover_distances)

{neighbor_indices, neighbor_distances}
end

{neighbor_indices, neighbor_distances}
end

defn get_batches(tensor, opts) do
{size, dim} = Nx.shape(tensor)
batch_size = opts[:batch_size]
num_batches = div(size, batch_size)
leftover_size = rem(size, batch_size)

batches =
tensor
|> Nx.slice_along_axis(0, num_batches * batch_size, axis: 0)
|> Nx.reshape({num_batches, batch_size, dim})

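# Rows that do not fill a complete batch are returned separately as the leftover.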
leftover =
if leftover_size > 0 do
Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0)
else
nil
end

{batches, leftover}
end

defnp brute_force_search(data, query, opts) do
k = opts[:num_neighbors]
metric = opts[:metric]
{m, d} = Nx.shape(data)
n = Nx.axis_size(query, 0)
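# Broadcast query and data to a shared {n, m, d} shape, then vectorize the first two axes
# so the metric is applied to one pair of rank-1 tensors per (query, sample) combination.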
x = query |> Nx.new_axis(1) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data])
y = data |> Nx.new_axis(0) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data])
distances = metric.(x, y) |> Nx.devectorize() |> Nx.rename(nil)

neighbor_indices =
Nx.argsort(distances, axis: 1, type: :u64) |> Nx.slice_along_axis(0, k, axis: 1)

neighbor_distances = Nx.take_along_axis(distances, neighbor_indices, axis: 1)
{neighbor_indices, neighbor_distances}
end
end
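The vectorized pairwise-distance pattern in brute_force_search/3 is easier to follow on tiny shapes. A standalone sketch, assuming only Nx and Scholar.Metrics.Distance (both already used above):

    query = Nx.tensor([[0, 0], [3, 4]])        # n = 2 queries, d = 2
    data = Nx.tensor([[0, 0], [1, 1], [3, 4]]) # m = 3 samples
    {n, d} = Nx.shape(query)
    {m, _} = Nx.shape(data)
    x = query |> Nx.new_axis(1) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data])
    y = data |> Nx.new_axis(0) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data])
    # Each vectorized application of the metric receives two length-d tensors.
    distances = Scholar.Metrics.Distance.euclidean(x, y) |> Nx.devectorize() |> Nx.rename(nil)
    # distances has shape {2, 3}; row 0 is [0.0, 1.4142..., 5.0].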
test/scholar/neighbors/brute_knn_test.exs (123 additions, 0 deletions)

@@ -0,0 +1,123 @@
defmodule Scholar.Neighbors.BruteKNNTest do
use ExUnit.Case, async: true
alias Scholar.Neighbors.BruteKNN
doctest BruteKNN

defp data do
Nx.tensor([
[10, 15],
[46, 63],
[68, 21],
[40, 33],
[25, 54],
[15, 43],
[44, 58],
[45, 40],
[62, 69],
[53, 67]
])
end

defp query do
Nx.tensor([
[12, 23],
[55, 30],
[41, 57],
[64, 72],
[26, 39]
])
end

defp result do
neighbor_indices =
Nx.tensor(
[
[0, 5, 3],
[7, 3, 2],
[6, 1, 9],
[8, 9, 1],
[5, 4, 3]
],
type: :u64
)

neighbor_distances =
Nx.tensor([
[8.246211051940918, 20.2237491607666, 29.73213768005371],
[14.142135620117188, 15.29705810546875, 15.81138801574707],
[3.1622776985168457, 7.8102498054504395, 15.620499610900879],
[3.605551242828369, 12.083045959472656, 20.124610900878906],
[11.704699516296387, 15.033296585083008, 15.231546401977539]
])

{neighbor_indices, neighbor_distances}
end

describe "fit" do
test "default" do
data = data()
k = 3
model = BruteKNN.fit(data, num_neighbors: k)
assert model.num_neighbors == 3
assert model.data == data
assert model.batch_size == nil
end

test "custom metric and batch_size" do
data = data()
k = 3
metric = &Scholar.Metrics.Distance.minkowski/2
batch_size = 2
model = BruteKNN.fit(data, num_neighbors: k, metric: metric, batch_size: batch_size)
assert model.num_neighbors == k
assert model.metric == metric
assert model.data == data
assert model.batch_size == batch_size
end
end

describe "predict" do
test "batch_size = 1" do
query = query()
k = 3
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 1)
{neighbors_true, distances_true} = result()
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query)
assert neighbors_pred == neighbors_true
assert distances_pred == distances_true
end

test "batch_size = 2" do
query = query()
k = 3
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 2)
{neighbors_true, distances_true} = result()
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query)
assert neighbors_pred == neighbors_true
assert distances_pred == distances_true
end

test "batch_size = 5" do
query = query()
k = 3
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 5)
{neighbors_true, distances_true} = result()
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query)
assert neighbors_pred == neighbors_true
assert distances_pred == distances_true
end

test "batch_size = 10" do
query = query()
k = 3
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 10)
{neighbors_true, distances_true} = result()
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query)

assert neighbors_pred == neighbors_true
assert distances_pred == distances_true
end
end
end
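To close the loop on the O(BN) memory claim in the moduledoc, a hedged usage sketch (shapes are arbitrary, chosen only for illustration):

    # With batch_size: 100, each loop iteration materializes at most a
    # {100, 10_000} distance matrix instead of a full {1_000, 10_000} one.
    data = Nx.iota({10_000, 8}, type: :f32)
    query = Nx.iota({1_000, 8}, type: :f32)
    model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 5, batch_size: 100)
    {neighbors, distances} = Scholar.Neighbors.BruteKNN.predict(model, query)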