diff --git a/transcriptomics_data_service/db.py b/transcriptomics_data_service/db.py index f60fbd1..0b12878 100644 --- a/transcriptomics_data_service/db.py +++ b/transcriptomics_data_service/db.py @@ -133,9 +133,7 @@ def _deserialize_gene_expression(self, rec: asyncpg.Record) -> GeneExpression: # CRUD: gene_expression_norm ############################ - async def fetch_gene_expressions_by_experiment_id( - self, experiment_result_id: str - ) -> Tuple[GeneExpression, ...]: + async def fetch_gene_expressions_by_experiment_id(self, experiment_result_id: str) -> Tuple[GeneExpression, ...]: """ Fetch gene expressions for a specific experiment_result_id. """ @@ -147,9 +145,7 @@ async def fetch_gene_expressions_by_experiment_id( res = await conn.fetch(query, experiment_result_id) return tuple([self._deserialize_gene_expression(record) for record in res]) - async def update_normalized_expressions( - self, expressions: List[GeneExpression], method: str - ): + async def update_normalized_expressions(self, expressions: List[GeneExpression], method: str): """ Update the normalized expressions in the database using batch updates. """ diff --git a/transcriptomics_data_service/main.py b/transcriptomics_data_service/main.py index 4b58b7d..e76ccc0 100644 --- a/transcriptomics_data_service/main.py +++ b/transcriptomics_data_service/main.py @@ -46,4 +46,3 @@ async def lifespan(_app: FastAPI): app.include_router(ingest_router) app.include_router(experiment_router) app.include_router(normalization_router) - diff --git a/transcriptomics_data_service/routers/normalization.py b/transcriptomics_data_service/routers/normalization.py index a92bcd8..d36729e 100644 --- a/transcriptomics_data_service/routers/normalization.py +++ b/transcriptomics_data_service/routers/normalization.py @@ -29,30 +29,29 @@ async def normalize( Normalize gene expressions using the specified method for a given experiment_result_id. """ # method validation - if method not in ['tpm', 'tmm', 'getmm']: + if method not in ["tpm", "tmm", "getmm"]: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Unsupported normalization method: {method}" + status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported normalization method: {method}" ) # load gene lengths - if method in ['tpm', 'getmm']: + if method in ["tpm", "getmm"]: if gene_lengths_file is None: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Gene lengths file is required for {method.upper()} normalization." + detail=f"Gene lengths file is required for {method.upper()} normalization.", ) gene_lengths = await _load_gene_lengths(gene_lengths_file) raw_counts_df = await _fetch_raw_counts(db, experiment_result_id) # normalization - if method == 'tpm': + if method == "tpm": raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths) normalized_df = read_counts2tpm(raw_counts_df, gene_lengths_series) - elif method == 'tmm': + elif method == "tmm": normalized_df = tmm_normalization(raw_counts_df) - elif method == 'getmm': + elif method == "getmm": raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths) normalized_df = getmm_normalization(raw_counts_df, gene_lengths_series) @@ -67,9 +66,9 @@ async def _load_gene_lengths(gene_lengths_file: UploadFile) -> pd.Series: Load gene lengths from the uploaded file. """ content = await gene_lengths_file.read() - gene_lengths_df = pd.read_csv(StringIO(content.decode('utf-8')), index_col='GeneID') - gene_lengths_series = gene_lengths_df['GeneLength'] - gene_lengths_series = gene_lengths_series.apply(pd.to_numeric, errors='raise') + gene_lengths_df = pd.read_csv(StringIO(content.decode("utf-8")), index_col="GeneID") + gene_lengths_series = gene_lengths_df["GeneLength"] + gene_lengths_series = gene_lengths_series.apply(pd.to_numeric, errors="raise") return gene_lengths_series @@ -84,15 +83,11 @@ async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame: data = [] for expr in expressions: - data.append({ - 'GeneID': expr.gene_code, - 'SampleID': expr.sample_id, - 'RawCount': expr.raw_count - }) + data.append({"GeneID": expr.gene_code, "SampleID": expr.sample_id, "RawCount": expr.raw_count}) df = pd.DataFrame(data) - raw_counts_df = df.pivot(index='GeneID', columns='SampleID', values='RawCount') + raw_counts_df = df.pivot(index="GeneID", columns="SampleID", values="RawCount") - raw_counts_df = raw_counts_df.apply(pd.to_numeric, errors='raise') + raw_counts_df = raw_counts_df.apply(pd.to_numeric, errors="raise") return raw_counts_df @@ -103,7 +98,9 @@ def _align_gene_lengths(raw_counts_df: pd.DataFrame, gene_lengths: pd.Series): """ common_genes = raw_counts_df.index.intersection(gene_lengths.index) if common_genes.empty: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No common genes between counts and gene lengths.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="No common genes between counts and gene lengths." + ) raw_counts_df = raw_counts_df.loc[common_genes] gene_lengths_series = gene_lengths.loc[common_genes] return raw_counts_df, gene_lengths_series @@ -115,24 +112,22 @@ async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_ """ # Fetch existing expressions to get raw_count values existing_expressions = await db.fetch_gene_expressions_by_experiment_id(experiment_result_id) - raw_count_dict = { - (expr.gene_code, expr.sample_id): expr.raw_count for expr in existing_expressions - } + raw_count_dict = {(expr.gene_code, expr.sample_id): expr.raw_count for expr in existing_expressions} normalized_df = normalized_df.reset_index().melt( - id_vars='GeneID', var_name='SampleID', value_name='NormalizedValue' + id_vars="GeneID", var_name="SampleID", value_name="NormalizedValue" ) expressions = [] for _, row in normalized_df.iterrows(): - gene_code = row['GeneID'] - sample_id = row['SampleID'] + gene_code = row["GeneID"] + sample_id = row["SampleID"] raw_count = raw_count_dict.get((gene_code, sample_id)) if raw_count is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Raw count not found for gene {gene_code}, sample {sample_id}" + detail=f"Raw count not found for gene {gene_code}, sample {sample_id}", ) gene_expression = GeneExpression( @@ -140,9 +135,9 @@ async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_ sample_id=sample_id, experiment_result_id=experiment_result_id, raw_count=raw_count, - tpm_count=row['NormalizedValue'] if method == 'tpm' else None, - tmm_count=row['NormalizedValue'] if method == 'tmm' else None, - getmm_count=row['NormalizedValue'] if method == 'getmm' else None, + tpm_count=row["NormalizedValue"] if method == "tpm" else None, + tmm_count=row["NormalizedValue"] if method == "tmm" else None, + getmm_count=row["NormalizedValue"] if method == "getmm" else None, ) expressions.append(gene_expression) diff --git a/transcriptomics_data_service/scripts/normalize.py b/transcriptomics_data_service/scripts/normalize.py index cca47bc..a1135ca 100644 --- a/transcriptomics_data_service/scripts/normalize.py +++ b/transcriptomics_data_service/scripts/normalize.py @@ -1,5 +1,6 @@ import pandas as pd + def read_counts2tpm(counts_df, gene_lengths, scale_library=1e6, scale_length=1e3): """ Convert raw read counts to TPM (Transcripts Per Million).