Skip to content

Commit

Permalink
Add generics to HNSW
Browse files Browse the repository at this point in the history
  • Loading branch information
hicder committed Dec 2, 2024
1 parent 643c6e6 commit 0d9b101
Show file tree
Hide file tree
Showing 17 changed files with 157 additions and 95 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rs/cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ tonic.workspace = true
rand.workspace = true
index.workspace = true
anyhow.workspace = true
quantization.workspace = true
3 changes: 2 additions & 1 deletion rs/cli/src/index_viewer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use clap::Parser;
use index::hnsw::reader::HnswReader;
use index::hnsw::utils::GraphTraversal;
use quantization::pq::ProductQuantizer;

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
Expand Down Expand Up @@ -31,7 +32,7 @@ pub fn main() {
let points_per_layer_0 = arg.points_per_layer_0;

let reader = HnswReader::new(arg.index_path);
let hnsw = reader.read();
let hnsw = reader.read::<u8, ProductQuantizer>();

let header = hnsw.get_header();
println!("Header: {:?}", header);
Expand Down
69 changes: 32 additions & 37 deletions rs/index/src/hnsw/builder.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::cmp::min;
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
use std::vec;

use anyhow::{anyhow, Context, Result};
use bit_vec::BitVec;
use log::debug;
use ordered_float::NotNan;
use quantization::quantization::Quantizer;
use quantization::typing::VectorT;
use rand::Rng;
use utils::distance::l2::L2DistanceCalculatorImpl::StreamingSIMD;

use super::index::Hnsw;
use super::utils::{BuilderContext, GraphTraversal, PointAndDistance};
Expand Down Expand Up @@ -45,53 +46,53 @@ impl Layer {
}

/// Builder for an HNSW index; vectors are added one at a time via `insert`.
pub struct HnswBuilder {
vectors: Box<dyn VectorStorage<u8>>,
pub struct HnswBuilder<T: VectorT<Q>, Q: Quantizer> {
vectors: Box<dyn VectorStorage<T>>,

max_neighbors: usize,
pub layers: Vec<Layer>,
pub current_top_layer: u8,
pub quantizer: Box<dyn Quantizer>,
pub quantizer: Q,
ef_contruction: u32,
pub entry_point: Vec<u32>,
max_layer: u8,
pub doc_id_mapping: Vec<u64>,
}

// TODO(hicder): support bare (unquantized) vectors in addition to quantized ones.
impl HnswBuilder {
impl<T: VectorT<Q>, Q: Quantizer> HnswBuilder<T, Q> {
pub fn new(
max_neighbors: usize,
max_layers: u8,
ef_construction: u32,
vector_storage_memory_size: usize,
vector_storage_file_size: usize,
num_features: usize,
quantizer: Box<dyn Quantizer>,
quantizer: Q,
base_directory: String,
) -> Self {
let vectors = Box::new(FileBackedAppendableVectorStorage::<u8>::new(
let vectors = Box::new(FileBackedAppendableVectorStorage::<T>::new(
base_directory.clone(),
vector_storage_memory_size,
vector_storage_file_size,
num_features,
));

Self {
vectors: vectors,
vectors,
max_neighbors,
max_layer: max_layers,
layers: vec![],
current_top_layer: 0,
quantizer: quantizer,
quantizer,
ef_contruction: ef_construction,
entry_point: vec![],
doc_id_mapping: Vec::new(),
}
}

pub fn from_hnsw(
hnsw: Hnsw,
hnsw: Hnsw<T, Q>,
output_directory: String,
vector_storage_config: VectorStorageConfig,
max_neighbors: usize,
Expand All @@ -108,7 +109,7 @@ impl HnswBuilder {
}

let tmp_vector_storage_dir = format!("{}/vector_storage_tmp", output_directory);
let mut vector_storage = Box::new(FileBackedAppendableVectorStorage::<u8>::new(
let mut vector_storage = Box::new(FileBackedAppendableVectorStorage::<T>::new(
tmp_vector_storage_dir,
vector_storage_config.memory_threshold,
vector_storage_config.file_size,
Expand All @@ -126,13 +127,12 @@ impl HnswBuilder {
loop {
let layer = &mut layers[current_top_layer as usize];
hnsw.visit(current_top_layer, |from: u32, to: u32| {
let distance = hnsw.quantizer.distance(
hnsw.vector_storage
.get(from as usize, &mut context)
.unwrap(),
hnsw.vector_storage.get(to as usize, &mut context).unwrap(),
utils::distance::l2::L2DistanceCalculatorImpl::StreamingSIMD,
);
let from_v = hnsw
.vector_storage
.get(from as usize, &mut context)
.unwrap();
let to_v = hnsw.vector_storage.get(to as usize, &mut context).unwrap();
let distance = T::distance(from_v, to_v, &hnsw.quantizer);
layer
.edges
.entry(from)
Expand Down Expand Up @@ -171,10 +171,6 @@ impl HnswBuilder {
}
}

pub fn append_vector_to_storage(&mut self, vector: &[u8]) -> Result<()> {
self.vectors.append(vector)
}

fn generate_id(&mut self, doc_id: u64) -> u32 {
let generated_id = self.doc_id_mapping.len() as u32;
self.doc_id_mapping.push(doc_id);
Expand Down Expand Up @@ -284,7 +280,7 @@ impl HnswBuilder {
}

let vector_storage_config = self.vectors.config();
let mut new_vector_storage = Box::new(FileBackedAppendableVectorStorage::<u8>::new(
let mut new_vector_storage = Box::new(FileBackedAppendableVectorStorage::<T>::new(
temp_dir.clone(),
vector_storage_config.memory_threshold,
vector_storage_config.file_size,
Expand All @@ -305,12 +301,12 @@ impl HnswBuilder {

/// Insert a vector into the index
pub fn insert(&mut self, doc_id: u64, vector: &[f32]) -> Result<()> {
let quantized_query = self.quantizer.quantize(vector);
let quantized_query = T::process_vector(vector, &self.quantizer);
let point_id = self.generate_id(doc_id);
let mut context = BuilderContext::new(point_id + 1);

let empty_graph = point_id == 0;
self.append_vector_to_storage(&quantized_query)?;
self.vectors.append(&quantized_query)?;
let layer = self.get_random_layer();

if empty_graph {
Expand Down Expand Up @@ -409,10 +405,10 @@ impl HnswBuilder {
fn distance_two_points(&self, a: u32, b: u32) -> f32 {
let a_vector = self.get_vector(a);
let b_vector = self.get_vector(b);
self.quantizer.distance(a_vector, b_vector, StreamingSIMD)
T::distance(a_vector, b_vector, &self.quantizer)
}

fn get_vector(&self, point_id: u32) -> &[u8] {
fn get_vector(&self, point_id: u32) -> &[T] {
self.vectors.get(point_id).unwrap()
}

Expand Down Expand Up @@ -475,7 +471,7 @@ impl HnswBuilder {
}
}

pub fn vectors(&mut self) -> &mut Box<dyn VectorStorage<u8>> {
pub fn vectors(&mut self) -> &mut Box<dyn VectorStorage<T>> {
&mut self.vectors
}

Expand Down Expand Up @@ -504,12 +500,12 @@ impl HnswBuilder {
}
}

impl GraphTraversal for HnswBuilder {
impl<T: VectorT<Q>, Q: Quantizer> GraphTraversal<T, Q> for HnswBuilder<T, Q> {
type ContextT = BuilderContext;

fn distance(&self, query: &[u8], point_id: u32, _context: &mut BuilderContext) -> f32 {
self.quantizer
.distance(query, self.get_vector(point_id), StreamingSIMD)
fn distance(&self, query: &[T], point_id: u32, _context: &mut BuilderContext) -> f32 {
let point = self.vectors.get(point_id).unwrap();
T::distance(query, point, &self.quantizer)
}

fn get_edges_for_point(&self, point_id: u32, layer: u8) -> Option<Vec<u32>> {
Expand Down Expand Up @@ -598,10 +594,8 @@ mod tests {
max_neighbors: 1,
layers: vec![layer],
current_top_layer: 0,
quantizer: Box::new(
ProductQuantizer::new(10, 2, 1, codebook, base_directory.clone())
.expect("ProductQuantizer should be created."),
),
quantizer: ProductQuantizer::new(10, 2, 1, codebook, base_directory.clone())
.expect("Can't create product quantizer"),
ef_contruction: 0,
entry_point: vec![0, 1],
max_layer: 0,
Expand Down Expand Up @@ -726,7 +720,8 @@ mod tests {

let vector_dir = format!("{}/vectors", base_directory);
fs::create_dir_all(vector_dir.clone()).unwrap();
let mut builder = HnswBuilder::new(5, 10, 20, 1024, 4096, 5, Box::new(pq), vector_dir);
let mut builder =
HnswBuilder::<u8, ProductQuantizer>::new(5, 10, 20, 1024, 4096, 5, pq, vector_dir);

for i in 0..100 {
builder
Expand Down
39 changes: 19 additions & 20 deletions rs/index/src/hnsw/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::fs::File;
use log::debug;
use memmap2::Mmap;
use num_traits::ToPrimitive;
use quantization::pq::ProductQuantizerReader;
use quantization::quantization::Quantizer;
use quantization::typing::VectorT;
use rand::Rng;

use super::utils::{GraphTraversal, TraversalContext};
Expand Down Expand Up @@ -36,13 +36,13 @@ impl TraversalContext for SearchContext {
}
}

pub struct Hnsw {
pub struct Hnsw<T: VectorT<Q>, Q: Quantizer> {
// Need this for mmap
#[allow(dead_code)]
backing_file: File,
mmap: Mmap,

pub vector_storage: FixedFileVectorStorage<u8>,
pub vector_storage: FixedFileVectorStorage<T>,

header: Header,
data_offset: usize,
Expand All @@ -52,13 +52,13 @@ pub struct Hnsw {
level_offsets_offset: usize,
doc_id_mapping_offset: usize,

pub quantizer: Box<dyn Quantizer + Send + Sync>,
pub quantizer: Q,
}

impl Hnsw {
impl<T: VectorT<Q>, Q: Quantizer> Hnsw<T, Q> {
pub fn new(
backing_file: File,
vector_storage: FixedFileVectorStorage<u8>,
vector_storage: FixedFileVectorStorage<T>,
header: Header,
data_offset: usize,
edges_offset: usize,
Expand All @@ -69,9 +69,11 @@ impl Hnsw {
base_directory: String,
) -> Self {
// Read quantizer
let pq_directory = format!("{}/quantizer", base_directory);
let pq_reader = ProductQuantizerReader::new(pq_directory);
let pq = pq_reader.read().unwrap();
let quantizer_directory = format!("{}/quantizer", base_directory);
// let pq_reader = ProductQuantizerReader::new(pq_directory);
// let pq = pq_reader.read().unwrap();

let quantizer = Q::read(quantizer_directory).unwrap();

let index_mmap = unsafe { Mmap::map(&backing_file).unwrap() };

Expand All @@ -86,7 +88,7 @@ impl Hnsw {
edge_offsets_offset,
level_offsets_offset,
doc_id_mapping_offset,
quantizer: Box::new(pq),
quantizer,
}
}

Expand All @@ -105,7 +107,7 @@ impl Hnsw {
ef: u32,
context: &mut SearchContext,
) -> Vec<IdWithScore> {
let quantized_query = self.quantizer.quantize(query);
let quantized_query = T::process_vector(query, &self.quantizer);
let mut current_layer: i32 = self.header.num_layers as i32 - 1;
let mut ep = self.get_entry_point_top_layer();
let mut working_set;
Expand Down Expand Up @@ -148,7 +150,7 @@ impl Hnsw {
self.data_offset
}

fn get_vector(&self, point_id: u32, context: &mut SearchContext) -> &[u8] {
fn get_vector(&self, point_id: u32, context: &mut SearchContext) -> &[T] {
self.vector_storage.get(point_id as usize, context).unwrap()
}

Expand Down Expand Up @@ -296,15 +298,12 @@ impl Hnsw {
}
}

impl GraphTraversal for Hnsw {
impl<T: VectorT<Q>, Q: Quantizer> GraphTraversal<T, Q> for Hnsw<T, Q> {
type ContextT = SearchContext;

fn distance(&self, query: &[u8], point_id: u32, context: &mut SearchContext) -> f32 {
self.quantizer.distance(
query,
self.get_vector(point_id, context),
utils::distance::l2::L2DistanceCalculatorImpl::StreamingSIMD,
)
fn distance(&self, query: &[T], point_id: u32, context: &mut SearchContext) -> f32 {
let point = self.get_vector(point_id, context);
T::distance(query, point, &self.quantizer)
}

fn get_edges_for_point(&self, point_id: u32, layer: u8) -> Option<Vec<u32>> {
Expand Down Expand Up @@ -412,7 +411,7 @@ impl GraphTraversal for Hnsw {
}
}

impl Index for Hnsw {
impl<T: VectorT<Q>, Q: Quantizer> Index for Hnsw<T, Q> {
fn search(
&self,
query: &[f32],
Expand Down
Loading

0 comments on commit 0d9b101

Please sign in to comment.