-
Notifications
You must be signed in to change notification settings - Fork 47
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Brute k-NN #257
Brute k-NN #257
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,260 @@ | ||||||||||||||||||||||||||||||||
defmodule Scholar.Neighbors.BruteKNN do | ||||||||||||||||||||||||||||||||
@moduledoc """ | ||||||||||||||||||||||||||||||||
Brute-Force k-Nearest Neighbor Search Algorithm. | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
In order to find the k-nearest neighbors the algorithm calculates | ||||||||||||||||||||||||||||||||
the distance between the query point and each of the data samples. | ||||||||||||||||||||||||||||||||
Therefore, its time complexity is $O(MN)$ for $N$ samples and $M$ query points. | ||||||||||||||||||||||||||||||||
It uses $O(BN)$ memory for batch size $B$. | ||||||||||||||||||||||||||||||||
Larger batch sizes will lead to faster predictions, but will consume more memory. | ||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||
import Nx.Defn | ||||||||||||||||||||||||||||||||
import Scholar.Shared | ||||||||||||||||||||||||||||||||
require Nx | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
@derive {Nx.Container, keep: [:num_neighbors, :metric, :batch_size], containers: [:data]} | ||||||||||||||||||||||||||||||||
defstruct [:num_neighbors, :metric, :data, :batch_size] | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
opts = [ | ||||||||||||||||||||||||||||||||
num_neighbors: [ | ||||||||||||||||||||||||||||||||
required: true, | ||||||||||||||||||||||||||||||||
type: :pos_integer, | ||||||||||||||||||||||||||||||||
doc: "The number of nearest neighbors." | ||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||
metric: [ | ||||||||||||||||||||||||||||||||
type: {:or, [{:custom, Scholar.Options, :metric, []}, {:fun, 2}]}, | ||||||||||||||||||||||||||||||||
default: {:minkowski, 2}, | ||||||||||||||||||||||||||||||||
doc: ~S""" | ||||||||||||||||||||||||||||||||
The function that measures distance between two points. Possible values: | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
* `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`) | ||||||||||||||||||||||||||||||||
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric. | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
* `:cosine` - Cosine metric. | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
* Anonymous function of arity 2 that takes two rank-1 tensors of same dimension and returns a scalar. | ||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||
batch_size: [ | ||||||||||||||||||||||||||||||||
type: :pos_integer, | ||||||||||||||||||||||||||||||||
doc: "The number of samples in a batch." | ||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
@opts_schema NimbleOptions.new!(opts) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
@doc """ | ||||||||||||||||||||||||||||||||
Fits a brute-force k-NN model. | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
## Options | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
#{NimbleOptions.docs(@opts_schema)} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
## Examples | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
iex> data = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) | ||||||||||||||||||||||||||||||||
iex> model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 2) | ||||||||||||||||||||||||||||||||
iex> model.num_neighbors | ||||||||||||||||||||||||||||||||
2 | ||||||||||||||||||||||||||||||||
iex> model.data | ||||||||||||||||||||||||||||||||
#Nx.Tensor< | ||||||||||||||||||||||||||||||||
s64[5][2] | ||||||||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||||||||
[1, 2], | ||||||||||||||||||||||||||||||||
[2, 3], | ||||||||||||||||||||||||||||||||
[3, 4], | ||||||||||||||||||||||||||||||||
[4, 5], | ||||||||||||||||||||||||||||||||
[5, 6] | ||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||
> | ||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||
deftransform fit(data, opts) do | ||||||||||||||||||||||||||||||||
if Nx.rank(data) != 2 do | ||||||||||||||||||||||||||||||||
raise ArgumentError, | ||||||||||||||||||||||||||||||||
"expected input tensor to have shape {num_samples, num_features}, | ||||||||||||||||||||||||||||||||
got tensor with shape: #{inspect(Nx.shape(data))}" | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
opts = NimbleOptions.validate!(opts, @opts_schema) | ||||||||||||||||||||||||||||||||
k = opts[:num_neighbors] | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if k > Nx.axis_size(data, 0) do | ||||||||||||||||||||||||||||||||
raise ArgumentError, | ||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||
expected num_neighbors to be less than or equal to \ | ||||||||||||||||||||||||||||||||
num_samples = #{Nx.axis_size(data, 0)}, got: #{k} | ||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
metric = | ||||||||||||||||||||||||||||||||
case opts[:metric] do | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Following up above: I would move this normalization to Scholar.Options.metric then. This way everywhere can rely on it being a 2-arity function. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean instead of returning the atom to return the anonymous function of arity 2? scholar/lib/scholar/neighbors/kd_tree.ex Lines 33 to 43 in 0d1bcc1
And here is how it is used there: scholar/lib/scholar/neighbors/kd_tree.ex Lines 283 to 286 in 0d1bcc1
I would honestly prefer metric to be stored as a function inside a field. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The main problem is that different algorithms support different metrics. Brute-force search works with literally any metric, while current implementation of random projection forest works only with the Euclidean distance. I am not sure about k-d tree; I think it works only with the three metrics specified in the docs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is fine then to keep this logic only in this module then :) IMO There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed. I would still suggest converting atoms to functions and storing them as fields. Seems cleaner and doesn't require choosing the right function to call each time |
||||||||||||||||||||||||||||||||
{:minkowski, p} -> | ||||||||||||||||||||||||||||||||
&Scholar.Metrics.Distance.minkowski(&1, &2, p: p) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
:cosine -> | ||||||||||||||||||||||||||||||||
&Scholar.Metrics.Distance.cosine/2 | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
fun when is_function(fun, 2) -> | ||||||||||||||||||||||||||||||||
fun | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not 100% sure about this. Could there be issues with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, this is fine! |
||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
%__MODULE__{ | ||||||||||||||||||||||||||||||||
num_neighbors: k, | ||||||||||||||||||||||||||||||||
metric: metric, | ||||||||||||||||||||||||||||||||
data: data, | ||||||||||||||||||||||||||||||||
batch_size: opts[:batch_size] | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
@doc """ | ||||||||||||||||||||||||||||||||
Computes nearest neighbors of query tensor using brute-force search. | ||||||||||||||||||||||||||||||||
Returns the neighbors indices and distances from query points. | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
## Examples | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
iex> data = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) | ||||||||||||||||||||||||||||||||
iex> model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 2) | ||||||||||||||||||||||||||||||||
iex> query = Nx.tensor([[1, 3], [4, 2], [3, 6]]) | ||||||||||||||||||||||||||||||||
iex> {neighbors, distances} = Scholar.Neighbors.BruteKNN.predict(model, query) | ||||||||||||||||||||||||||||||||
iex> neighbors | ||||||||||||||||||||||||||||||||
#Nx.Tensor< | ||||||||||||||||||||||||||||||||
u64[3][2] | ||||||||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||||||||
[0, 1], | ||||||||||||||||||||||||||||||||
[1, 2], | ||||||||||||||||||||||||||||||||
[3, 2] | ||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||
> | ||||||||||||||||||||||||||||||||
iex> distances | ||||||||||||||||||||||||||||||||
#Nx.Tensor< | ||||||||||||||||||||||||||||||||
f32[3][2] | ||||||||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||||||||
[1.0, 1.0], | ||||||||||||||||||||||||||||||||
[2.2360680103302, 2.2360680103302], | ||||||||||||||||||||||||||||||||
[1.4142135381698608, 2.0] | ||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||
> | ||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||
deftransform predict(%__MODULE__{} = model, query) do | ||||||||||||||||||||||||||||||||
if Nx.rank(query) != 2 do | ||||||||||||||||||||||||||||||||
raise ArgumentError, | ||||||||||||||||||||||||||||||||
"expected query tensor to have shape {num_queries, num_features}, | ||||||||||||||||||||||||||||||||
got tensor with shape: #{inspect(Nx.shape(query))}" | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if Nx.axis_size(model.data, 1) != Nx.axis_size(query, 1) do | ||||||||||||||||||||||||||||||||
raise ArgumentError, | ||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||
expected query tensor to have the same dimension as tensor used for fitting the model, \ | ||||||||||||||||||||||||||||||||
got #{inspect(Nx.axis_size(model.data, 1))} \ | ||||||||||||||||||||||||||||||||
and #{inspect(Nx.axis_size(query, 1))} | ||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
predict_n(model, query) | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
defn predict_n(%__MODULE__{} = model, query) do | ||||||||||||||||||||||||||||||||
k = model.num_neighbors | ||||||||||||||||||||||||||||||||
metric = model.metric | ||||||||||||||||||||||||||||||||
data = model.data | ||||||||||||||||||||||||||||||||
type = Nx.Type.merge(to_float_type(data), to_float_type(query)) | ||||||||||||||||||||||||||||||||
query_size = Nx.axis_size(query, 0) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
batch_size = | ||||||||||||||||||||||||||||||||
case model.batch_size do | ||||||||||||||||||||||||||||||||
nil -> query_size | ||||||||||||||||||||||||||||||||
_ -> min(model.batch_size, query_size) | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
{batches, leftover} = get_batches(query, batch_size: batch_size) | ||||||||||||||||||||||||||||||||
num_batches = Nx.axis_size(batches, 0) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
{neighbor_indices, neighbor_distances, _} = | ||||||||||||||||||||||||||||||||
while { | ||||||||||||||||||||||||||||||||
neighbor_indices = Nx.broadcast(Nx.u64(0), {query_size, k}), | ||||||||||||||||||||||||||||||||
neighbor_distances = | ||||||||||||||||||||||||||||||||
Nx.broadcast(Nx.as_type(:nan, type), {query_size, k}), | ||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||
data, | ||||||||||||||||||||||||||||||||
batches, | ||||||||||||||||||||||||||||||||
i = Nx.u64(0) | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||
i < num_batches do | ||||||||||||||||||||||||||||||||
batch = batches[i] | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
{batch_indices, batch_distances} = | ||||||||||||||||||||||||||||||||
brute_force_search(data, batch, num_neighbors: k, metric: metric) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
neighbor_indices = Nx.put_slice(neighbor_indices, [i * batch_size, 0], batch_indices) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
neighbor_distances = | ||||||||||||||||||||||||||||||||
Nx.put_slice(neighbor_distances, [i * batch_size, 0], batch_distances) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
{neighbor_indices, neighbor_distances, {data, batches, i + 1}} | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
{neighbor_indices, neighbor_distances} = | ||||||||||||||||||||||||||||||||
case leftover do | ||||||||||||||||||||||||||||||||
nil -> | ||||||||||||||||||||||||||||||||
{neighbor_indices, neighbor_distances} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
_ -> | ||||||||||||||||||||||||||||||||
leftover_size = Nx.axis_size(leftover, 0) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
leftover = | ||||||||||||||||||||||||||||||||
Nx.slice_along_axis(query, query_size - leftover_size, leftover_size, axis: 0) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
{leftover_indices, leftover_distances} = | ||||||||||||||||||||||||||||||||
brute_force_search(data, leftover, num_neighbors: k, metric: metric) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
neighbor_indices = | ||||||||||||||||||||||||||||||||
Nx.put_slice(neighbor_indices, [num_batches * batch_size, 0], leftover_indices) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
neighbor_distances = | ||||||||||||||||||||||||||||||||
Nx.put_slice(neighbor_distances, [num_batches * batch_size, 0], leftover_distances) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
{neighbor_indices, neighbor_distances} | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
{neighbor_indices, neighbor_distances} | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
defn get_batches(tensor, opts) do | ||||||||||||||||||||||||||||||||
{size, dim} = Nx.shape(tensor) | ||||||||||||||||||||||||||||||||
batch_size = opts[:batch_size] | ||||||||||||||||||||||||||||||||
num_batches = div(size, batch_size) | ||||||||||||||||||||||||||||||||
leftover_size = rem(size, batch_size) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
batches = | ||||||||||||||||||||||||||||||||
tensor | ||||||||||||||||||||||||||||||||
|> Nx.slice_along_axis(0, num_batches * batch_size, axis: 0) | ||||||||||||||||||||||||||||||||
|> Nx.reshape({num_batches, batch_size, dim}) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
leftover = | ||||||||||||||||||||||||||||||||
if leftover_size > 0 do | ||||||||||||||||||||||||||||||||
Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0) | ||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||
nil | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
{batches, leftover} | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
defnp brute_force_search(data, query, opts) do | ||||||||||||||||||||||||||||||||
k = opts[:num_neighbors] | ||||||||||||||||||||||||||||||||
metric = opts[:metric] | ||||||||||||||||||||||||||||||||
{m, d} = Nx.shape(data) | ||||||||||||||||||||||||||||||||
n = Nx.axis_size(query, 0) | ||||||||||||||||||||||||||||||||
x = query |> Nx.new_axis(1) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data]) | ||||||||||||||||||||||||||||||||
y = data |> Nx.new_axis(0) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data]) | ||||||||||||||||||||||||||||||||
distances = metric.(x, y) |> Nx.devectorize() |> Nx.rename(nil) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
neighbor_indices = | ||||||||||||||||||||||||||||||||
Nx.argsort(distances, axis: 1, type: :u64) |> Nx.slice_along_axis(0, k, axis: 1) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
neighbor_distances = Nx.take_along_axis(distances, neighbor_indices, axis: 1) | ||||||||||||||||||||||||||||||||
{neighbor_indices, neighbor_distances} | ||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
defmodule Scholar.Neighbors.BruteKNNTest do | ||
use ExUnit.Case, async: true | ||
alias Scholar.Neighbors.BruteKNN | ||
doctest BruteKNN | ||
|
||
defp data do | ||
Nx.tensor([ | ||
[10, 15], | ||
[46, 63], | ||
[68, 21], | ||
[40, 33], | ||
[25, 54], | ||
[15, 43], | ||
[44, 58], | ||
[45, 40], | ||
[62, 69], | ||
[53, 67] | ||
]) | ||
end | ||
|
||
defp query do | ||
Nx.tensor([ | ||
[12, 23], | ||
[55, 30], | ||
[41, 57], | ||
[64, 72], | ||
[26, 39] | ||
]) | ||
end | ||
|
||
defp result do | ||
neighbor_indices = | ||
Nx.tensor( | ||
[ | ||
[0, 5, 3], | ||
[7, 3, 2], | ||
[6, 1, 9], | ||
[8, 9, 1], | ||
[5, 4, 3] | ||
], | ||
type: :u64 | ||
) | ||
|
||
neighbor_distances = | ||
Nx.tensor([ | ||
[8.246211051940918, 20.2237491607666, 29.73213768005371], | ||
[14.142135620117188, 15.29705810546875, 15.81138801574707], | ||
[3.1622776985168457, 7.8102498054504395, 15.620499610900879], | ||
[3.605551242828369, 12.083045959472656, 20.124610900878906], | ||
[11.704699516296387, 15.033296585083008, 15.231546401977539] | ||
]) | ||
|
||
{neighbor_indices, neighbor_distances} | ||
end | ||
|
||
describe "fit" do | ||
test "default" do | ||
data = data() | ||
k = 3 | ||
model = BruteKNN.fit(data, num_neighbors: k) | ||
assert model.num_neighbors == 3 | ||
assert model.data == data | ||
assert model.batch_size == nil | ||
end | ||
|
||
test "custom metric and batch_size" do | ||
data = data() | ||
k = 3 | ||
metric = &Scholar.Metrics.Distance.minkowski/2 | ||
batch_size = 2 | ||
model = BruteKNN.fit(data, num_neighbors: k, metric: metric, batch_size: batch_size) | ||
assert model.num_neighbors == k | ||
assert model.metric == metric | ||
assert model.data == data | ||
assert model.batch_size == batch_size | ||
end | ||
end | ||
|
||
describe "predict" do | ||
test "batch_size = 1" do | ||
query = query() | ||
k = 3 | ||
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 1) | ||
{neighbors_true, distances_true} = result() | ||
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query) | ||
assert neighbors_pred == neighbors_true | ||
assert distances_pred == distances_true | ||
end | ||
|
||
test "batch_size = 2" do | ||
query = query() | ||
k = 3 | ||
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 2) | ||
{neighbors_true, distances_true} = result() | ||
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query) | ||
assert neighbors_pred == neighbors_true | ||
assert distances_pred == distances_true | ||
end | ||
|
||
test "batch_size = 5" do | ||
query = query() | ||
k = 3 | ||
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 5) | ||
{neighbors_true, distances_true} = result() | ||
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query) | ||
assert neighbors_pred == neighbors_true | ||
assert distances_pred == distances_true | ||
end | ||
|
||
test "batch_size = 10" do | ||
query = query() | ||
k = 3 | ||
model = BruteKNN.fit(data(), num_neighbors: k, batch_size: 10) | ||
{neighbors_true, distances_true} = result() | ||
{neighbors_pred, distances_pred} = BruteKNN.predict(model, query) | ||
|
||
assert neighbors_pred == | ||
neighbors_true | ||
|
||
assert distances_pred == distances_true | ||
end | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps
Scholar.Options.metric
should be edited to support functions of arity 2?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If everywhere we accept a metric we also accept functions, then yes. The easiest is probably to make it so it always returns a 2-arity function and then we call it. :)