diff --git a/migrations/01-initial/up.sql b/migrations/01-initial/up.sql index 2dc058c..8cf0b3c 100644 --- a/migrations/01-initial/up.sql +++ b/migrations/01-initial/up.sql @@ -38,15 +38,13 @@ CREATE UNIQUE INDEX block_uidx ON block(sequence_hash, block_group_id, start, en CREATE TABLE edges ( id INTEGER PRIMARY KEY NOT NULL, - source_id INTEGER NOT NULL, + source_id INTEGER, target_id INTEGER, - origin INTEGER NOT NULL, chromosome_index INTEGER NOT NULL, phased INTEGER NOT NULL, FOREIGN KEY(source_id) REFERENCES block(id), FOREIGN KEY(target_id) REFERENCES block(id), constraint chk_phased check (phased in (0, 1)) - constraint chk_origin check (origin in (0, 1)) ); CREATE UNIQUE INDEX edge_uidx ON edges(source_id, target_id, chromosome_index, phased); diff --git a/src/main.rs b/src/main.rs index b820135..b24dc2d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -80,7 +80,7 @@ fn import_fasta(fasta: &String, name: &String, shallow: bool, conn: &mut Connect (sequence.len() as i32), &"1".to_string(), ); - let edge = models::Edge::create(conn, block.id, None, 1, 0, 0); + let edge = models::Edge::create(conn, Some(block.id), None, 0, 0); } println!("Created it"); } else { diff --git a/src/models.rs b/src/models.rs index e47368e..a57b8bc 100644 --- a/src/models.rs +++ b/src/models.rs @@ -178,9 +178,8 @@ impl Block { #[derive(Debug)] pub struct Edge { pub id: i32, - pub source_id: i32, + pub source_id: Option, pub target_id: Option, - pub origin: i32, pub chromosome_index: i32, pub phased: i32, } @@ -188,28 +187,31 @@ pub struct Edge { impl Edge { pub fn create( conn: &Connection, - source_id: i32, + source_id: Option, target_id: Option, - origin: i32, chromosome_index: i32, phased: i32, ) -> Edge { let mut query; let mut id_query; let mut placeholders: Vec = vec![]; - if target_id.is_some() { - query = "INSERT INTO edges (source_id, target_id, origin, chromosome_index, phased) VALUES (?1, ?2, ?3, ?4, ?5) RETURNING *"; + if target_id.is_some() && source_id.is_some() { + query = "INSERT INTO edges (source_id, target_id, chromosome_index, phased) VALUES (?1, ?2, ?3, ?4) RETURNING *"; id_query = "select id from edges where source_id = ?1 and target_id = ?2 and chromosome_index = ?3 and phased = ?4"; placeholders.push(Value::from(source_id)); placeholders.push(target_id.unwrap().into()); - placeholders.push(origin.into()); + placeholders.push(chromosome_index.into()); + placeholders.push(phased.into()); + } else if target_id.is_some() { + id_query = "select id from edges where target_id = ?1 and source_id is null and chromosome_index = ?2 and phased = ?3"; + query = "INSERT INTO edges (target_id, chromosome_index, phased) VALUES (?1, ?2, ?3) RETURNING *"; + placeholders.push(target_id.into()); placeholders.push(chromosome_index.into()); placeholders.push(phased.into()); } else { id_query = "select id from edges where source_id = ?1 and target_id is null and chromosome_index = ?2 and phased = ?3"; - query = "INSERT INTO edges (source_id, origin, chromosome_index, phased) VALUES (?1, ?2, ?3, ?4) RETURNING *"; + query = "INSERT INTO edges (source_id, chromosome_index, phased) VALUES (?1, ?2, ?3) RETURNING *"; placeholders.push(source_id.into()); - placeholders.push(origin.into()); placeholders.push(chromosome_index.into()); placeholders.push(phased.into()); } @@ -219,9 +221,8 @@ impl Edge { id: row.get(0)?, source_id: row.get(1)?, target_id: row.get(2)?, - origin: row.get(3)?, - chromosome_index: row.get(4)?, - phased: row.get(5)?, + chromosome_index: row.get(3)?, + phased: row.get(4)?, }) }) { Ok(edge) => edge, @@ -234,7 +235,6 @@ impl Edge { .unwrap(), source_id, target_id, - origin, chromosome_index, phased, } @@ -258,7 +258,7 @@ pub struct Path { } impl Path { - pub fn create(conn: &mut Connection, name: &str, block_group_id: i32, edges: Vec) -> Path { + pub fn create(conn: &Connection, name: &str, block_group_id: i32, edges: Vec) -> Path { let query = "INSERT INTO path (name, block_group_id, edges) VALUES (?1, ?2, ?3) RETURNING (id)"; let mut stmt = conn.prepare(query).unwrap(); @@ -300,6 +300,29 @@ impl Path { rows.next().unwrap().unwrap() } + pub fn get_paths(conn: &Connection, query: &str, placeholders: Vec) -> Vec { + let mut stmt = conn.prepare(query).unwrap(); + let mut rows = stmt + .query_map(params_from_iter(placeholders), |row| { + let mut edge_str: String = row.get(3).unwrap(); + Ok(Path { + id: row.get(0)?, + block_group_id: row.get(1)?, + name: row.get(2)?, + edges: edge_str + .split(',') + .map(|v| v.parse::().unwrap()) + .collect::>(), + }) + }) + .unwrap(); + let mut paths = vec![]; + for row in rows { + paths.push(row.unwrap()); + } + paths + } + pub fn edges_to_graph(conn: &mut Connection, edges: &Vec) -> DiGraphMap<(u32), ()> { let edge_str = (*edges) .iter() @@ -310,15 +333,19 @@ impl Path { let mut stmt = conn.prepare(&query).unwrap(); let mut rows = stmt .query_map([], |row| { - let source_id: u32 = row.get(0).unwrap(); - let target_id: u32 = row.get(1).unwrap(); + let source_id: Option = row.get(0).unwrap(); + let target_id: Option = row.get(1).unwrap(); Ok((source_id, target_id)) }) .unwrap(); let mut graph = DiGraphMap::new(); for edge in rows { let (source, target) = edge.unwrap(); - graph.add_edge(source, target, ()); + if let Some(source_value) = source { + if let Some(target_value) = target { + graph.add_edge(source_value, target_value, ()); + } + } } graph } @@ -375,14 +402,14 @@ impl BlockGroup { } } - pub fn clone(conn: &mut Connection, source_id: i32, target_id: i32) { + pub fn clone(conn: &mut Connection, source_block_group_id: i32, target_block_group_id: i32) { let mut stmt = conn .prepare_cached( "SELECT id, sequence_hash, start, end, strand from block where block_group_id = ?1", ) .unwrap(); let mut block_map: HashMap = HashMap::new(); - let mut it = stmt.query([source_id]).unwrap(); + let mut it = stmt.query([source_block_group_id]).unwrap(); let mut row = it.next().unwrap(); while row.is_some() { let block = row.unwrap(); @@ -391,16 +418,14 @@ impl BlockGroup { let start = block.get(2).unwrap(); let end = block.get(3).unwrap(); let strand: String = block.get(4).unwrap(); - let new_block = Block::create(conn, &hash, target_id, start, end, &strand); + let new_block = Block::create(conn, &hash, target_block_group_id, start, end, &strand); block_map.insert(block_id, new_block.id); row = it.next().unwrap(); } // todo: figure out rusqlite's rarray let mut stmt = conn - .prepare_cached( - "SELECT source_id, target_id, origin from edges where source_id IN (?1)", - ) + .prepare_cached("SELECT id, source_id, target_id from edges where source_id IN (?1)") .unwrap(); let block_keys = block_map .keys() @@ -409,34 +434,62 @@ impl BlockGroup { .join(", "); let mut it = stmt.query([block_keys]).unwrap(); let mut row = it.next().unwrap(); + let mut edge_map = HashMap::new(); while row.is_some() { let edge = row.unwrap(); - let source_id: i32 = edge.get(0).unwrap(); - let target_id: Option = edge.get(1).unwrap(); - let origin: i32 = edge.get(2).unwrap(); - if (target_id.is_some()) { + let edge_id: i32 = edge.get(0).unwrap(); + let source_id: Option = edge.get(1).unwrap(); + let target_id: Option = edge.get(2).unwrap(); + let mut new_edge; + if target_id.is_some() && source_id.is_some() { let target_id = target_id.unwrap(); - Edge::create( + let source_id = source_id.unwrap(); + new_edge = Edge::create( conn, - *block_map.get(&source_id).unwrap_or(&source_id), + Some(*block_map.get(&source_id).unwrap_or(&source_id)), Some(*block_map.get(&target_id).unwrap_or(&target_id)), - origin, 0, 0, ); - } else { - Edge::create( + } else if (target_id.is_some()) { + let target_id = target_id.unwrap(); + new_edge = Edge::create( + conn, + None, + Some(*block_map.get(&target_id).unwrap_or(&target_id)), + 0, + 0, + ); + } else if source_id.is_some() { + let source_id = source_id.unwrap(); + new_edge = Edge::create( conn, - *block_map.get(&source_id).unwrap_or(&source_id), + Some(*block_map.get(&source_id).unwrap_or(&source_id)), None, - origin, 0, 0, ); + } else { + panic!("no source and target specified."); } + edge_map.insert(edge_id, new_edge.id); row = it.next().unwrap(); } + + let existing_paths = Path::get_paths( + conn, + "SELECT * from path where block_group_id = ?1", + vec![Value::from(source_block_group_id)], + ); + + for path in existing_paths { + let mut new_edges = vec![]; + for edge in path.edges { + new_edges.push(*edge_map.get(&edge).unwrap()); + } + Path::create(conn, &path.name, target_block_group_id, new_edges); + } } pub fn get_or_create_sample_block_group( @@ -474,12 +527,44 @@ impl BlockGroup { } let new_bg_id = BlockGroup::create(conn, collection_name, Some(sample_name), group_name); - // clone parent blocks/edges + // clone parent blocks/edges/path BlockGroup::clone(conn, bg_id, new_bg_id.id); new_bg_id.id } + pub fn get_all_sequences(conn: &Connection, block_group_id: i32) -> Vec { + let blocks_query = "WITH RECURSIVE traverse(block_id, sequence, block_start, block_end, block_strand, depth, global_start, global_end) AS ( + SELECT edges.target_id as start_block, seq.sequence, block.start, block.end, block.strand, 0 as depth, 0 as global_start, block.end - block.start as global_end FROM block_group left join block on (block_group.id = block.block_group_id) left join sequence seq on (seq.hash = block.sequence_hash) left join edges on (block.id = edges.target_id) WHERE block.block_group_id = ?1 AND edges.source_id is null + UNION ALL + SELECT e2.target_id, seq2.sequence, b2.start, b2.end, b2.strand, t2.depth + 1, t2.global_end, t2.global_end + b2.end - b2.start FROM edges e2 left join block b2 on (b2.id = e2.target_id) left join sequence seq2 on (seq2.hash = b2.sequence_hash) JOIN traverse t2 ON e2.source_id = t2.block_id where e2.target_id is not null order by depth desc + ) SELECT sequence, block_start, block_end, block_strand, depth FROM traverse;"; + let mut stmt = conn.prepare_cached(blocks_query).unwrap(); + let mut sequences = vec![]; + let rows = stmt + .query_map([block_group_id], |row| { + Ok(( + row.get(0)?, + row.get(1)?, + row.get(2)?, + row.get(3)?, + row.get(4)?, + )) + }) + .unwrap(); + let mut seq_index = 0; + for row in rows { + let (seq, start, end, strand, depth): (String, usize, usize, String, u32) = + row.unwrap(); + if depth == 0 { + seq_index = sequences.len(); + sequences.push("".to_string()); + } + sequences[seq_index].push_str(&seq[start..end]); + } + sequences + } + #[allow(clippy::ptr_arg)] #[allow(clippy::too_many_arguments)] pub fn insert_change( @@ -532,13 +617,13 @@ impl BlockGroup { ); row = it.next().unwrap(); } + // TODO: probably don't need the graph, just get vector of source_ids. let mut dfs = Dfs::new(&graph, path.edges[0] as u32); let mut path_start = 0; let mut path_end = 0; let mut new_edges = vec![]; let mut previous_block: Option<&Block> = None; let mut next_node = dfs.next(&graph); - // while let Some(nx) = dfs.next(&graph) { while next_node.is_some() { let nx = next_node.unwrap(); let block = blocks.get(&(nx as i32)).unwrap(); @@ -550,6 +635,7 @@ impl BlockGroup { let contains_start = path_start <= start && start < path_end; let contains_end = path_start <= end && end < path_end; + let overlap = path_start <= end && start <= path_end; if contains_start && contains_end { // our range is fully contained w/in the block @@ -560,7 +646,7 @@ impl BlockGroup { &block.sequence_hash, block_group_id, block.start, - start - block.start, + start - path_start, &block.strand, ); let right_block = Block::create( @@ -572,10 +658,10 @@ impl BlockGroup { &block.strand, ); if let Some(value) = previous_block { - new_edges.push((value.id, left_block.id, 0)) + new_edges.push((Some(value.id), Some(left_block.id))) } - new_edges.push((left_block.id, new_block_id, 0)); - new_edges.push((new_block_id, right_block.id, 0)); + new_edges.push((Some(left_block.id), Some(new_block_id))); + new_edges.push((Some(new_block_id), Some(right_block.id))); } else if contains_start { // our range is overlapping the end of the block // |----block---| @@ -589,9 +675,11 @@ impl BlockGroup { &block.strand, ); if let Some(value) = previous_block { - new_edges.push((value.id, left_block.id, 0)); + new_edges.push((Some(value.id), Some(left_block.id))); + } else { + new_edges.push((None, Some(left_block.id))); } - new_edges.push((left_block.id, new_block_id, 0)); + new_edges.push((Some(left_block.id), Some(new_block_id))); } else if contains_end { // our range is overlapping the beginning of the block // |----block---| @@ -600,22 +688,24 @@ impl BlockGroup { conn, &block.sequence_hash, block_group_id, - end, + path_end - end, block.end, &block.strand, ); // what stuff went to this block? - new_edges.push((new_block_id, right_block.id, 0)); + new_edges.push((Some(new_block_id), Some(right_block.id))); let last_node = dfs.next(&graph); if last_node.is_some() { let next_block = blocks.get(&(last_node.unwrap() as i32)).unwrap(); - new_edges.push((right_block.id, next_block.id, 0)); + new_edges.push((Some(right_block.id), Some(next_block.id))); } break; - } else { + } else if overlap { // our range is the whole block, ignore it // |--block---| // |-----range------| + } else { + // not yet at the range } path_start += block_length; @@ -629,14 +719,7 @@ impl BlockGroup { println!("change is {path:?} {graph:?} {blocks:?} {new_edges:?}"); for new_edge in new_edges { - Edge::create( - conn, - new_edge.0, - Some(new_edge.1), - new_edge.2, - chromosome_index, - phased, - ); + Edge::create(conn, new_edge.0, new_edge.1, chromosome_index, phased); } } @@ -718,14 +801,16 @@ mod tests { let t_block = Block::create(conn, &t_seq_hash, block_group.id, 0, 10, &"1".to_string()); let c_block = Block::create(conn, &c_seq_hash, block_group.id, 0, 10, &"1".to_string()); let g_block = Block::create(conn, &g_seq_hash, block_group.id, 0, 10, &"1".to_string()); - let edge_1 = Edge::create(conn, a_block.id, Some(t_block.id), 1, 0, 0); - let edge_2 = Edge::create(conn, t_block.id, Some(c_block.id), 0, 0, 0); - let edge_3 = Edge::create(conn, c_block.id, Some(g_block.id), 0, 0, 0); + let edge_0 = Edge::create(conn, None, Some(a_block.id), 0, 0); + let edge_1 = Edge::create(conn, Some(a_block.id), Some(t_block.id), 0, 0); + let edge_2 = Edge::create(conn, Some(t_block.id), Some(c_block.id), 0, 0); + let edge_3 = Edge::create(conn, Some(c_block.id), Some(g_block.id), 0, 0); + let edge_4 = Edge::create(conn, Some(g_block.id), None, 0, 0); let path = Path::create( conn, "chr1", block_group.id, - vec![edge_1.id, edge_2.id, edge_3.id], + vec![edge_0.id, edge_1.id, edge_2.id, edge_3.id, edge_4.id], ); (block_group.id, path.id) } @@ -747,38 +832,13 @@ mod tests { ); BlockGroup::insert_change(&mut conn, path_id, 7, 15, insert.id, 1, 0); - let blocks_query = "WITH RECURSIVE traverse(block_id, sequence, block_start, block_end, depth, global_start, global_end) AS ( - SELECT edges.source_id, seq.sequence, block.start, block.end, 0 as depth, 0 as global_start, block.end - block.start as global_end FROM block_group left join block on (block_group.id = block.block_group_id) left join sequence seq on (seq.hash = block.sequence_hash) left join edges on (block.id = edges.source_id) WHERE block.block_group_id = ?1 AND edges.origin = 1 - UNION ALL - SELECT e2.target_id, seq2.sequence, b2.start, b2.end, t2.depth + 1, t2.global_end, t2.global_end + b2.end - b2.start FROM edges e2 left join block b2 on (b2.id = e2.target_id) left join sequence seq2 on (seq2.hash = b2.sequence_hash) JOIN traverse t2 ON e2.source_id = t2.block_id where e2.target_id is not null order by depth desc - ) SELECT block_id, sequence, block_start, block_end, depth, global_start, global_end FROM traverse;"; - let mut stmt = conn.prepare_cached(blocks_query).unwrap(); - - #[derive(Debug)] - struct BlockInfo { - id: i32, - sequence: String, - block_start: i32, - block_end: i32, - depth: i32, - global_start: i32, - global_end: i32, - } - let rows = stmt - .query_map([block_group_id], |row| { - Ok(BlockInfo { - id: row.get(0).unwrap(), - sequence: row.get(1).unwrap(), - block_start: row.get(2).unwrap(), - block_end: row.get(3).unwrap(), - depth: row.get(4).unwrap(), - global_start: row.get(5).unwrap(), - global_end: row.get(6).unwrap(), - }) - }) - .unwrap(); - for block in rows { - println!("{block:?}"); - } + let all_sequences = BlockGroup::get_all_sequences(&conn, block_group_id); + assert_eq!( + all_sequences, + [ + "AAAAAAAAAATTTTTTTTTTCCCCCCCCCCGGGGGGGGGG", + "AAAAAAANNNNTTTTTCCCCCCCCCCGGGGGGGGGG" + ] + ) } }