From 532c63c4155dc14674cb1eca31c281701ed47722 Mon Sep 17 00:00:00 2001 From: hofer Date: Thu, 1 Aug 2024 18:20:18 -0400 Subject: [PATCH] Add code to split blocks and update edges --- migrations/01-initial/up.sql | 2 +- src/models/block.rs | 454 +++++++++++++++++++++++++++++++++++ src/models/edge.rs | 86 +++++++ 3 files changed, 541 insertions(+), 1 deletion(-) diff --git a/migrations/01-initial/up.sql b/migrations/01-initial/up.sql index 20a84fd..d507ca8 100644 --- a/migrations/01-initial/up.sql +++ b/migrations/01-initial/up.sql @@ -38,7 +38,7 @@ CREATE TABLE block ( strand TEXT NOT NULL DEFAULT "1", FOREIGN KEY(sequence_hash) REFERENCES sequence(hash), FOREIGN KEY(block_group_id) REFERENCES block_group(id), - constraint chk_strand check (strand in ('-1', '1', '0', '.', '?')) + constraint chk_strand check (strand in ('-1', '1', '0', '.', '?', '+', '-')) ); CREATE UNIQUE INDEX block_uidx ON block(sequence_hash, block_group_id, start, end, strand); diff --git a/src/models/block.rs b/src/models/block.rs index 3bc3d85..0802f99 100644 --- a/src/models/block.rs +++ b/src/models/block.rs @@ -1,5 +1,7 @@ use rusqlite::Connection; +use crate::models::edge::{Edge, UpdatedEdge}; + #[derive(Debug)] pub struct Block { pub id: i32, @@ -58,4 +60,456 @@ impl Block { } } } + + pub fn edges_into(conn: &Connection, block_id: i32) -> Vec { + let edge_query = "select id, source_id, target_id, chromosome_index, phased from edges where target_id = ?1;"; + let mut stmt = conn.prepare_cached(edge_query).unwrap(); + + let mut edges: Vec = vec![]; + let mut it = stmt.query([block_id]).unwrap(); + let mut row = it.next().unwrap(); + while row.is_some() { + let edge = row.unwrap(); + let edge_id: i32 = edge.get(0).unwrap(); + let source_block_id: i32 = edge.get(1).unwrap(); + let target_block_id: i32 = edge.get(2).unwrap(); + let chromosome_index: i32 = edge.get(3).unwrap(); + let phased: i32 = edge.get(4).unwrap(); + edges.push(Edge { + id: edge_id, + source_id: source_block_id, + target_id: Some(target_block_id), + chromosome_index, + phased, + }); + row = it.next().unwrap(); + } + + edges + } + + pub fn edges_out_of(conn: &Connection, block_id: i32) -> Vec { + let edge_query = "select id, source_id, target_id, chromosome_index, phased from edges where source_id = ?1;"; + let mut stmt = conn.prepare_cached(edge_query).unwrap(); + + let mut edges: Vec = vec![]; + let mut it = stmt.query([block_id]).unwrap(); + let mut row = it.next().unwrap(); + while row.is_some() { + let edge = row.unwrap(); + let edge_id: i32 = edge.get(0).unwrap(); + let source_block_id: i32 = edge.get(1).unwrap(); + let target_block_id: i32 = edge.get(2).unwrap(); + let chromosome_index: i32 = edge.get(3).unwrap(); + let phased: i32 = edge.get(4).unwrap(); + edges.push(Edge { + id: edge_id, + source_id: source_block_id, + target_id: Some(target_block_id), + chromosome_index, + phased, + }); + row = it.next().unwrap(); + } + + edges + } + + pub fn split( + conn: &Connection, + block: Block, + coordinate: i32, + chromosome_index: i32, + phased: i32, + ) -> Option<(Block, Block)> { + if coordinate < block.start || coordinate >= block.end { + println!("Coordinate is out of block bounds"); + return None; + } + let new_left_block = Block::create( + conn, + &block.sequence_hash, + block.block_group_id, + block.start, + coordinate, + &block.strand, + ); + let new_right_block = Block::create( + conn, + &block.sequence_hash, + block.block_group_id, + coordinate, + block.end, + &block.strand, + ); + + let mut replacement_edges: Vec = vec![]; + + let edges_into = Block::edges_into(conn, block.id); + + for edge in edges_into.iter() { + replacement_edges.push(UpdatedEdge { + id: edge.id, + new_source_id: Some(edge.source_id), + new_target_id: Some(new_left_block.id), + }); + } + + let edges_out_of = Block::edges_out_of(conn, block.id); + + for edge in edges_out_of.iter() { + replacement_edges.push(UpdatedEdge { + id: edge.id, + new_source_id: Some(new_right_block.id), + new_target_id: edge.target_id, + }); + } + + Edge::create( + conn, + new_left_block.id, + Some(new_right_block.id), + chromosome_index, + phased, + ); + + Edge::bulk_update(conn, replacement_edges); + + // TODO: Delete existing block? + + Some((new_left_block, new_right_block)) + } +} + +#[cfg(test)] +mod tests { + use rusqlite::Connection; + use std::collections::HashSet; + // Note this useful idiom: importing names from outer (for mod tests) scope. + use super::*; + + use crate::migrations::run_migrations; + use crate::models::{sequence::Sequence, BlockGroup, Collection}; + + fn get_connection() -> Connection { + let mut conn = Connection::open_in_memory() + .unwrap_or_else(|_| panic!("Error opening in memory test db")); + rusqlite::vtab::array::load_module(&conn).unwrap(); + run_migrations(&mut conn); + conn + } + + #[test] + fn test_edges_into() { + let conn = &mut get_connection(); + Collection::create(conn, &"test collection".to_string()); + let block_group = BlockGroup::create( + conn, + &"test collection".to_string(), + None, + &"test block group".to_string(), + ); + let sequence1_hash = + Sequence::create(conn, "DNA".to_string(), &"ATCGATCG".to_string(), true); + let block1 = Block::create( + conn, + &sequence1_hash, + block_group.id, + 0, + 8, + &"+".to_string(), + ); + let sequence2_hash = + Sequence::create(conn, "DNA".to_string(), &"AAAAAAAA".to_string(), true); + let block2 = Block::create( + conn, + &sequence2_hash, + block_group.id, + 1, + 8, + &"+".to_string(), + ); + let sequence3_hash = + Sequence::create(conn, "DNA".to_string(), &"CCCCCCCC".to_string(), true); + let block3 = Block::create( + conn, + &sequence3_hash, + block_group.id, + 1, + 8, + &"+".to_string(), + ); + let sequence4_hash = + Sequence::create(conn, "DNA".to_string(), &"GGGGGGGG".to_string(), true); + let block4 = Block::create( + conn, + &sequence4_hash, + block_group.id, + 1, + 8, + &"+".to_string(), + ); + let edge1 = Edge::create(conn, block1.id, Some(block3.id), 0, 0); + let edge2 = Edge::create(conn, block2.id, Some(block3.id), 0, 0); + Edge::create(conn, block3.id, Some(block4.id), 0, 0); + + let edges_into_block3 = Block::edges_into(conn, block3.id); + assert_eq!(edges_into_block3.len(), 2); + + let mut actual_ids = HashSet::new(); + actual_ids.insert(edges_into_block3[0].id); + actual_ids.insert(edges_into_block3[1].id); + let mut expected_ids = HashSet::new(); + expected_ids.insert(edge1.id); + expected_ids.insert(edge2.id); + assert_eq!(actual_ids, expected_ids); + } + + #[test] + fn test_no_edges_into() { + let conn = &mut get_connection(); + Collection::create(conn, &"test collection".to_string()); + let block_group = BlockGroup::create( + conn, + &"test collection".to_string(), + None, + &"test block group".to_string(), + ); + let sequence1_hash = + Sequence::create(conn, "DNA".to_string(), &"ATCGATCG".to_string(), true); + let block1 = Block::create( + conn, + &sequence1_hash, + block_group.id, + 0, + 8, + &"+".to_string(), + ); + let sequence2_hash = + Sequence::create(conn, "DNA".to_string(), &"AAAAAAAA".to_string(), true); + let block2 = Block::create( + conn, + &sequence2_hash, + block_group.id, + 1, + 8, + &"+".to_string(), + ); + Edge::create(conn, block1.id, Some(block2.id), 0, 0); + + let edges_into_block1 = Block::edges_into(conn, block1.id); + assert_eq!(edges_into_block1.len(), 0); + } + + #[test] + fn test_edges_out_of() { + let conn = &mut get_connection(); + Collection::create(conn, &"test collection".to_string()); + let block_group = BlockGroup::create( + conn, + &"test collection".to_string(), + None, + &"test block group".to_string(), + ); + let sequence1_hash = + Sequence::create(conn, "DNA".to_string(), &"ATCGATCG".to_string(), true); + let block1 = Block::create( + conn, + &sequence1_hash, + block_group.id, + 0, + 8, + &"+".to_string(), + ); + let sequence2_hash = + Sequence::create(conn, "DNA".to_string(), &"AAAAAAAA".to_string(), true); + let block2 = Block::create( + conn, + &sequence2_hash, + block_group.id, + 1, + 8, + &"+".to_string(), + ); + let sequence3_hash = + Sequence::create(conn, "DNA".to_string(), &"CCCCCCCC".to_string(), true); + let block3 = Block::create( + conn, + &sequence3_hash, + block_group.id, + 1, + 8, + &"+".to_string(), + ); + let sequence4_hash = + Sequence::create(conn, "DNA".to_string(), &"GGGGGGGG".to_string(), true); + let block4 = Block::create( + conn, + &sequence4_hash, + block_group.id, + 1, + 8, + &"+".to_string(), + ); + Edge::create(conn, block1.id, Some(block2.id), 0, 0); + let edge1 = Edge::create(conn, block2.id, Some(block3.id), 0, 0); + let edge2 = Edge::create(conn, block2.id, Some(block4.id), 0, 0); + + let edges_out_of_block2 = Block::edges_out_of(conn, block2.id); + assert_eq!(edges_out_of_block2.len(), 2); + + let mut actual_ids = HashSet::new(); + actual_ids.insert(edges_out_of_block2[0].id); + actual_ids.insert(edges_out_of_block2[1].id); + let mut expected_ids = HashSet::new(); + expected_ids.insert(edge1.id); + expected_ids.insert(edge2.id); + assert_eq!(actual_ids, expected_ids); + } + + #[test] + fn test_no_edges_out_of() { + let conn = &mut get_connection(); + Collection::create(conn, &"test collection".to_string()); + let block_group = BlockGroup::create( + conn, + &"test collection".to_string(), + None, + &"test block group".to_string(), + ); + let sequence1_hash = + Sequence::create(conn, "DNA".to_string(), &"ATCGATCG".to_string(), true); + let block1 = Block::create( + conn, + &sequence1_hash, + block_group.id, + 0, + 8, + &"+".to_string(), + ); + let sequence2_hash = + Sequence::create(conn, "DNA".to_string(), &"AAAAAAAA".to_string(), true); + let block2 = Block::create( + conn, + &sequence2_hash, + block_group.id, + 1, + 8, + &"+".to_string(), + ); + Edge::create(conn, block1.id, Some(block2.id), 0, 0); + + let edges_out_of_block2 = Block::edges_out_of(conn, block2.id); + assert_eq!(edges_out_of_block2.len(), 0); + } + + #[test] + fn test_split_block() { + let conn = &mut get_connection(); + Collection::create(conn, &"test collection".to_string()); + let block_group = BlockGroup::create( + conn, + &"test collection".to_string(), + None, + &"test block group".to_string(), + ); + let sequence1_hash = + Sequence::create(conn, "DNA".to_string(), &"ATCGATCG".to_string(), true); + let block1 = Block::create( + conn, + &sequence1_hash, + block_group.id, + 0, + 8, + &"+".to_string(), + ); + let sequence2_hash = + Sequence::create(conn, "DNA".to_string(), &"AAAAAAAA".to_string(), true); + let block2 = Block::create( + conn, + &sequence2_hash, + block_group.id, + 1, + 8, + &"+".to_string(), + ); + let sequence3_hash = + Sequence::create(conn, "DNA".to_string(), &"CCCCCCCC".to_string(), true); + let block3 = Block::create( + conn, + &sequence3_hash, + block_group.id, + 1, + 8, + &"+".to_string(), + ); + let sequence4_hash = + Sequence::create(conn, "DNA".to_string(), &"GGGGGGGG".to_string(), true); + let block4 = Block::create( + conn, + &sequence4_hash, + block_group.id, + 1, + 8, + &"+".to_string(), + ); + let edge1 = Edge::create(conn, block1.id, Some(block3.id), 0, 0); + let edge2 = Edge::create(conn, block2.id, Some(block3.id), 0, 0); + let edge3 = Edge::create(conn, block3.id, Some(block4.id), 0, 0); + + let (left_block, right_block) = Block::split(conn, block3, 4, 0, 0).unwrap(); + + let edges_into_left_block = Block::edges_into(conn, left_block.id); + assert_eq!(edges_into_left_block.len(), 2); + + let mut actual_incoming_ids = HashSet::new(); + actual_incoming_ids.insert(edges_into_left_block[0].id); + actual_incoming_ids.insert(edges_into_left_block[1].id); + let mut expected_incoming_ids = HashSet::new(); + expected_incoming_ids.insert(edge1.id); + expected_incoming_ids.insert(edge2.id); + assert_eq!(actual_incoming_ids, expected_incoming_ids); + + let edges_out_of_right_block = Block::edges_out_of(conn, right_block.id); + assert_eq!(edges_out_of_right_block.len(), 1); + assert_eq!(edges_out_of_right_block[0].id, edge3.id); + + let new_edge = Edge::lookup(conn, Some(left_block.id), Some(right_block.id)); + assert!(new_edge.is_some()); + } + + #[test] + fn test_split_block_bad_coordinate() { + let conn = &mut get_connection(); + Collection::create(conn, &"test collection".to_string()); + let block_group = BlockGroup::create( + conn, + &"test collection".to_string(), + None, + &"test block group".to_string(), + ); + let sequence1_hash = + Sequence::create(conn, "DNA".to_string(), &"ATCGATCG".to_string(), true); + let block1 = Block::create( + conn, + &sequence1_hash, + block_group.id, + 0, + 8, + &"+".to_string(), + ); + let result = Block::split(conn, block1, -1, 0, 0); + assert!(result.is_none()); + + let block2 = Block::create( + conn, + &sequence1_hash, + block_group.id, + 0, + 8, + &"+".to_string(), + ); + let result = Block::split(conn, block2, 100, 0, 0); + assert!(result.is_none()); + } } diff --git a/src/models/edge.rs b/src/models/edge.rs index a8c395e..f54148e 100644 --- a/src/models/edge.rs +++ b/src/models/edge.rs @@ -66,4 +66,90 @@ impl Edge { } } } + + pub fn bulk_update(conn: &Connection, edges_to_update: Vec) { + for edge_to_update in edges_to_update { + let update_query; + let mut placeholders: Vec = vec![]; + if edge_to_update.new_source_id.is_some() && edge_to_update.new_target_id.is_some() { + update_query = "update edges set source_id = ?1, target_id = ?2 where id = ?3"; + placeholders.push(edge_to_update.new_source_id.unwrap()); + placeholders.push(edge_to_update.new_target_id.unwrap()); + } else if edge_to_update.new_source_id.is_some() { + update_query = "update edges set source_id = ?1 where id = ?2"; + placeholders.push(edge_to_update.new_source_id.unwrap()); + } else if edge_to_update.new_target_id.is_some() { + update_query = "update edges set target_id = ?1 where id = ?2"; + placeholders.push(edge_to_update.new_target_id.unwrap()); + } else { + continue; + } + + println!("{update_query} {placeholders:?}"); + + let edge = Edge::lookup( + conn, + edge_to_update.new_source_id, + edge_to_update.new_target_id, + ); + if edge.is_none() { + placeholders.push(edge_to_update.id); + println!("updating {update_query} {placeholders:?}"); + let mut stmt = conn.prepare_cached(update_query).unwrap(); + stmt.execute(params_from_iter(&placeholders)).unwrap(); + } else { + println!("edge exists"); + } + } + } + + pub fn lookup( + conn: &Connection, + source_id: Option, + target_id: Option, + ) -> Option { + let query; + let mut stmt; + let mut it; + if source_id.is_some() && target_id.is_some() { + query = "select id, source_id, target_id, chromosome_index, phased from edges where source_id = ?1 and target_id = ?2;"; + stmt = conn.prepare_cached(query).unwrap(); + it = stmt + .query([source_id.unwrap(), target_id.unwrap()]) + .unwrap(); + } else if source_id.is_some() { + query = "select id, source_id, target_id, chromosome_index, phased from edges where source_id = ?1 and target_id is null;"; + stmt = conn.prepare_cached(query).unwrap(); + it = stmt.query([source_id.unwrap()]).unwrap(); + } else if target_id.is_some() { + query = "select id, source_id, target_id, chromosome_index, phased from edges where target_id = ?1 and source_id is null;"; + stmt = conn.prepare_cached(query).unwrap(); + it = stmt.query([target_id.unwrap()]).unwrap(); + } else { + return None; + } + + let row = it.next().unwrap(); + if row.is_some() { + let edge = row.unwrap(); + let source_id: i32 = edge.get(1).unwrap(); + let target_id: Option = edge.get(2).unwrap(); + Some(Edge { + id: edge.get(0).unwrap(), + source_id, + target_id, + chromosome_index: edge.get(3).unwrap(), + phased: edge.get(4).unwrap(), + }) + } else { + None + } + } +} + +#[derive(Debug)] +pub struct UpdatedEdge { + pub id: i32, + pub new_source_id: Option, + pub new_target_id: Option, }