From 5c1786fc2ba27ee9ba909ee1ad4f990f714afd8c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?=
Date: Tue, 16 Apr 2024 18:59:50 +0200
Subject: [PATCH] Brute k-NN (#257)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add BruteKNN

* update docs and tests

* mix format

* move get_batches inside BruteKNN

* raise error when k > n

---------

Co-authored-by: Krsto Proroković
---
 lib/scholar/neighbors/brute_knn.ex        | 274 ++++++++++++++++++++++
 test/scholar/neighbors/brute_knn_test.exs | 144 ++++++++++
 2 files changed, 418 insertions(+)
 create mode 100644 lib/scholar/neighbors/brute_knn.ex
 create mode 100644 test/scholar/neighbors/brute_knn_test.exs

diff --git a/lib/scholar/neighbors/brute_knn.ex b/lib/scholar/neighbors/brute_knn.ex
new file mode 100644
index 00000000..bfb584af
--- /dev/null
+++ b/lib/scholar/neighbors/brute_knn.ex
@@ -0,0 +1,274 @@
+defmodule Scholar.Neighbors.BruteKNN do
+  @moduledoc """
+  Brute-Force k-Nearest Neighbor Search Algorithm.
+
+  In order to find the k-nearest neighbors the algorithm calculates
+  the distance between the query point and each of the data samples.
+  Therefore, its time complexity is $O(MN)$ for $N$ samples and $M$ query points.
+  It uses $O(BN)$ memory for batch size $B$.
+  Larger batch sizes will lead to faster predictions, but will consume more memory.
+  """
+  import Nx.Defn
+  import Scholar.Shared
+  require Nx
+
+  @derive {Nx.Container, keep: [:num_neighbors, :metric, :batch_size], containers: [:data]}
+  defstruct [:num_neighbors, :metric, :data, :batch_size]
+
+  opts = [
+    num_neighbors: [
+      required: true,
+      type: :pos_integer,
+      doc: "The number of nearest neighbors."
+    ],
+    metric: [
+      type: {:or, [{:custom, Scholar.Options, :metric, []}, {:fun, 2}]},
+      default: {:minkowski, 2},
+      doc: ~S"""
+      The function that measures distance between two points. Possible values:
+
+      * `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`)
+      we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.
+
+      * `:cosine` - Cosine metric.
+
+      * Anonymous function of arity 2 that takes two rank-1 tensors of same dimension and returns a scalar.
+      """
+    ],
+    batch_size: [
+      type: :pos_integer,
+      doc: "The number of query points in a batch. Defaults to all query points at once."
+    ]
+  ]
+
+  @opts_schema NimbleOptions.new!(opts)
+
+  @doc """
+  Fits a brute-force k-NN model.
+
+  ## Options
+
+  #{NimbleOptions.docs(@opts_schema)}
+
+  ## Examples
+
+      iex> data = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
+      iex> model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 2)
+      iex> model.num_neighbors
+      2
+      iex> model.data
+      #Nx.Tensor<
+        s64[5][2]
+        [
+          [1, 2],
+          [2, 3],
+          [3, 4],
+          [4, 5],
+          [5, 6]
+        ]
+      >
+  """
+  deftransform fit(data, opts) do
+    if Nx.rank(data) != 2 do
+      raise ArgumentError,
+            "expected input tensor to have shape {num_samples, num_features}, " <>
+              "got tensor with shape: #{inspect(Nx.shape(data))}"
+    end
+
+    opts = NimbleOptions.validate!(opts, @opts_schema)
+    k = opts[:num_neighbors]
+
+    if k > Nx.axis_size(data, 0) do
+      raise ArgumentError,
+            """
+            expected num_neighbors to be less than or equal to \
+            num_samples = #{Nx.axis_size(data, 0)}, got: #{k}
+            """
+    end
+
+    metric =
+      case opts[:metric] do
+        {:minkowski, p} ->
+          &Scholar.Metrics.Distance.minkowski(&1, &2, p: p)
+
+        :cosine ->
+          &Scholar.Metrics.Distance.cosine/2
+
+        fun when is_function(fun, 2) ->
+          fun
+      end
+
+    %__MODULE__{
+      num_neighbors: k,
+      metric: metric,
+      data: data,
+      batch_size: opts[:batch_size]
+    }
+  end
+
+  @doc """
+  Computes nearest neighbors of query tensor using brute-force search.
+  Returns the neighbors indices and distances from query points.
+
+  ## Examples
+
+      iex> data = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
+      iex> model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 2)
+      iex> query = Nx.tensor([[1, 3], [4, 2], [3, 6]])
+      iex> {neighbors, distances} = Scholar.Neighbors.BruteKNN.predict(model, query)
+      iex> neighbors
+      #Nx.Tensor<
+        u64[3][2]
+        [
+          [0, 1],
+          [1, 2],
+          [3, 2]
+        ]
+      >
+      iex> distances
+      #Nx.Tensor<
+        f32[3][2]
+        [
+          [1.0, 1.0],
+          [2.2360680103302, 2.2360680103302],
+          [1.4142135381698608, 2.0]
+        ]
+      >
+  """
+  deftransform predict(%__MODULE__{} = model, query) do
+    if Nx.rank(query) != 2 do
+      raise ArgumentError,
+            "expected query tensor to have shape {num_queries, num_features}, " <>
+              "got tensor with shape: #{inspect(Nx.shape(query))}"
+    end
+
+    if Nx.axis_size(model.data, 1) != Nx.axis_size(query, 1) do
+      raise ArgumentError,
+            """
+            expected query tensor to have the same dimension as tensor used for fitting the model, \
+            got #{inspect(Nx.axis_size(model.data, 1))} \
+            and #{inspect(Nx.axis_size(query, 1))}
+            """
+    end
+
+    predict_n(model, query)
+  end
+
+  # Performs the search for `predict/2`, which validates the inputs beforehand.
+  # The query is processed in batches of at most `batch_size` points and the results
+  # of each batch are written into tensors of shape {num_queries, num_neighbors}.
+  @doc false
+  defn predict_n(%__MODULE__{} = model, query) do
+    k = model.num_neighbors
+    metric = model.metric
+    data = model.data
+    type = Nx.Type.merge(to_float_type(data), to_float_type(query))
+    query_size = Nx.axis_size(query, 0)
+
+    # Without an explicit batch size, all query points are processed as a single batch.
+    batch_size =
+      case model.batch_size do
+        nil -> query_size
+        _ -> min(model.batch_size, query_size)
+      end
+
+    {batches, leftover} = get_batches(query, batch_size: batch_size)
+    num_batches = Nx.axis_size(batches, 0)
+
+    # Outputs start as placeholders (index 0, NaN distance) that every batch overwrites.
+    {neighbor_indices, neighbor_distances, _} =
+      while {
+              neighbor_indices = Nx.broadcast(Nx.u64(0), {query_size, k}),
+              neighbor_distances =
+                Nx.broadcast(Nx.as_type(:nan, type), {query_size, k}),
+              {
+                data,
+                batches,
+                i = Nx.u64(0)
+              }
+            },
+            i < num_batches do
+        batch = batches[i]
+
+        {batch_indices, batch_distances} =
+          brute_force_search(data, batch, num_neighbors: k, metric: metric)
+
+        neighbor_indices = Nx.put_slice(neighbor_indices, [i * batch_size, 0], batch_indices)
+
+        neighbor_distances =
+          Nx.put_slice(neighbor_distances, [i * batch_size, 0], batch_distances)
+
+        {neighbor_indices, neighbor_distances, {data, batches, i + 1}}
+      end
+
+    # Query points that did not fill a whole batch are searched as one smaller batch.
+    {neighbor_indices, neighbor_distances} =
+      case leftover do
+        nil ->
+          {neighbor_indices, neighbor_distances}
+
+        _ ->
+          leftover_size = Nx.axis_size(leftover, 0)
+
+          leftover =
+            Nx.slice_along_axis(query, query_size - leftover_size, leftover_size, axis: 0)
+
+          {leftover_indices, leftover_distances} =
+            brute_force_search(data, leftover, num_neighbors: k, metric: metric)
+
+          neighbor_indices =
+            Nx.put_slice(neighbor_indices, [num_batches * batch_size, 0], leftover_indices)
+
+          neighbor_distances =
+            Nx.put_slice(neighbor_distances, [num_batches * batch_size, 0], leftover_distances)
+
+          {neighbor_indices, neighbor_distances}
+      end
+
+    {neighbor_indices, neighbor_distances}
+  end
+
+  # Splits `tensor` along its first axis into `div(size, batch_size)` full batches,
+  # returned with shape {num_batches, batch_size, dim}. Rows that do not fill
+  # a whole batch are returned as `leftover`, which is `nil` when there are none.
+  @doc false
+  defn get_batches(tensor, opts) do
+    {size, dim} = Nx.shape(tensor)
+    batch_size = opts[:batch_size]
+    num_batches = div(size, batch_size)
+    leftover_size = rem(size, batch_size)
+
+    batches =
+      tensor
+      |> Nx.slice_along_axis(0, num_batches * batch_size, axis: 0)
+      |> Nx.reshape({num_batches, batch_size, dim})
+
+    leftover =
+      if leftover_size > 0 do
+        Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0)
+      else
+        nil
+      end
+
+    {batches, leftover}
+  end
+
+  # Computes the distance between every query point and every data point, then keeps
+  # the `k` nearest data points of each query, ordered by increasing distance.
+  defnp brute_force_search(data, query, opts) do
+    k = opts[:num_neighbors]
+    metric = opts[:metric]
+    {m, d} = Nx.shape(data)
+    n = Nx.axis_size(query, 0)
+    # Vectorize over all {query, data} pairs so that `metric` receives two rank-1 tensors.
+    x = query |> Nx.new_axis(1) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data])
+    y = data |> Nx.new_axis(0) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data])
+    distances = metric.(x, y) |> Nx.devectorize() |> Nx.rename(nil)
+
+    neighbor_indices =
+      Nx.argsort(distances, axis: 1, type: :u64) |> Nx.slice_along_axis(0, k, axis: 1)
+
+    neighbor_distances = Nx.take_along_axis(distances, neighbor_indices, axis: 1)
+    {neighbor_indices, neighbor_distances}
+  end
+end
diff --git a/test/scholar/neighbors/brute_knn_test.exs b/test/scholar/neighbors/brute_knn_test.exs
new file mode 100644
index 00000000..052f02bb
--- /dev/null
+++ b/test/scholar/neighbors/brute_knn_test.exs
@@ -0,0 +1,144 @@
+defmodule Scholar.Neighbors.BruteKNNTest do
+  use ExUnit.Case, async: true
+  alias Scholar.Neighbors.BruteKNN
+  doctest BruteKNN
+
+  defp data do
+    Nx.tensor([
+      [10, 15],
+      [46, 63],
+      [68, 21],
+      [40, 33],
+      [25, 54],
+      [15, 43],
+      [44, 58],
+      [45, 40],
+      [62, 69],
+      [53, 67]
+    ])
+  end
+
+  defp query do
+    Nx.tensor([
+      [12, 23],
+      [55, 30],
+      [41, 57],
+      [64, 72],
+      [26, 39]
+    ])
+  end
+
+  defp result do
+    neighbor_indices =
+      Nx.tensor(
+        [
+          [0, 5, 3],
+          [7, 3, 2],
+          [6, 1, 9],
+          [8, 9, 1],
+          [5, 4, 3]
+        ],
+        type: :u64
+      )
+
+    neighbor_distances =
+      Nx.tensor([
+        [8.246211051940918, 20.2237491607666, 29.73213768005371],
+        [14.142135620117188, 15.29705810546875, 15.81138801574707],
+        [3.1622776985168457, 7.8102498054504395, 15.620499610900879],
+        [3.605551242828369, 12.083045959472656, 20.124610900878906],
+        [11.704699516296387, 15.033296585083008, 15.231546401977539]
+      ])
+
+    {neighbor_indices, neighbor_distances}
+  end
+
+  describe "fit" do
+    test "default" do
+      data = data()
+      k = 3
+      model = BruteKNN.fit(data, num_neighbors: k)
+      assert model.num_neighbors == 3
+      assert model.data == data
+      assert model.batch_size == nil
+    end
+
+    test "custom metric and batch_size" do
+      data = data()
+      k = 3
+      metric = &Scholar.Metrics.Distance.minkowski/2
+      batch_size = 2
+      model = BruteKNN.fit(data, num_neighbors: k, metric: metric, batch_size: batch_size)
+      assert model.num_neighbors == k
+      assert model.metric == metric
+      assert model.data == data
+      assert model.batch_size == batch_size
+    end
+
+    test "raises when data is not rank 2" do
+      message =
+        "expected input tensor to have shape {num_samples, num_features}, " <>
+          "got tensor with shape: {3}"
+
+      assert_raise ArgumentError, message, fn ->
+        BruteKNN.fit(Nx.tensor([1, 2, 3]), num_neighbors: 1)
+      end
+    end
+
+    test "raises when num_neighbors exceeds the number of samples" do
+      assert_raise ArgumentError, ~r/expected num_neighbors to be less than or equal to/, fn ->
+        BruteKNN.fit(data(), num_neighbors: 11)
+      end
+    end
+  end
+
+  describe "predict" do
+    test "batch_size = 1" do
+      query = query()
+      k = 3
+      model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 1)
+      {neighbors_true, distances_true} = result()
+      {neighbors_pred, distances_pred} = BruteKNN.predict(model, query)
+      assert neighbors_pred == neighbors_true
+      assert distances_pred == distances_true
+    end
+
+    test "batch_size = 2" do
+      query = query()
+      k = 3
+      model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 2)
+      {neighbors_true, distances_true} = result()
+      {neighbors_pred, distances_pred} = BruteKNN.predict(model, query)
+      assert neighbors_pred == neighbors_true
+      assert distances_pred == distances_true
+    end
+
+    test "batch_size = 5" do
+      query = query()
+      k = 3
+      model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 5)
+      {neighbors_true, distances_true} = result()
+      {neighbors_pred, distances_pred} = BruteKNN.predict(model, query)
+      assert neighbors_pred == neighbors_true
+      assert distances_pred == distances_true
+    end
+
+    test "batch_size = 10" do
+      query = query()
+      k = 3
+      model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 10)
+      {neighbors_true, distances_true} = result()
+      {neighbors_pred, distances_pred} = BruteKNN.predict(model, query)
+      assert neighbors_pred == neighbors_true
+      assert distances_pred == distances_true
+    end
+
+    test "raises when query and data have a different number of features" do
+      model = BruteKNN.fit(data(), num_neighbors: 3)
+
+      assert_raise ArgumentError, ~r/expected query tensor to have the same dimension/, fn ->
+        BruteKNN.predict(model, Nx.tensor([[1, 2, 3]]))
+      end
+    end
+  end
+end