Batch k-NN #253

Closed
wants to merge 2 commits into from

krstopro
Member

@krstopro krstopro commented Apr 8, 2024

I am honestly not sure where this should be implemented. Right now I have added linear_search (a better name might be needed, e.g. brute_force_search) inside Scholar.Neighbors.Utils. This way it can be used inside other modules such as Scholar.Neighbors.KNearestNeighbors or the dimensionality reduction algorithms (t-SNE, Trimap, PacMAP).

The function itself still needs to be documented and tested, and more distance metrics are needed. I am just submitting a draft to get some feedback.

Closes #239.

@krstopro krstopro marked this pull request as draft April 8, 2024 19:53
@josevalim
Contributor

It looks good to me. For now it is a private module, so we don't have to sweat too much about the name. Does this function have any relationship with find_neighbour? If so, maybe we call one linear_search and the other batch_linear_search?

@krstopro
Member Author

krstopro commented Apr 9, 2024

I would rename find_neighbors to linear_search_with_candidates, as that is what it really is: it performs a linear search, but only over the indices specified by the candidates tensor.

|> Nx.subtract(Nx.take(data, candidate_indices))
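
A minimal sketch of what a candidate-restricted linear search could look like, assuming squared Euclidean distance and a candidates tensor of shape {query_size, num_candidates} holding row indices into data. The module name, function name, and argument order are hypothetical and are not taken from the PR.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule CandidateSearchSketch do
  # Hypothetical helper: for each query row, search only among the data rows
  # listed in the matching row of `candidates`.
  def search(data, query, candidates, k) do
    # Gather candidate points: {query_size, num_candidates, dim}
    candidate_points = Nx.take(data, candidates)

    # Squared Euclidean distance from each query to each of its candidates
    distances =
      query
      |> Nx.new_axis(1)
      |> Nx.subtract(candidate_points)
      |> Nx.pow(2)
      |> Nx.sum(axes: [2])

    # Positions of the k smallest distances within each candidate list
    positions =
      distances
      |> Nx.argsort(axis: 1)
      |> Nx.slice_along_axis(0, k, axis: 1)

    # Map positions back to row indices of `data`
    neighbor_indices = Nx.take_along_axis(candidates, positions, axis: 1)
    neighbor_distances = Nx.take_along_axis(distances, positions, axis: 1)
    {neighbor_indices, neighbor_distances}
  end
end
```

The only difference from a plain linear search in this sketch is the Nx.take step, which narrows the search to the candidate rows before distances are computed.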

@msluszniak
Contributor

There is a function for handling multiple distances in the t-SNE, NNDescent, and Trimap modules. It might be worth moving it to the Scholar.Shared module.

Comment on lines +70 to +72
distances = distance_fn.(data, leftover)
indices = Nx.argsort(distances, axis: 1, type: :u64) |> Nx.slice_along_axis(0, k, axis: 1)
distances = Nx.take_along_axis(distances, indices, axis: 1)
Contributor

These lines also appear in the main part of the search inside the while loop. It may be worth moving them into a separate function, but it's up to you since it's just 3 lines.
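
For illustration, the three quoted lines could be pulled into a helper along these lines. This is only a sketch: the module and function names are hypothetical, and the `type: :u64` option from the diff is left out for simplicity.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule TopKSketch do
  # Hypothetical helper: given a {num_queries, num_points} distance matrix,
  # return the indices and distances of the k closest points per query row.
  def top_k(distances, k) do
    indices =
      distances
      |> Nx.argsort(axis: 1)
      |> Nx.slice_along_axis(0, k, axis: 1)

    {indices, Nx.take_along_axis(distances, indices, axis: 1)}
  end
end
```

With a helper like this, both the loop body and the leftover branch would compute their distance matrix and then call the same function.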

Member Author

I'm considering it. The code still needs some polishing.
I am also thinking of abstracting the whole thing as a kind of map over batches. It might be useful for Task 1 of #246 as well.

{
data,
batches,
i = 0
Contributor

Suggested change
- i = 0
+ i = Nx.u64(0)

{query_size, dim} = Nx.shape(query)
num_batches = div(query_size, batch_size)
leftover_size = rem(query_size, batch_size)

Member Author

This part might also need to be moved inside Scholar.Shared as a function that takes a tensor and batch size and returns {batches, leftover} where batches is a tensor of shape {num_batches, batch_size, dim} and leftover is a tensor of shape {leftover_size, dim}.
Might be relevant for #246, task 1.
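
A rough sketch of such a splitting function, under a few assumptions: the input is a rank-2 tensor, the module and function names are hypothetical, and nil is returned in place of an empty tensor, because Nx shapes (as far as I know) cannot have a zero-sized axis.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule BatchSplitSketch do
  # Hypothetical helper: split a {n, dim} tensor into
  # {batches, leftover}, where batches has shape {num_batches, batch_size, dim}
  # and leftover has shape {leftover_size, dim}.
  def split(tensor, batch_size) do
    {n, dim} = Nx.shape(tensor)
    num_batches = div(n, batch_size)
    leftover_size = rem(n, batch_size)
    full = num_batches * batch_size

    # nil when there is no complete batch (assumption: no empty tensors)
    batches =
      if num_batches > 0 do
        tensor
        |> Nx.slice_along_axis(0, full, axis: 0)
        |> Nx.reshape({num_batches, batch_size, dim})
      end

    # nil when batch_size divides n evenly
    leftover =
      if leftover_size > 0 do
        Nx.slice_along_axis(tensor, full, leftover_size, axis: 0)
      end

    {batches, leftover}
  end
end

# Usage: 5 rows with batch size 2 give two full batches and one leftover row.
{batches, leftover} = BatchSplitSketch.split(Nx.iota({5, 2}), 2)
IO.inspect(Nx.shape(batches))
IO.inspect(Nx.shape(leftover))
```

Callers would need to handle the nil cases explicitly, which is one reason the real function might prefer a different convention.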

Contributor

Yes, definitely. In Python, there is even a function, divmod, that does that.

@josevalim
Contributor

linear_search and linear_candidate_search (or linear_neighbour_search) are fine to me! As I said, it is your call, the whole API is private, so focus on what makes the code cleaner :)

@krstopro krstopro mentioned this pull request Apr 11, 2024
4 tasks
@krstopro
Member Author

Closing this in favor of #257.

@krstopro krstopro closed this Apr 13, 2024
Development

Successfully merging this pull request may close these issues.

Implement batch version of brute-force k-NN search
3 participants