Skip to content

Commit

Permalink
Complete idx assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
BuildKite committed Dec 11, 2024
1 parent c532368 commit 982f089
Showing 1 changed file with 143 additions and 60 deletions.
203 changes: 143 additions & 60 deletions rs/index/src/ivf/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,29 +131,46 @@ impl PostingListWithStoppingPoints {

impl PartialOrd for PostingListWithStoppingPoints {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.stopping_points.first().and_then(|sp| {
other.stopping_points.first().map(|osp| {
sp.duplicated_vector_idx
.partial_cmp(&osp.duplicated_vector_idx)
})
})?
if let (Some(sp), Some(osp)) = (self.stopping_points.first(), other.stopping_points.first())
{
match sp
.duplicated_vector_idx
.partial_cmp(&osp.duplicated_vector_idx)
{
Some(Ordering::Equal) => self
.posting_list
.iter()
.partial_cmp(other.posting_list.iter()),
other => other,
}
} else {
None
}
}
}

impl Ord for PostingListWithStoppingPoints {
fn cmp(&self, other: &Self) -> Ordering {
match (self.stopping_points.first(), other.stopping_points.first()) {
(Some(sp), Some(osp)) => sp.duplicated_vector_idx.cmp(&osp.duplicated_vector_idx),
_ => panic!("Comparison is only valid when stopping_points is not empty"),
if let (Some(sp), Some(osp)) = (self.stopping_points.first(), other.stopping_points.first())
{
match sp.duplicated_vector_idx.cmp(&osp.duplicated_vector_idx) {
Ordering::Equal => self.posting_list.iter().cmp(other.posting_list.iter()),
other => other,
}
} else {
panic!("Comparison is only valid when stopping_points is not empty");
}
}
}

impl PartialEq for PostingListWithStoppingPoints {
fn eq(&self, other: &Self) -> bool {
match (self.stopping_points.first(), other.stopping_points.first()) {
(Some(sp), Some(osp)) => sp.duplicated_vector_idx == osp.duplicated_vector_idx,
_ => false,
if let (Some(sp), Some(osp)) = (self.stopping_points.first(), other.stopping_points.first())
{
sp.duplicated_vector_idx == osp.duplicated_vector_idx
&& self.posting_list.iter().eq(other.posting_list.iter())
} else {
false
}
}
}
Expand Down Expand Up @@ -478,10 +495,10 @@ impl IvfBuilder {

for list_index in 0..self.posting_lists.len() {
let posting_list = self.posting_lists.get(list_index as u32)?;
lists_with_stopping_points.push(PostingListWithStoppingPoints {
posting_list: posting_list.iter().collect::<Vec<_>>(),
stopping_points: SortedVec::new(),
});
lists_with_stopping_points.push(PostingListWithStoppingPoints::new(
posting_list.iter().collect::<Vec<_>>(),
SortedVec::new(),
));
for (index_in_list, vector_storage_index) in posting_list.iter().enumerate() {
let dup_vec_instance = DuplicatedVectorInstance {
posting_list_idx: list_index,
Expand Down Expand Up @@ -515,36 +532,7 @@ impl IvfBuilder {
Ok(filtered_lists)
}

// [11,12,13]
// [0,2,4,6,8,20]
// [9,18,20]
// [14,15,16,18]
// [1,3,5,7,18,20]
// [10,15,21]
//
// cur_idx = -1
// 15 -> -1 + 2 + 1 = 2
// 14 -> 0
// 10 -> 1
//
// cur_idx = 2
// 18 -> 2 + 6 + 1 = 9
// 9 -> 3
// 16 -> 4
// 1 -> 5
// 3 -> 6
// 5 -> 7
// 7 -> 8
// 18 -> 9
//
// cur_idx = 9
// 20 -> 9 + 5 + 1 = 15
// 0 -> 10
// 2 -> 11
// 4 -> 12
// 6 -> 13
// 8 -> 14
fn reassign_duplicated_vectors(&mut self, assigned_ids: &mut Vec<i32>) -> Result<()> {
fn assign_ids_until_last_stopping_point(&mut self, assigned_ids: &mut Vec<i32>) -> Result<i32> {
let mut min_heap: BinaryHeap<Reverse<PostingListWithStoppingPoints>> = BinaryHeap::from(
self.build_posting_lists_with_stopping_points()?
.into_iter()
Expand Down Expand Up @@ -579,14 +567,13 @@ impl IvfBuilder {
if list_with_stopping_points.stopping_points.len() > 1 {
// Remove reassigned vectors from posting list, remove the min
// duplicated vector idx from stopping point list
min_heap.push(Reverse(PostingListWithStoppingPoints {
posting_list: list_with_stopping_points.posting_list
[idx_in_posting_list + 1..]
min_heap.push(Reverse(PostingListWithStoppingPoints::new(
list_with_stopping_points.posting_list[idx_in_posting_list + 1..]
.to_vec(),
stopping_points: SortedVec::from_unsorted(
SortedVec::from_unsorted(
list_with_stopping_points.stopping_points[1..].to_vec(),
),
}));
)));
}
break;
}
Expand All @@ -603,7 +590,27 @@ impl IvfBuilder {
cur_idx += 1;
}

Ok(())
Ok(cur_idx)
}

/// Assign new ids to the vectors
fn get_reassigned_ids(&mut self) -> Result<Vec<i32>> {
let vector_length = self.vectors.len();
let mut assigned_ids = vec![-1; vector_length];

let mut cur_idx = self.assign_ids_until_last_stopping_point(&mut assigned_ids)?;

for list_index in 0..self.posting_lists.len() {
let posting_list = self.posting_lists.get(list_index as u32)?;
for original_vector_index in posting_list.iter() {
if assigned_ids[original_vector_index as usize] >= 0 {
continue;
}
assigned_ids[original_vector_index as usize] = cur_idx;
cur_idx += 1;
}
}
Ok(assigned_ids)
}

pub fn cleanup(&mut self) -> Result<()> {
Expand Down Expand Up @@ -821,8 +828,8 @@ mod tests {
}

#[test]
fn test_reassign_duplicated_vectors() {
let temp_dir = tempdir::TempDir::new("reassign_duplicated_vectors_test")
fn test_assign_ids_until_last_stopping_point() {
let temp_dir = tempdir::TempDir::new("assign_ids_until_last_stopping_point_test")
.expect("Failed to create temporary directory");
let base_directory = temp_dir
.path()
Expand Down Expand Up @@ -851,19 +858,22 @@ mod tests {
})
.expect("Failed to create builder");

assert!(builder.add_posting_list(&vec![1, 3, 5, 7, 18, 20]).is_ok());
assert!(builder.add_posting_list(&vec![9, 18, 20]).is_ok());
assert!(builder.add_posting_list(&vec![1, 3, 5, 7, 18, 20]).is_ok());
assert!(builder.add_posting_list(&vec![14, 15, 16, 18]).is_ok());
assert!(builder.add_posting_list(&vec![0, 2, 4, 6, 8, 20]).is_ok());
assert!(builder.add_posting_list(&vec![10, 15, 21]).is_ok());

let mut assigned_ids = vec![-1; 22];
assert!(builder
.reassign_duplicated_vectors(&mut assigned_ids)
.is_ok());
assert_eq!(
builder
.assign_ids_until_last_stopping_point(&mut assigned_ids)
.expect("Failed to reassign ids for duplicated vectors"),
16
);

assert_eq!(assigned_ids[14], 0);
assert_eq!(assigned_ids[10], 1);
assert_eq!(assigned_ids[10], 0);
assert_eq!(assigned_ids[14], 1);
assert_eq!(assigned_ids[15], 2);
assert_eq!(assigned_ids[1], 3);
assert_eq!(assigned_ids[3], 4);
Expand All @@ -884,6 +894,79 @@ mod tests {
assert_eq!(assigned_ids[13], -1);
}

#[test]
fn test_get_reassigned_ids() {
let temp_dir = tempdir::TempDir::new("get_reassigned_ids_test")
.expect("Failed to create temporary directory");
let base_directory = temp_dir
.path()
.to_str()
.expect("Failed to convert temporary directory path to string")
.to_string();
let num_clusters = 4;
let num_vectors = 22;
let num_features = 1;
let file_size = 4096;
let balance_factor = 0.0;
let max_posting_list_size = usize::MAX;
let mut builder = IvfBuilder::new(IvfBuilderConfig {
max_iteration: 1000,
batch_size: 4,
num_clusters,
num_data_points: num_vectors,
max_clusters_per_vector: 2,
distance_threshold: 0.1,
base_directory,
memory_size: 1024,
file_size,
num_features,
tolerance: balance_factor,
max_posting_list_size,
})
.expect("Failed to create builder");

for i in 0..num_vectors {
builder
.add_vector(i as u64, generate_random_vector(num_features))
.expect("Vector should be added");
}

assert!(builder.add_posting_list(&vec![11, 12, 13]).is_ok());
assert!(builder.add_posting_list(&vec![0, 2, 4, 6, 8, 20]).is_ok());
assert!(builder.add_posting_list(&vec![9, 18, 20]).is_ok());
assert!(builder.add_posting_list(&vec![14, 15, 16, 18]).is_ok());
assert!(builder.add_posting_list(&vec![1, 3, 5, 7, 18, 20]).is_ok());
assert!(builder.add_posting_list(&vec![10, 15, 21]).is_ok());
assert!(builder.add_posting_list(&vec![10, 15, 17, 19]).is_ok());

let assigned_ids = builder
.get_reassigned_ids()
.expect("Failed to reassign ids for duplicated vectors");

assert_eq!(assigned_ids[10], 0);
assert_eq!(assigned_ids[14], 1);
assert_eq!(assigned_ids[15], 2);
assert_eq!(assigned_ids[1], 3);
assert_eq!(assigned_ids[3], 4);
assert_eq!(assigned_ids[5], 5);
assert_eq!(assigned_ids[7], 6);
assert_eq!(assigned_ids[9], 7);
assert_eq!(assigned_ids[16], 8);
assert_eq!(assigned_ids[18], 9);
assert_eq!(assigned_ids[0], 10);
assert_eq!(assigned_ids[2], 11);
assert_eq!(assigned_ids[4], 12);
assert_eq!(assigned_ids[6], 13);
assert_eq!(assigned_ids[8], 14);
assert_eq!(assigned_ids[20], 15);
assert_eq!(assigned_ids[11], 16);
assert_eq!(assigned_ids[12], 17);
assert_eq!(assigned_ids[13], 18);
assert_eq!(assigned_ids[21], 19);
assert_eq!(assigned_ids[17], 20);
assert_eq!(assigned_ids[19], 21);
}

#[test]
fn test_ivf_builder() {
let temp_dir = tempdir::TempDir::new("ivf_builder_test")
Expand Down

0 comments on commit 982f089

Please sign in to comment.