Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalization #11

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
881 changes: 434 additions & 447 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jsonschema = "^4.21.1"
pydantic-settings = "^2.1.0"
asyncpg = "^0.29.0"
pandas = "^2.2.3"
conorm = "^1.2.0"
noctillion marked this conversation as resolved.
Show resolved Hide resolved

[tool.poetry.group.dev.dependencies]
aioresponses = "^0.7.6"
Expand Down
114 changes: 95 additions & 19 deletions transcriptomics_data_service/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Annotated, AsyncIterator
from typing import Annotated, AsyncIterator, List, Tuple
import asyncpg
from bento_lib.db.pg_async import PgAsyncDatabase
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -83,26 +83,28 @@ async def create_gene_expressions(self, expressions: list[GeneExpression], trans
Rows on gene_expressions can only be created as part of an RCM ingestion.
Ingestion is all-or-nothing, hence the transaction.
"""
async with transaction_conn.transaction():
# sub-transaction
for gene_expression in expressions:
await self._create_gene_expression(gene_expression, transaction_conn)
# Prepare data for bulk insertion
records = [
noctillion marked this conversation as resolved.
Show resolved Hide resolved
(
expr.gene_code,
expr.sample_id,
expr.experiment_result_id,
expr.raw_count,
expr.tpm_count,
expr.tmm_count,
expr.getmm_count,
)
for expr in expressions
]

async def _create_gene_expression(self, expression: GeneExpression, transaction_conn: asyncpg.Connection):
# Creates a row on gene_expressions within a transaction.
query = """
INSERT INTO gene_expressions (gene_code, sample_id, experiment_result_id, raw_count, tpm_count, tmm_count)
VALUES ($1, $2, $3, $4, $5, $6)
INSERT INTO gene_expressions (
gene_code, sample_id, experiment_result_id, raw_count, tpm_count, tmm_count, getmm_count
) VALUES ($1, $2, $3, $4, $5, $6, $7)
"""
await transaction_conn.execute(
query,
expression.gene_code,
expression.sample_id,
expression.experiment_result_id,
expression.raw_count,
expression.tpm_count,
expression.tmm_count,
)

await transaction_conn.executemany(query, records)
self.logger.info(f"Inserted {len(records)} gene expression records.")

async def fetch_expressions(self) -> tuple[GeneExpression, ...]:
return tuple([r async for r in self._select_expressions(None)])
Expand All @@ -112,7 +114,7 @@ async def _select_expressions(self, exp_id: str | None) -> AsyncIterator[GeneExp
where_clause = "WHERE experiment_result_id = $1" if exp_id is not None else ""
query = f"SELECT * FROM gene_expressions {where_clause}"
async with self.connect() as conn:
res = await conn.fetch(query, *((exp_id) if exp_id is not None else ()))
res = await conn.fetch(query, *(exp_id,) if exp_id is not None else ())
for r in map(lambda g: self._deserialize_gene_expression(g), res):
yield r

Expand All @@ -124,8 +126,82 @@ def _deserialize_gene_expression(self, rec: asyncpg.Record) -> GeneExpression:
raw_count=rec["raw_count"],
tpm_count=rec["tpm_count"],
tmm_count=rec["tmm_count"],
getmm_count=rec["getmm_count"],
)

############################
# CRUD: gene_expression_norm
############################

async def fetch_gene_expressions_by_experiment_id(self, experiment_result_id: str) -> Tuple[GeneExpression, ...]:
noctillion marked this conversation as resolved.
Show resolved Hide resolved
"""
Fetch gene expressions for a specific experiment_result_id.
"""
conn: asyncpg.Connection
async with self.connect() as conn:
query = """
SELECT * FROM gene_expressions WHERE experiment_result_id = $1
"""
res = await conn.fetch(query, experiment_result_id)
return tuple([self._deserialize_gene_expression(record) for record in res])

async def update_normalized_expressions(self, expressions: List[GeneExpression], method: str):
noctillion marked this conversation as resolved.
Show resolved Hide resolved
"""
Update the normalized expressions in the database using batch updates.
"""
conn: asyncpg.Connection
async with self.connect() as conn:
noctillion marked this conversation as resolved.
Show resolved Hide resolved
async with conn.transaction():
if method == "tpm":
column = "tpm_count"
elif method == "tmm":
column = "tmm_count"
elif method == "getmm":
column = "getmm_count"
noctillion marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError(f"Unsupported normalization method: {method}")

# Prepare data for bulk update
records = [
(
getattr(expr, column),
expr.experiment_result_id,
expr.gene_code,
expr.sample_id,
)
for expr in expressions
]

await conn.execute(
f"""
CREATE TEMPORARY TABLE temp_updates (
value DOUBLE PRECISION,
experiment_result_id VARCHAR(255),
gene_code VARCHAR(255),
sample_id VARCHAR(255)
) ON COMMIT DROP
"""
)

await conn.copy_records_to_table(
"temp_updates",
records=records,
columns=["value", "experiment_result_id", "gene_code", "sample_id"],
)

# Update the main table
await conn.execute(
f"""
UPDATE gene_expressions
SET {column} = temp_updates.value
FROM temp_updates
WHERE gene_expressions.experiment_result_id = temp_updates.experiment_result_id
AND gene_expressions.gene_code = temp_updates.gene_code
AND gene_expressions.sample_id = temp_updates.sample_id
"""
)
self.logger.info(f"Updated normalized values for method '{method}'.")

@asynccontextmanager
async def transaction_connection(self):
conn: asyncpg.Connection
Expand Down
2 changes: 2 additions & 0 deletions transcriptomics_data_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from transcriptomics_data_service.routers.experiment_results import experiment_router
from transcriptomics_data_service.routers.expressions import expression_router
from transcriptomics_data_service.routers.ingest import ingest_router
from transcriptomics_data_service.routers.normalization import normalization_router
from . import __version__
from .config import get_config
from .constants import BENTO_SERVICE_KIND, SERVICE_TYPE
Expand Down Expand Up @@ -44,3 +45,4 @@ async def lifespan(_app: FastAPI):
app.include_router(expression_router)
app.include_router(ingest_router)
app.include_router(experiment_router)
app.include_router(normalization_router)
1 change: 1 addition & 0 deletions transcriptomics_data_service/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ class GeneExpression(BaseModel):
raw_count: int
tpm_count: float | None = None
tmm_count: float | None = None
getmm_count: float | None = None
15 changes: 0 additions & 15 deletions transcriptomics_data_service/routers/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,3 @@ def _load_csv(file_bytes: bytes, logger: Logger) -> pd.DataFrame:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Error parsing CSV: {e}")
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Value error in CSV data: {e}")


@ingest_router.post("/normalize/{experiment_result_id}")
async def normalize(
db: DatabaseDependency,
experiment_result_id: str,
features_lengths_file: UploadFile = File(...),
status_code=status.HTTP_200_OK,
):
features_lengths = json.load(features_lengths_file.file)
# TODO validate shape
# TODO validate experiment_result_id exists
# TODO algorithm selection argument?
# TODO perform the normalization in a transaction
return
153 changes: 153 additions & 0 deletions transcriptomics_data_service/routers/normalization.py
noctillion marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from fastapi import APIRouter, HTTPException, UploadFile, File, status
import pandas as pd
from io import StringIO

from transcriptomics_data_service.db import DatabaseDependency
from transcriptomics_data_service.models import GeneExpression
from transcriptomics_data_service.scripts.normalize import (
read_counts2tpm,
tmm_normalization,
getmm_normalization,
)

# Constants for normalization methods
NORM_TPM = "tpm"
NORM_TMM = "tmm"
NORM_GETMM = "getmm"

# List of all valid normalization methods
VALID_METHODS = [NORM_TPM, NORM_TMM, NORM_GETMM]

__all__ = ["normalization_router"]

normalization_router = APIRouter(prefix="/normalize")


@normalization_router.post(
"/{method}/{experiment_result_id}",
noctillion marked this conversation as resolved.
Show resolved Hide resolved
status_code=status.HTTP_200_OK,
)
async def normalize(
method: str,
experiment_result_id: str,
db: DatabaseDependency,
gene_lengths_file: UploadFile = File(None),
):
"""
Normalize gene expressions using the specified method for a given experiment_result_id.
"""
# method validation
if method not in VALID_METHODS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported normalization method: {method}"
)

# load gene lengths
if method in [NORM_TPM, NORM_GETMM]:
if gene_lengths_file is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Gene lengths file is required for {method.upper()} normalization.",
)
gene_lengths = await _load_gene_lengths(gene_lengths_file)

raw_counts_df = await _fetch_raw_counts(db, experiment_result_id)

# normalization
if method == NORM_TPM:
raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths)
normalized_df = read_counts2tpm(raw_counts_df, gene_lengths_series)
elif method == NORM_TMM:
normalized_df = tmm_normalization(raw_counts_df)
elif method == NORM_GETMM:
raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths)
normalized_df = getmm_normalization(raw_counts_df, gene_lengths_series)

# database update using normalized values
await _update_normalized_values(db, normalized_df, experiment_result_id, method=method)

return {"message": f"{method.upper()} normalization completed successfully"}


async def _load_gene_lengths(gene_lengths_file: UploadFile) -> pd.Series:
"""
Load gene lengths from the uploaded file.
"""
content = await gene_lengths_file.read()
gene_lengths_df = pd.read_csv(StringIO(content.decode("utf-8")), index_col="GeneID")
gene_lengths_series = gene_lengths_df["GeneLength"]
gene_lengths_series = gene_lengths_series.apply(pd.to_numeric, errors="raise")
return gene_lengths_series


async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame:
"""
Fetch raw counts from the database for the given experiment_result_id.
Returns a DataFrame with genes as rows and samples as columns.
"""
expressions = await db.fetch_gene_expressions_by_experiment_id(experiment_result_id)
if not expressions:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment result not found.")

data = []
for expr in expressions:
data.append({"GeneID": expr.gene_code, "SampleID": expr.sample_id, "RawCount": expr.raw_count})
df = pd.DataFrame(data)
raw_counts_df = df.pivot(index="GeneID", columns="SampleID", values="RawCount")

raw_counts_df = raw_counts_df.apply(pd.to_numeric, errors="raise")

return raw_counts_df


def _align_gene_lengths(raw_counts_df: pd.DataFrame, gene_lengths: pd.Series):
"""
Align the gene lengths with the raw counts DataFrame based on GeneID.
"""
common_genes = raw_counts_df.index.intersection(gene_lengths.index)
if common_genes.empty:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="No common genes between counts and gene lengths."
)
raw_counts_df = raw_counts_df.loc[common_genes]
gene_lengths_series = gene_lengths.loc[common_genes]
return raw_counts_df, gene_lengths_series


async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_result_id: str, method: str):
"""
Update the normalized values in the database
"""
# Fetch existing expressions to get raw_count values
existing_expressions = await db.fetch_gene_expressions_by_experiment_id(experiment_result_id)
raw_count_dict = {(expr.gene_code, expr.sample_id): expr.raw_count for expr in existing_expressions}

normalized_df = normalized_df.reset_index().melt(
id_vars="GeneID", var_name="SampleID", value_name="NormalizedValue"
)

expressions = []
for _, row in normalized_df.iterrows():
gene_code = row["GeneID"]
sample_id = row["SampleID"]
raw_count = raw_count_dict.get((gene_code, sample_id))

if raw_count is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Raw count not found for gene {gene_code}, sample {sample_id}",
)

gene_expression = GeneExpression(
gene_code=gene_code,
sample_id=sample_id,
experiment_result_id=experiment_result_id,
raw_count=raw_count,
tpm_count=row["NormalizedValue"] if method == NORM_TPM else None,
tmm_count=row["NormalizedValue"] if method == NORM_TMM else None,
getmm_count=row["NormalizedValue"] if method == NORM_GETMM else None,
noctillion marked this conversation as resolved.
Show resolved Hide resolved
)
expressions.append(gene_expression)

# Update expressions in the database
await db.update_normalized_expressions(expressions, method)
Loading