diff --git a/Cargo.lock b/Cargo.lock index b631a48..be95ac1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -591,6 +591,7 @@ dependencies = [ "bio", "clap", "include_dir", + "itertools", "noodles", "petgraph", "rusqlite", diff --git a/Cargo.toml b/Cargo.toml index 31ce0d3..8230f76 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" bio = "2.0.0" clap = { version = "4.5.8", features = ["derive"] } include_dir = "0.7.4" +itertools = "0.13.0" rusqlite = { version = "0.31.0", features = ["bundled", "array"] } rusqlite_migration = { version = "1.2.0" , features = ["from-directory"]} sha2 = "0.10.8" diff --git a/migrations/01-initial/up.sql b/migrations/01-initial/up.sql index 3227709..7c39843 100644 --- a/migrations/01-initial/up.sql +++ b/migrations/01-initial/up.sql @@ -98,10 +98,9 @@ CREATE UNIQUE INDEX new_edge_uidx ON new_edges(source_hash, source_coordinate, t CREATE TABLE path_edges ( id INTEGER PRIMARY KEY NOT NULL, path_id INTEGER NOT NULL, - source_edge_id INTEGER, - target_edge_id INTEGER, - FOREIGN KEY(source_edge_id) REFERENCES new_edges(id), - FOREIGN KEY(target_edge_id) REFERENCES new_edges(id), + index_in_path INTEGER NOT NULL, + edge_id INTEGER NOT NULL, + FOREIGN KEY(edge_id) REFERENCES new_edges(id), FOREIGN KEY(path_id) REFERENCES path(id) ); -CREATE UNIQUE INDEX path_edges_uidx ON path_edges(path_id, source_edge_id, target_edge_id); +CREATE UNIQUE INDEX path_edges_uidx ON path_edges(path_id, edge_id); diff --git a/src/main.rs b/src/main.rs index bea2e9d..5c76689 100644 --- a/src/main.rs +++ b/src/main.rs @@ -302,7 +302,7 @@ mod tests { HashSet::from_iter(vec!["ATCGATCGATCGATCGATCGGGAACACACAGAGA".to_string()]) ); assert_eq!( - Path::sequence(&conn, 1), + Path::get_sequence(&conn, 1), "ATCGATCGATCGATCGATCGGGAACACACAGAGA".to_string() ); } diff --git a/src/models.rs b/src/models.rs index 86d6294..1d4abb4 100644 --- a/src/models.rs +++ b/src/models.rs @@ -277,16 +277,8 @@ impl BlockGroup { let sequence_hashes = block_map .values() .map(|block| format!("\"{id}\"", id = block.sequence_hash)) - .collect::>() - .join(","); - let mut sequence_map = HashMap::new(); - for sequence in Sequence::get_sequences( - conn, - &format!("select * from sequence where hash in ({sequence_hashes})"), - vec![], - ) { - sequence_map.insert(sequence.hash, sequence.sequence); - } + .collect::>(); + let sequence_map = Sequence::get_sequences_by_hash(conn, sequence_hashes); let block_ids = block_map .keys() .map(|id| format!("{id}")) @@ -323,7 +315,8 @@ impl BlockGroup { let block = block_map.get(&start_node).unwrap(); let block_sequence = sequence_map.get(&block.sequence_hash).unwrap(); sequences.insert( - block_sequence[(block.start as usize)..(block.end as usize)].to_string(), + block_sequence.sequence[(block.start as usize)..(block.end as usize)] + .to_string(), ); } else { for path in all_simple_paths(&graph, start_node, *end_node) { @@ -332,7 +325,8 @@ impl BlockGroup { let block = block_map.get(&node).unwrap(); let block_sequence = sequence_map.get(&block.sequence_hash).unwrap(); current_sequence.push_str( - &block_sequence[(block.start as usize)..(block.end as usize)], + &block_sequence.sequence + [(block.start as usize)..(block.end as usize)], ); } sequences.insert(current_sequence); @@ -428,20 +422,19 @@ impl BlockGroup { // |----range---| let start_split_point = block.start + start - path_start; let end_split_point = block.start + end - path_start; - let mut next_block; - if start_split_point == block.start { + let next_block = if start_split_point == block.start { if let Some(pb) = previous_block { new_edges.push((Some(pb.id), Some(new_block_id))); } - next_block = block.clone(); + block.clone() } else { let (left_block, right_block) = Block::split(conn, block, start_split_point, chromosome_index, phased) .unwrap(); Block::delete(conn, block.id); new_edges.push((Some(left_block.id), Some(new_block_id))); - next_block = right_block.clone(); - } + right_block.clone() + }; if end_split_point == next_block.start { new_edges.push((Some(new_block_id), Some(next_block.id))); @@ -585,7 +578,6 @@ impl ChangeLog { mod tests { use super::*; use crate::migrations::run_migrations; - use std::hash::Hash; fn get_connection() -> Connection { let mut conn = Connection::open_in_memory() diff --git a/src/models/new_edge.rs b/src/models/new_edge.rs index b31c16f..7bd833c 100644 --- a/src/models/new_edge.rs +++ b/src/models/new_edge.rs @@ -1,7 +1,7 @@ use rusqlite::types::Value; use rusqlite::{params_from_iter, Connection}; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct NewEdge { pub id: i32, pub source_hash: Option, @@ -85,4 +85,32 @@ impl NewEdge { } } } + + pub fn load(conn: &Connection, edge_ids: Vec) -> Vec { + let formatted_edge_ids = edge_ids + .into_iter() + .map(|edge_id| edge_id.to_string()) + .collect::>() + .join(","); + let query = format!("select id, source_hash, source_coordinate, target_hash, target_coordinate, chromosome_index, phased from new_edges where id in ({});", formatted_edge_ids); + let mut stmt = conn.prepare_cached(&query).unwrap(); + let rows = stmt + .query_map([], |row| { + Ok(NewEdge { + id: row.get(0)?, + source_hash: row.get(1)?, + source_coordinate: row.get(2)?, + target_hash: row.get(3)?, + target_coordinate: row.get(4)?, + chromosome_index: row.get(5)?, + phased: row.get(6)?, + }) + }) + .unwrap(); + let mut objs = vec![]; + for row in rows { + objs.push(row.unwrap()); + } + objs + } } diff --git a/src/models/path.rs b/src/models/path.rs index f2339d9..cf7a538 100644 --- a/src/models/path.rs +++ b/src/models/path.rs @@ -1,9 +1,12 @@ -use crate::models::{block::Block, edge::Edge, path_edge::PathEdge}; +use crate::models::{block::Block, new_edge::NewEdge, path_edge::PathEdge, sequence::Sequence}; use petgraph::graphmap::DiGraphMap; use petgraph::prelude::Dfs; use petgraph::Direction; use rusqlite::types::Value; use rusqlite::{params_from_iter, Connection}; +use std::collections::{HashMap, HashSet}; + +use itertools::Itertools; #[derive(Debug)] pub struct Path { @@ -49,10 +52,12 @@ pub fn revcomp(seq: &str) -> String { #[derive(Clone, Debug)] pub struct NewBlock { pub id: i32, - pub sequence_hash: String, + pub sequence: Sequence, pub block_sequence: String, - pub start: i32, - pub end: i32, + pub sequence_start: i32, + pub sequence_end: i32, + pub path_start: i32, + pub path_end: i32, pub strand: String, } @@ -84,6 +89,33 @@ impl Path { path } + pub fn new_create( + conn: &Connection, + name: &str, + block_group_id: i32, + edge_ids: Vec, + ) -> Path { + let query = "INSERT INTO path (name, block_group_id) VALUES (?1, ?2) RETURNING (id)"; + let mut stmt = conn.prepare(query).unwrap(); + let mut rows = stmt + .query_map((name, block_group_id), |row| { + Ok(Path { + id: row.get(0)?, + name: name.to_string(), + block_group_id, + blocks: vec![], + }) + }) + .unwrap(); + let path = rows.next().unwrap().unwrap(); + + for (index, edge_id) in edge_ids.iter().enumerate() { + PathEdge::create(conn, path.id, index.try_into().unwrap(), *edge_id); + } + + path + } + pub fn get(conn: &mut Connection, path_id: i32) -> Path { let query = "SELECT id, block_group_id, name from path where id = ?1;"; let mut stmt = conn.prepare(query).unwrap(); @@ -120,7 +152,7 @@ impl Path { paths } - pub fn sequence(conn: &Connection, path_id: i32) -> String { + pub fn get_sequence(conn: &Connection, path_id: i32) -> String { let block_ids = PathBlock::get_blocks(conn, path_id); let mut sequence = "".to_string(); for block_id in block_ids { @@ -134,10 +166,98 @@ impl Path { sequence } - pub fn get_new_blocks(conn: &Connection, path_id: i32) -> Vec { - let mut new_blocks = vec![]; - let edges = PathEdge::edges_for(conn, path_id); - new_blocks + pub fn new_get_sequence(conn: &Connection, path: Path) -> String { + let blocks = Path::blocks_for(conn, path); + blocks + .into_iter() + .map(|block| block.block_sequence) + .collect::>() + .join("") + } + + pub fn edge_pairs_to_block( + block_id: i32, + path: &Path, + into: NewEdge, + out_of: NewEdge, + sequences_by_hash: &HashMap, + current_path_length: i32, + ) -> NewBlock { + if into.target_hash.is_none() || out_of.source_hash.is_none() { + panic!( + "Consecutive edges in path {} have None as internal block sequence", + path.id + ); + } + + if into.target_hash != out_of.source_hash { + panic!( + "Consecutive edges in path {0} don't share the same block", + path.id + ); + } + + let sequence = sequences_by_hash.get(&into.target_hash.unwrap()).unwrap(); + let start = into.target_coordinate.unwrap(); + let end = out_of.source_coordinate.unwrap(); + + let strand; + let block_sequence; + + if end >= start { + strand = "+"; + block_sequence = sequence.sequence[start as usize..end as usize].to_string(); + } else { + strand = "-"; + block_sequence = revcomp(&sequence.sequence[end as usize..start as usize + 1]); + } + + NewBlock { + id: block_id, + sequence: sequence.clone(), + block_sequence, + sequence_start: start, + sequence_end: end, + path_start: current_path_length, + path_end: current_path_length + end, + strand: strand.to_string(), + } + } + + pub fn blocks_for(conn: &Connection, path: Path) -> Vec { + let edges = PathEdge::edges_for(conn, path.id); + let mut sequence_hashes = HashSet::new(); + for edge in &edges { + if edge.source_hash.is_some() { + sequence_hashes.insert(edge.source_hash.clone().unwrap()); + } + if edge.target_hash.is_some() { + sequence_hashes.insert(edge.target_hash.clone().unwrap()); + } + } + let sequences_by_hash = Sequence::get_sequences_by_hash( + conn, + sequence_hashes + .into_iter() + .map(|hash| format!("\"{hash}\"")) + .collect(), + ); + + let mut blocks = vec![]; + let mut path_length = 0; + for (index, (into, out_of)) in edges.into_iter().tuple_windows().enumerate() { + let block = Path::edge_pairs_to_block( + index as i32, + &path, + into, + out_of, + &sequences_by_hash, + path_length, + ); + path_length += block.block_sequence.len() as i32; + blocks.push(block); + } + blocks } } @@ -290,7 +410,7 @@ mod tests { use super::*; use crate::migrations::run_migrations; - use crate::models::{sequence::Sequence, BlockGroup, Collection}; + use crate::models::{sequence::Sequence, BlockGroup, Collection, Edge}; fn get_connection() -> Connection { let mut conn = Connection::open_in_memory() @@ -323,7 +443,7 @@ mod tests { block_group.id, vec![block1.id, block2.id, block3.id], ); - assert_eq!(Path::sequence(conn, path.id), "ATCGATCGAAAAAAACCCCCCC"); + assert_eq!(Path::get_sequence(conn, path.id), "ATCGATCGAAAAAAACCCCCCC"); } #[test] @@ -349,7 +469,7 @@ mod tests { block_group.id, vec![block3.id, block2.id, block1.id], ); - assert_eq!(Path::sequence(conn, path.id), "GGGGGGGTTTTTTTCGATCGAT"); + assert_eq!(Path::get_sequence(conn, path.id), "GGGGGGGTTTTTTTCGATCGAT"); } #[test] diff --git a/src/models/path_edge.rs b/src/models/path_edge.rs index 06da710..1413611 100644 --- a/src/models/path_edge.rs +++ b/src/models/path_edge.rs @@ -1,32 +1,28 @@ -use crate::models::new_edge::NewEdge; +use crate::models::{new_edge::NewEdge, path::Path}; use rusqlite::types::Value; use rusqlite::{params_from_iter, Connection}; +use std::collections::HashMap; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct PathEdge { pub id: i32, pub path_id: i32, - pub source_edge_id: Option, - pub target_edge_id: Option, + pub index_in_path: i32, + pub edge_id: i32, } impl PathEdge { - pub fn create( - conn: &Connection, - path_id: i32, - source_edge_id: Option, - target_edge_id: Option, - ) -> PathEdge { + pub fn create(conn: &Connection, path_id: i32, index_in_path: i32, edge_id: i32) -> PathEdge { let query = - "INSERT INTO path_edges (path_id, source_edge_id, target_edge_id) VALUES (?1, ?2, ?3) RETURNING (id)"; + "INSERT INTO path_edges (path_id, index_in_path, edge_id) VALUES (?1, ?2, ?3) RETURNING (id)"; let mut stmt = conn.prepare(query).unwrap(); let mut rows = stmt - .query_map((path_id, source_edge_id, target_edge_id), |row| { + .query_map((path_id, index_in_path, edge_id), |row| { Ok(PathEdge { id: row.get(0)?, path_id, - source_edge_id, - target_edge_id, + index_in_path, + edge_id, }) }) .unwrap(); @@ -35,31 +31,17 @@ impl PathEdge { Err(rusqlite::Error::SqliteFailure(err, details)) => { if err.code == rusqlite::ErrorCode::ConstraintViolation { println!("{err:?} {details:?}"); - let query; let mut placeholders = vec![path_id]; - if let Some(s) = source_edge_id { - if let Some(t) = target_edge_id { - query = "SELECT id from path_edges where path_id = ?1 AND source_edge_id = ?2 AND target_edge_id = ?3;"; - placeholders.push(s); - placeholders.push(t); - } else { - query = "SELECT id from path_edges where path_id = ?1 AND source_edge_id = ?2 AND target_edge_id is null;"; - placeholders.push(s); - } - } else if let Some(t) = target_edge_id { - query = "SELECT id from path_edges where path_id = ?1 AND source_edge_id is null AND target_edge_id = ?2;"; - placeholders.push(t); - } else { - panic!("No edge ids passed"); - } + let query = "SELECT id from path_edges where path_id = ?1 AND edge_id = ?2;"; + placeholders.push(edge_id); println!("{query} {placeholders:?}"); PathEdge { id: conn .query_row(query, params_from_iter(&placeholders), |row| row.get(0)) .unwrap(), path_id, - source_edge_id, - target_edge_id, + index_in_path, + edge_id, } } else { panic!("something bad happened querying the database") @@ -78,8 +60,8 @@ impl PathEdge { Ok(PathEdge { id: row.get(0)?, path_id: row.get(1)?, - source_edge_id: row.get(2)?, - target_edge_id: row.get(3)?, + index_in_path: row.get(2)?, + edge_id: row.get(3)?, }) }) .unwrap(); @@ -91,12 +73,171 @@ impl PathEdge { } pub fn edges_for(conn: &Connection, path_id: i32) -> Vec { - let edges = vec![]; let path_edges = PathEdge::query( conn, - "select * from path_edges where path_id = ?1", + "select * from path_edges where path_id = ?1 order by index_in_path ASC", vec![Value::from(path_id)], ); - edges + let edge_ids = path_edges.into_iter().map(|path_edge| path_edge.edge_id); + let edges = NewEdge::load(conn, edge_ids.clone().collect()); + let edges_by_id = edges + .into_iter() + .map(|edge| (edge.id, edge)) + .collect::>(); + edge_ids + .into_iter() + .map(|edge_id| edges_by_id[&edge_id].clone()) + .collect::>() + } +} + +mod tests { + use rusqlite::Connection; + // Note this useful idiom: importing names from outer (for mod tests) scope. + use super::*; + + use crate::migrations::run_migrations; + use crate::models::{sequence::Sequence, BlockGroup, Collection}; + + fn get_connection() -> Connection { + let mut conn = Connection::open_in_memory() + .unwrap_or_else(|_| panic!("Error opening in memory test db")); + rusqlite::vtab::array::load_module(&conn).unwrap(); + run_migrations(&mut conn); + conn + } + + #[test] + fn test_gets_sequence() { + let conn = &mut get_connection(); + Collection::create(conn, "test collection"); + let block_group = BlockGroup::create(conn, "test collection", None, "test block group"); + let sequence1_hash = Sequence::create(conn, "DNA", "ATCGATCG", true); + let edge1 = NewEdge::create( + conn, + None, + None, + Some(sequence1_hash.clone()), + Some(0), + 0, + 0, + ); + let sequence2_hash = Sequence::create(conn, "DNA", "AAAAAAAA", true); + let edge2 = NewEdge::create( + conn, + Some(sequence1_hash.clone()), + Some(8), + Some(sequence2_hash.clone()), + Some(1), + 0, + 0, + ); + let sequence3_hash = Sequence::create(conn, "DNA", "CCCCCCCC", true); + let edge3 = NewEdge::create( + conn, + Some(sequence2_hash.clone()), + Some(8), + Some(sequence3_hash.clone()), + Some(1), + 0, + 0, + ); + let sequence4_hash = Sequence::create(conn, "DNA", "GGGGGGGG", true); + let edge4 = NewEdge::create( + conn, + Some(sequence3_hash.clone()), + Some(8), + Some(sequence4_hash.clone()), + Some(1), + 0, + 0, + ); + let edge5 = NewEdge::create( + conn, + Some(sequence4_hash.clone()), + Some(8), + None, + None, + 0, + 0, + ); + + let path = Path::new_create( + conn, + "chr1", + block_group.id, + vec![edge1.id, edge2.id, edge3.id, edge4.id, edge5.id], + ); + assert_eq!( + Path::new_get_sequence(conn, path), + "ATCGATCGAAAAAAACCCCCCCGGGGGGG" + ); + } + + #[test] + fn test_gets_sequence_with_rc() { + let conn = &mut get_connection(); + Collection::create(conn, "test collection"); + let block_group = BlockGroup::create(conn, "test collection", None, "test block group"); + let sequence1_hash = Sequence::create(conn, "DNA", "ATCGATCG", true); + let edge5 = NewEdge::create( + conn, + Some(sequence1_hash.clone()), + Some(0), + None, + None, + 0, + 0, + ); + let sequence2_hash = Sequence::create(conn, "DNA", "AAAAAAAA", true); + let edge4 = NewEdge::create( + conn, + Some(sequence2_hash.clone()), + Some(1), + Some(sequence1_hash.clone()), + Some(7), + 0, + 0, + ); + let sequence3_hash = Sequence::create(conn, "DNA", "CCCCCCCC", true); + let edge3 = NewEdge::create( + conn, + Some(sequence3_hash.clone()), + Some(1), + Some(sequence2_hash.clone()), + Some(7), + 0, + 0, + ); + let sequence4_hash = Sequence::create(conn, "DNA", "GGGGGGGG", true); + let edge2 = NewEdge::create( + conn, + Some(sequence4_hash.clone()), + Some(1), + Some(sequence3_hash.clone()), + Some(7), + 0, + 0, + ); + let edge1 = NewEdge::create( + conn, + None, + None, + Some(sequence4_hash.clone()), + Some(7), + 0, + 0, + ); + + let path = Path::new_create( + conn, + "chr1", + block_group.id, + vec![edge1.id, edge2.id, edge3.id, edge4.id, edge5.id], + ); + assert_eq!( + Path::new_get_sequence(conn, path), + "CCCCCCCGGGGGGGTTTTTTTCGATCGAT" + ); } } diff --git a/src/models/sequence.rs b/src/models/sequence.rs index 1854e54..57a3ef3 100644 --- a/src/models/sequence.rs +++ b/src/models/sequence.rs @@ -1,8 +1,9 @@ use rusqlite::types::Value; use rusqlite::{params_from_iter, Connection}; use sha2::{Digest, Sha256}; +use std::collections::HashMap; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Sequence { pub hash: String, pub sequence_type: String, @@ -67,4 +68,21 @@ impl Sequence { } objs } + + pub fn get_sequences_by_hash( + conn: &Connection, + hashes: Vec, + ) -> HashMap { + let mut sequence_map = HashMap::new(); + let joined_hashes = &hashes.join(","); + for sequence in Sequence::get_sequences( + conn, + &format!("select * from sequence where hash in ({0})", joined_hashes), + vec![], + ) { + sequence_map.insert(sequence.hash.clone(), sequence); + } + + sequence_map + } }