From 9f08253e2d9613e57151bc9db4c3c5bdb7f07244 Mon Sep 17 00:00:00 2001 From: John Elizarraras Date: Mon, 18 Mar 2024 14:36:40 -0500 Subject: [PATCH 1/4] add NTA to webgestalt --- webgestalt_lib/Cargo.toml | 1 + webgestalt_lib/src/methods.rs | 1 + webgestalt_lib/src/methods/nta.rs | 101 ++++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+) create mode 100644 webgestalt_lib/src/methods/nta.rs diff --git a/webgestalt_lib/Cargo.toml b/webgestalt_lib/Cargo.toml index 51570cd..cf2fbac 100644 --- a/webgestalt_lib/Cargo.toml +++ b/webgestalt_lib/Cargo.toml @@ -17,6 +17,7 @@ rand = { version = "0.8.5", features = ["small_rng"] } rayon = "1.8.0" statrs = "0.16.0" ahash = "0.8.6" +ndarray = { version = "0.15.6", features = ["rayon"] } [dev-dependencies] pretty_assertions = "1.4.0" diff --git a/webgestalt_lib/src/methods.rs b/webgestalt_lib/src/methods.rs index 69983e6..fdd11eb 100644 --- a/webgestalt_lib/src/methods.rs +++ b/webgestalt_lib/src/methods.rs @@ -1,3 +1,4 @@ pub mod gsea; pub mod multilist; +pub mod nta; pub mod ora; diff --git a/webgestalt_lib/src/methods/nta.rs b/webgestalt_lib/src/methods/nta.rs new file mode 100644 index 0000000..15376b1 --- /dev/null +++ b/webgestalt_lib/src/methods/nta.rs @@ -0,0 +1,101 @@ +use ndarray::{Array2, Axis, Zip}; +use std::ops::Div; + +/// A struct representing the options for the NTA algorithm +pub struct NTAOptions { + /// A vector of vectors of strings representing the edge list of the graph + pub edge_list: Vec<Vec<String>>, + /// A vector of strings representing the seeds + pub seeds: Vec<String>, + /// An integer representing the neighborhood size + pub neighborhood_size: usize, + /// A float representing the reset probability during random walk (default: 0.5) + pub reset_probability: f64, + /// A float representing the tolerance for probability calculation + pub tolerance: f64, +} + +impl Default for NTAOptions { + fn default() -> Self { + NTAOptions { + edge_list: vec![], + seeds: vec![], + neighborhood_size: 50, +
reset_probability: 0.5, + tolerance: 0.000001, + } + } +} + +/// Uses random walk to calculate the neighborhood of a set of nodes +/// Returns [`Vec`]representing the nodes in the neighborhood +/// +/// # Parameters +/// - `config` - A [`NTAOptions`] struct containing the edge list, seeds, neighborhood size, reset probability, and tolerance +pub fn nta(config: NTAOptions) -> Vec<String> { + println!("Building Graph"); + let unique_nodes = ahash::AHashSet::from_iter(config.edge_list.iter().flatten().cloned()); + let mut node_map: ahash::AHashMap<String, usize> = ahash::AHashMap::default(); + let mut reverse_map: ahash::AHashMap<usize, String> = ahash::AHashMap::default(); + for (i, node) in unique_nodes.iter().enumerate() { + node_map.insert(node.clone(), i); + reverse_map.insert(i, node.clone()); + } + let mut graph = Array2::<f64>::zeros((unique_nodes.len(), unique_nodes.len())); + for edge in config.edge_list.iter() { + let node1 = node_map.get(&edge[0]).unwrap(); + let node2 = node_map.get(&edge[1]).unwrap(); + graph[[*node1, *node2]] = 1.0; + graph[[*node2, *node1]] = 1.0; + } + println!("Calculating NTA"); + let node_indices: Vec<usize> = config + .seeds + .iter() + .map(|seed| *node_map.get(seed).unwrap()) + .collect(); + let walk_res = random_walk_probability( + &graph, + &node_indices, + config.reset_probability, + config.reset_probability, + ); + let walk = walk_res.to_vec(); + let mut top_n = walk.iter().enumerate().collect::<Vec<_>>(); + top_n.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); + top_n.truncate(config.neighborhood_size); + top_n + .iter() + .map(|(i, _p)| reverse_map.get(i).unwrap().clone()) + .collect() +} + +fn random_walk_probability( + adj_matrix: &ndarray::Array2<f64>, + node_indices: &Vec<usize>, + r: f64, + tolerance: f64, +) -> ndarray::Array1<f64> { + let num_nodes = node_indices.len() as f64; + let de = adj_matrix.sum_axis(Axis(0)); + // de to 2d array + let de = de.insert_axis(Axis(1)); + let temp = adj_matrix.t().div(de); + let w = temp.t(); + let mut p0 = ndarray::Array1::from_elem(w.shape()[0], 0.0);
+ for i in node_indices { + p0[*i] = 1.0 / num_nodes; + } + let mut pt = p0.clone(); + let mut pt1 = w.dot(&pt) * (1.0 - r) + (r * &p0); + while Zip::from(&pt1) + .and(&pt) + .par_map_collect(|a, b| (a - b).abs()) + .sum() + > tolerance + { + pt = pt1; + pt1 = w.dot(&pt) * (1.0 - r) + (r * &p0); + } + pt1 +} From 3d8675b2e6bea6cf41e4075233a3fd4bf75840b3 Mon Sep 17 00:00:00 2001 From: John Elizarraras Date: Mon, 18 Mar 2024 15:01:03 -0500 Subject: [PATCH 2/4] no selection of top nodes, just the string and score --- webgestalt_lib/src/methods/nta.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/webgestalt_lib/src/methods/nta.rs b/webgestalt_lib/src/methods/nta.rs index 15376b1..52fa072 100644 --- a/webgestalt_lib/src/methods/nta.rs +++ b/webgestalt_lib/src/methods/nta.rs @@ -7,8 +7,6 @@ pub struct NTAOptions { pub edge_list: Vec<Vec<String>>, /// A vector of strings representing the seeds pub seeds: Vec<String>, - /// An integer representing the neighborhood size - pub neighborhood_size: usize, /// A float representing the reset probability during random walk (default: 0.5) pub reset_probability: f64, /// A float representing the tolerance for probability calculation @@ -20,19 +18,24 @@ impl Default for NTAOptions { NTAOptions { edge_list: vec![], seeds: vec![], - neighborhood_size: 50, reset_probability: 0.5, tolerance: 0.000001, } } } +pub struct NTAResult { + pub neighborhood: Vec<String>, + pub scores: Vec<f64>, + pub candidates: Vec<String>, +} + /// Uses random walk to calculate the neighborhood of a set of nodes /// Returns [`Vec`]representing the nodes in the neighborhood /// /// # Parameters /// - `config` - A [`NTAOptions`] struct containing the edge list, seeds, neighborhood size, reset probability, and tolerance -pub fn nta(config: NTAOptions) -> Vec<String> { +pub fn nta(config: NTAOptions) -> Vec<(String, f64)> { println!("Building Graph"); let unique_nodes = ahash::AHashSet::from_iter(config.edge_list.iter().flatten().cloned()); let mut node_map:
ahash::AHashMap<String, usize> = ahash::AHashMap::default(); @@ -60,13 +63,10 @@ pub fn nta(config: NTAOptions) -> Vec<String> { config.reset_probability, config.reset_probability, ); - let walk = walk_res.to_vec(); - let mut top_n = walk.iter().enumerate().collect::<Vec<_>>(); - top_n.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); - top_n.truncate(config.neighborhood_size); - top_n - .iter() - .map(|(i, _p)| reverse_map.get(i).unwrap().clone()) + let mut walk = walk_res.iter().enumerate().collect::<Vec<_>>(); + walk.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); + walk.iter() + .map(|(i, p)| (reverse_map.get(&i).unwrap().clone(), **p)) .collect() } From 800daa68e757a4c0e3f2095bb070a9c3478fe1ea Mon Sep 17 00:00:00 2001 From: John Elizarraras Date: Mon, 18 Mar 2024 15:03:56 -0500 Subject: [PATCH 3/4] update Option to Config --- webgestalt_lib/src/methods/nta.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/webgestalt_lib/src/methods/nta.rs b/webgestalt_lib/src/methods/nta.rs index 52fa072..edf5b7d 100644 --- a/webgestalt_lib/src/methods/nta.rs +++ b/webgestalt_lib/src/methods/nta.rs @@ -2,7 +2,7 @@ use ndarray::{Array2, Axis, Zip}; use std::ops::Div; /// A struct representing the options for the NTA algorithm -pub struct NTAOptions { +pub struct NTAConfig { /// A vector of vectors of strings representing the edge list of the graph pub edge_list: Vec<Vec<String>>, /// A vector of strings representing the seeds @@ -13,9 +13,9 @@ pub struct NTAOptions { pub tolerance: f64, } -impl Default for NTAOptions { +impl Default for NTAConfig { fn default() -> Self { - NTAOptions { + NTAConfig { edge_list: vec![], seeds: vec![], reset_probability: 0.5, @@ -35,7 +35,7 @@ pub struct NTAResult { /// /// # Parameters /// - `config` - A [`NTAOptions`] struct containing the edge list, seeds, neighborhood size, reset probability, and tolerance -pub fn nta(config: NTAOptions) -> Vec<(String, f64)> { +pub fn nta(config: NTAConfig) -> Vec<(String, f64)> { println!("Building Graph"); let unique_nodes =
ahash::AHashSet::from_iter(config.edge_list.iter().flatten().cloned()); let mut node_map: ahash::AHashMap<String, usize> = ahash::AHashMap::default(); From 9382f0a3a5402cdc94c3858be161a2cefc6573b7 Mon Sep 17 00:00:00 2001 From: John Elizarraras Date: Mon, 18 Mar 2024 16:37:12 -0500 Subject: [PATCH 4/4] update version --- Cargo.lock | 19 +++++++++++++++++-- Cargo.toml | 4 ++-- webgestalt_lib/Cargo.toml | 2 +- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6ccc524..e6a3ecf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -312,6 +312,20 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", + "rayon", +] + [[package]] name = "num-complex" version = "0.4.5" @@ -603,7 +617,7 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "webgestalt" -version = "0.2.1" +version = "0.3.0" dependencies = [ "bincode", "clap", @@ -613,10 +627,11 @@ dependencies = [ [[package]] name = "webgestalt_lib" -version = "0.2.1" +version = "0.3.0" dependencies = [ "ahash", "csv", + "ndarray", "pretty_assertions", "rand", "rayon", diff --git a/Cargo.toml b/Cargo.toml index c9fedfa..bba0022 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ [package] name = "webgestalt" -version = "0.2.1" +version = "0.3.0" authors = ["John Elizarraras"] edition = "2021" rust-version = "1.63.0" @@ -16,7 +16,7 @@ repository = "https://github.com/bzhanglab/webgestalt_rust" bincode = "1.3.3" clap = { version = "4.4.15", features = ["derive"] } owo-colors = { version = "4.0.0", features = ["supports-colors"] } -webgestalt_lib = { version = "0.2.0", path = "webgestalt_lib" } +webgestalt_lib = { version = "0.3.0", path = "webgestalt_lib" } [profile.release]
opt-level = 3 diff --git a/webgestalt_lib/Cargo.toml b/webgestalt_lib/Cargo.toml index cf2fbac..2fed8a0 100644 --- a/webgestalt_lib/Cargo.toml +++ b/webgestalt_lib/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "webgestalt_lib" -version = "0.2.1" +version = "0.3.0" authors = ["John Elizarraras"] edition = "2021" rust-version = "1.63.0"