diff --git a/migrations/core/01-initial/up.sql b/migrations/core/01-initial/up.sql
index 4df5d4d..bb6dff8 100644
--- a/migrations/core/01-initial/up.sql
+++ b/migrations/core/01-initial/up.sql
@@ -38,6 +38,13 @@ CREATE TABLE block_group (
 CREATE UNIQUE INDEX block_group_uidx ON block_group(collection_name, sample_name, name) WHERE sample_name is not null;
 CREATE UNIQUE INDEX block_group_null_sample_uidx ON block_group(collection_name, name) WHERE sample_name is null;
 
+CREATE TABLE block_group_lineage (
+  id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+  parent_id INTEGER NOT NULL,
+  child_id INTEGER NOT NULL
+) STRICT;
+CREATE UNIQUE INDEX block_group_lineage_uidx ON block_group_lineage(parent_id, child_id);
+
 CREATE TABLE path (
   id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
   block_group_id INTEGER NOT NULL,
diff --git a/src/models/block_group.rs b/src/models/block_group.rs
index 3406059..91077ad 100644
--- a/src/models/block_group.rs
+++ b/src/models/block_group.rs
@@ -2,6 +2,7 @@ use std::collections::{HashMap, HashSet};
 
 use intervaltree::IntervalTree;
 use itertools::Itertools;
+use petgraph::graphmap::DiGraphMap;
 use petgraph::Direction;
 use rusqlite::{params_from_iter, types::Value as SQLValue, Connection};
 use serde::{Deserialize, Serialize};
@@ -163,6 +164,101 @@ impl BlockGroup {
         objs
     }
 
+    /// Records that `target_block_group_id` descends from `source_block_group_id`.
+    ///
+    /// Uses `INSERT OR IGNORE`, so recording an already-known pair again changes nothing.
+    ///
+    /// # Panics
+    /// Panics if the statement cannot be prepared or executed.
+    pub fn add_relation(conn: &Connection, source_block_group_id: i64, target_block_group_id: i64) {
+        let query =
+            "INSERT OR IGNORE INTO block_group_lineage (parent_id, child_id) values (?1, ?2)";
+        let mut stmt = conn.prepare(query).unwrap();
+        stmt.execute((source_block_group_id, target_block_group_id))
+            .unwrap();
+    }
+
+    /// Returns the ids of the direct children recorded for `block_group_id`.
+    ///
+    /// # Panics
+    /// Panics if the query cannot be prepared or a row cannot be read.
+    pub fn get_children(conn: &Connection, block_group_id: i64) -> Vec<i64> {
+        let query = "select child_id from block_group_lineage where parent_id = ?1;";
+        let mut stmt = conn.prepare(query).unwrap();
+        let mut children = vec![];
+        for row in stmt.query_map((block_group_id,), |row| row.get(0)).unwrap() {
+            children.push(row.unwrap());
+        }
+        children
+    }
+
+    /// Returns the ids of the direct parents recorded for `block_group_id`.
+    ///
+    /// # Panics
+    /// Panics if the query cannot be prepared or a row cannot be read.
+    pub fn get_parents(conn: &Connection, block_group_id: i64) -> Vec<i64> {
+        let query = "select parent_id from block_group_lineage where child_id = ?1;";
+        let mut stmt = conn.prepare(query).unwrap();
+        let mut ids = vec![];
+        for row in stmt.query_map((block_group_id,), |row| row.get(0)).unwrap() {
+            ids.push(row.unwrap());
+        }
+        ids
+    }
+
+    /// Returns every lineage path from `block_group_id` up to an ancestor without parents.
+    ///
+    /// Each inner vector starts with a direct parent of `block_group_id` and ends with an
+    /// ancestor that has no recorded parent; `block_group_id` itself is never included. The
+    /// result is empty when no parents are recorded.
+    ///
+    /// The lineage is assumed to be acyclic: on a cycle the recursive query would keep
+    /// producing rows at ever-increasing depth.
+    ///
+    /// # Panics
+    /// Panics if the query cannot be prepared or a row cannot be read.
+    pub fn get_ancestors(conn: &Connection, block_group_id: i64) -> Vec<Vec<i64>> {
+        let query = "WITH RECURSIVE ancestors(parent_id, child_id, depth) AS ( \
+                      VALUES(?1, NULL, 0) UNION \
+                      select bgt.parent_id, bgt.child_id, ancestors.depth+1 from block_group_lineage bgt join ancestors ON bgt.child_id=ancestors.parent_id \
+                      ) SELECT parent_id, child_id, depth from ancestors where child_id is not null ORDER BY depth, parent_id DESC;";
+        let mut stmt = conn.prepare(query).unwrap();
+
+        // Rows are ordered by depth, then by descending parent id.
+        let mut graph: DiGraphMap<i64, ()> = DiGraphMap::new();
+        for row in stmt
+            .query_map((block_group_id,), |row| Ok((row.get(0)?, row.get(1)?)))
+            .unwrap()
+        {
+            let (parent_id, child_id): (i64, i64) = row.unwrap();
+            graph.add_edge(parent_id, child_id, ());
+        }
+
+        // Walk child -> parent edges depth-first, carrying the path taken so far, so that an
+        // ancestor reachable along several routes contributes one path per route.
+        let mut paths = vec![];
+        let mut stack: Vec<(i64, Vec<i64>)> = vec![(block_group_id, vec![])];
+        while let Some((node, path)) = stack.pop() {
+            let parents: Vec<i64> = graph
+                .neighbors_directed(node, Direction::Incoming)
+                .collect();
+            if parents.is_empty() {
+                // A node without parents ends a path; the start node alone is not a path.
+                if !path.is_empty() {
+                    paths.push(path);
+                }
+                continue;
+            }
+            // Push in reverse so the first listed parent is expanded first.
+            for parent in parents.into_iter().rev() {
+                let mut next_path = path.clone();
+                next_path.push(parent);
+                stack.push((parent, next_path));
+            }
+        }
+        paths
+    }
+
     pub fn get_by_id(conn: &Connection, id: i64) -> BlockGroup {
         let query = "SELECT * FROM block_group WHERE id = ?1";
         let mut stmt = conn.prepare(query).unwrap();
@@ -233,6 +329,8 @@ impl BlockGroup {
                 &edges.iter().map(|ap| ap.edge_id).collect::<Vec<i64>>(),
             );
         }
+
+        BlockGroup::add_relation(conn, source_block_group_id, target_block_group_id);
     }
 
     pub fn get_or_create_sample_block_group(
@@ -308,6 +406,16 @@ impl BlockGroup {
         }
     }
 
+    /// Builds the graph for `block_group_id` from its edges and derived boundary edges.
+    pub fn get_graph(conn: &Connection, block_group_id: i64) -> DiGraphMap<i64, ()> {
+        let mut edges = BlockGroupEdge::edges_for_block_group(conn, block_group_id);
+        let blocks = Edge::blocks_from_edges(conn, &edges);
+        let boundary_edges = Edge::boundary_edges_from_sequences(&blocks);
+        edges.extend(boundary_edges);
+        let (graph, _) = Edge::build_graph(&edges, &blocks);
+        graph
+    }
+
     pub fn get_all_sequences(conn: &Connection, block_group_id: i64) -> HashSet<String> {
         let mut edges = BlockGroupEdge::edges_for_block_group(conn, block_group_id);
         let blocks = Edge::blocks_from_edges(conn, &edges);
@@ -1534,4 +1642,68 @@ mod tests {
             ])
         );
     }
