Skip to content

Commit

Permalink
Performance improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
maltekuehl committed Sep 20, 2024
1 parent 365dcf9 commit e67ae65
Show file tree
Hide file tree
Showing 12 changed files with 102 additions and 227 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"h5py>=3.0.0,<4",
"numpy>=1.19.0,<3",
"pandas>=2.2.0,<3",
"pyarrow>=15.0.0",
"pybiomart==0.2.0",
"tqdm>=4.0.0,<5",
"xarray>=2024.0.0",
Expand Down
16 changes: 10 additions & 6 deletions pytximport/core/_tximport.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ def tximport(

try:
if transcript_gene_map.suffix == ".csv":
transcript_gene_map = pd.read_csv(transcript_gene_map, header=0)
transcript_gene_map = pd.read_csv(transcript_gene_map, header=0, engine="pyarrow")
else:
transcript_gene_map = pd.read_table(transcript_gene_map, header=0)
transcript_gene_map = pd.read_table(transcript_gene_map, header=0, engine="pyarrow")
except Exception as exception:
raise ValueError(f"Could not read the transcript to gene mapping: {exception}")

Expand Down Expand Up @@ -271,6 +271,9 @@ def tximport(
# TODO: this may break in newer versions of sailfish (>0.10.1), when no explicit auxDir is provided
importer_kwargs["aux_dir_name"] = "aux"

# Pass whether to load inferential replicates to the importer
importer_kwargs["inferential_replicates"] = inferential_replicates

elif inferential_replicates:
warning("Inferential replicates are not supported for this data type.")
inferential_replicates = False
Expand Down Expand Up @@ -394,9 +397,11 @@ def tximport(
# Remove appended gene names after underscore for RSEM data for both transcript and gene ids
if (
data_type == "rsem"
and (gene_level and transcript_data.coords["gene_id"].values[0].count("_") > 0)
or (not gene_level and transcript_data.coords["transcript_id"].values[0].count("_") > 0)
and ignore_after_bar
and (
(gene_level and transcript_data.coords["gene_id"].values[0].count("_") > 0)
or (not gene_level and transcript_data.coords["transcript_id"].values[0].count("_") > 0)
)
):
warning(
(
Expand Down Expand Up @@ -570,8 +575,7 @@ def tximport(
df_gene_data.to_csv(output_path, index=True, header=True, quoting=2)

# End the timer
end_time = time()
log(25, f"Finished the import in {end_time - start_time:.2f} seconds.")
log(25, f"Finished the import in {time() - start_time:.2f} seconds.")

if return_data:
return result
Expand Down
80 changes: 36 additions & 44 deletions pytximport/importers/_read_kallisto.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ..definitions import InferentialReplicates, TranscriptData
from ..utils._convert_counts_to_tpm import convert_counts_to_tpm
from ._read_tsv import read_tsv


def read_inferential_replicates_kallisto(
Expand Down Expand Up @@ -57,6 +58,7 @@ def read_kallisto(
counts_column: str = "est_counts",
length_column: str = "aux/eff_lengths",
abundance_column: Optional[str] = None,
inferential_replicates: bool = False,
) -> TranscriptData:
"""Read a kallisto quantification file.
Expand Down Expand Up @@ -90,49 +92,39 @@ def read_kallisto(
if abundance_column is not None:
abundance = f.file[abundance_column][:]

# Check that the length of the counts, length, and abundances are the same
assert (
len(transcript_ids) == len(counts) == len(length)
), "The transcript ids, counts and length have different length."

# Calculate the transcript-level TPM if the abundance was not included
if abundance_column is None:
warning("Abundance column not provided, calculating TPM.")
abundance = convert_counts_to_tpm(counts, length)
else:
assert len(transcript_ids) == len(abundance), "The transcript ids and abundance have different length."

# Create a DataFrame with the transcript-level expression
transcripts = TranscriptData(
transcript_id=transcript_ids,
counts=counts,
length=length,
abundance=abundance,
inferential_replicates=None,
)

if inferential_replicates:
transcripts["inferential_replicates"] = read_inferential_replicates_kallisto(
file_path,
)

elif file_path.suffix == ".tsv":
# Read the quantification file as a tsv, tab separated and the first line is the column names
transcript_data = pd.read_table(file_path, header=0)

# Check that the columns are in the table
assert id_column in transcript_data.columns, f"Could not find the transcript id column `{id_column}`."
assert counts_column in transcript_data.columns, f"Could not find the counts column `{counts_column}`."
assert length_column in transcript_data.columns, f"Could not find the length column `{length_column}`."

transcript_ids = transcript_data[id_column].values
counts = transcript_data[counts_column].astype("float64").values
length = transcript_data[length_column].astype("float64").values

if abundance_column is not None:
assert (
abundance_column in transcript_data.columns
), f"Could not find the abundance column `{abundance_column}`."
abundance = transcript_data[abundance_column].astype("float64").values

# Check that the length of the counts, length, and abundances are the same
assert (
len(transcript_ids) == len(counts) == len(length)
), "The transcript ids, counts and length have different length."

# Calculate the transcript-level TPM if the abundance was not included
if abundance_column is None:
warning("Abundance column not provided, calculating TPM.")
abundance = convert_counts_to_tpm(counts, length)
else:
assert len(transcript_ids) == len(abundance), "The transcript ids and abundance have different length."

# Create a DataFrame with the transcript-level expression
transcripts = TranscriptData(
transcript_id=transcript_ids,
counts=counts,
length=length,
abundance=abundance,
inferential_replicates=None,
)

transcripts["inferential_replicates"] = read_inferential_replicates_kallisto(
file_path,
)

# Return the transcript-level expression
transcripts = read_tsv(
file_path,
id_column=id_column,
counts_column=counts_column,
length_column=length_column,
abundance_column=abundance_column,
)

return transcripts
10 changes: 6 additions & 4 deletions pytximport/importers/_read_salmon.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def read_salmon(
length_column: str = "EffectiveLength",
abundance_column: str = "TPM",
aux_dir_name: Literal["aux_info", "aux"] = "aux_info",
inferential_replicates: bool = False,
) -> TranscriptData:
"""Read a salmon quantification file.
Expand Down Expand Up @@ -129,9 +130,10 @@ def read_salmon(
abundance_column=abundance_column,
)

transcript_data["inferential_replicates"] = read_inferential_replicates_salmon(
file_path,
aux_dir_name=aux_dir_name,
)
if inferential_replicates:
transcript_data["inferential_replicates"] = read_inferential_replicates_salmon(
file_path,
aux_dir_name=aux_dir_name,
)

return transcript_data
27 changes: 21 additions & 6 deletions pytximport/importers/_read_tsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd

from ..definitions import TranscriptData
Expand Down Expand Up @@ -36,20 +37,20 @@ def parse_dataframe(
if abundance_column is None:
warning("Abundance column not provided, calculating TPM.", UserWarning)
abundance = convert_counts_to_tpm(
counts=transcript_dataframe[counts_column].astype("float64").values, # type: ignore
length=transcript_dataframe[length_column].astype("float64").values, # type: ignore
counts=transcript_dataframe[counts_column].values, # type: ignore
length=transcript_dataframe[length_column].values, # type: ignore
)
else:
assert (
abundance_column in transcript_dataframe.columns
), f"Could not find the abundance column `{abundance_column}`."
abundance = transcript_dataframe[abundance_column].astype("float64").values # type: ignore
abundance = transcript_dataframe[abundance_column].values # type: ignore

# Create a DataFrame with the transcript-level expression
transcripts = TranscriptData(
transcript_id=transcript_dataframe[id_column].values, # type: ignore
counts=transcript_dataframe[counts_column].astype("float64").values, # type: ignore
length=transcript_dataframe[length_column].astype("float64").values, # type: ignore
counts=transcript_dataframe[counts_column].values, # type: ignore
length=transcript_dataframe[length_column].values, # type: ignore
abundance=abundance,
inferential_replicates=None,
)
Expand Down Expand Up @@ -87,7 +88,21 @@ def read_tsv(
if file_path.suffix == ".gz":
transcript_dataframe = pd.read_table(file_path, header=0, compression="gzip", sep="\t")
else:
transcript_dataframe = pd.read_table(file_path, header=0, sep="\t")
usecols = [id_column, counts_column, length_column]
dtype = {id_column: str, counts_column: np.float64, length_column: np.float64}

if abundance_column is not None:
usecols.append(abundance_column)
dtype[abundance_column] = np.float64

transcript_dataframe = pd.read_table(
file_path,
header=0,
sep="\t",
engine="pyarrow",
usecols=usecols,
dtype=dtype,
)

return parse_dataframe(
transcript_dataframe,
Expand Down
44 changes: 27 additions & 17 deletions pytximport/utils/_convert_transcripts_to_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,25 @@ def convert_transcripts_to_genes(
log(25, "Matching gene_ids.")
transcript_gene_dict = transcript_gene_map.set_index("transcript_id")["gene_id"].to_dict()
gene_ids_raw = transcript_data["transcript_id"].to_series().map(transcript_gene_dict).values
gene_ids = np.repeat(gene_ids_raw, transcript_data["abundance"].shape[1]) # type: ignore
# gene_ids = np.repeat(gene_ids_raw, transcript_data["abundance"].shape[1]) # type: ignore

# Remove the transcript_id coordinate
transcript_data = transcript_data.drop_vars("transcript_id")
transcript_data = transcript_data.assign_coords(gene_id=gene_ids_raw)

# Rename the first dimension to gene
transcript_data = transcript_data.rename({"transcript_id": "gene_id"})
# Remove the transcript_id coordinate and rename the variable to gene_id
transcript_data = (
transcript_data.drop_vars("transcript_id")
.assign_coords(gene_id=gene_ids_raw)
.rename({"transcript_id": "gene_id"})
)

# Get the unique genes but keep the order
unique_genes = list(pd.Series(gene_ids).unique())
unique_genes = pd.Series(gene_ids_raw).unique()

log(25, "Creating gene abundance.")
# We already calculate the abundance length product here so that we can reuse the sum
transcript_data["abundance_length_product"] = xr.apply_ufunc(
np.multiply,
transcript_data["abundance"],
transcript_data["length"],
)
transcript_data_summed_by_gene = transcript_data.groupby("gene_id").sum()
abundance_gene = xr.DataArray(
transcript_data_summed_by_gene["abundance"],
Expand All @@ -109,28 +115,32 @@ def convert_transcripts_to_genes(
dims=["gene_id", "file"],
)

inferential_replicates_gene = None
if "inferential_replicates" in transcript_data.data_vars:
log(25, "Creating inferential replicates.")
inferential_replicates_gene = xr.DataArray(
transcript_data_summed_by_gene["inferential_replicates"],
dims=["gene_id", "bootstraps", "file"],
)
else:
inferential_replicates_gene = None

variances_gene = None
if "variance" in transcript_data.data_vars and inferential_replicates_gene is not None:
log(25, "Creating variances.")
variances_gene = inferential_replicates_gene.var(dim="bootstraps", ddof=1)

log(25, "Creating lengths.")
transcript_data["abundance_length_product"] = transcript_data["abundance"] * transcript_data["length"]
abundance_weighted_length = transcript_data.groupby("gene_id").sum()["abundance_length_product"]
length = xr.DataArray(abundance_weighted_length / abundance_gene.data, dims=["gene_id", "file"], name="length")
length = xr.DataArray(
transcript_data_summed_by_gene["abundance_length_product"] / abundance_gene.data,
dims=["gene_id", "file"],
name="length",
)

log(25, "Replacing missing lengths.")
average_transcript_length_across_samples = transcript_data["length"].mean(axis=1)
average_gene_length = average_transcript_length_across_samples.groupby("gene_id").mean()
length = replace_missing_average_transcript_length(length, average_gene_length)
length = replace_missing_average_transcript_length(
length,
# Average gene length across samples
transcript_data["length"].mean(axis=1).groupby("gene_id").mean(),
)

# Convert the counts to the desired count type
if counts_from_abundance is not None:
Expand All @@ -153,7 +163,7 @@ def convert_transcripts_to_genes(
if inferential_replicates_gene is not None:
data_vars["inferential_replicates"] = inferential_replicates_gene

if "variance" in transcript_data.data_vars:
if variances_gene is not None:
data_vars["variance"] = variances_gene

gene_expression = xr.Dataset(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def replace_missing_average_transcript_length(
Args:
length (xr.DataArray): The average length of transcripts at the gene level with a sample dimension.
length_gene_mean (xr.DataArray): The mean length of the transcript of the genes across samples.
length_gene_mean (xr.DataArray): The mean length of the transcripts of the genes across samples.
Returns:
xr.DataArray: The average length of transcripts at the gene level with a sample dimension.
Expand Down
2 changes: 0 additions & 2 deletions test/benchmark/.gitignore

This file was deleted.

Loading

0 comments on commit e67ae65

Please sign in to comment.