From eb217fcca4b6f5d98415166e4f69444144d9eb6c Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Mon, 11 Nov 2024 13:24:38 -0800 Subject: [PATCH] test & fix --- src/pairwise.rs | 19 +++++++++++++++++-- src/python/tests/test_pairwise.py | 24 ++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/src/pairwise.rs b/src/pairwise.rs index 0b903fc7..dbec61bd 100644 --- a/src/pairwise.rs +++ b/src/pairwise.rs @@ -22,8 +22,6 @@ pub fn pairwise( write_all: bool, output: Option, ) -> Result<(), Box> { - // @CTB test for heterogenous scaled. - // Load all sigs into memory at once. let collection = load_collection( &siglist, @@ -39,6 +37,23 @@ pub fn pairwise( ) } + // pull scaled from command line; if not specified, calculate max and + // use that. + let common_scaled = match selection.scaled() { + Some(s) => s, + None => { + let s = *collection.max_scaled().expect("no records!?") as u32; + eprintln!( + "Setting scaled={} based on max scaled in collection", + s + ); + s + } + }; + + let mut selection = selection; + selection.set_scaled(common_scaled); + let sketches = collection.load_sketches(&selection)?; // set up a multi-producer, single-consumer channel. diff --git a/src/python/tests/test_pairwise.py b/src/python/tests/test_pairwise.py index a87cf126..ffa9d96b 100644 --- a/src/python/tests/test_pairwise.py +++ b/src/python/tests/test_pairwise.py @@ -771,3 +771,27 @@ def test_simple_scaled_heterogenous(runtmp): df = pandas.read_csv(output) assert len(df) == 1 assert set(list(df["scaled"])) == {10_000} + + +def test_simple_scaled_heterogenous(runtmp): + # test basic execution w/heterogeneous scaled - specified on command line + query_list = runtmp.output("query.txt") + against_list = runtmp.output("against.txt") + + sig2 = get_test_data("2.fa.sig.gz") + sig47 = get_test_data("47.fa.sig.gz") + sig63 = get_test_data("63.fa.sig.gz") + sig47_ds = runtmp.output("47-10k.sig.zip") + + runtmp.sourmash("sig", "downsample", sig47, "-o", sig47_ds, "--scaled", "10_000") + + make_file_list(query_list, [sig2, sig47_ds, sig63]) + + output = runtmp.output("out.csv") + + runtmp.sourmash("scripts", "pairwise", query_list, "-o", output, + '--scaled=15_000') + assert os.path.exists(output) + df = pandas.read_csv(output) + assert len(df) == 1 + assert set(list(df["scaled"])) == {15_000}