From a21c4d614ba984b7dcf3253099f54cca33add905 Mon Sep 17 00:00:00 2001 From: hofer Date: Fri, 6 Sep 2024 14:40:59 -0400 Subject: [PATCH] Get GFA import working with new data model --- src/imports.rs | 1 + src/imports/fasta.rs | 2 +- src/imports/gfa.rs | 197 +++++++++++++++++++++++++++++ src/main.rs | 223 --------------------------------- src/models/block_group.rs | 39 +++++- src/models/block_group_edge.rs | 6 +- src/models/change_log.rs | 2 +- src/models/edge.rs | 16 ++- src/models/path.rs | 2 +- src/models/path_edge.rs | 7 +- 10 files changed, 254 insertions(+), 241 deletions(-) create mode 100644 src/imports/gfa.rs diff --git a/src/imports.rs b/src/imports.rs index 62a9731..da35a12 100644 --- a/src/imports.rs +++ b/src/imports.rs @@ -1 +1,2 @@ pub mod fasta; +pub mod gfa; diff --git a/src/imports/fasta.rs b/src/imports/fasta.rs index aff6e6c..7e50377 100644 --- a/src/imports/fasta.rs +++ b/src/imports/fasta.rs @@ -57,7 +57,7 @@ pub fn import_fasta(fasta: &String, name: &str, shallow: bool, conn: &mut Connec 0, 0, ); - BlockGroupEdge::bulk_create(conn, block_group.id, vec![edge_into.id, edge_out_of.id]); + BlockGroupEdge::bulk_create(conn, block_group.id, &[edge_into.id, edge_out_of.id]); Path::create( conn, &name, diff --git a/src/imports/gfa.rs b/src/imports/gfa.rs new file mode 100644 index 0000000..684da1c --- /dev/null +++ b/src/imports/gfa.rs @@ -0,0 +1,197 @@ +use gfa_reader::Gfa; +use rusqlite::Connection; +use std::collections::{HashMap, HashSet}; + +use crate::models::{ + self, + block_group::BlockGroup, + block_group_edge::BlockGroupEdge, + edge::{Edge, EdgeData}, + path::Path, + sequence::Sequence, + strand::Strand, +}; + +fn import_gfa(gfa_path: &str, collection_name: &str, conn: &Connection) { + models::Collection::create(conn, collection_name); + let block_group = BlockGroup::create(conn, collection_name, None, ""); + let gfa: Gfa = Gfa::parse_gfa_file(gfa_path); + let mut sequences_by_segment_id: HashMap = HashMap::new(); + + for segment in &gfa.segments { + let input_sequence = segment.sequence.get_string(&gfa.sequence); + let sequence = Sequence::new() + .sequence_type("DNA") + .sequence(input_sequence) + .save(conn); + sequences_by_segment_id.insert(segment.id, sequence); + } + + let mut edges = HashSet::new(); + for link in &gfa.links { + let source = sequences_by_segment_id.get(&link.from).unwrap(); + let target = sequences_by_segment_id.get(&link.to).unwrap(); + edges.insert(edge_data_from_fields( + &source.hash, + source.length, + &target.hash, + )); + } + + for input_path in &gfa.paths { + let mut source_hash = Edge::PATH_START_HASH; + let mut source_coordinate = 0; + for segment_id in input_path.nodes.iter() { + let target = sequences_by_segment_id.get(segment_id).unwrap(); + edges.insert(edge_data_from_fields( + source_hash, + source_coordinate, + &target.hash, + )); + source_hash = &target.hash; + source_coordinate = target.length; + } + edges.insert(edge_data_from_fields( + source_hash, + source_coordinate, + Edge::PATH_END_HASH, + )); + } + + for input_walk in &gfa.walk { + let mut source_hash = Edge::PATH_START_HASH; + let mut source_coordinate = 0; + for segment_id in input_walk.walk_id.iter() { + let target = sequences_by_segment_id.get(segment_id).unwrap(); + edges.insert(edge_data_from_fields( + source_hash, + source_coordinate, + &target.hash, + )); + source_hash = &target.hash; + source_coordinate = target.length; + } + edges.insert(edge_data_from_fields( + source_hash, + source_coordinate, + Edge::PATH_END_HASH, + )); + } + + let edge_ids = Edge::bulk_create(conn, edges.into_iter().collect::>()); + BlockGroupEdge::bulk_create(conn, block_group.id, &edge_ids); + + let saved_edges = Edge::bulk_load(conn, &edge_ids); + let mut edge_ids_by_data = HashMap::new(); + for edge in saved_edges { + let key = + edge_data_from_fields(&edge.source_hash, edge.source_coordinate, &edge.target_hash); + edge_ids_by_data.insert(key, edge.id); + } + + for input_path in &gfa.paths { + let path_name = &input_path.name; + let mut source_hash = Edge::PATH_START_HASH; + let mut source_coordinate = 0; + let mut path_edge_ids = vec![]; + for segment_id in input_path.nodes.iter() { + let target = sequences_by_segment_id.get(segment_id).unwrap(); + let key = edge_data_from_fields(source_hash, source_coordinate, &target.hash); + let edge_id = *edge_ids_by_data.get(&key).unwrap(); + path_edge_ids.push(edge_id); + source_hash = &target.hash; + source_coordinate = target.length; + } + let key = edge_data_from_fields(source_hash, source_coordinate, Edge::PATH_END_HASH); + let edge_id = *edge_ids_by_data.get(&key).unwrap(); + path_edge_ids.push(edge_id); + Path::create(conn, path_name, block_group.id, path_edge_ids); + } + + for input_walk in &gfa.walk { + let path_name = &input_walk.sample_id; + let mut source_hash = Edge::PATH_START_HASH; + let mut source_coordinate = 0; + let mut path_edge_ids = vec![]; + for segment_id in input_walk.walk_id.iter() { + let target = sequences_by_segment_id.get(segment_id).unwrap(); + let key = edge_data_from_fields(source_hash, source_coordinate, &target.hash); + let edge_id = *edge_ids_by_data.get(&key).unwrap(); + path_edge_ids.push(edge_id); + source_hash = &target.hash; + source_coordinate = target.length; + } + let key = edge_data_from_fields(source_hash, source_coordinate, Edge::PATH_END_HASH); + let edge_id = *edge_ids_by_data.get(&key).unwrap(); + path_edge_ids.push(edge_id); + Path::create(conn, path_name, block_group.id, path_edge_ids); + } +} + +fn edge_data_from_fields(source_hash: &str, source_coordinate: i32, target_hash: &str) -> EdgeData { + EdgeData { + source_hash: source_hash.to_string(), + source_coordinate, + source_strand: Strand::Forward, + target_hash: target_hash.to_string(), + target_coordinate: 0, + target_strand: Strand::Forward, + chromosome_index: 0, + phased: 0, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::get_connection; + use rusqlite::{types::Value as SQLValue, Connection}; + use std::fs; + use std::path::PathBuf; + + #[test] + fn test_import_simple_gfa() { + let mut gfa_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + gfa_path.push("fixtures/simple.gfa"); + let collection_name = "test".to_string(); + let conn = &get_connection(None); + import_gfa(gfa_path.to_str().unwrap(), &collection_name, conn); + + let block_group_id = BlockGroup::get_id(conn, &collection_name, None, ""); + let path = Path::get_paths( + conn, + "select * from path where block_group_id = ?1 AND name = ?2", + vec![ + SQLValue::from(block_group_id), + SQLValue::from("124".to_string()), + ], + )[0] + .clone(); + + let result = Path::sequence(conn, path); + assert_eq!(result, "ATGGCATATTCGCAGCT"); + } + + #[test] + fn test_import_gfa_with_walk() { + let mut gfa_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + gfa_path.push("fixtures/walk.gfa"); + let collection_name = "walk".to_string(); + let conn = &mut get_connection(None); + import_gfa(gfa_path.to_str().unwrap(), &collection_name, conn); + + let block_group_id = BlockGroup::get_id(conn, &collection_name, None, ""); + let path = Path::get_paths( + conn, + "select * from path where block_group_id = ?1 AND name = ?2", + vec![ + SQLValue::from(block_group_id), + SQLValue::from("291344".to_string()), + ], + )[0] + .clone(); + + let result = Path::sequence(conn, path); + assert_eq!(result, "ACCTACAAATTCAAAC"); + } +} diff --git a/src/main.rs b/src/main.rs index c61d392..fe9b683 100644 --- a/src/main.rs +++ b/src/main.rs @@ -65,94 +65,6 @@ enum Commands { }, } -fn import_gfa(gfa_path: &str, collection_name: &String, conn: &mut Connection) { - run_migrations(conn); - - let gfa: Gfa = Gfa::parse_gfa_file(gfa_path); - - let collection = models::Collection::create(conn, collection_name); - - let mut blocks_by_segment_id: HashMap = HashMap::new(); - - for segment in &gfa.segments { - let sequence = segment.sequence.get_string(&gfa.sequence); - let seq_hash = - models::Sequence::create(conn, "DNA".to_string(), &sequence.to_string(), true); - let block = Block { - id: 0, - path_id: 0, - sequence_hash: seq_hash, - start: 0, - end: (sequence.len() as i32), - strand: "1".to_string(), - }; - let segment_id = segment.id; - blocks_by_segment_id.insert(segment_id, block); - } - - let mut created_blocks_by_segment_id: HashMap = HashMap::new(); - - for input_path in &gfa.paths { - let path_name = &input_path.name; - // TODO: Fix Some(1) - let path = models::Path::create(conn, &collection.name, None, path_name, Some(1)); - for segment_id in input_path.nodes.iter() { - let block = blocks_by_segment_id.get(segment_id).unwrap(); - let created_block = Block::create( - conn, - &block.sequence_hash, - path.id, - block.start, - block.end, - &block.strand, - ); - created_blocks_by_segment_id.insert(*segment_id, created_block); - } - } - - for input_walk in &gfa.walk { - // TODO: Is this what we want to use for the path name? - let walk_id = &input_walk.sample_id; - // TODO: Fix Some(1) - let path = models::Path::create(conn, &collection.name, None, walk_id, Some(1)); - for segment_id in input_walk.walk_id.iter() { - let block = blocks_by_segment_id.get(segment_id).unwrap(); - let created_block = Block::create( - conn, - &block.sequence_hash, - path.id, - block.start, - block.end, - &block.strand, - ); - created_blocks_by_segment_id.insert(*segment_id, created_block); - } - } - - let mut source_block_ids: HashSet = HashSet::new(); - let mut target_block_ids: HashSet = HashSet::new(); - - for link in &gfa.links { - let source_segment_id = link.from; - let target_segment_id = link.to; - let source_block = created_blocks_by_segment_id - .get(&source_segment_id) - .unwrap(); - let target_block = created_blocks_by_segment_id - .get(&target_segment_id) - .unwrap(); - models::Edge::create(conn, source_block.id, Some(target_block.id)); - source_block_ids.insert(source_block.id); - target_block_ids.insert(target_block.id); - } - - let end_block_ids = target_block_ids.difference(&source_block_ids); - - for end_block_id in end_block_ids { - models::Edge::create(conn, *end_block_id, None); - } -} - fn main() { let cli = Cli::parse(); @@ -197,138 +109,3 @@ fn main() { None => {} } } - -#[cfg(test)] -mod tests { - use rusqlite::Connection; - use std::fs; - // Note this useful idiom: importing names from outer (for mod tests) scope. - use super::*; - use gen::test_helpers::get_connection; - - #[test] - fn test_add_fasta() { - let mut fasta_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - fasta_path.push("fixtures/simple.fa"); - let mut conn = get_connection(None); - import_fasta( - &fasta_path.to_str().unwrap().to_string(), - "test", - false, - &mut conn, - ); - assert_eq!( - BlockGroup::get_all_sequences(&conn, 1), - HashSet::from_iter(vec!["ATCGATCGATCGATCGATCGGGAACACACAGAGA".to_string()]) - ); - - let path = Path::get(&conn, 1); - assert_eq!( - Path::sequence(&conn, path), - "ATCGATCGATCGATCGATCGGGAACACACAGAGA".to_string() - ); - } - - #[test] - fn test_update_fasta_with_vcf() { - let mut vcf_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - vcf_path.push("fixtures/simple.vcf"); - let mut fasta_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - fasta_path.push("fixtures/simple.fa"); - let conn = &mut get_connection("test.db"); - let collection = "test".to_string(); - import_fasta( - &fasta_path.to_str().unwrap().to_string(), - &collection, - false, - conn, - ); - update_with_vcf( - &vcf_path.to_str().unwrap().to_string(), - &collection, - "".to_string(), - "".to_string(), - conn, - ); - assert_eq!( - BlockGroup::get_all_sequences(conn, 1), - HashSet::from_iter(vec!["ATCGATCGATCGATCGATCGGGAACACACAGAGA".to_string()]) - ); - // A homozygous set of variants should only return 1 sequence - assert_eq!( - BlockGroup::get_all_sequences(conn, 2), - HashSet::from_iter(vec!["ATCATCGATAGAGATCGATCGGGAACACACAGAGA".to_string()]) - ); - // This individual is homozygous for the first variant and does not contain the second - assert_eq!( - BlockGroup::get_all_sequences(conn, 3), - HashSet::from_iter(vec!["ATCATCGATCGATCGATCGGGAACACACAGAGA".to_string()]) - ); - } - - #[test] - fn test_update_fasta_with_vcf_custom_genotype() { - let mut vcf_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - vcf_path.push("fixtures/general.vcf"); - let mut fasta_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - fasta_path.push("fixtures/simple.fa"); - let conn = &mut get_connection("test.db"); - let collection = "test".to_string(); - import_fasta( - &fasta_path.to_str().unwrap().to_string(), - &collection, - false, - conn, - ); - update_with_vcf( - &vcf_path.to_str().unwrap().to_string(), - &collection, - "0/1".to_string(), - "sample 1".to_string(), - conn, - ); - assert_eq!( - BlockGroup::get_all_sequences(conn, 1), - HashSet::from_iter(vec!["ATCGATCGATCGATCGATCGGGAACACACAGAGA".to_string()]) - ); - assert_eq!( - BlockGroup::get_all_sequences(conn, 2), - HashSet::from_iter( - [ - "ATCGATCGATAGAGATCGATCGGGAACACACAGAGA", - "ATCATCGATAGAGATCGATCGGGAACACACAGAGA", - "ATCGATCGATCGATCGATCGGGAACACACAGAGA", - "ATCATCGATCGATCGATCGGGAACACACAGAGA" - ] - .iter() - .map(|v| v.to_string()) - ) - ); - } - - #[test] - fn test_import_simple_gfa() { - let mut gfa_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - gfa_path.push("fixtures/simple.gfa"); - - let collection_name = "test".to_string(); - let conn = &mut get_connection(); - import_gfa(gfa_path.to_str().unwrap(), &collection_name, conn); - - let result = Path::sequence(conn, &collection_name, None, "124", 1); - assert_eq!(result, "ATGGCATATTCGCAGCT"); - } - - #[test] - fn test_import_gfa_with_walk() { - let mut gfa_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - gfa_path.push("fixtures/walk.gfa"); - - let collection_name = "walk".to_string(); - let conn = &mut get_connection(); - import_gfa(gfa_path.to_str().unwrap(), &collection_name, conn); - - let result = Path::sequence(conn, &collection_name, None, "291344", 1); - assert_eq!(result, "ACCTACAAATTCAAAC"); - } -} diff --git a/src/models/block_group.rs b/src/models/block_group.rs index 32f290d..0db8d89 100644 --- a/src/models/block_group.rs +++ b/src/models/block_group.rs @@ -154,8 +154,8 @@ impl BlockGroup { let edge_ids = BlockGroupEdge::edges_for_block_group(conn, source_block_group_id) .iter() .map(|edge| edge.id) - .collect(); - BlockGroupEdge::bulk_create(conn, target_block_group_id, edge_ids); + .collect::>(); + BlockGroupEdge::bulk_create(conn, target_block_group_id, &edge_ids); for path in existing_paths { let edge_ids = PathEdge::edges_for(conn, path.id) @@ -207,6 +207,35 @@ impl BlockGroup { new_bg_id.id } + pub fn get_id( + conn: &Connection, + collection_name: &str, + sample_name: Option<&str>, + group_name: &str, + ) -> i32 { + let result = if sample_name.is_some() { + conn.query_row( + "select id from block_group where collection_name = ?1 AND sample_name = ?2 AND name = ?3", + (collection_name, sample_name, group_name.clone()), + |row| row.get(0), + ) + } else { + conn.query_row( + "select id from block_group where collection_name = ?1 AND sample_name IS NULL AND name = ?2", + (collection_name, group_name.clone()), + |row| row.get(0), + ) + }; + + match result { + Ok(res) => res, + Err(rusqlite::Error::QueryReturnedNoRows) => 0, + Err(_e) => { + panic!("Error querying the database: {_e}"); + } + } + } + pub fn get_block_boundaries( source_edges: Option<&Vec<&Edge>>, target_edges: Option<&Vec<&Edge>>, @@ -446,7 +475,7 @@ impl BlockGroup { for (block_group_id, new_edges) in new_edges_by_block_group { let edge_ids = Edge::bulk_create(conn, new_edges); - BlockGroupEdge::bulk_create(conn, block_group_id, edge_ids); + BlockGroupEdge::bulk_create(conn, block_group_id, &edge_ids); } ChangeLog::bulk_create( conn, @@ -475,7 +504,7 @@ impl BlockGroup { ) { let new_edges = BlockGroup::set_up_new_edges(change, tree); let edge_ids = Edge::bulk_create(conn, new_edges); - BlockGroupEdge::bulk_create(conn, change.block_group_id, edge_ids); + BlockGroupEdge::bulk_create(conn, change.block_group_id, &edge_ids); ChangeLog::new( change.path.id, change.start, @@ -644,7 +673,7 @@ mod tests { BlockGroupEdge::bulk_create( conn, block_group.id, - vec![edge0.id, edge1.id, edge2.id, edge3.id, edge4.id], + &[edge0.id, edge1.id, edge2.id, edge3.id, edge4.id], ); let path = Path::create( conn, diff --git a/src/models/block_group_edge.rs b/src/models/block_group_edge.rs index 8a2fd7d..dd1bc63 100644 --- a/src/models/block_group_edge.rs +++ b/src/models/block_group_edge.rs @@ -10,7 +10,7 @@ pub struct BlockGroupEdge { } impl BlockGroupEdge { - pub fn bulk_create(conn: &Connection, block_group_id: i32, edge_ids: Vec) { + pub fn bulk_create(conn: &Connection, block_group_id: i32, edge_ids: &[i32]) { for chunk in edge_ids.chunks(100000) { let mut rows_to_insert = vec![]; for edge_id in chunk { @@ -37,8 +37,8 @@ impl BlockGroupEdge { let edge_ids = block_group_edges .into_iter() .map(|block_group_edge| block_group_edge.edge_id) - .collect(); - Edge::bulk_load(conn, edge_ids) + .collect::>(); + Edge::bulk_load(conn, &edge_ids) } pub fn query(conn: &Connection, query: &str, placeholders: Vec) -> Vec { diff --git a/src/models/change_log.rs b/src/models/change_log.rs index 43f9f35..e47ba8e 100644 --- a/src/models/change_log.rs +++ b/src/models/change_log.rs @@ -357,7 +357,7 @@ mod tests { BlockGroupEdge::bulk_create( conn, block_group.id, - vec![edge0.id, edge1.id, edge2.id, edge3.id, edge4.id], + &[edge0.id, edge1.id, edge2.id, edge3.id, edge4.id], ); let path = Path::create( conn, diff --git a/src/models/edge.rs b/src/models/edge.rs index 54b5754..91e1dfb 100644 --- a/src/models/edge.rs +++ b/src/models/edge.rs @@ -1,3 +1,4 @@ +use itertools::Itertools; use rusqlite::types::Value; use rusqlite::{params_from_iter, Connection}; use std::collections::HashSet; @@ -102,9 +103,9 @@ impl Edge { } } - pub fn bulk_load(conn: &Connection, edge_ids: Vec) -> Vec { + pub fn bulk_load(conn: &Connection, edge_ids: &[i32]) -> Vec { let formatted_edge_ids = edge_ids - .into_iter() + .iter() .map(|edge_id| edge_id.to_string()) .collect::>() .join(","); @@ -211,7 +212,12 @@ impl Edge { edge_ids.push(row.unwrap()); } - existing_edge_ids.extend(edge_ids); + existing_edge_ids.extend( + edge_ids + .into_iter() + .sorted_by(|c1, c2| Ord::cmp(&c1, &c2)) + .collect::>(), + ); } existing_edge_ids @@ -285,7 +291,7 @@ mod tests { let edge_ids = Edge::bulk_create(conn, vec![edge1, edge2, edge3]); assert_eq!(edge_ids.len(), 3); - let edges = Edge::bulk_load(conn, edge_ids); + let edges = Edge::bulk_load(conn, &edge_ids); assert_eq!(edges.len(), 3); let edges_by_source_hash = edges @@ -369,7 +375,7 @@ mod tests { let edge_ids = Edge::bulk_create(conn, vec![edge1, edge2, edge3]); assert_eq!(edge_ids.len(), 3); - let edges = Edge::bulk_load(conn, edge_ids); + let edges = Edge::bulk_load(conn, &edge_ids); assert_eq!(edges.len(), 3); let edges_by_source_hash = edges diff --git a/src/models/path.rs b/src/models/path.rs index 75a2895..aed9b2e 100644 --- a/src/models/path.rs +++ b/src/models/path.rs @@ -138,7 +138,7 @@ impl Path { ) -> NewBlock { if into.target_hash != out_of.source_hash { panic!( - "Consecutive edges in path {0} don't share the same block", + "Consecutive edges in path {0} don't share the same sequence", path.id ); } diff --git a/src/models/path_edge.rs b/src/models/path_edge.rs index 9934a4c..3bc4787 100644 --- a/src/models/path_edge.rs +++ b/src/models/path_edge.rs @@ -78,8 +78,11 @@ impl PathEdge { "select * from path_edges where path_id = ?1 order by index_in_path ASC", vec![Value::from(path_id)], ); - let edge_ids = path_edges.into_iter().map(|path_edge| path_edge.edge_id); - let edges = Edge::bulk_load(conn, edge_ids.clone().collect()); + let edge_ids = path_edges + .into_iter() + .map(|path_edge| path_edge.edge_id) + .collect::>(); + let edges = Edge::bulk_load(conn, &edge_ids); let edges_by_id = edges .into_iter() .map(|edge| (edge.id, edge))