Skip to content

Commit

Permalink
Refactor, add test, fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
dkhofer committed Sep 18, 2024
1 parent 5ec11d5 commit 9cb4cb4
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 53 deletions.
143 changes: 91 additions & 52 deletions src/exports/gfa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,61 +36,33 @@ pub fn export_gfa(conn: &Connection, collection_name: &str, filename: &PathBuf)
let file = File::create(filename).unwrap();
let mut writer = BufWriter::new(file);

let terminal_block_ids = write_blocks(&mut writer, blocks.clone());
let mut terminal_block_ids = HashSet::new();
for block in &blocks {
if block.sequence_hash == Sequence::PATH_START_HASH
|| block.sequence_hash == Sequence::PATH_END_HASH
{
terminal_block_ids.insert(block.id);
continue;
}
}

write_edges(
write_segments(&mut writer, blocks.clone(), terminal_block_ids.clone());
write_links(
&mut writer,
graph,
edges_by_node_pair.clone(),
terminal_block_ids,
);

let paths = Path::get_paths_for_collection(conn, collection_name);
let edges_by_path_id =
PathEdge::edges_for_paths(conn, paths.iter().map(|path| path.id).collect());
let node_pairs_by_edge_id = edges_by_node_pair
.iter()
.map(|(node_pair, edge)| (edge.id, *node_pair))
.collect::<HashMap<i32, (i32, i32)>>();

println!("here1");
for path in paths {
println!("here2");
println!("{}", path.name);
let edges_for_path = edges_by_path_id.get(&path.id).unwrap();
let mut node_ids = vec![];
let mut node_strands = vec![];
// Edges actually have too much information, the target of one is the same as the source of
// the next, so just iterate and take the target node to get the path of segments.
for edge in edges_for_path[0..edges_for_path.len() - 1].iter() {
let (_, target) = node_pairs_by_edge_id.get(&edge.id).unwrap();
node_ids.push(*target);
node_strands.push(edge.target_strand);
}

writer
.write_all(&path_line(&path.name, &node_ids, &node_strands).into_bytes())
.unwrap_or_else(|_| panic!("Error writing path {} to GFA stream", path.name));
}
}

fn path_line(path_name: &str, node_ids: &[i32], node_strands: &[Strand]) -> String {
let nodes = node_ids
.iter()
.zip(node_strands.iter())
.map(|(node_id, node_strand)| format!("{}{}", node_id + 1, node_strand))
.collect::<Vec<String>>()
.join(",");
format!("P\t{}\t{}\n", path_name, nodes)
write_paths(&mut writer, conn, collection_name, edges_by_node_pair);
}

fn write_blocks(writer: &mut BufWriter<File>, blocks: Vec<GroupBlock>) -> HashSet<i32> {
let mut terminal_block_ids = HashSet::new();
fn write_segments(
writer: &mut BufWriter<File>,
blocks: Vec<GroupBlock>,
terminal_block_ids: HashSet<i32>,
) {
for block in &blocks {
if block.sequence_hash == Sequence::PATH_START_HASH
|| block.sequence_hash == Sequence::PATH_END_HASH
{
terminal_block_ids.insert(block.id);
if terminal_block_ids.contains(&block.id) {
continue;
}
writer
Expand All @@ -102,11 +74,13 @@ fn write_blocks(writer: &mut BufWriter<File>, blocks: Vec<GroupBlock>) -> HashSe
)
});
}
}

terminal_block_ids
/// Formats one GFA segment (`S`) record, terminated by a newline.
///
/// `index` is zero-based; the emitted segment name is `index + 1`. The record is
/// tab-separated: `S`, the segment name, `sequence`, and a trailing `*` column.
fn segment_line(sequence: &str, index: usize) -> String {
    let segment_name = index + 1;
    let trailing_column = "*";
    format!("S\t{}\t{}\t{}\n", segment_name, sequence, trailing_column)
}

