Skip to content

Commit

Permalink
Plug IntSeqDecoderIterator into Ivf (#269)
Browse files Browse the repository at this point in the history
  • Loading branch information
tyb0807 authored Jan 3, 2025
1 parent 478837c commit 4dda8b6
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 49 deletions.
11 changes: 9 additions & 2 deletions rs/compression/src/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,19 @@ pub trait IntSeqEncoder {
fn write(&self, writer: &mut BufWriter<&mut File>) -> Result<usize>;
}

pub trait IntSeqDecoderIterator: Iterator {
pub trait IntSeqDecoder {
type IteratorType: Iterator<Item = Self::Item>;
type Item;

/// Creates a decoder
fn new_decoder(encoded_data: &[u8]) -> Self
where
Self: Sized;

/// Creates an iterator that iterates the encoded data and decodes one element at a time on the
/// fly
fn get_iterator(&self) -> Self::IteratorType;

/// Returns the number of elements in the sequence
fn len(&self) -> usize;
fn num_elem(&self) -> usize;
}
40 changes: 29 additions & 11 deletions rs/compression/src/noc/noc.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::fs::File;
use std::io::{BufWriter, Write};
use std::ptr::NonNull;

use anyhow::{anyhow, Result};
use utils::io::wrap_write;
use utils::mem::get_ith_val_from_raw_ptr;

use crate::compression::{IntSeqDecoderIterator, IntSeqEncoder};
use crate::compression::{IntSeqDecoder, IntSeqEncoder};

pub struct PlainEncoder {
num_elem: usize,
Expand Down Expand Up @@ -60,32 +61,49 @@ impl IntSeqEncoder for PlainEncoder {
}
}

pub struct PlainDecoderIterator {
pub struct PlainDecoder {
size: usize,
cur_index: usize,
encoded_data_ptr: *const u64,
encoded_data_ptr: NonNull<u64>,
}

impl IntSeqDecoderIterator for PlainDecoderIterator {
impl IntSeqDecoder for PlainDecoder {
type IteratorType = PlainDecodingIterator;
type Item = u64;

fn new_decoder(encoded_data: &[u8]) -> Self {
let encoded_data_ptr = NonNull::new(encoded_data.as_ptr() as *mut u64)
.expect("Encoded data pointer should not be null");
Self {
size: encoded_data.len(),
encoded_data_ptr,
}
}

fn get_iterator(&self) -> Self::IteratorType {
PlainDecodingIterator {
num_elem: self.num_elem(),
cur_index: 0,
encoded_data_ptr: encoded_data.as_ptr() as *const u64,
encoded_data_ptr: self.encoded_data_ptr,
}
}

fn len(&self) -> usize {
self.size
fn num_elem(&self) -> usize {
self.size / std::mem::size_of::<Self::Item>()
}
}

impl Iterator for PlainDecoderIterator {
pub struct PlainDecodingIterator {
num_elem: usize,
cur_index: usize,
encoded_data_ptr: NonNull<u64>,
}

impl Iterator for PlainDecodingIterator {
type Item = u64;

fn next(&mut self) -> Option<Self::Item> {
if self.cur_index < self.size {
let value = get_ith_val_from_raw_ptr(self.encoded_data_ptr, self.cur_index);
if self.cur_index < self.num_elem {
let value = get_ith_val_from_raw_ptr(self.encoded_data_ptr.as_ptr(), self.cur_index);
self.cur_index += 1;
Some(value)
} else {
Expand Down
48 changes: 28 additions & 20 deletions rs/index/src/ivf/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::BinaryHeap;
use std::marker::PhantomData;

use anyhow::{Context, Result};
use compression::compression::IntSeqDecoder;
use quantization::quantization::Quantizer;
use quantization::typing::VectorOps;
use utils::distance::l2::L2DistanceCalculatorImpl::StreamingSIMD;
Expand All @@ -13,7 +14,7 @@ use crate::posting_list::combined_file::FixedIndexFile;
use crate::utils::{IdWithScore, SearchContext};
use crate::vector::fixed_file::FixedFileVectorStorage;

pub struct Ivf<Q: Quantizer, D: DistanceCalculator> {
pub struct Ivf<Q: Quantizer, DC: DistanceCalculator, D: IntSeqDecoder<Item = u64>> {
// The dataset.
pub vector_storage: FixedFileVectorStorage<Q::QuantizedT>,

Expand All @@ -30,10 +31,11 @@ pub struct Ivf<Q: Quantizer, D: DistanceCalculator> {

pub quantizer: Q,

_marker: PhantomData<D>,
_distance_calculator_marker: PhantomData<DC>,
_decoder_marker: PhantomData<D>,
}

impl<Q: Quantizer, D: DistanceCalculator> Ivf<Q, D> {
impl<Q: Quantizer, DC: DistanceCalculator, D: IntSeqDecoder<Item = u64>> Ivf<Q, DC, D> {
pub fn new(
vector_storage: FixedFileVectorStorage<Q::QuantizedT>,
index_storage: FixedIndexFile,
Expand All @@ -45,7 +47,8 @@ impl<Q: Quantizer, D: DistanceCalculator> Ivf<Q, D> {
index_storage,
num_clusters,
quantizer,
_marker: PhantomData,
_distance_calculator_marker: PhantomData,
_decoder_marker: PhantomData,
}
}

Expand All @@ -59,7 +62,7 @@ impl<Q: Quantizer, D: DistanceCalculator> Ivf<Q, D> {
let centroid = index_storage
.get_centroid(i as usize)
.with_context(|| format!("Failed to get centroid at index {}", i))?;
let dist = D::calculate(&vector, &centroid);
let dist = DC::calculate(&vector, &centroid);
distances.push((i as usize, dist));
}
distances.select_nth_unstable_by(num_probes - 1, |a, b| a.1.total_cmp(&b.1));
Expand All @@ -78,15 +81,16 @@ impl<Q: Quantizer, D: DistanceCalculator> Ivf<Q, D> {
if let Ok(byte_slice) = self.index_storage.get_posting_list(centroid) {
let quantized_query = Q::QuantizedT::process_vector(query, &self.quantizer);
let mut results: Vec<IdWithScore> = Vec::new();
for idx in transmute_u8_to_slice::<u64>(byte_slice).iter() {
match self.vector_storage.get(*idx as usize, context) {
let decoder = D::new_decoder(&byte_slice);
for idx in decoder.get_iterator() {
match self.vector_storage.get(idx as usize, context) {
Some(vector) => {
let distance =
self.quantizer
.distance(&quantized_query, vector, StreamingSIMD);
results.push(IdWithScore {
score: distance,
id: *idx,
id: idx,
});
}
None => {}
Expand Down Expand Up @@ -149,7 +153,9 @@ impl<Q: Quantizer, D: DistanceCalculator> Ivf<Q, D> {
}
}

impl<Q: Quantizer, D: DistanceCalculator> Searchable for Ivf<Q, D> {
impl<Q: Quantizer, DC: DistanceCalculator, D: IntSeqDecoder<Item = u64>> Searchable
for Ivf<Q, DC, D>
{
fn search(
&self,
query: &[f32],
Expand Down Expand Up @@ -180,11 +186,12 @@ mod tests {
use std::io::Write;

use anyhow::anyhow;
use compression::noc::noc::PlainDecoder;
use num_traits::ops::bytes::ToBytes;
use quantization::noq::noq::NoQuantizer;
use quantization::pq::pq::ProductQuantizer;
use utils::distance::l2::L2DistanceCalculator;
use utils::mem::transmute_slice_to_u8;
use utils::mem::{transmute_slice_to_u8, transmute_u8_to_slice};

use super::*;

Expand Down Expand Up @@ -341,7 +348,7 @@ mod tests {
let num_clusters = 2;

let quantizer = NoQuantizer::new(3);
let ivf = Ivf::<NoQuantizer, L2DistanceCalculator>::new(
let ivf = Ivf::<_, L2DistanceCalculator, PlainDecoder>::new(
storage,
index_storage,
num_clusters,
Expand Down Expand Up @@ -392,12 +399,13 @@ mod tests {
FixedIndexFile::new(file_path).expect("FixedIndexFile should be created");
let num_probes = 2;

let nearest = Ivf::<NoQuantizer, L2DistanceCalculator>::find_nearest_centroids(
&vector,
&index_storage,
num_probes,
)
.expect("Nearest centroids should be found");
let nearest =
Ivf::<NoQuantizer, L2DistanceCalculator, PlainDecoder>::find_nearest_centroids(
&vector,
&index_storage,
num_probes,
)
.expect("Nearest centroids should be found");

assert_eq!(nearest[0], 1);
assert_eq!(nearest[1], 0);
Expand Down Expand Up @@ -443,7 +451,7 @@ mod tests {
let num_probes = 2;

let quantizer = NoQuantizer::new(num_features);
let ivf: Ivf<NoQuantizer, L2DistanceCalculator> =
let ivf: Ivf<_, L2DistanceCalculator, PlainDecoder> =
Ivf::new(storage, index_storage, num_clusters, quantizer);

let query = vec![2.0, 3.0, 4.0];
Expand Down Expand Up @@ -518,7 +526,7 @@ mod tests {
let num_clusters = 2;
let num_probes = 2;

let ivf: Ivf<ProductQuantizer, L2DistanceCalculator> =
let ivf: Ivf<_, L2DistanceCalculator, PlainDecoder> =
Ivf::new(storage, index_storage, num_clusters, quantizer);

let query = vec![2.0, 3.0, 4.0];
Expand Down Expand Up @@ -571,7 +579,7 @@ mod tests {
let num_probes = 1;

let quantizer = NoQuantizer::new(num_features);
let ivf: Ivf<NoQuantizer, L2DistanceCalculator> =
let ivf: Ivf<_, L2DistanceCalculator, PlainDecoder> =
Ivf::new(storage, index_storage, num_clusters, quantizer);

let query = vec![1.0, 2.0, 3.0];
Expand Down
23 changes: 16 additions & 7 deletions rs/index/src/ivf/reader.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use anyhow::Result;
use compression::compression::IntSeqDecoder;
use quantization::quantization::Quantizer;
use utils::DistanceCalculator;

Expand All @@ -15,7 +16,9 @@ impl IvfReader {
Self { base_directory }
}

pub fn read<Q: Quantizer, D: DistanceCalculator>(&self) -> Result<Ivf<Q, D>> {
pub fn read<Q: Quantizer, DC: DistanceCalculator, D: IntSeqDecoder<Item = u64>>(
&self,
) -> Result<Ivf<Q, DC, D>> {
let index_storage = FixedIndexFile::new(format!("{}/index", self.base_directory))?;

let vector_storage_path = format!("{}/vectors", self.base_directory);
Expand All @@ -30,7 +33,7 @@ impl IvfReader {
let quantizer_directory = format!("{}/quantizer", self.base_directory);
let quantizer = Q::read(quantizer_directory).unwrap();

Ok(Ivf::new(
Ok(Ivf::<_, DC, D>::new(
vector_storage,
index_storage,
num_clusters,
Expand All @@ -43,7 +46,7 @@ impl IvfReader {
mod tests {
use std::fs;

use compression::noc::noc::PlainEncoder;
use compression::noc::noc::{PlainDecoder, PlainEncoder};
use quantization::noq::noq::{NoQuantizer, NoQuantizerWriter};
use tempdir::TempDir;
use utils::distance::l2::L2DistanceCalculator;
Expand All @@ -70,7 +73,10 @@ mod tests {
let num_features = 4;
let file_size = 4096;
let quantizer = NoQuantizer::new(num_features);
let writer = IvfWriter::<_, PlainEncoder, L2DistanceCalculator>::new(base_directory.clone(), quantizer);
let writer = IvfWriter::<_, PlainEncoder, L2DistanceCalculator>::new(
base_directory.clone(),
quantizer,
);

let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
Expand Down Expand Up @@ -107,7 +113,7 @@ mod tests {

let reader = IvfReader::new(base_directory.clone());
let index = reader
.read::<NoQuantizer, L2DistanceCalculator>()
.read::<NoQuantizer, L2DistanceCalculator, PlainDecoder>()
.expect("Failed to read index file");

// Check if files were created
Expand Down Expand Up @@ -214,7 +220,10 @@ mod tests {
let noq_writer = NoQuantizerWriter::new(quantizer_directory);
assert!(noq_writer.write(&quantizer).is_ok());

let writer = IvfWriter::<_, PlainEncoder, L2DistanceCalculator>::new(base_directory.clone(), quantizer);
let writer = IvfWriter::<_, PlainEncoder, L2DistanceCalculator>::new(
base_directory.clone(),
quantizer,
);

let mut builder: IvfBuilder<L2DistanceCalculator> = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
Expand Down Expand Up @@ -243,7 +252,7 @@ mod tests {

let reader = IvfReader::new(base_directory.clone());
let index = reader
.read::<NoQuantizer, L2DistanceCalculator>()
.read::<NoQuantizer, L2DistanceCalculator, PlainDecoder>()
.expect("Failed to read index file");

let num_centroids = index.num_clusters;
Expand Down
4 changes: 0 additions & 4 deletions rs/index/src/posting_list/combined_file.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::marker::PhantomData;
use std::mem::size_of;

use anyhow::{anyhow, Result};
Expand Down Expand Up @@ -26,8 +25,6 @@ pub struct Header {
}

pub struct FixedIndexFile {
_marker: PhantomData<u64>,

mmap: Mmap,
header: Header,
doc_id_mapping_offset: usize,
Expand All @@ -52,7 +49,6 @@ impl FixedIndexFile {
Self::align_to_next_boundary(centroid_offset + header.centroids_len as usize, 8)
+ size_of::<u64>(); // FileBackedAppendablePostingListStorage's first u64 encodes num_clusters
Ok(Self {
_marker: PhantomData,
mmap,
header,
doc_id_mapping_offset,
Expand Down
2 changes: 2 additions & 0 deletions rs/index/src/segment/immutable_segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,5 @@ impl Searchable for ImmutableSegment {
}

impl SegmentSearchable for ImmutableSegment {}
unsafe impl Send for ImmutableSegment {}
unsafe impl Sync for ImmutableSegment {}
7 changes: 4 additions & 3 deletions rs/index/src/spann/index.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::cmp::Ordering;

use compression::noc::noc::PlainDecoder;
use log::debug;
use quantization::noq::noq::NoQuantizer;
use utils::distance::l2::L2DistanceCalculator;
Expand All @@ -10,13 +11,13 @@ use crate::ivf::index::Ivf;

pub struct Spann {
centroids: Hnsw<NoQuantizer>,
posting_lists: Ivf<NoQuantizer, L2DistanceCalculator>,
posting_lists: Ivf<NoQuantizer, L2DistanceCalculator, PlainDecoder>,
}

impl Spann {
pub fn new(
centroids: Hnsw<NoQuantizer>,
posting_lists: Ivf<NoQuantizer, L2DistanceCalculator>,
posting_lists: Ivf<NoQuantizer, L2DistanceCalculator, PlainDecoder>,
) -> Self {
Self {
centroids,
Expand All @@ -28,7 +29,7 @@ impl Spann {
&self.centroids
}

pub fn get_posting_lists(&self) -> &Ivf<NoQuantizer, L2DistanceCalculator> {
pub fn get_posting_lists(&self) -> &Ivf<NoQuantizer, L2DistanceCalculator, PlainDecoder> {
&self.posting_lists
}
}
Expand Down
Loading

0 comments on commit 4dda8b6

Please sign in to comment.