diff --git a/transcriptomics_data_service/db.py b/transcriptomics_data_service/db.py index 92f793f..c839f4f 100644 --- a/transcriptomics_data_service/db.py +++ b/transcriptomics_data_service/db.py @@ -1,5 +1,5 @@ import logging -from typing import Annotated, AsyncIterator, List, Tuple +from typing import Annotated, AsyncIterator, List, Tuple, Optional import asyncpg from bento_lib.db.pg_async import PgAsyncDatabase from contextlib import asynccontextmanager @@ -111,9 +111,6 @@ async def create_gene_expressions(self, expressions: list[GeneExpression], trans await transaction_conn.executemany(query, records) self.logger.info(f"Inserted {len(records)} gene expression records.") - async def fetch_expressions(self) -> tuple[GeneExpression, ...]: - return tuple([r async for r in self._select_expressions(None)]) - async def _select_expressions(self, exp_id: str | None) -> AsyncIterator[GeneExpression]: conn: asyncpg.Connection where_clause = "WHERE experiment_result_id = $1" if exp_id is not None else "" @@ -123,28 +120,6 @@ async def _select_expressions(self, exp_id: str | None) -> AsyncIterator[GeneExp for r in map(lambda g: self._deserialize_gene_expression(g), res): yield r - async def fetch_gene_expressions_by_experiment_id(self, experiment_result_id: str) -> Tuple[GeneExpression, ...]: - """ - Fetch gene expressions for a specific experiment_result_id. - """ - conn: asyncpg.Connection - async with self.connect() as conn: - query = """ - SELECT * FROM gene_expressions WHERE experiment_result_id = $1 - """ - res = await conn.fetch(query, experiment_result_id) - return tuple([self._deserialize_gene_expression(record) for record in res]) - - async def fetch_gene_expressions( - self, experiments: list[str], method: str = "raw", paginate: bool = False - ) -> Tuple[Tuple[GeneExpression, ...], int]: - if not experiments: - return (), 0 - # TODO: refactor this fetch_gene_expressions_by_experiment_id and implement pagination - experiment_result_id = experiments[0] - expressions = await self.fetch_gene_expressions_by_experiment_id(experiment_result_id) - return expressions, len(expressions) - def _deserialize_gene_expression(self, rec: asyncpg.Record) -> GeneExpression: return GeneExpression( gene_code=rec["gene_code"], @@ -219,6 +194,83 @@ async def transaction_connection(self): # operations must be made using this connection for the transaction to apply yield conn + async def fetch_gene_expressions( + self, + genes: Optional[List[str]] = None, + experiments: Optional[List[str]] = None, + sample_ids: Optional[List[str]] = None, + method: str = "raw", + page: int = 1, + page_size: int = 100, + paginate: bool = True, + ) -> Tuple[List[GeneExpression], int]: + """ + Fetch gene expressions based on genes, experiments, sample_ids, and method, with optional pagination. + Returns a tuple of (expressions list, total_records count). + """ + conn: asyncpg.Connection + async with self.connect() as conn: + # Query builder + base_query = """ + SELECT gene_code, sample_id, experiment_result_id, raw_count, tpm_count, tmm_count, getmm_count + FROM gene_expressions + """ + count_query = "SELECT COUNT(*) FROM gene_expressions" + conditions = [] + params = [] + param_counter = 1 + + if genes: + conditions.append(f"gene_code = ANY(${param_counter}::text[])") + params.append(genes) + param_counter += 1 + + if experiments: + conditions.append(f"experiment_result_id = ANY(${param_counter}::text[])") + params.append(experiments) + param_counter += 1 + + if sample_ids: + conditions.append(f"sample_id = ANY(${param_counter}::text[])") + params.append(sample_ids) + param_counter += 1 + + if method != "raw": + conditions.append(f"{method}_count IS NOT NULL") + + where_clause = " WHERE " + " AND ".join(conditions) if conditions else "" + + order_clause = " ORDER BY gene_code, sample_id" + + query = base_query + where_clause + order_clause + count_query += where_clause + + # Pagination + if paginate: + limit_offset_clause = f" LIMIT ${param_counter} OFFSET ${param_counter + 1}" + params.extend([page_size, (page - 1) * page_size]) + query += limit_offset_clause + + total_records_params = params[:-2] if paginate else params + total_records = await conn.fetchval(count_query, *total_records_params) + + res = await conn.fetch(query, *params) + + expressions = [ + GeneExpression( + gene_code=record["gene_code"], + sample_id=record["sample_id"], + experiment_result_id=record["experiment_result_id"], + raw_count=record["raw_count"], + tpm_count=record["tpm_count"], + tmm_count=record["tmm_count"], + getmm_count=record["getmm_count"], + ) + for record in res + ] + + return expressions, total_records + @lru_cache() def get_db(config: ConfigDependency, logger: LoggerDependency) -> Database: diff --git a/transcriptomics_data_service/main.py b/transcriptomics_data_service/main.py index e76ccc0..783462b 100644 --- a/transcriptomics_data_service/main.py +++ b/transcriptomics_data_service/main.py @@ -4,9 +4,9 @@ from transcriptomics_data_service.db import get_db from transcriptomics_data_service.routers.experiment_results import experiment_router -from transcriptomics_data_service.routers.expressions import expression_router from transcriptomics_data_service.routers.ingest import ingest_router from transcriptomics_data_service.routers.normalization import normalization_router +from transcriptomics_data_service.routers.query import query_router from . import __version__ from .config import get_config from .constants import BENTO_SERVICE_KIND, SERVICE_TYPE @@ -42,7 +42,7 @@ async def lifespan(_app: FastAPI): lifespan=lifespan, ) -app.include_router(expression_router) app.include_router(ingest_router) app.include_router(experiment_router) app.include_router(normalization_router) +app.include_router(query_router) diff --git a/transcriptomics_data_service/models.py b/transcriptomics_data_service/models.py index ab4317b..3b031ed 100644 --- a/transcriptomics_data_service/models.py +++ b/transcriptomics_data_service/models.py @@ -1,22 +1,73 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field, validator +from typing import List, Optional +from enum import Enum __all__ = [ "ExperimentResult", "GeneExpression", + "GeneExpressionData", + "PaginationMeta", + "GeneExpressionResponse", + "MethodEnum", + "QueryParameters", ] +class PaginatedRequest(BaseModel): + page: int = Field(1, ge=1, description="Current page number") + page_size: int = Field(100, ge=1, le=1000, description="Number of records per page") + + +class PaginatedResponse(PaginatedRequest): + total_records: int = Field(..., ge=0, description="Total number of records") + total_pages: int = Field(..., ge=1, description="Total number of pages") + + class ExperimentResult(BaseModel): - experiment_result_id: str - assembly_id: str | None = None - assembly_name: str | None = None + experiment_result_id: str = Field(..., min_length=1, max_length=255) + assembly_id: Optional[str] = Field(None, max_length=255) + assembly_name: Optional[str] = Field(None, max_length=255) class GeneExpression(BaseModel): - gene_code: str - sample_id: str - experiment_result_id: str + gene_code: str = Field(..., min_length=1, max_length=255) + sample_id: str = Field(..., min_length=1, max_length=255) + experiment_result_id: str = Field(..., min_length=1, max_length=255) raw_count: int - tpm_count: float | None = None - tmm_count: float | None = None - getmm_count: float | None = None + tpm_count: Optional[float] = None + tmm_count: Optional[float] = None + getmm_count: Optional[float] = None + + +class GeneExpressionData(BaseModel): + gene_code: str = Field(..., min_length=1, max_length=255, description="Gene code") + sample_id: str = Field(..., min_length=1, max_length=255, description="Sample ID") + experiment_result_id: str = Field(..., min_length=1, max_length=255, description="Experiment result ID") + count: float = Field(..., description="Expression count") + method: str = Field(..., description="Method used to calculate the expression count") + + +class GeneExpressionResponse(PaginatedResponse): + expressions: List[GeneExpressionData] + + +class MethodEnum(str, Enum): + raw = "raw" + tpm = "tpm" + tmm = "tmm" + getmm = "getmm" + + +class QueryParameters(PaginatedRequest): + genes: Optional[List[str]] = Field(None, description="List of gene codes to retrieve") + experiments: Optional[List[str]] = Field(None, description="List of experiment result IDs to retrieve data from") + sample_ids: Optional[List[str]] = Field(None, description="List of sample IDs to retrieve data from") + method: MethodEnum = Field(MethodEnum.raw, description="Data method to retrieve: 'raw', 'tpm', 'tmm', 'getmm'") + + @validator("genes", "experiments", "sample_ids", each_item=True) + def validate_identifiers(cls, value): + if not (1 <= len(value) <= 255): + raise ValueError("Each identifier must be between 1 and 255 characters long.") + if not value.replace("_", "").isalnum(): + raise ValueError("Identifiers must contain only alphanumeric characters and underscores.") + return value diff --git a/transcriptomics_data_service/routers/query.py b/transcriptomics_data_service/routers/query.py new file mode 100644 index 0000000..8700b23 --- /dev/null +++ b/transcriptomics_data_service/routers/query.py @@ -0,0 +1,87 @@ +from fastapi import APIRouter, HTTPException, status, Query + +from transcriptomics_data_service.db import DatabaseDependency +from transcriptomics_data_service.logger import LoggerDependency +from transcriptomics_data_service.models import ( + GeneExpressionData, + GeneExpressionResponse, + MethodEnum, + QueryParameters, +) + +query_router = APIRouter() + + +async def get_expressions_handler( + params: QueryParameters, + db: DatabaseDependency, + logger: LoggerDependency, +): + """ + Handler for fetching and returning gene expression data. + """ + logger.info(f"Received query parameters: {params}") + + expressions, total_records = await db.fetch_gene_expressions( + genes=params.genes, + experiments=params.experiments, + sample_ids=params.sample_ids, + method=params.method.value, + page=params.page, + page_size=params.page_size, + ) + + if not expressions: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No gene expression data found for the given parameters.", + ) + + response_data = [] + method_field = f"{params.method.value}_count" if params.method != MethodEnum.raw else "raw_count" + for expr in expressions: + count = getattr(expr, method_field) + response_item = GeneExpressionData( + gene_code=expr.gene_code, + sample_id=expr.sample_id, + experiment_result_id=expr.experiment_result_id, + count=count, + method=method_field, + ) + response_data.append(response_item) + + total_pages = (total_records + params.page_size - 1) // params.page_size + + return GeneExpressionResponse( + expressions=response_data, + total_records=total_records, + page=params.page, + page_size=params.page_size, + total_pages=total_pages, + ) + + +@query_router.post( + "/expressions", + status_code=status.HTTP_200_OK, + response_model=GeneExpressionResponse, +) +async def get_expressions_post( + params: QueryParameters, + db: DatabaseDependency, + logger: LoggerDependency, +): + """ + Retrieve gene expression data via POST request. + + Example JSON body: + { + "genes": ["gene1", "gene2"], + "experiments": ["exp1"], + "sample_ids": ["sample1"], + "method": "tmm", + "page": 1, + "page_size": 100 + } + """ + return await get_expressions_handler(params, db, logger) diff --git a/transcriptomics_data_service/scripts/normalize.py b/transcriptomics_data_service/scripts/normalize.py index 9bcd402..dd7f169 100644 --- a/transcriptomics_data_service/scripts/normalize.py +++ b/transcriptomics_data_service/scripts/normalize.py @@ -108,7 +108,9 @@ def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=Tru return normalized_data -def getmm_normalization(counts_df, gene_lengths, logratio_trim=0.3, sum_trim=0.05, scaling_factor=1e3, weighting=True, n_jobs=-1): +def getmm_normalization( + counts_df, gene_lengths, logratio_trim=0.3, sum_trim=0.05, scaling_factor=1e3, weighting=True, n_jobs=-1 +): """Perform GeTMM normalization on counts data.""" counts_df, gene_lengths = prepare_counts_and_lengths(counts_df, gene_lengths) rpk = counts_df.mul(scaling_factor).div(gene_lengths, axis=0)