fn write_edges(
fn write_links(
writer: &mut BufWriter<File>,
graph: DiGraphMap<i32, ()>,
edges_by_node_pair: HashMap<(i32, i32), Edge>,
Expand All @@ -130,10 +104,6 @@ fn write_edges(
}
}

/// Renders a single newline-terminated GFA segment (`S`) line.
///
/// The segment name written is the zero-based `index` plus one, followed by the
/// `sequence` text and a final `*` column, all separated by tabs.
fn segment_line(sequence: &str, index: usize) -> String {
    let one_based_name = index + 1;
    let last_field = "*";
    format!("S\t{}\t{}\t{}\n", one_based_name, sequence, last_field)
}

fn link_line(
source_index: i32,
source_strand: Strand,
Expand All @@ -149,6 +119,48 @@ fn link_line(
)
}

/// Writes one GFA path (`P`) line to `writer` for every path in the collection.
///
/// Paths and their edges are loaded through `conn`. `edges_by_node_pair` maps a
/// `(source, target)` node pair to its edge; it is inverted here so that each edge of a
/// path can be resolved back to the node it targets.
///
/// # Panics
///
/// Panics if a path has no edge list or an empty one, if one of its edges is missing
/// from `edges_by_node_pair`, or if writing to `writer` fails. Each panic message names
/// the offending path.
fn write_paths(
    writer: &mut BufWriter<File>,
    conn: &Connection,
    collection_name: &str,
    edges_by_node_pair: HashMap<(i32, i32), Edge>,
) {
    let paths = Path::get_paths_for_collection(conn, collection_name);
    let edges_by_path_id =
        PathEdge::edges_for_paths(conn, paths.iter().map(|path| path.id).collect());
    // Invert the lookup: edge id -> (source node, target node).
    let node_pairs_by_edge_id = edges_by_node_pair
        .iter()
        .map(|(node_pair, edge)| (edge.id, *node_pair))
        .collect::<HashMap<i32, (i32, i32)>>();

    for path in paths {
        let edges_for_path = edges_by_path_id
            .get(&path.id)
            .unwrap_or_else(|| panic!("No edges found for path {}", path.name));

        // Edges carry redundant information: the target of one edge is the source of the
        // next, so the targets of all edges but the last give the path's segments.
        // `split_last` drops that final edge without the `len() - 1` underflow an empty
        // edge list would otherwise cause.
        // NOTE(review): presumably the last edge targets the virtual path-end node, which
        // has no segment of its own - confirm against the importer.
        let (_, leading_edges) = edges_for_path
            .split_last()
            .unwrap_or_else(|| panic!("Path {} has an empty edge list", path.name));

        let mut node_ids = Vec::with_capacity(leading_edges.len());
        let mut node_strands = Vec::with_capacity(leading_edges.len());
        for edge in leading_edges {
            let (_, target) = node_pairs_by_edge_id.get(&edge.id).unwrap_or_else(|| {
                panic!(
                    "Edge {} of path {} is missing from the node pair map",
                    edge.id, path.name
                )
            });
            node_ids.push(*target);
            node_strands.push(edge.target_strand);
        }

        writer
            .write_all(path_line(&path.name, &node_ids, &node_strands).as_bytes())
            .unwrap_or_else(|_| panic!("Error writing path {} to GFA stream", path.name));
    }
}

/// Formats one newline-terminated GFA path (`P`) record.
///
/// `node_ids` and `node_strands` are walked in lockstep; each pair becomes one step
/// written as the node id plus one followed by the strand's `Display` form, and the
/// steps are joined with commas. If the slices differ in length, the extra entries of
/// the longer one are ignored.
fn path_line(path_name: &str, node_ids: &[i32], node_strands: &[Strand]) -> String {
    let mut steps: Vec<String> = Vec::new();
    for (id, strand) in node_ids.iter().zip(node_strands) {
        steps.push(format!("{}{}", id + 1, strand));
    }
    let joined_steps = steps.join(",");
    format!("P\t{}\t{}\n", path_name, joined_steps)
}

mod tests {
use rusqlite::Connection;
// Note this useful idiom: importing names from outer (for mod tests) scope.
Expand Down Expand Up @@ -277,4 +289,31 @@ mod tests {
assert_eq!(paths.len(), 1);
assert_eq!(Path::sequence(&conn, paths[0].clone()), "AAAATTTTGGGGCCCC");
}

/// Imports the simple GFA fixture, exports it to a temporary file, re-imports that file
/// into a second collection, and checks both block groups yield the same sequences.
#[test]
fn test_simple_round_trip() {
    setup_gen_dir();

    // First import: the checked-in fixture goes into the "test" collection.
    let fixture_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("fixtures/simple.gfa");
    let collection_name = "test".to_string();
    let connection = get_connection(None);
    let conn = &connection;
    import_gfa(&fixture_path, &collection_name, conn);

    let original_block_group_id = BlockGroup::get_id(conn, &collection_name, None, "");
    let expected_sequences = BlockGroup::get_all_sequences(conn, original_block_group_id);

    // Export into a scratch directory; `temp_dir` stays bound until the end of the test.
    let temp_dir = tempdir().expect("Couldn't get handle to temp directory");
    let exported_path = PathBuf::from(temp_dir.path()).join("intermediate.gfa");
    export_gfa(conn, &collection_name, &exported_path);

    // Second import: read the exported file back into a separate collection.
    import_gfa(&exported_path, "test collection 2", conn);
    let mut reimported_groups = Collection::get_block_groups(conn, "test collection 2");
    let reimported_group = reimported_groups.pop().unwrap();
    let actual_sequences = BlockGroup::get_all_sequences(conn, reimported_group.id);

    assert_eq!(expected_sequences, actual_sequences);
}
}
2 changes: 1 addition & 1 deletion src/models/path_edge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl PathEdge {
let path_edges = PathEdge::query(
conn,
format!(
"select id, path_id, index_in_path, edge_id from path_edges where path_id in ({})",
"select id, path_id, index_in_path, edge_id from path_edges where path_id in ({}) ORDER BY path_id, index_in_path",
placeholder_string
)
.as_str(),
Expand Down

0 comments on commit 9cb4cb4

Please sign in to comment.