From c0a8428bd1c9f629b815d5ecc241168a105532c0 Mon Sep 17 00:00:00 2001
From: Julian
Date: Tue, 10 Dec 2024 17:50:29 -0500
Subject: [PATCH 1/7] remove conorm

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index a2a1218..3bb5a45 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ jsonschema = "^4.21.1"
 pydantic-settings = "^2.1.0"
 asyncpg = "^0.29.0"
 pandas = "^2.2.3"
-conorm = "^1.2.0"
+joblib = "^1.4.2"
 
 [tool.poetry.group.dev.dependencies]
 aioresponses = "^0.7.6"

From befc03c918ef4ff5751388b3378f904b08c17502 Mon Sep 17 00:00:00 2001
From: Julian
Date: Thu, 12 Dec 2024 13:26:46 -0500
Subject: [PATCH 2/7] poetry: update lock file

---
 poetry.lock | 27 ++++++++++++---------------
 1 file changed, 12 insertions(+), 15 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index a556a57..4927ed7 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -529,20 +529,6 @@ files = [
     {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
 ]
 
-[[package]]
-name = "conorm"
-version = "1.2.0"
-description = "Normalization methods for RNA-seq count data."
-optional = false
-python-versions = ">=3.7"
-files = [
-    {file = "conorm-1.2.0.tar.gz", hash = "sha256:b4fe4a5d27b9e8c797c4db56f292a80f6e8bca401cbe92a62d9e4688864a986e"},
-]
-
-[package.dependencies]
-numpy = "*"
-pandas = "*"
-
 [[package]]
 name = "coverage"
 version = "7.6.4"

@@ -1062,6 +1048,17 @@ MarkupSafe = ">=2.0"
 [package.extras]
 i18n = ["Babel (>=2.7)"]
 
+[[package]]
+name = "joblib"
+version = "1.4.2"
+description = "Lightweight pipelining with Python functions"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"},
+    {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"},
+]
+
 [[package]]
 name = "jsonschema"
 version = "4.23.0"

@@ -2965,4 +2962,4 @@ propcache = ">=0.2.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10.0"
-content-hash = "8dcd26f85316220b00d12be77315972eb165cddd78ffb31fb5ce16d2fea05061"
+content-hash = "e657110014e82d15e964ada3e9273d562b968bbfd50e7efcad372322b91fc327"

From 3c3d9f4754c09ec730ab8aa61e1e9c9bfa4ba596 Mon Sep 17 00:00:00 2001
From: Julian
Date: Thu, 12 Dec 2024 13:27:31 -0500
Subject: [PATCH 3/7] temporary fetch_gene_expressions change

---
 transcriptomics_data_service/db.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/transcriptomics_data_service/db.py b/transcriptomics_data_service/db.py
index 91a4cad..636677a 100644
--- a/transcriptomics_data_service/db.py
+++ b/transcriptomics_data_service/db.py
@@ -135,6 +135,14 @@ async def fetch_gene_expressions_by_experiment_id(self, experiment_result_id: st
             res = await conn.fetch(query, experiment_result_id)
             return tuple([self._deserialize_gene_expression(record) for record in res])
 
+    async def fetch_gene_expressions(self, experiments: list[str], method: str = "raw", paginate: bool = False) -> Tuple[Tuple[GeneExpression, ...], int]:
+        if not experiments:
+            return (), 0
+        # TODO: refactor this fetch_gene_expressions_by_experiment_id and implement pagination
+        experiment_result_id = experiments[0]
+        expressions = await self.fetch_gene_expressions_by_experiment_id(experiment_result_id)
+        return expressions, len(expressions)
+
     def _deserialize_gene_expression(self, rec: asyncpg.Record) -> GeneExpression:
         return GeneExpression(
             gene_code=rec["gene_code"],
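A usage sketch (not part of the patch series): how the new fetch_gene_expressions wrapper from PATCH 3/7 is expected to be called. The `db` handle and the experiment ID are assumptions for illustration; the call signature is taken verbatim from the patch.

    # `db` stands in for an initialized Database instance from
    # transcriptomics_data_service.db; "experiment-1" is a made-up ID.
    async def fetch_raw_counts_demo(db):
        # Pagination is not implemented yet (see the TODO in the patch), so only
        # the first experiment ID is consulted and the full result set is returned.
        expressions, total = await db.fetch_gene_expressions(
            experiments=["experiment-1"], method="raw", paginate=False
        )
        return expressions, total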
From 9e9e5c45b48cd97aef537e30225404fd360d23bf Mon Sep 17 00:00:00 2001
From: Julian
Date: Thu, 12 Dec 2024 13:29:29 -0500
Subject: [PATCH 4/7] implement normalization methods

---
 .../scripts/normalize.py | 193 ++++++++++++------
 1 file changed, 126 insertions(+), 67 deletions(-)

diff --git a/transcriptomics_data_service/scripts/normalize.py b/transcriptomics_data_service/scripts/normalize.py
index a1135ca..d211dbf 100644
--- a/transcriptomics_data_service/scripts/normalize.py
+++ b/transcriptomics_data_service/scripts/normalize.py
@@ -1,72 +1,131 @@
 import pandas as pd
+import numpy as np
+from joblib import Parallel, delayed
+
+def filter_counts(counts_df):
+    """Filter out genes (rows) and samples (columns) with zero total counts."""
+    row_filter = counts_df.sum(axis=1) > 0
+    col_filter = counts_df.sum(axis=0) > 0
+    return counts_df.loc[row_filter, col_filter]
-
-
-def read_counts2tpm(counts_df, gene_lengths, scale_library=1e6, scale_length=1e3):
-    """
-    Convert raw read counts to TPM (Transcripts Per Million).
-
-    Parameters:
-    counts_df (DataFrame): DataFrame with genes as rows and samples as columns.
-    gene_lengths (Series): Series with gene lengths, index matches counts_df.index.
-    scale_library (int or float): Scaling factor for library size normalization (default 1e6).
-    scale_length (int or float): Scaling factor for gene length scaling (default 1e3).
-
-    Returns:
-    DataFrame: TPM-normalized values.
-    """
-    # Ensure counts_df and gene_lengths are aligned
+
+def prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=None):
+    """Align counts and gene_lengths, drop zeros, and optionally scale gene lengths."""
     counts_df = counts_df.loc[gene_lengths.index]
-
-    # Scale gene lengths
-    gene_lengths_scaled = gene_lengths / scale_length
-
-    # Calculate Reads Per Scaled Kilobase (RPK)
-    rpk = counts_df.div(gene_lengths_scaled, axis=0)
-
-    # Calculate scaling factors
-    scaling_factors = rpk.sum(axis=0) / scale_library
-
-    # Calculate TPM
-    tpm = rpk.div(scaling_factors, axis=1)
-
+    valid_lengths = gene_lengths.replace(0, pd.NA).dropna()
+    counts_df = counts_df.loc[valid_lengths.index]
+    gene_lengths = valid_lengths
+    if scale_length is not None:
+        gene_lengths = gene_lengths / scale_length
+    return filter_counts(counts_df), gene_lengths
+
+def parallel_apply(columns, func, n_jobs=-1):
+    """Apply a function to each column in parallel and combine results."""
+    results = Parallel(n_jobs=n_jobs)(delayed(func)(col) for col in columns)
+    return pd.concat(results, axis=1)
+
+def trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim):
+    """Perform log ratio and sum trimming."""
+    n = len(log_ratio)
+    loL = int(np.floor(n * logratio_trim / 2))
+    hiL = n - loL
+    lr_order = np.argsort(log_ratio)
+    trimmed_idx = lr_order[loL:hiL]
+
+    lr_t = log_ratio[trimmed_idx]
+    w_t = w[trimmed_idx]
+    mean_t = log_mean[trimmed_idx]
+
+    n_t = len(mean_t)
+    loS = int(np.floor(n_t * sum_trim / 2))
+    hiS = n_t - loS
+    mean_order = np.argsort(mean_t)
+    final_idx = mean_order[loS:hiS]
+
+    return lr_t[final_idx], w_t[final_idx]
+
+def compute_TMM_normalization_factors(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1):
+    """Compute TMM normalization factors for counts data."""
+    lib_sizes = counts_df.sum(axis=0)
+    median_lib = lib_sizes.median()
+    ref_sample = (lib_sizes - median_lib).abs().idxmin()
+
+    ref_counts = counts_df[ref_sample].values
+    sample_names = counts_df.columns
+    data_values = counts_df.values
+
+    norm_factors = pd.Series(index=sample_names, dtype='float64')
+    norm_factors[ref_sample] = 1.0
+
+    def compute_norm_factor(sample):
+        if sample == ref_sample:
+            return sample, 1.0
+
+        i = sample_names.get_loc(sample)
+        data_i = data_values[:, i]
+
+        mask = (data_i > 0) & (ref_counts > 0)
+        data_i_masked = data_i[mask]
+        data_r_masked = ref_counts[mask]
+
+        N_i = data_i_masked.sum()
+        N_r = data_r_masked.sum()
+
+        data_i_norm = data_i_masked / N_i
+        data_r_norm = data_r_masked / N_r
+
+        log_ratio = np.log2(data_i_norm) - np.log2(data_r_norm)
+        log_mean = 0.5 * (np.log2(data_i_norm) + np.log2(data_r_norm))
+
+        w = 1.0 / (data_i_norm + data_r_norm) if weighting else np.ones_like(log_ratio)
+
+        lr_final, w_final = trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim)
+
+        mean_M = np.sum(w_final * lr_final) / np.sum(w_final)
+        norm_factor = 2 ** mean_M
+        return sample, norm_factor
+
+    samples = [s for s in sample_names if s != ref_sample]
+    results = Parallel(n_jobs=n_jobs)(delayed(compute_norm_factor)(s) for s in samples)
+
+    for sample, nf in results:
+        norm_factors[sample] = nf
+
+    norm_factors = norm_factors / np.exp(np.mean(np.log(norm_factors)))
+    return norm_factors
+
+def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1):
+    """Perform TMM normalization on counts data."""
+    counts_df = filter_counts(counts_df)
+    norm_factors = compute_TMM_normalization_factors(counts_df, logratio_trim, sum_trim, weighting, n_jobs)
+    lib_sizes = counts_df.sum(axis=0)
+    normalized_data = counts_df.div(lib_sizes, axis=1).div(norm_factors, axis=1) * lib_sizes.mean()
+    return normalized_data
+
+def getmm_normalization(counts_df, gene_lengths, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1):
+    """Perform GeTMM normalization on counts data."""
+    counts_df, gene_lengths = prepare_counts_and_lengths(counts_df, gene_lengths)
+    rpk = counts_df.mul(1e3).div(gene_lengths, axis=0)
+    return tmm_normalization(rpk, logratio_trim, sum_trim, weighting, n_jobs)
+
+def compute_rpk(counts_df, gene_lengths_scaled, n_jobs=-1):
+    """Compute RPK values in parallel."""
+    columns = counts_df.columns
+    def rpk_col(col):
+        return counts_df[col] / gene_lengths_scaled
+    rpk = parallel_apply(columns, rpk_col, n_jobs)
+    rpk.columns = columns
+    return rpk
+
+def tpm_normalization(counts_df, gene_lengths, scale_library=1e6, scale_length=1e3, n_jobs=-1):
+    """Convert raw read counts to TPM in parallel."""
+    counts_df, gene_lengths_scaled = prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=scale_length)
+    rpk = compute_rpk(counts_df, gene_lengths_scaled, n_jobs)
+    scaling_factors = rpk.sum(axis=0).replace(0, pd.NA)
+    scaling_factors_norm = scaling_factors / scale_library
+
+    def tpm_col(col):
+        return rpk[col] / scaling_factors_norm[col]
+
+    tpm = parallel_apply(rpk.columns, tpm_col, n_jobs)
+    tpm.columns = rpk.columns
     return tpm
-
-
-def tmm_normalization(counts_df):
-    """
-    Perform TMM normalization on counts data.
-
-    Parameters:
-    counts_df (DataFrame): DataFrame with genes as rows and samples as columns.
-
-    Returns:
-    DataFrame: TMM-normalized values.
-    """
-    try:
-        import conorm
-    except ImportError:
-        raise ImportError("The 'conorm' package is required for this function but is not installed.")
-    normalized_array = conorm.tmm(counts_df)
-    normalized_df = pd.DataFrame(normalized_array, columns=counts_df.columns, index=counts_df.index)
-    return normalized_df
-
-
-def getmm_normalization(counts_df, gene_lengths):
-    """
-    Perform GeTMM normalization on counts data.
-
-    Parameters:
-    counts_df (DataFrame): DataFrame with genes as rows and samples as columns.
-    gene_lengths (Series): Series with gene lengths, index matches counts_df.index.
-
-    Returns:
-    DataFrame: GeTMM-normalized values.
-    """
-    try:
-        import conorm
-    except ImportError:
-        raise ImportError("The 'conorm' package is required for this function but is not installed.")
-
-    normalized_array = conorm.getmm(counts_df, gene_lengths)
-    normalized_df = pd.DataFrame(normalized_array, columns=counts_df.columns, index=counts_df.index)
-    return normalized_df
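A small self-contained sketch (not part of the patches) exercising the functions added in PATCH 4/7 on made-up counts; it assumes the patched module is importable as transcriptomics_data_service.scripts.normalize.

    import pandas as pd

    from transcriptomics_data_service.scripts.normalize import tmm_normalization, tpm_normalization

    # Toy input: genes as rows, samples as columns; lengths are in bases.
    counts = pd.DataFrame(
        {"sample_A": [100, 200, 0, 50], "sample_B": [80, 300, 10, 60]},
        index=["gene1", "gene2", "gene3", "gene4"],
    )
    gene_lengths = pd.Series([1000, 2000, 1500, 500], index=counts.index)

    # TMM rescales each library by its trimmed-mean-of-M-values factor,
    # computed against the reference sample with the most median library size.
    tmm_df = tmm_normalization(counts)

    # TPM divides by gene length (scaled to kilobases by scale_length), then
    # rescales each sample so its column sums to scale_library (1e6 by default).
    tpm_df = tpm_normalization(counts, gene_lengths)
    print(tpm_df.sum(axis=0))  # each sample column sums to ~1e6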
""" content = await gene_lengths_file.read() - gene_lengths_df = pd.read_csv(StringIO(content.decode("utf-8")), index_col="GeneID") - gene_lengths_series = gene_lengths_df["GeneLength"] + gene_lengths_df = pd.read_csv(StringIO(content.decode("utf-8")), index_col=0) + if gene_lengths_df.shape[1] != 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Gene lengths file should contain exactly one column of gene lengths.", + ) + gene_lengths_series = gene_lengths_df.iloc[:, 0] gene_lengths_series = gene_lengths_series.apply(pd.to_numeric, errors="raise") return gene_lengths_series @@ -85,7 +93,9 @@ async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame: Fetch raw counts from the database for the given experiment_result_id. Returns a DataFrame with genes as rows and samples as columns. """ - expressions = await db.fetch_gene_expressions_by_experiment_id(experiment_result_id) + expressions, _ = await db.fetch_gene_expressions( + experiments=[experiment_result_id], method="raw", paginate=False + ) if not expressions: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment result not found.") @@ -116,10 +126,12 @@ def _align_gene_lengths(raw_counts_df: pd.DataFrame, gene_lengths: pd.Series): async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_result_id: str, method: str): """ - Update the normalized values in the database + Update the normalized values in the database. """ # Fetch existing expressions to get raw_count values - existing_expressions = await db.fetch_gene_expressions_by_experiment_id(experiment_result_id) + existing_expressions, _ = await db.fetch_gene_expressions( + experiments=[experiment_result_id], method="raw", paginate=False + ) raw_count_dict = {(expr.gene_code, expr.sample_id): expr.raw_count for expr in existing_expressions} normalized_df = normalized_df.reset_index().melt( @@ -138,6 +150,7 @@ async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_ detail=f"Raw count not found for gene {gene_code}, sample {sample_id}", ) + # Create a GeneExpression object with the normalized value gene_expression = GeneExpression( gene_code=gene_code, sample_id=sample_id, From 577f5ac12ed0d3ca4631301cb40697f320988845 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 13 Dec 2024 14:23:07 -0500 Subject: [PATCH 6/7] lint --- transcriptomics_data_service/db.py | 4 +++- .../routers/normalization.py | 4 +--- transcriptomics_data_service/scripts/normalize.py | 15 +++++++++++++-- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/transcriptomics_data_service/db.py b/transcriptomics_data_service/db.py index 636677a..92f793f 100644 --- a/transcriptomics_data_service/db.py +++ b/transcriptomics_data_service/db.py @@ -135,7 +135,9 @@ async def fetch_gene_expressions_by_experiment_id(self, experiment_result_id: st res = await conn.fetch(query, experiment_result_id) return tuple([self._deserialize_gene_expression(record) for record in res]) - async def fetch_gene_expressions(self, experiments: list[str], method: str = "raw", paginate: bool = False) -> Tuple[Tuple[GeneExpression, ...], int]: + async def fetch_gene_expressions( + self, experiments: list[str], method: str = "raw", paginate: bool = False + ) -> Tuple[Tuple[GeneExpression, ...], int]: if not experiments: return (), 0 # TODO: refactor this fetch_gene_expressions_by_experiment_id and implement pagination diff --git a/transcriptomics_data_service/routers/normalization.py 
b/transcriptomics_data_service/routers/normalization.py index d845ce7..439ed3a 100644 --- a/transcriptomics_data_service/routers/normalization.py +++ b/transcriptomics_data_service/routers/normalization.py @@ -93,9 +93,7 @@ async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame: Fetch raw counts from the database for the given experiment_result_id. Returns a DataFrame with genes as rows and samples as columns. """ - expressions, _ = await db.fetch_gene_expressions( - experiments=[experiment_result_id], method="raw", paginate=False - ) + expressions, _ = await db.fetch_gene_expressions(experiments=[experiment_result_id], method="raw", paginate=False) if not expressions: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment result not found.") diff --git a/transcriptomics_data_service/scripts/normalize.py b/transcriptomics_data_service/scripts/normalize.py index d211dbf..55f9799 100644 --- a/transcriptomics_data_service/scripts/normalize.py +++ b/transcriptomics_data_service/scripts/normalize.py @@ -2,12 +2,14 @@ import numpy as np from joblib import Parallel, delayed + def filter_counts(counts_df): """Filter out genes (rows) and samples (columns) with zero total counts.""" row_filter = counts_df.sum(axis=1) > 0 col_filter = counts_df.sum(axis=0) > 0 return counts_df.loc[row_filter, col_filter] + def prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=None): """Align counts and gene_lengths, drop zeros, and optionally scale gene lengths.""" counts_df = counts_df.loc[gene_lengths.index] @@ -18,11 +20,13 @@ def prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=None): gene_lengths = gene_lengths / scale_length return filter_counts(counts_df), gene_lengths + def parallel_apply(columns, func, n_jobs=-1): """Apply a function to each column in parallel and combine results.""" results = Parallel(n_jobs=n_jobs)(delayed(func)(col) for col in columns) return pd.concat(results, axis=1) + def trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim): """Perform log ratio and sum trimming.""" n = len(log_ratio) @@ -43,6 +47,7 @@ def trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim): return lr_t[final_idx], w_t[final_idx] + def compute_TMM_normalization_factors(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1): """Compute TMM normalization factors for counts data.""" lib_sizes = counts_df.sum(axis=0) @@ -53,7 +58,7 @@ def compute_TMM_normalization_factors(counts_df, logratio_trim=0.3, sum_trim=0.0 sample_names = counts_df.columns data_values = counts_df.values - norm_factors = pd.Series(index=sample_names, dtype='float64') + norm_factors = pd.Series(index=sample_names, dtype="float64") norm_factors[ref_sample] = 1.0 def compute_norm_factor(sample): @@ -81,7 +86,7 @@ def compute_norm_factor(sample): lr_final, w_final = trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim) mean_M = np.sum(w_final * lr_final) / np.sum(w_final) - norm_factor = 2 ** mean_M + norm_factor = 2**mean_M return sample, norm_factor samples = [s for s in sample_names if s != ref_sample] @@ -93,6 +98,7 @@ def compute_norm_factor(sample): norm_factors = norm_factors / np.exp(np.mean(np.log(norm_factors))) return norm_factors + def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1): """Perform TMM normalization on counts data.""" counts_df = filter_counts(counts_df) @@ -101,21 +107,26 @@ def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=Tru normalized_data = 
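A hedged sketch of calling the refactored endpoint from PATCH 5/7. The route path, host, and port below are assumptions (the route decorator sits outside these hunks); what is grounded in the patch is the CSV contract: the loader now uses index_col=0 (first column = gene IDs) and requires exactly one further numeric column, so header names are free-form.

    # gene_lengths.csv, for example:
    # GeneID,GeneLength
    # gene1,1000
    # gene2,2000

    import httpx

    # Hypothetical route shape: <host>/normalize/<experiment_result_id>/<method>
    url = "http://localhost:5000/normalize/EXPERIMENT_1/tpm"
    with open("gene_lengths.csv", "rb") as fh:
        r = httpx.post(url, files={"gene_lengths_file": ("gene_lengths.csv", fh, "text/csv")})
    print(r.status_code, r.json())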
From 577f5ac12ed0d3ca4631301cb40697f320988845 Mon Sep 17 00:00:00 2001
From: Julian
Date: Fri, 13 Dec 2024 14:23:07 -0500
Subject: [PATCH 6/7] lint

---
 transcriptomics_data_service/db.py | 4 +++-
 .../routers/normalization.py | 4 +---
 transcriptomics_data_service/scripts/normalize.py | 15 +++++++++++++--
 3 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/transcriptomics_data_service/db.py b/transcriptomics_data_service/db.py
index 636677a..92f793f 100644
--- a/transcriptomics_data_service/db.py
+++ b/transcriptomics_data_service/db.py
@@ -135,7 +135,9 @@ async def fetch_gene_expressions_by_experiment_id(self, experiment_result_id: st
             res = await conn.fetch(query, experiment_result_id)
             return tuple([self._deserialize_gene_expression(record) for record in res])
 
-    async def fetch_gene_expressions(self, experiments: list[str], method: str = "raw", paginate: bool = False) -> Tuple[Tuple[GeneExpression, ...], int]:
+    async def fetch_gene_expressions(
+        self, experiments: list[str], method: str = "raw", paginate: bool = False
+    ) -> Tuple[Tuple[GeneExpression, ...], int]:
         if not experiments:
             return (), 0
         # TODO: refactor this fetch_gene_expressions_by_experiment_id and implement pagination

diff --git a/transcriptomics_data_service/routers/normalization.py b/transcriptomics_data_service/routers/normalization.py
index d845ce7..439ed3a 100644
--- a/transcriptomics_data_service/routers/normalization.py
+++ b/transcriptomics_data_service/routers/normalization.py
@@ -93,9 +93,7 @@ async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame:
     Fetch raw counts from the database for the given experiment_result_id.
     Returns a DataFrame with genes as rows and samples as columns.
     """
-    expressions, _ = await db.fetch_gene_expressions(
-        experiments=[experiment_result_id], method="raw", paginate=False
-    )
+    expressions, _ = await db.fetch_gene_expressions(experiments=[experiment_result_id], method="raw", paginate=False)
     if not expressions:
         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment result not found.")
 

diff --git a/transcriptomics_data_service/scripts/normalize.py b/transcriptomics_data_service/scripts/normalize.py
index d211dbf..55f9799 100644
--- a/transcriptomics_data_service/scripts/normalize.py
+++ b/transcriptomics_data_service/scripts/normalize.py
@@ -2,12 +2,14 @@
 import numpy as np
 from joblib import Parallel, delayed
 
+
 def filter_counts(counts_df):
     """Filter out genes (rows) and samples (columns) with zero total counts."""
     row_filter = counts_df.sum(axis=1) > 0
     col_filter = counts_df.sum(axis=0) > 0
     return counts_df.loc[row_filter, col_filter]
 
+
 def prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=None):
     """Align counts and gene_lengths, drop zeros, and optionally scale gene lengths."""
     counts_df = counts_df.loc[gene_lengths.index]
@@ -18,11 +20,13 @@ def prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=None):
     gene_lengths = gene_lengths / scale_length
     return filter_counts(counts_df), gene_lengths
 
+
 def parallel_apply(columns, func, n_jobs=-1):
     """Apply a function to each column in parallel and combine results."""
     results = Parallel(n_jobs=n_jobs)(delayed(func)(col) for col in columns)
     return pd.concat(results, axis=1)
 
+
 def trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim):
     """Perform log ratio and sum trimming."""
     n = len(log_ratio)
@@ -43,6 +47,7 @@ def trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim):
 
     return lr_t[final_idx], w_t[final_idx]
 
+
 def compute_TMM_normalization_factors(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1):
     """Compute TMM normalization factors for counts data."""
     lib_sizes = counts_df.sum(axis=0)
@@ -53,7 +58,7 @@ def compute_TMM_normalization_factors(counts_df, logratio_trim=0.3, sum_trim=0.0
     sample_names = counts_df.columns
     data_values = counts_df.values
 
-    norm_factors = pd.Series(index=sample_names, dtype='float64')
+    norm_factors = pd.Series(index=sample_names, dtype="float64")
     norm_factors[ref_sample] = 1.0
 
     def compute_norm_factor(sample):
@@ -81,7 +86,7 @@ def compute_norm_factor(sample):
         lr_final, w_final = trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim)
 
         mean_M = np.sum(w_final * lr_final) / np.sum(w_final)
-        norm_factor = 2 ** mean_M
+        norm_factor = 2**mean_M
         return sample, norm_factor
 
     samples = [s for s in sample_names if s != ref_sample]
@@ -93,6 +98,7 @@ def compute_norm_factor(sample):
     norm_factors = norm_factors / np.exp(np.mean(np.log(norm_factors)))
     return norm_factors
 
+
 def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1):
     """Perform TMM normalization on counts data."""
     counts_df = filter_counts(counts_df)
@@ -101,21 +107,26 @@ def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=Tru
     normalized_data = counts_df.div(lib_sizes, axis=1).div(norm_factors, axis=1) * lib_sizes.mean()
     return normalized_data
 
+
 def getmm_normalization(counts_df, gene_lengths, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1):
     """Perform GeTMM normalization on counts data."""
     counts_df, gene_lengths = prepare_counts_and_lengths(counts_df, gene_lengths)
     rpk = counts_df.mul(1e3).div(gene_lengths, axis=0)
     return tmm_normalization(rpk, logratio_trim, sum_trim, weighting, n_jobs)
 
+
 def compute_rpk(counts_df, gene_lengths_scaled, n_jobs=-1):
     """Compute RPK values in parallel."""
     columns = counts_df.columns
+
     def rpk_col(col):
         return counts_df[col] / gene_lengths_scaled
+
     rpk = parallel_apply(columns, rpk_col, n_jobs)
     rpk.columns = columns
     return rpk
 
+
 def tpm_normalization(counts_df, gene_lengths, scale_library=1e6, scale_length=1e3, n_jobs=-1):
     """Convert raw read counts to TPM in parallel."""
     counts_df, gene_lengths_scaled = prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=scale_length)

From 293d1312bf64dde8ee1859c043eea5109e138354 Mon Sep 17 00:00:00 2001
From: Julian
Date: Wed, 18 Dec 2024 14:41:04 -0500
Subject: [PATCH 7/7] add scaling_factor as default argument in getmm

---
 transcriptomics_data_service/scripts/normalize.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/transcriptomics_data_service/scripts/normalize.py b/transcriptomics_data_service/scripts/normalize.py
index 55f9799..9bcd402 100644
--- a/transcriptomics_data_service/scripts/normalize.py
+++ b/transcriptomics_data_service/scripts/normalize.py
@@ -108,10 +108,10 @@ def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=Tru
     return normalized_data
 
 
-def getmm_normalization(counts_df, gene_lengths, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1):
+def getmm_normalization(counts_df, gene_lengths, logratio_trim=0.3, sum_trim=0.05, scaling_factor=1e3, weighting=True, n_jobs=-1):
     """Perform GeTMM normalization on counts data."""
     counts_df, gene_lengths = prepare_counts_and_lengths(counts_df, gene_lengths)
-    rpk = counts_df.mul(1e3).div(gene_lengths, axis=0)
+    rpk = counts_df.mul(scaling_factor).div(gene_lengths, axis=0)
     return tmm_normalization(rpk, logratio_trim, sum_trim, weighting, n_jobs)
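A short sketch (toy data, not from the patches) of the scaling_factor argument introduced in PATCH 7/7. The default of 1e3 reproduces the previous hard-coded per-kilobase behaviour; since the TMM factors are computed from within-sample proportions and are therefore scale-invariant, a different factor only rescales the GeTMM output linearly.

    import pandas as pd

    from transcriptomics_data_service.scripts.normalize import getmm_normalization

    counts = pd.DataFrame({"s1": [120, 30, 75], "s2": [90, 45, 60]}, index=["g1", "g2", "g3"])
    gene_lengths = pd.Series([1500, 900, 2100], index=counts.index)

    # Default keeps the old behaviour: rpk = counts * 1e3 / length.
    getmm_default = getmm_normalization(counts, gene_lengths)

    # A larger factor rescales the result linearly (here 1000x the default).
    getmm_scaled = getmm_normalization(counts, gene_lengths, scaling_factor=1e6)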