From 9cb4cb4c363e78b807c2832a6582d87bbf7fd2e9 Mon Sep 17 00:00:00 2001 From: hofer Date: Wed, 18 Sep 2024 12:23:46 -0400 Subject: [PATCH] Refactor, add test, fix test --- src/exports/gfa.rs | 143 +++++++++++++++++++++++++--------------- src/models/path_edge.rs | 2 +- 2 files changed, 92 insertions(+), 53 deletions(-) diff --git a/src/exports/gfa.rs b/src/exports/gfa.rs index 1c69543..7e6d333 100644 --- a/src/exports/gfa.rs +++ b/src/exports/gfa.rs @@ -36,61 +36,33 @@ pub fn export_gfa(conn: &Connection, collection_name: &str, filename: &PathBuf) let file = File::create(filename).unwrap(); let mut writer = BufWriter::new(file); - let terminal_block_ids = write_blocks(&mut writer, blocks.clone()); + let mut terminal_block_ids = HashSet::new(); + for block in &blocks { + if block.sequence_hash == Sequence::PATH_START_HASH + || block.sequence_hash == Sequence::PATH_END_HASH + { + terminal_block_ids.insert(block.id); + continue; + } + } - write_edges( + write_segments(&mut writer, blocks.clone(), terminal_block_ids.clone()); + write_links( &mut writer, graph, edges_by_node_pair.clone(), terminal_block_ids, ); - - let paths = Path::get_paths_for_collection(conn, collection_name); - let edges_by_path_id = - PathEdge::edges_for_paths(conn, paths.iter().map(|path| path.id).collect()); - let node_pairs_by_edge_id = edges_by_node_pair - .iter() - .map(|(node_pair, edge)| (edge.id, *node_pair)) - .collect::>(); - - println!("here1"); - for path in paths { - println!("here2"); - println!("{}", path.name); - let edges_for_path = edges_by_path_id.get(&path.id).unwrap(); - let mut node_ids = vec![]; - let mut node_strands = vec![]; - // Edges actually have too much information, the target of one is the same as the source of - // the next, so just iterate and take the target node to get the path of segments. - for edge in edges_for_path[0..edges_for_path.len() - 1].iter() { - let (_, target) = node_pairs_by_edge_id.get(&edge.id).unwrap(); - node_ids.push(*target); - node_strands.push(edge.target_strand); - } - - writer - .write_all(&path_line(&path.name, &node_ids, &node_strands).into_bytes()) - .unwrap_or_else(|_| panic!("Error writing path {} to GFA stream", path.name)); - } -} - -fn path_line(path_name: &str, node_ids: &[i32], node_strands: &[Strand]) -> String { - let nodes = node_ids - .iter() - .zip(node_strands.iter()) - .map(|(node_id, node_strand)| format!("{}{}", node_id + 1, node_strand)) - .collect::>() - .join(","); - format!("P\t{}\t{}\n", path_name, nodes) + write_paths(&mut writer, conn, collection_name, edges_by_node_pair); } -fn write_blocks(writer: &mut BufWriter, blocks: Vec) -> HashSet { - let mut terminal_block_ids = HashSet::new(); +fn write_segments( + writer: &mut BufWriter, + blocks: Vec, + terminal_block_ids: HashSet, +) { for block in &blocks { - if block.sequence_hash == Sequence::PATH_START_HASH - || block.sequence_hash == Sequence::PATH_END_HASH - { - terminal_block_ids.insert(block.id); + if terminal_block_ids.contains(&block.id) { continue; } writer @@ -102,11 +74,13 @@ fn write_blocks(writer: &mut BufWriter, blocks: Vec) -> HashSe ) }); } +} - terminal_block_ids +fn segment_line(sequence: &str, index: usize) -> String { + format!("S\t{}\t{}\t{}\n", index + 1, sequence, "*") } -fn write_edges( +fn write_links( writer: &mut BufWriter, graph: DiGraphMap, edges_by_node_pair: HashMap<(i32, i32), Edge>, @@ -130,10 +104,6 @@ fn write_edges( } } -fn segment_line(sequence: &str, index: usize) -> String { - format!("S\t{}\t{}\t{}\n", index + 1, sequence, "*") -} - fn link_line( source_index: i32, source_strand: Strand, @@ -149,6 +119,48 @@ fn link_line( ) } +fn write_paths( + writer: &mut BufWriter, + conn: &Connection, + collection_name: &str, + edges_by_node_pair: HashMap<(i32, i32), Edge>, +) { + let paths = Path::get_paths_for_collection(conn, collection_name); + let edges_by_path_id = + PathEdge::edges_for_paths(conn, paths.iter().map(|path| path.id).collect()); + let node_pairs_by_edge_id = edges_by_node_pair + .iter() + .map(|(node_pair, edge)| (edge.id, *node_pair)) + .collect::>(); + + for path in paths { + let edges_for_path = edges_by_path_id.get(&path.id).unwrap(); + let mut node_ids = vec![]; + let mut node_strands = vec![]; + // Edges actually have too much information, the target of one is the same as the source of + // the next, so just iterate and take the target node to get the path of segments. + for edge in edges_for_path[0..edges_for_path.len() - 1].iter() { + let (_, target) = node_pairs_by_edge_id.get(&edge.id).unwrap(); + node_ids.push(*target); + node_strands.push(edge.target_strand); + } + + writer + .write_all(&path_line(&path.name, &node_ids, &node_strands).into_bytes()) + .unwrap_or_else(|_| panic!("Error writing path {} to GFA stream", path.name)); + } +} + +fn path_line(path_name: &str, node_ids: &[i32], node_strands: &[Strand]) -> String { + let nodes = node_ids + .iter() + .zip(node_strands.iter()) + .map(|(node_id, node_strand)| format!("{}{}", node_id + 1, node_strand)) + .collect::>() + .join(","); + format!("P\t{}\t{}\n", path_name, nodes) +} + mod tests { use rusqlite::Connection; // Note this useful idiom: importing names from outer (for mod tests) scope. @@ -277,4 +289,31 @@ mod tests { assert_eq!(paths.len(), 1); assert_eq!(Path::sequence(&conn, paths[0].clone()), "AAAATTTTGGGGCCCC"); } + + #[test] + fn test_simple_round_trip() { + setup_gen_dir(); + let mut gfa_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + gfa_path.push("fixtures/simple.gfa"); + let collection_name = "test".to_string(); + let conn = &get_connection(None); + import_gfa(&gfa_path, &collection_name, conn); + + let block_group_id = BlockGroup::get_id(conn, &collection_name, None, ""); + let all_sequences = BlockGroup::get_all_sequences(conn, block_group_id); + + let temp_dir = tempdir().expect("Couldn't get handle to temp directory"); + let mut gfa_path = PathBuf::from(temp_dir.path()); + gfa_path.push("intermediate.gfa"); + + export_gfa(conn, &collection_name, &gfa_path); + import_gfa(&gfa_path, "test collection 2", conn); + + let block_group2 = Collection::get_block_groups(conn, "test collection 2") + .pop() + .unwrap(); + let all_sequences2 = BlockGroup::get_all_sequences(conn, block_group2.id); + + assert_eq!(all_sequences, all_sequences2); + } } diff --git a/src/models/path_edge.rs b/src/models/path_edge.rs index 4cbf457..f5e8b80 100644 --- a/src/models/path_edge.rs +++ b/src/models/path_edge.rs @@ -100,7 +100,7 @@ impl PathEdge { let path_edges = PathEdge::query( conn, format!( - "select id, path_id, index_in_path, edge_id from path_edges where path_id in ({})", + "select id, path_id, index_in_path, edge_id from path_edges where path_id in ({}) ORDER BY path_id, index_in_path", placeholder_string ) .as_str(),