Skip to content

Commit

Permalink
MLP-667 | initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna committed May 20, 2024
1 parent d0a86bf commit c61baef
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
21 changes: 18 additions & 3 deletions src/graphs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub struct Graph {
pub edges_indexed_to: bool,

pub cols_to_keys_to_inds: HashMap<String, HashMap<String, usize>>,
pub cols_to_inds_to_keys: HashMap<String, HashMap<usize, String>>,
pub coo_by_from_edge_to: HashMap<(String, String, String), Vec<Vec<usize>>>,
pub cols_to_features: HashMap<String, HashMap<String, Vec<Vec<f64>>>>,
}
Expand All @@ -90,6 +91,7 @@ impl Graph {
edges_indexed_to: false,
cols_to_features: HashMap::new(),
cols_to_keys_to_inds: HashMap::new(),
cols_to_inds_to_keys: HashMap::new(),
coo_by_from_edge_to: HashMap::new(),
}))
}
Expand Down Expand Up @@ -157,9 +159,22 @@ impl Graph {
self.cols_to_keys_to_inds
.insert(col_name.clone(), HashMap::new());
}
let col_inds = self.cols_to_keys_to_inds.get_mut(&col_name).unwrap();
let cur_ind = col_inds.len();
col_inds.insert(String::from_utf8(key.clone()).unwrap(), cur_ind);

if !self.cols_to_inds_to_keys.contains_key(&col_name) {
self.cols_to_inds_to_keys
.insert(col_name.clone(), HashMap::new());
}

let keys_to_inds: &mut HashMap<String, usize> =
self.cols_to_keys_to_inds.get_mut(&col_name).unwrap();
let inds_to_keys: &mut HashMap<usize, String> =
self.cols_to_inds_to_keys.get_mut(&col_name).unwrap();

let cur_ind = keys_to_inds.len();
let cur_key_str = String::from_utf8(key.clone()).unwrap();

keys_to_inds.insert(cur_key_str.clone(), cur_ind);
inds_to_keys.insert(cur_ind, cur_key_str);

if !self.cols_to_features.contains_key(&col_name) {
self.cols_to_features
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fn graph_to_pyg_format(py: Python, request: DataLoadRequest) -> PyResult<PygComp
construct::construct_cols_to_keys_to_inds(graph.cols_to_keys_to_inds.clone(), py)?;

let cols_to_inds_to_keys =
construct::construct_cols_to_inds_to_keys(graph.cols_to_keys_to_inds, py)?;
construct::construct_cols_to_inds_to_keys(graph.cols_to_inds_to_keys, py)?;

println!("Finished retrieval!");

Expand Down
12 changes: 4 additions & 8 deletions src/output/construct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,12 @@ pub fn construct_cols_to_keys_to_inds(

#[cfg(not(test))]
pub fn construct_cols_to_inds_to_keys(
input: HashMap<String, HashMap<String, usize>>,
input: HashMap<String, HashMap<usize, String>>,
py: Python,
) -> PyResult<&PyDict> {
let dict = PyDict::new(py);
input.iter().for_each(|(col_name, inner_map)| {
let inner_dict = PyDict::new(py);
inner_map.iter().for_each(|(key, value)| {
inner_dict.set_item(value, key).unwrap();
});
dict.set_item(col_name, inner_dict).unwrap();
});
input
.iter()
.for_each(|item| dict.set_item(item.0, item.1).unwrap());
Ok(dict)
}

0 comments on commit c61baef

Please sign in to comment.