diff --git a/Cargo.lock b/Cargo.lock index 6ccc524..e6a3ecf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -312,6 +312,20 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", + "rayon", +] + [[package]] name = "num-complex" version = "0.4.5" @@ -603,7 +617,7 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "webgestalt" -version = "0.2.1" +version = "0.3.0" dependencies = [ "bincode", "clap", @@ -613,10 +627,11 @@ dependencies = [ [[package]] name = "webgestalt_lib" -version = "0.2.1" +version = "0.3.0" dependencies = [ "ahash", "csv", + "ndarray", "pretty_assertions", "rand", "rayon", diff --git a/Cargo.toml b/Cargo.toml index c9fedfa..bba0022 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ [package] name = "webgestalt" -version = "0.2.1" +version = "0.3.0" authors = ["John Elizarraras"] edition = "2021" rust-version = "1.63.0" @@ -16,7 +16,7 @@ repository = "https://github.com/bzhanglab/webgestalt_rust" bincode = "1.3.3" clap = { version = "4.4.15", features = ["derive"] } owo-colors = { version = "4.0.0", features = ["supports-colors"] } -webgestalt_lib = { version = "0.2.0", path = "webgestalt_lib" } +webgestalt_lib = { version = "0.3.0", path = "webgestalt_lib" } [profile.release] opt-level = 3 diff --git a/webgestalt_lib/Cargo.toml b/webgestalt_lib/Cargo.toml index 51570cd..2fed8a0 100644 --- a/webgestalt_lib/Cargo.toml +++ b/webgestalt_lib/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "webgestalt_lib" -version = "0.2.1" +version = "0.3.0" authors = ["John Elizarraras"] edition = "2021" rust-version = "1.63.0" @@ -17,6 +17,7 @@ rand = { version = "0.8.5", features = ["small_rng"] } rayon = "1.8.0" statrs = "0.16.0" ahash = "0.8.6" +ndarray = { version = "0.15.6", features = ["rayon"] } [dev-dependencies] pretty_assertions = "1.4.0" diff --git a/webgestalt_lib/src/methods.rs b/webgestalt_lib/src/methods.rs index 69983e6..fdd11eb 100644 --- a/webgestalt_lib/src/methods.rs +++ b/webgestalt_lib/src/methods.rs @@ -1,3 +1,4 @@ pub mod gsea; pub mod multilist; +pub mod nta; pub mod ora; diff --git a/webgestalt_lib/src/methods/nta.rs b/webgestalt_lib/src/methods/nta.rs new file mode 100644 index 0000000..edf5b7d --- /dev/null +++ b/webgestalt_lib/src/methods/nta.rs @@ -0,0 +1,101 @@ +use ndarray::{Array2, Axis, Zip}; +use std::ops::Div; + +/// A struct representing the options for the NTA algorithm +pub struct NTAConfig { + /// A vector of vectors of strings representing the edge list of the graph + pub edge_list: Vec>, + /// A vector of strings representing the seeds + pub seeds: Vec, + /// A float representing the reset probability during random walk (default: 0.5) + pub reset_probability: f64, + /// A float representing the tolerance for probability calculation + pub tolerance: f64, +} + +impl Default for NTAConfig { + fn default() -> Self { + NTAConfig { + edge_list: vec![], + seeds: vec![], + reset_probability: 0.5, + tolerance: 0.000001, + } + } +} + +pub struct NTAResult { + pub neighborhood: Vec, + pub scores: Vec, + pub candidates: Vec, +} + +/// Uses random walk to calculate the neighborhood of a set of nodes +/// Returns [`Vec`]representing the nodes in the neighborhood +/// +/// # Parameters +/// - `config` - A [`NTAOptions`] struct containing the edge list, seeds, neighborhood size, reset probability, and tolerance +pub fn nta(config: NTAConfig) -> Vec<(String, f64)> { + println!("Building Graph"); + let unique_nodes = ahash::AHashSet::from_iter(config.edge_list.iter().flatten().cloned()); + let mut node_map: ahash::AHashMap = ahash::AHashMap::default(); + let mut reverse_map: ahash::AHashMap = ahash::AHashMap::default(); + for (i, node) in unique_nodes.iter().enumerate() { + node_map.insert(node.clone(), i); + reverse_map.insert(i, node.clone()); + } + let mut graph = Array2::::zeros((unique_nodes.len(), unique_nodes.len())); + for edge in config.edge_list.iter() { + let node1 = node_map.get(&edge[0]).unwrap(); + let node2 = node_map.get(&edge[1]).unwrap(); + graph[[*node1, *node2]] = 1.0; + graph[[*node2, *node1]] = 1.0; + } + println!("Calculating NTA"); + let node_indices: Vec = config + .seeds + .iter() + .map(|seed| *node_map.get(seed).unwrap()) + .collect(); + let walk_res = random_walk_probability( + &graph, + &node_indices, + config.reset_probability, + config.reset_probability, + ); + let mut walk = walk_res.iter().enumerate().collect::>(); + walk.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); + walk.iter() + .map(|(i, p)| (reverse_map.get(&i).unwrap().clone(), **p)) + .collect() +} + +fn random_walk_probability( + adj_matrix: &ndarray::Array2, + node_indices: &Vec, + r: f64, + tolerance: f64, +) -> ndarray::Array1 { + let num_nodes = node_indices.len() as f64; + let de = adj_matrix.sum_axis(Axis(0)); + // de to 2d array + let de = de.insert_axis(Axis(1)); + let temp = adj_matrix.t().div(de); + let w = temp.t(); + let mut p0 = ndarray::Array1::from_elem(w.shape()[0], 0.0); + for i in node_indices { + p0[*i] = 1.0 / num_nodes; + } + let mut pt = p0.clone(); + let mut pt1 = w.dot(&pt) * (1.0 - r) + (r * &p0); + while Zip::from(&pt1) + .and(&pt) + .par_map_collect(|a, b| (a - b).abs()) + .sum() + > tolerance + { + pt = pt1; + pt1 = w.dot(&pt) * (1.0 - r) + (r * &p0); + } + pt1 +}