diff --git a/bento_reference_service/db.py b/bento_reference_service/db.py index d6048ec..f838982 100644 --- a/bento_reference_service/db.py +++ b/bento_reference_service/db.py @@ -90,10 +90,28 @@ def deserialize_genome(self, rec: asyncpg.Record, external_resource_uris: bool) taxon=OntologyTerm(id=rec["taxon_id"], label=rec["taxon_label"]), ) - async def _select_genomes(self, g_id: str | None, external_resource_uris: bool) -> AsyncIterator[GenomeWithURIs]: + async def _select_genomes( + self, + g_ids: list[str] | None, + taxon_id: str | None = None, + external_resource_uris: bool = False, + ) -> AsyncIterator[GenomeWithURIs]: + where_items: list[str] = [] + q_params: list[str | int] = [] + + def _q_param(pv: str | int) -> str: + q_params.append(pv) + return f"${len(q_params)}" + + if g_ids: + g_id_ors = " OR ".join(f"g.id = {_q_param(g_id)}" for g_id in g_ids) + where_items.append(f"({g_id_ors})") + + if taxon_id: + where_items.append(f"taxon_id = {_q_param(taxon_id)}") + conn: asyncpg.Connection async with self.connect() as conn: - where_clause = "WHERE g.id = $1" if g_id is not None else "" res = await conn.fetch( f""" SELECT @@ -122,19 +140,21 @@ async def _select_genomes(self, g_id: str | None, external_resource_uris: bool) ) SELECT jsonb_agg(contigs_tmp.*) FROM contigs_tmp ) contigs - FROM genomes g {where_clause} + FROM genomes g {('WHERE ' + ' AND '.join(where_items)) if where_items else ''} """, - *((g_id,) if g_id is not None else ()), + *q_params, ) for r in map(lambda g: self.deserialize_genome(g, external_resource_uris), res): yield r - async def get_genomes(self, external_resource_uris: bool = False) -> tuple[GenomeWithURIs, ...]: - return tuple([r async for r in self._select_genomes(None, external_resource_uris)]) + async def get_genomes( + self, g_ids: list[str] | None = None, taxon_id: str | None = None, external_resource_uris: bool = False + ) -> tuple[GenomeWithURIs, ...]: + return tuple([r async for r in self._select_genomes(g_ids, taxon_id, external_resource_uris)]) - async def get_genome(self, g_id: str, external_resource_uris: bool = False) -> GenomeWithURIs | None: - return await anext(self._select_genomes(g_id, external_resource_uris), None) + async def get_genome(self, g_id: str, *, external_resource_uris: bool = False) -> GenomeWithURIs | None: + return await anext(self._select_genomes([g_id], external_resource_uris=external_resource_uris), None) async def delete_genome(self, g_id: str) -> None: conn: asyncpg.Connection @@ -165,7 +185,7 @@ async def get_genome_and_contig_by_checksum_str( chk_norm, ) - genome_res = (await anext(self._select_genomes(contig_res["genome_id"], False), None)) if contig_res else None + genome_res = (await anext(self._select_genomes([contig_res["genome_id"]]), None)) if contig_res else None if genome_res is None or contig_res is None: return None return genome_res, self.deserialize_contig(contig_res) diff --git a/bento_reference_service/routers/genomes.py b/bento_reference_service/routers/genomes.py index 65fef29..094f1ce 100644 --- a/bento_reference_service/routers/genomes.py +++ b/bento_reference_service/routers/genomes.py @@ -33,9 +33,12 @@ async def get_genome_or_raise_404( @genome_router.get("", dependencies=[authz_middleware.dep_public_endpoint()]) async def genomes_list( - db: DatabaseDependency, response_format: str | None = None + db: DatabaseDependency, + ids: Annotated[list[str] | None, Query()] = None, + taxon_id: str | None = None, + response_format: str | None = None, ) -> tuple[m.GenomeWithURIs, ...] | tuple[str, ...]: - genomes = await db.get_genomes(external_resource_uris=True) + genomes = await db.get_genomes(ids, taxon_id, external_resource_uris=True) if response_format == "id_list": return tuple(g.id for g in genomes) # else, format as full response diff --git a/bento_reference_service/sql/schema.sql b/bento_reference_service/sql/schema.sql index 58abaea..e3289f4 100644 --- a/bento_reference_service/sql/schema.sql +++ b/bento_reference_service/sql/schema.sql @@ -11,6 +11,7 @@ CREATE TABLE IF NOT EXISTS genomes ( taxon_id VARCHAR(31) NOT NULL, -- e.g., NCBITaxon:9606 taxon_label TEXT NOT NULL -- e.g., Homo sapiens ); +CREATE INDEX IF NOT EXISTS genomes_id_trgm_idx ON genomes USING GIN (id gin_trgm_ops); -- Migration (v0.2.0): add genomes.gff3_uri and genomes.gff3_tbi_uri if they do not exist: ALTER TABLE genomes @@ -25,6 +26,7 @@ CREATE TABLE IF NOT EXISTS genome_aliases ( PRIMARY KEY (genome_id, alias) ); CREATE INDEX IF NOT EXISTS genome_aliases_genome_idx ON genome_aliases (genome_id); +CREATE INDEX IF NOT EXISTS genome_aliases_alias_trgm_idx ON genome_aliases USING GIN (alias gin_trgm_ops); CREATE TABLE IF NOT EXISTS genome_contigs ( genome_id VARCHAR(31) NOT NULL REFERENCES genomes ON DELETE CASCADE, diff --git a/tests/test_db.py b/tests/test_db.py index 81985ef..c3e3885 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -30,6 +30,22 @@ async def test_create_genome(db: Database, db_cleanup): await _set_up_hg38_subset_genome(db) +async def test_get_genomes(db: Database, db_cleanup): + # start with two genomes, so we validate that we get the right one(s) + await _set_up_sars_cov_2_genome(db) + await _set_up_hg38_subset_genome(db) + + assert len(await db.get_genomes(g_ids=[SARS_COV_2_GENOME_ID, TEST_GENOME_HG38_CHR1_F100K_OBJ.id])) == 2 + + res = await db.get_genomes(g_ids=[SARS_COV_2_GENOME_ID]) + assert len(res) == 1 + assert res[0].id == SARS_COV_2_GENOME_ID + + res = await db.get_genomes(taxon_id="NCBITaxon:9606") + assert len(res) == 1 + assert res[0].id == TEST_GENOME_HG38_CHR1_F100K_OBJ.id + + @pytest.mark.parametrize( "checksum,genome_id,contig_name", [