Skip to content

Commit

Permalink
MRG: re-establish tax gather reading flexibility (#2986)
Browse files Browse the repository at this point in the history
A while back, I introduced `GatherRow` to handle checking for required
gather columns for us. However, it ended up being overly restrictive --
any extra columns cause `gather_csv` reading to fail.

Here, I add a filtration step that lets us ignore unspecified columns
entirely before reading a GatherRow. Initializing the GatherRow after
this filtration continues to handle the checks for all required columns
while restoring flexibility.

As a consequence, we can actually delete all the `non-essential` names
in `GatherRow` and avoid carrying them around (saving some memory)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
bluegenes and pre-commit-ci[bot] authored Feb 8, 2024
1 parent 427712c commit e5cdc36
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 28 deletions.
51 changes: 24 additions & 27 deletions src/sourmash/tax/tax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import abc, defaultdict
from itertools import zip_longest
from typing import NamedTuple
from dataclasses import dataclass, field, replace, asdict
from dataclasses import dataclass, field, replace, asdict, fields
import gzip

from sourmash import sqlite_utils, sourmash_args
Expand Down Expand Up @@ -742,7 +742,10 @@ def load_gather_results(
for n, row in enumerate(r):
# try reading each gather row into a TaxResult
try:
gatherRow = GatherRow(**row)
filt_row = filter_row(
row, GatherRow
) # filter row first to allow extra (unused) columns in csv
gatherRow = GatherRow(**filt_row)
except TypeError as exc:
raise ValueError(
f"'{gather_csv}' is missing columns needed for taxonomic summarization. Please run gather with sourmash >= 4.4."
Expand Down Expand Up @@ -1675,6 +1678,20 @@ def load(cls, locations, **kwargs):
return tax_assign


def filter_row(row, dataclass_type):
"""
Filter the row to only include keys that exist in the dataclass fields.
This allows extra columns to be passed in with the gather csv while still
taking advantage of the checks for required columns that come with dataclass
initialization.
"""
valid_keys = {field.name for field in fields(dataclass_type)}
# 'match_name' and 'name' should be interchangeable (sourmash 4.x)
if "match_name" in row.keys() and "name" not in row.keys():
row["name"] = row.pop("match_name")
return {k: v for k, v in row.items() if k in valid_keys}


@dataclass
class GatherRow:
"""
Expand All @@ -1689,7 +1706,8 @@ class GatherRow:
with sourmash_args.FileInputCSV(gather_csv) as r:
for row in enumerate(r):
gatherRow = GatherRow(**row)
filt_row = filter_row(row, GatherRow) # filter first to allow extra columns
gatherRow = GatherRow(**filt_row)
"""

# essential columns
Expand All @@ -1706,32 +1724,10 @@ class GatherRow:
ksize: int
scaled: int

# non-essential
intersect_bp: int = None
f_orig_query: float = None
f_match: float = None
average_abund: float = None
median_abund: float = None
std_abund: float = None
filename: str = None
md5: str = None
f_match_orig: float = None
gather_result_rank: str = None
moltype: str = None
# non-essential, but used if available
query_n_hashes: int = None
query_abundance: int = None
query_containment_ani: float = None
match_containment_ani: float = None
average_containment_ani: float = None
max_containment_ani: float = None
potential_false_negative: bool = None
n_unique_weighted_found: int = None
sum_weighted_found: int = None
total_weighted_hashes: int = None
query_containment_ani_low: float = None
query_containment_ani_high: float = None
match_containment_ani_low: float = None
match_containment_ani_high: float = None


@dataclass
Expand Down Expand Up @@ -1854,7 +1850,8 @@ class TaxResult(BaseTaxResult):
with sourmash_args.FileInputCSV(gather_csv) as r:
for row in enumerate(r):
gatherRow = GatherRow(**row)
filt_row = filter_row(row, GatherRow) # this filters any extra columns
gatherRow = GatherRow(**filt_row) # this checks for required columns and raises TypeError for any missing
# initialize TaxResult
tax_res = TaxResult(raw=gatherRow)
Expand Down
19 changes: 18 additions & 1 deletion tests/test_tax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
LineageDB,
LineageDB_Sqlite,
MultiLineageDB,
filter_row,
)


Expand Down Expand Up @@ -93,7 +94,8 @@ def make_GatherRow(gather_dict=None, exclude_cols=[]):
gatherD.update(gather_dict)
for col in exclude_cols:
gatherD.pop(col)
gatherRaw = GatherRow(**gatherD)
fgatherD = filter_row(gatherD, GatherRow)
gatherRaw = GatherRow(**fgatherD)
return gatherRaw


Expand Down Expand Up @@ -807,6 +809,21 @@ def test_GatherRow_old_gather():
assert "__init__() missing 1 required positional argument: 'query_bp'" in str(exc)


def test_GatherRow_match_name_not_name():
# gather contains match_name but not name column
gA = {"match_name": "gA.1 name"}
grow = make_GatherRow(gA, exclude_cols=["name"])
print(grow)
assert grow.name == "gA.1 name"


def test_GatherRow_extra_cols():
# gather contains extra columns
gA = {"not-a-col": "nope"}
grow = make_GatherRow(gA)
assert isinstance(grow, GatherRow)


def test_get_ident_default():
ident = "GCF_001881345.1"
n_id = get_ident(ident)
Expand Down

0 comments on commit e5cdc36

Please sign in to comment.