From d43814c7b02836c5913cbb28f4e7f326cfff4b35 Mon Sep 17 00:00:00 2001 From: Diego Martin Date: Tue, 28 Nov 2023 19:28:01 +0000 Subject: [PATCH] get_predicate implementation for vec --- src/engine/array.rs | 23 ++++++++++++++++++++--- tests/common/mod.rs | 1 + tests/get_predicate_test.rs | 24 ++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 tests/get_predicate_test.rs diff --git a/src/engine/array.rs b/src/engine/array.rs index d4042fb..8b348f7 100644 --- a/src/engine/array.rs +++ b/src/engine/array.rs @@ -1,4 +1,4 @@ -use sprs::CsVec; +use sprs::{CsVec, CsMat, CsMatBase}; use crate::storage::ZarrArray; @@ -10,8 +10,25 @@ impl EngineStrategy> for ZarrArray { Ok(&self.transpose_view() * &selection) } - fn get_predicate(&self, index: usize) -> EngineResult> { - unimplemented!() + fn get_predicate(&self, index: usize) -> EngineResult> { + + + let mut rows: Vec = vec![]; + let mut cols: Vec = vec![]; + let mut values: Vec = vec![]; + let iterator = self.into_iter(); + for value in iterator { + if *value.0 == index as u8 { + rows.push(value.1.0); + cols.push(value.1.1); + values.push(*value.0); + } + + } + //CsMatBase, Vec, Vec> + let result = CsMat::new(self.shape(),rows, cols, values); + + Ok(result) } fn get_object(&self, index: usize) -> EngineResult> { diff --git a/tests/common/mod.rs b/tests/common/mod.rs index e50f97f..d876071 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -68,6 +68,7 @@ impl Predicate { pub fn get_idx(self, dictionary: &Dictionary) -> u8 { dictionary.get_predicate_idx_unchecked(self.into()) as u8 } + } impl From for &str { diff --git a/tests/get_predicate_test.rs b/tests/get_predicate_test.rs new file mode 100644 index 0000000..4780d18 --- /dev/null +++ b/tests/get_predicate_test.rs @@ -0,0 +1,24 @@ +use remote_hdt::{ + engine::EngineStrategy, + storage::{matrix::MatrixLayout, tabular::TabularLayout, ChunkingStrategy, LocalStorage}, +}; +use sprs::CsVec; + +mod common; + +#[test] +fn get_predicate_tabular_test() { + let mut storage = LocalStorage::new(TabularLayout); + common::setup(common::TABULAR_ZARR, &mut storage, ChunkingStrategy::Chunk); + + let actual = storage + .load_sparse(common::TABULAR_ZARR) + .unwrap() + .get_predicate(common::Predicate::InstanceOf.get_idx(&storage.get_dictionary()) as usize) + .unwrap(); + + assert_eq!( + actual, + CsVec::new(9, vec![0, 1, 2, 7, 8], vec![2, 4, 5, 7, 8]) + ) +}