Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
noctillion committed Nov 13, 2024
1 parent 63f7716 commit 516e500
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 36 deletions.
8 changes: 2 additions & 6 deletions transcriptomics_data_service/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,7 @@ def _deserialize_gene_expression(self, rec: asyncpg.Record) -> GeneExpression:
# CRUD: gene_expression_norm
############################

async def fetch_gene_expressions_by_experiment_id(
self, experiment_result_id: str
) -> Tuple[GeneExpression, ...]:
async def fetch_gene_expressions_by_experiment_id(self, experiment_result_id: str) -> Tuple[GeneExpression, ...]:
"""
Fetch gene expressions for a specific experiment_result_id.
"""
Expand All @@ -147,9 +145,7 @@ async def fetch_gene_expressions_by_experiment_id(
res = await conn.fetch(query, experiment_result_id)
return tuple([self._deserialize_gene_expression(record) for record in res])

async def update_normalized_expressions(
self, expressions: List[GeneExpression], method: str
):
async def update_normalized_expressions(self, expressions: List[GeneExpression], method: str):
"""
Update the normalized expressions in the database using batch updates.
"""
Expand Down
1 change: 0 additions & 1 deletion transcriptomics_data_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,3 @@ async def lifespan(_app: FastAPI):
app.include_router(ingest_router)
app.include_router(experiment_router)
app.include_router(normalization_router)

53 changes: 24 additions & 29 deletions transcriptomics_data_service/routers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,29 @@ async def normalize(
Normalize gene expressions using the specified method for a given experiment_result_id.
"""
# method validation
if method not in ['tpm', 'tmm', 'getmm']:
if method not in ["tpm", "tmm", "getmm"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported normalization method: {method}"
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported normalization method: {method}"
)

# load gene lengths
if method in ['tpm', 'getmm']:
if method in ["tpm", "getmm"]:
if gene_lengths_file is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Gene lengths file is required for {method.upper()} normalization."
detail=f"Gene lengths file is required for {method.upper()} normalization.",
)
gene_lengths = await _load_gene_lengths(gene_lengths_file)

raw_counts_df = await _fetch_raw_counts(db, experiment_result_id)

# normalization
if method == 'tpm':
if method == "tpm":
raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths)
normalized_df = read_counts2tpm(raw_counts_df, gene_lengths_series)
elif method == 'tmm':
elif method == "tmm":
normalized_df = tmm_normalization(raw_counts_df)
elif method == 'getmm':
elif method == "getmm":
raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths)
normalized_df = getmm_normalization(raw_counts_df, gene_lengths_series)

Expand All @@ -67,9 +66,9 @@ async def _load_gene_lengths(gene_lengths_file: UploadFile) -> pd.Series:
Load gene lengths from the uploaded file.
"""
content = await gene_lengths_file.read()
gene_lengths_df = pd.read_csv(StringIO(content.decode('utf-8')), index_col='GeneID')
gene_lengths_series = gene_lengths_df['GeneLength']
gene_lengths_series = gene_lengths_series.apply(pd.to_numeric, errors='raise')
gene_lengths_df = pd.read_csv(StringIO(content.decode("utf-8")), index_col="GeneID")
gene_lengths_series = gene_lengths_df["GeneLength"]
gene_lengths_series = gene_lengths_series.apply(pd.to_numeric, errors="raise")
return gene_lengths_series


Expand All @@ -84,15 +83,11 @@ async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame:

data = []
for expr in expressions:
data.append({
'GeneID': expr.gene_code,
'SampleID': expr.sample_id,
'RawCount': expr.raw_count
})
data.append({"GeneID": expr.gene_code, "SampleID": expr.sample_id, "RawCount": expr.raw_count})
df = pd.DataFrame(data)
raw_counts_df = df.pivot(index='GeneID', columns='SampleID', values='RawCount')
raw_counts_df = df.pivot(index="GeneID", columns="SampleID", values="RawCount")

raw_counts_df = raw_counts_df.apply(pd.to_numeric, errors='raise')
raw_counts_df = raw_counts_df.apply(pd.to_numeric, errors="raise")

return raw_counts_df

Expand All @@ -103,7 +98,9 @@ def _align_gene_lengths(raw_counts_df: pd.DataFrame, gene_lengths: pd.Series):
"""
common_genes = raw_counts_df.index.intersection(gene_lengths.index)
if common_genes.empty:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No common genes between counts and gene lengths.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="No common genes between counts and gene lengths."
)
raw_counts_df = raw_counts_df.loc[common_genes]
gene_lengths_series = gene_lengths.loc[common_genes]
return raw_counts_df, gene_lengths_series
Expand All @@ -115,34 +112,32 @@ async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_
"""
# Fetch existing expressions to get raw_count values
existing_expressions = await db.fetch_gene_expressions_by_experiment_id(experiment_result_id)
raw_count_dict = {
(expr.gene_code, expr.sample_id): expr.raw_count for expr in existing_expressions
}
raw_count_dict = {(expr.gene_code, expr.sample_id): expr.raw_count for expr in existing_expressions}

normalized_df = normalized_df.reset_index().melt(
id_vars='GeneID', var_name='SampleID', value_name='NormalizedValue'
id_vars="GeneID", var_name="SampleID", value_name="NormalizedValue"
)

expressions = []
for _, row in normalized_df.iterrows():
gene_code = row['GeneID']
sample_id = row['SampleID']
gene_code = row["GeneID"]
sample_id = row["SampleID"]
raw_count = raw_count_dict.get((gene_code, sample_id))

if raw_count is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Raw count not found for gene {gene_code}, sample {sample_id}"
detail=f"Raw count not found for gene {gene_code}, sample {sample_id}",
)

gene_expression = GeneExpression(
gene_code=gene_code,
sample_id=sample_id,
experiment_result_id=experiment_result_id,
raw_count=raw_count,
tpm_count=row['NormalizedValue'] if method == 'tpm' else None,
tmm_count=row['NormalizedValue'] if method == 'tmm' else None,
getmm_count=row['NormalizedValue'] if method == 'getmm' else None,
tpm_count=row["NormalizedValue"] if method == "tpm" else None,
tmm_count=row["NormalizedValue"] if method == "tmm" else None,
getmm_count=row["NormalizedValue"] if method == "getmm" else None,
)
expressions.append(gene_expression)

Expand Down
1 change: 1 addition & 0 deletions transcriptomics_data_service/scripts/normalize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pandas as pd


def read_counts2tpm(counts_df, gene_lengths, scale_library=1e6, scale_length=1e3):
"""
Convert raw read counts to TPM (Transcripts Per Million).
Expand Down

0 comments on commit 516e500

Please sign in to comment.