diff --git a/src/graphs.rs b/src/graphs.rs index 2640ba3..8dcbca3 100644 --- a/src/graphs.rs +++ b/src/graphs.rs @@ -65,6 +65,7 @@ pub struct Graph { pub edges_indexed_to: bool, pub cols_to_keys_to_inds: HashMap>, + pub cols_to_inds_to_keys: HashMap>, pub coo_by_from_edge_to: HashMap<(String, String, String), Vec>>, pub cols_to_features: HashMap>>>, } @@ -90,6 +91,7 @@ impl Graph { edges_indexed_to: false, cols_to_features: HashMap::new(), cols_to_keys_to_inds: HashMap::new(), + cols_to_inds_to_keys: HashMap::new(), coo_by_from_edge_to: HashMap::new(), })) } @@ -157,9 +159,22 @@ impl Graph { self.cols_to_keys_to_inds .insert(col_name.clone(), HashMap::new()); } - let col_inds = self.cols_to_keys_to_inds.get_mut(&col_name).unwrap(); - let cur_ind = col_inds.len(); - col_inds.insert(String::from_utf8(key.clone()).unwrap(), cur_ind); + + if !self.cols_to_inds_to_keys.contains_key(&col_name) { + self.cols_to_inds_to_keys + .insert(col_name.clone(), HashMap::new()); + } + + let keys_to_inds: &mut HashMap = + self.cols_to_keys_to_inds.get_mut(&col_name).unwrap(); + let inds_to_keys: &mut HashMap = + self.cols_to_inds_to_keys.get_mut(&col_name).unwrap(); + + let cur_ind = keys_to_inds.len(); + let cur_key_str = String::from_utf8(key.clone()).unwrap(); + + keys_to_inds.insert(cur_key_str.clone(), cur_ind); + inds_to_keys.insert(cur_ind, cur_key_str); if !self.cols_to_features.contains_key(&col_name) { self.cols_to_features diff --git a/src/lib.rs b/src/lib.rs index ddae9d8..4be0d64 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,7 +46,7 @@ fn graph_to_pyg_format(py: Python, request: DataLoadRequest) -> PyResult>, + input: HashMap>, py: Python, ) -> PyResult<&PyDict> { let dict = PyDict::new(py); - input.iter().for_each(|(col_name, inner_map)| { - let inner_dict = PyDict::new(py); - inner_map.iter().for_each(|(key, value)| { - inner_dict.set_item(value, key).unwrap(); - }); - dict.set_item(col_name, inner_dict).unwrap(); - }); + input + .iter() + .for_each(|item| dict.set_item(item.0, item.1).unwrap()); Ok(dict) }