diff --git a/Cargo.lock b/Cargo.lock index b28300a..9e5192c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -806,7 +806,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "phenolrs" -version = "0.4.0" +version = "0.4.1" dependencies = [ "anyhow", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 1f95220..6262dfd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "phenolrs" -version = "0.4.0" +version = "0.4.1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/graphs.rs b/src/graphs.rs index 2640ba3..8dcbca3 100644 --- a/src/graphs.rs +++ b/src/graphs.rs @@ -65,6 +65,7 @@ pub struct Graph { pub edges_indexed_to: bool, pub cols_to_keys_to_inds: HashMap>, + pub cols_to_inds_to_keys: HashMap>, pub coo_by_from_edge_to: HashMap<(String, String, String), Vec>>, pub cols_to_features: HashMap>>>, } @@ -90,6 +91,7 @@ impl Graph { edges_indexed_to: false, cols_to_features: HashMap::new(), cols_to_keys_to_inds: HashMap::new(), + cols_to_inds_to_keys: HashMap::new(), coo_by_from_edge_to: HashMap::new(), })) } @@ -157,9 +159,22 @@ impl Graph { self.cols_to_keys_to_inds .insert(col_name.clone(), HashMap::new()); } - let col_inds = self.cols_to_keys_to_inds.get_mut(&col_name).unwrap(); - let cur_ind = col_inds.len(); - col_inds.insert(String::from_utf8(key.clone()).unwrap(), cur_ind); + + if !self.cols_to_inds_to_keys.contains_key(&col_name) { + self.cols_to_inds_to_keys + .insert(col_name.clone(), HashMap::new()); + } + + let keys_to_inds: &mut HashMap = + self.cols_to_keys_to_inds.get_mut(&col_name).unwrap(); + let inds_to_keys: &mut HashMap = + self.cols_to_inds_to_keys.get_mut(&col_name).unwrap(); + + let cur_ind = keys_to_inds.len(); + let cur_key_str = String::from_utf8(key.clone()).unwrap(); + + keys_to_inds.insert(cur_key_str.clone(), cur_ind); + inds_to_keys.insert(cur_ind, cur_key_str); if !self.cols_to_features.contains_key(&col_name) { self.cols_to_features diff --git a/src/lib.rs b/src/lib.rs index ddae9d8..4be0d64 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,7 +46,7 @@ fn graph_to_pyg_format(py: Python, request: DataLoadRequest) -> PyResult>, + input: HashMap>, py: Python, ) -> PyResult<&PyDict> { let dict = PyDict::new(py); - input.iter().for_each(|(col_name, inner_map)| { - let inner_dict = PyDict::new(py); - inner_map.iter().for_each(|(key, value)| { - inner_dict.set_item(value, key).unwrap(); - }); - dict.set_item(col_name, inner_dict).unwrap(); - }); + input + .iter() + .for_each(|item| dict.set_item(item.0, item.1).unwrap()); Ok(dict) }