From e5cdc36b8eb9975134912cd929d170a4b503a0b3 Mon Sep 17 00:00:00 2001
From: Tessa Pierce Ward
Date: Thu, 8 Feb 2024 15:58:13 -0800
Subject: [PATCH] MRG: re-establish `tax` gather reading flexibility (#2986)

A while back, I introduced `GatherRow` to handle checking for required
gather columns for us. However, it ended up being overly restrictive --
any extra columns cause `gather_csv` reading to fail.

Here, I add a filtration step that lets us ignore unspecified columns
entirely before reading a `GatherRow`. Initializing the `GatherRow`
after this filtration still handles the checks for all required columns
while restoring flexibility.

As a consequence, we can actually delete all the `non-essential` names
in `GatherRow` and avoid carrying them around (saving some memory).

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 src/sourmash/tax/tax_utils.py | 51 +++++++++++++++++------------------
 tests/test_tax_utils.py       | 19 ++++++++++++-
 2 files changed, 42 insertions(+), 28 deletions(-)

diff --git a/src/sourmash/tax/tax_utils.py b/src/sourmash/tax/tax_utils.py
index 55b30a540e..d1827c2aad 100644
--- a/src/sourmash/tax/tax_utils.py
+++ b/src/sourmash/tax/tax_utils.py
@@ -6,7 +6,7 @@
 from collections import abc, defaultdict
 from itertools import zip_longest
 from typing import NamedTuple
-from dataclasses import dataclass, field, replace, asdict
+from dataclasses import dataclass, field, replace, asdict, fields
 import gzip
 
 from sourmash import sqlite_utils, sourmash_args
@@ -742,7 +742,10 @@ def load_gather_results(
         for n, row in enumerate(r):
             # try reading each gather row into a TaxResult
             try:
-                gatherRow = GatherRow(**row)
+                filt_row = filter_row(
+                    row, GatherRow
+                )  # filter row first to allow extra (unused) columns in csv
+                gatherRow = GatherRow(**filt_row)
             except TypeError as exc:
                 raise ValueError(
                     f"'{gather_csv}' is missing columns needed for taxonomic summarization. Please run gather with sourmash >= 4.4."
@@ -1675,6 +1678,20 @@ def load(cls, locations, **kwargs):
         return tax_assign
 
 
+def filter_row(row, dataclass_type):
+    """
+    Filter the row to only include keys that exist in the dataclass fields.
+    This allows extra columns to be passed in with the gather csv while still
+    taking advantage of the checks for required columns that come with dataclass
+    initialization.
+    """
+    valid_keys = {field.name for field in fields(dataclass_type)}
+    # 'match_name' and 'name' should be interchangeable (sourmash 4.x)
+    if "match_name" in row.keys() and "name" not in row.keys():
+        row["name"] = row.pop("match_name")
+    return {k: v for k, v in row.items() if k in valid_keys}
+
+
 @dataclass
 class GatherRow:
     """
@@ -1689,7 +1706,8 @@ class GatherRow:
 
     with sourmash_args.FileInputCSV(gather_csv) as r:
         for row in enumerate(r):
-            gatherRow = GatherRow(**row)
+            filt_row = filter_row(row, GatherRow)  # filter first to allow extra columns
+            gatherRow = GatherRow(**filt_row)
 
     """
 
     # essential columns
@@ -1706,32 +1724,10 @@ class GatherRow:
     query_name: str
     query_md5: str
     query_filename: str
     query_bp: int
     unique_intersect_bp: int
     remaining_bp: int
     f_unique_to_query: float
     f_unique_weighted: float
     ksize: int
     scaled: int
 
-    # non-essential
-    intersect_bp: int = None
-    f_orig_query: float = None
-    f_match: float = None
-    average_abund: float = None
-    median_abund: float = None
-    std_abund: float = None
-    filename: str = None
-    md5: str = None
-    f_match_orig: float = None
-    gather_result_rank: str = None
-    moltype: str = None
+    # non-essential, but used if available
     query_n_hashes: int = None
-    query_abundance: int = None
-    query_containment_ani: float = None
-    match_containment_ani: float = None
-    average_containment_ani: float = None
-    max_containment_ani: float = None
-    potential_false_negative: bool = None
-    n_unique_weighted_found: int = None
     sum_weighted_found: int = None
     total_weighted_hashes: int = None
-    query_containment_ani_low: float = None
-    query_containment_ani_high: float = None
-    match_containment_ani_low: float = None
-    match_containment_ani_high: float = None
 
 
 @dataclass
@@ -1854,7 +1850,8 @@ class TaxResult(BaseTaxResult):
 
     with sourmash_args.FileInputCSV(gather_csv) as r:
         for row in enumerate(r):
-            gatherRow = GatherRow(**row)
+            filt_row = filter_row(row, GatherRow)  # this filters any extra columns
+            gatherRow = GatherRow(**filt_row)  # this checks for required columns and raises TypeError for any missing
 
     # initialize TaxResult
     tax_res = TaxResult(raw=gatherRow)
diff --git a/tests/test_tax_utils.py b/tests/test_tax_utils.py
index a362984532..bd0060b65a 100644
--- a/tests/test_tax_utils.py
+++ b/tests/test_tax_utils.py
@@ -37,6 +37,7 @@
     LineageDB,
     LineageDB_Sqlite,
     MultiLineageDB,
+    filter_row,
 )
 
 
@@ -93,7 +94,8 @@ def make_GatherRow(gather_dict=None, exclude_cols=[]):
         gatherD.update(gather_dict)
     for col in exclude_cols:
         gatherD.pop(col)
-    gatherRaw = GatherRow(**gatherD)
+    fgatherD = filter_row(gatherD, GatherRow)
+    gatherRaw = GatherRow(**fgatherD)
     return gatherRaw
 
 
@@ -807,6 +809,21 @@ def test_GatherRow_old_gather():
     assert "__init__() missing 1 required positional argument: 'query_bp'" in str(exc)
 
 
+def test_GatherRow_match_name_not_name():
+    # gather contains match_name but not name column
+    gA = {"match_name": "gA.1 name"}
+    grow = make_GatherRow(gA, exclude_cols=["name"])
+    print(grow)
+    assert grow.name == "gA.1 name"
+
+
+def test_GatherRow_extra_cols():
+    # gather contains extra columns
+    gA = {"not-a-col": "nope"}
+    grow = make_GatherRow(gA)
+    assert isinstance(grow, GatherRow)
+
+
 def test_get_ident_default():
     ident = "GCF_001881345.1"
     n_id = get_ident(ident)
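
A minimal sketch of the filtration pattern introduced above, runnable with only the standard library. `MiniRow` and the sample row values are hypothetical stand-ins for `GatherRow` and a real gather CSV row; the `filter_row` logic here mirrors the function added in this patch.

    from dataclasses import dataclass, fields


    @dataclass
    class MiniRow:
        # hypothetical stand-in for GatherRow: two "essential" columns only
        name: str
        f_unique_weighted: float


    def filter_row(row, dataclass_type):
        # keep only keys that are dataclass fields; alias 'match_name' -> 'name'
        valid_keys = {field.name for field in fields(dataclass_type)}
        if "match_name" in row.keys() and "name" not in row.keys():
            row["name"] = row.pop("match_name")
        return {k: v for k, v in row.items() if k in valid_keys}


    # a CSV row with an aliased name column and an extra, unused column
    csv_row = {
        "match_name": "GCF_001881345.1",
        "f_unique_weighted": 0.42,
        "extra_col": "ignored",
    }

    # passing the raw row would raise TypeError for the unexpected keyword;
    # filtering first keeps the required-column check but tolerates extras
    mini = MiniRow(**filter_row(csv_row, MiniRow))
    assert mini.name == "GCF_001881345.1"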