+
+    #[test]
+    fn test_adds_relation_on_clone() {
+        let conn = &get_connection(None);
+        let (block_group_id, _path) = setup_block_group(conn);
+        let new_bg = BlockGroup::create(conn, "test", None, "test2");
+        let new_bg_id = new_bg.id;
+        BlockGroup::clone(conn, block_group_id, new_bg_id);
+        assert_eq!(
+            BlockGroup::get_children(conn, block_group_id),
+            vec![new_bg_id]
+        );
+        assert_eq!(
+            BlockGroup::get_parents(conn, new_bg_id),
+            vec![block_group_id]
+        );
+    }
+
+    #[test]
+    fn test_finds_all_ancestors() {
+        let conn = &get_connection(None);
+        let _collection = Collection::create(conn, "test");
+        let bg1 = BlockGroup::create(conn, "test", None, "test1");
+        let bg2 = BlockGroup::create(conn, "test", None, "test2");
+        let bg3 = BlockGroup::create(conn, "test", None, "test3");
+        let bg4 = BlockGroup::create(conn, "test", None, "test4");
+        BlockGroup::add_relation(conn, bg1.id, bg2.id);
+        BlockGroup::add_relation(conn, bg2.id, bg3.id);
+        BlockGroup::add_relation(conn, bg3.id, bg4.id);
+        BlockGroup::add_relation(conn, bg1.id, bg4.id);
+        BlockGroup::add_relation(conn, bg1.id, bg3.id);
+        assert_eq!(
+            BlockGroup::get_ancestors(conn, bg4.id),
+            vec![
+                vec![bg3.id, bg2.id, bg1.id],
+                vec![bg3.id, bg1.id],
+                vec![bg1.id]
+            ]
+        );
+    }
+
+    #[test]
+    fn test_finds_ancestors_through_shared_intermediate() {
+        let conn = &get_connection(None);
+        let _collection = Collection::create(conn, "test");
+        let bg1 = BlockGroup::create(conn, "test", None, "test1");
+        let bg2 = BlockGroup::create(conn, "test", None, "test2");
+        let bg3 = BlockGroup::create(conn, "test", None, "test3");
+        let bg4 = BlockGroup::create(conn, "test", None, "test4");
+        let bg5 = BlockGroup::create(conn, "test", None, "test5");
+        BlockGroup::add_relation(conn, bg1.id, bg2.id);
+        BlockGroup::add_relation(conn, bg2.id, bg3.id);
+        BlockGroup::add_relation(conn, bg3.id, bg4.id);
+        BlockGroup::add_relation(conn, bg4.id, bg5.id);
+        // bg2 is a parent of bg4 both directly and through bg3.
+        BlockGroup::add_relation(conn, bg2.id, bg4.id);
+        assert_eq!(
+            BlockGroup::get_ancestors(conn, bg5.id),
+            vec![
+                vec![bg4.id, bg3.id, bg2.id, bg1.id],
+                vec![bg4.id, bg2.id, bg1.id]
+            ]
+        );
+    }
 }