Skip to content

Commit

Permalink
Add benchmark scripts and ignore types where pandas DataFrames cause …
Browse files Browse the repository at this point in the history
…mypy problems
  • Loading branch information
maltekuehl committed Sep 16, 2024
1 parent 75d7d25 commit 4df75a1
Show file tree
Hide file tree
Showing 14 changed files with 166 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.9
3.11
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ dev = [
"mypy",
"myst-parser",
"nbsphinx",
"pandas-stubs",
"pandoc",
"pre-commit",
"pybiomart",
Expand Down
2 changes: 1 addition & 1 deletion pytximport/core/_tximport.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def tximport(

result = convert_transcripts_to_genes(
transcript_data,
transcript_gene_map,
transcript_gene_map, # type: ignore
ignore_after_bar=ignore_after_bar,
ignore_transcript_version=ignore_transcript_version,
counts_from_abundance=counts_from_abundance,
Expand Down
12 changes: 6 additions & 6 deletions pytximport/importers/_read_tsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,20 @@ def parse_dataframe(
if abundance_column is None:
warning("Abundance column not provided, calculating TPM.", UserWarning)
abundance = convert_counts_to_tpm(
counts=transcript_dataframe[counts_column].astype("float64").values,
length=transcript_dataframe[length_column].astype("float64").values,
counts=transcript_dataframe[counts_column].astype("float64").values, # type: ignore
length=transcript_dataframe[length_column].astype("float64").values, # type: ignore
)
else:
assert (
abundance_column in transcript_dataframe.columns
), f"Could not find the abundance column `{abundance_column}`."
abundance = transcript_dataframe[abundance_column].astype("float64").values
abundance = transcript_dataframe[abundance_column].astype("float64").values # type: ignore

# Create a DataFrame with the transcript-level expression
transcripts = TranscriptData(
transcript_id=transcript_dataframe[id_column].values,
counts=transcript_dataframe[counts_column].astype("float64").values,
length=transcript_dataframe[length_column].astype("float64").values,
transcript_id=transcript_dataframe[id_column].values, # type: ignore
counts=transcript_dataframe[counts_column].astype("float64").values, # type: ignore
length=transcript_dataframe[length_column].astype("float64").values, # type: ignore
abundance=abundance,
inferential_replicates=None,
)
Expand Down
4 changes: 2 additions & 2 deletions pytximport/utils/_convert_transcripts_to_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def convert_transcripts_to_genes(

if ignore_transcript_version:
# Ignore the transcript version in both the data and the transcript gene map
transcript_data, transcript_gene_map, transcript_ids = remove_transcript_version(
transcript_data, transcript_gene_map, transcript_ids = remove_transcript_version( # type: ignore
transcript_data,
transcript_gene_map,
transcript_ids, # type: ignore
Expand Down Expand Up @@ -84,7 +84,7 @@ def convert_transcripts_to_genes(
log(25, "Matching gene_ids.")
transcript_gene_dict = transcript_gene_map.set_index("transcript_id")["gene_id"].to_dict()
gene_ids_raw = transcript_data["transcript_id"].to_series().map(transcript_gene_dict).values
gene_ids = np.repeat(gene_ids_raw, transcript_data["abundance"].shape[1])
gene_ids = np.repeat(gene_ids_raw, transcript_data["abundance"].shape[1]) # type: ignore

# Remove the transcript_id coordinate
transcript_data = transcript_data.drop_vars("transcript_id")
Expand Down
12 changes: 7 additions & 5 deletions pytximport/utils/_create_transcript_gene_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ def create_transcript_gene_map(
elif species == "mouse":
dataset = Dataset(name="mmusculus_gene_ensembl", host=host)

transcript_gene_map = dataset.query(attributes=[source_field, target_field])
transcript_gene_map.columns = [
"transcript_id",
("gene_id" if target_field != "external_transcript_name" else "transcript_name"),
]
transcript_gene_map: pd.DataFrame = dataset.query(attributes=[source_field, target_field])
transcript_gene_map.columns = pd.Index(
[
"transcript_id",
("gene_id" if target_field != "external_transcript_name" else "transcript_name"),
]
)

transcript_gene_map.dropna(inplace=True)
transcript_gene_map.drop_duplicates(inplace=True)
Expand Down
4 changes: 2 additions & 2 deletions pytximport/utils/_remove_transcript_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import pandas as pd
import xarray as xr
Expand All @@ -9,7 +9,7 @@ def remove_transcript_version(
transcript_target_map: Optional[pd.DataFrame] = None,
transcript_ids: Optional[List[str]] = None,
id_column: str = "transcript_id",
) -> Tuple[xr.Dataset, pd.DataFrame, List[str]]:
) -> Tuple[xr.Dataset, Union[pd.DataFrame, None], List[str]]:
"""Remove the transcript version from the transcript data and the transcript target map.
Args:
Expand Down
22 changes: 11 additions & 11 deletions pytximport/utils/_replace_transcript_ids_with_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@

def replace_transcript_ids_with_names(
transcript_data: Union[ad.AnnData, xr.Dataset],
transcript_name_map: Optional[Union[pd.DataFrame, Union[str, Path]]] = None,
transcript_name_map: Union[pd.DataFrame, Union[str, Path]],
) -> Union[ad.AnnData, xr.Dataset]:
"""Replace transcript IDs with transcript names.
Args:
transcript_data (Union[ad.AnnData, xr.Dataset]): The transcript-level expression data.
transcript_name_map (Optional[Union[pd.DataFrame, Union[str, Path]]], optional): The mapping from transcripts to
names. Contains two columns: `transcript_id` and `transcript_name`. Defaults to None.
transcript_name_map (Union[pd.DataFrame, Union[str, Path]]): The mapping from transcripts to
names. Contains two columns: `transcript_id` and `transcript_name`.
Returns:
Union[ad.AnnData, xr.Dataset]: The transcript-level expression data with the transcript names.
Expand All @@ -27,14 +27,11 @@ def replace_transcript_ids_with_names(
transcript_name_map = pd.read_table(transcript_name_map, header=0)
transcript_name_map = transcript_name_map.drop_duplicates()

# Assert that transcript_id and transcript_name are present in the mapping
if transcript_name_map is not None:
assert "transcript_id" in transcript_name_map.columns, "The mapping does not contain a `transcript_id` column."
assert (
"transcript_name" in transcript_name_map.columns
), "The mapping does not contain a `transcript_name` column."
# Check that transcript_id and transcript_name are present in the mapping
assert "transcript_id" in transcript_name_map.columns, "The mapping does not contain a `transcript_id` column."
assert "transcript_name" in transcript_name_map.columns, "The mapping does not contain a `transcript_name` column."

# Check whether the transcript_data is an AnnData object and convert it to a DataFrame
# Check whether the transcript_data is an AnnData object and convert it to an xr.Dataset
return_as_anndata = False
if isinstance(transcript_data, ad.AnnData):
return_as_anndata = True
Expand All @@ -51,7 +48,10 @@ def replace_transcript_ids_with_names(
)

# Remove the transcript version
transcript_data, transcript_name_map, _ = remove_transcript_version(transcript_data, transcript_name_map)
transcript_data, transcript_name_map, _ = remove_transcript_version( # type: ignore
transcript_data,
transcript_name_map,
)

transcript_name_dict = transcript_name_map.set_index("transcript_id")["transcript_name"].to_dict()
transcript_names = transcript_data["transcript_id"].to_series().map(transcript_name_dict).values
Expand Down
2 changes: 2 additions & 0 deletions test/benchmark/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.json
*.csv
75 changes: 75 additions & 0 deletions test/benchmark/bench.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
options(repos = c(CRAN = "https://cloud.r-project.org"))

if (!requireNamespace("tximport", quietly = TRUE)) {
install.packages("BiocManager")
BiocManager::install("tximport")
}

if (!requireNamespace("bench", quietly = TRUE)) {
install.packages("bench")
}

if (!requireNamespace("readr", quietly = TRUE)) {
install.packages("readr")
}

if (!requireNamespace("matrixStats", quietly = TRUE)) {
install.packages("matrixStats")
}

# Load libraries
library(tximport)
library(bench)
library(readr)
library(matrixStats)

# Define the file paths and transcript-gene mapping
transcript_gene_mapping <- read_tsv("../data/fabry_disease/transcript_gene_mapping_human.tsv")

files <- c(
"../data/fabry_disease/SRR16504309_wt/quant.sf",
"../data/fabry_disease/SRR16504310_wt/quant.sf",
"../data/fabry_disease/SRR16504311_ko/quant.sf",
"../data/fabry_disease/SRR16504312_ko/quant.sf"
)

# Create a function to benchmark
tximport_benchmark <- function() {
txi <- tximport(
files,
type = "salmon",
tx2gene = transcript_gene_mapping,
ignoreTxVersion = TRUE,
ignoreAfterBar = TRUE,
dropInfReps = FALSE,
countsFromAbundance = "lengthScaledTPM",
infRepStat = rowMedians
)
return(txi)
}

# Run the benchmark
benchmark_results <- bench::mark(
tximport_benchmark(),
filter_gc = FALSE,
iterations = 11
)

benchmark_results_df <- as.data.frame(
lapply(benchmark_results, as.character),
stringsAsFactors = FALSE
)

benchmark_results_df <- benchmark_results_df[, !grepl(
"result",
names(benchmark_results_df)
)]

benchmark_results_df <- benchmark_results_df[, !grepl(
"memory",
names(benchmark_results_df)
)]

print(benchmark_results)

write.csv(as.data.frame(benchmark_results_df), "tximport_time_memory.csv")
50 changes: 50 additions & 0 deletions test/benchmark/bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Benchmark pytximport."""

from logging import INFO, getLogger

import numpy as np
import pandas as pd
import pyperf

from pytximport import tximport

# Load transcript-gene mapping (this can be done once outside the benchmark)
transcript_gene_mapping_human = pd.read_table(
"../data/fabry_disease/transcript_gene_mapping_human.tsv",
header=0,
sep="\t",
)

# Define the files list
files = [
"../data/fabry_disease/SRR16504309_wt/quant.sf",
"../data/fabry_disease/SRR16504310_wt/quant.sf",
"../data/fabry_disease/SRR16504311_ko/quant.sf",
"../data/fabry_disease/SRR16504312_ko/quant.sf",
]


# Function to benchmark
def tximport_benchmark():
"""Benchmark pytximport.
Returns:
ad.AnnData: The AnnData object.
"""
txi = tximport(
files,
"salmon",
transcript_gene_mapping_human,
inferential_replicates=True,
inferential_replicate_transformer=lambda x: np.median(x, axis=1),
)
return txi


# Run the pyperf benchmark
if __name__ == "__main__":
# Set log level to 25
getLogger().setLevel(25)

runner = pyperf.Runner()
runner.bench_func("pytximport", tximport_benchmark)
5 changes: 5 additions & 0 deletions test/benchmark/pytximport.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
rm *.json
python3 bench.py -o ./pytximport_time.json -p 1 -n 11 -w 0
python3 -m pyperf stats ./pytximport_time.json
python3 bench.py -o ./pytximport_memory.json -p 1 -n 11 -w 0 --track-memory
python3 -m pyperf stats ./pytximport_memory.json
2 changes: 2 additions & 0 deletions test/benchmark/tximport.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
rm *.csv
Rscript bench.R
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def transcript_name_mapping_human() -> pd.DataFrame:

@pytest.fixture(scope="session")
def transcript_name_mapping_human_path() -> Path:
"""Provides the path to the transcript id to transcript name mapping for human samples."""
"""Provide the path to the transcript id to transcript name mapping for human samples."""
return Path(FILE_DIR) / "transcript_name_mapping_human.tsv"


Expand Down

0 comments on commit 4df75a1

Please sign in to comment.