Skip to content

Commit

Permalink
Further performance improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
maltekuehl committed Sep 21, 2024
1 parent e67ae65 commit 89c8e45
Show file tree
Hide file tree
Showing 12 changed files with 67 additions and 59 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ ipython_config.py
# Environments
.env
.venv
.python-version
env/
venv/
ENV/
Expand Down Expand Up @@ -69,6 +70,7 @@ requirements.dev.txt
manuscript.pdf
/data/rpgn_example/
/rpgn/
/benchmark/

# All .DS_Store files
**/.DS_Store
Expand Down
1 change: 0 additions & 1 deletion .python-version

This file was deleted.

2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,13 @@ dev = [
"pandas-stubs",
"pandoc",
"pre-commit",
"pybiomart",
"pytest",
"sphinx",
"sphinx-autoapi",
"sphinx-autodoc-typehints",
"sphinx-copybutton",
"sphinx-design",
"sphinx-rtd-theme",
"virtualenv",
]

[project.scripts]
Expand Down
13 changes: 9 additions & 4 deletions pytximport/core/_tximport.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,14 @@ def tximport(
raise FileNotFoundError(f"The transcript to gene mapping does not exist: {transcript_gene_map}")

try:
if transcript_gene_map.suffix == ".csv":
transcript_gene_map = pd.read_csv(transcript_gene_map, header=0, engine="pyarrow")
else:
transcript_gene_map = pd.read_table(transcript_gene_map, header=0, engine="pyarrow")
transcript_gene_map = pd.read_csv(
transcript_gene_map,
header=0,
engine="pyarrow",
sep=("," if transcript_gene_map.suffix == ".csv" else "\t"),
usecols=["transcript_id", "gene_id"],
dtype={"transcript_id": str, "gene_id": str},
)
except Exception as exception:
raise ValueError(f"Could not read the transcript to gene mapping: {exception}")

Expand Down Expand Up @@ -445,6 +449,7 @@ def tximport(
if transcript_gene_map is None:
raise ValueError("A transcript to gene mapping must be provided for `dtu_scaled_tpm`.")

log(25, "Calculating median gene length over isoforms.")
transcript_data = get_median_length_over_isoform(
transcript_data,
transcript_gene_map,
Expand Down
1 change: 0 additions & 1 deletion pytximport/importers/_read_kallisto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Optional, Union

import numpy as np
import pandas as pd
from h5py import File

from ..definitions import InferentialReplicates, TranscriptData
Expand Down
36 changes: 17 additions & 19 deletions pytximport/importers/_read_tsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,25 +84,23 @@ def read_tsv(
if not file_path.exists():
raise ImportError(f"The file does not exist: {file_path}")

# Read the quantification file as a tsv, tab separated and the first line is the column names
if file_path.suffix == ".gz":
transcript_dataframe = pd.read_table(file_path, header=0, compression="gzip", sep="\t")
else:
usecols = [id_column, counts_column, length_column]
dtype = {id_column: str, counts_column: np.float64, length_column: np.float64}

if abundance_column is not None:
usecols.append(abundance_column)
dtype[abundance_column] = np.float64

transcript_dataframe = pd.read_table(
file_path,
header=0,
sep="\t",
engine="pyarrow",
usecols=usecols,
dtype=dtype,
)
# Read the quantification file as a tsv, tab separated with the first line being the column names
usecols = [id_column, counts_column, length_column]
dtype = {id_column: str, counts_column: np.float64, length_column: np.float64}

if abundance_column is not None:
usecols.append(abundance_column)
dtype[abundance_column] = np.float64

transcript_dataframe = pd.read_table(
file_path,
header=0,
sep="\t",
compression=("gzip" if file_path.suffix == ".gz" else None),
engine="pyarrow",
usecols=usecols,
dtype=dtype,
)

return parse_dataframe(
transcript_dataframe,
Expand Down
5 changes: 1 addition & 4 deletions pytximport/utils/_convert_abundance_to_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ def convert_abundance_to_counts(
raise ValueError("The count transform must be 'scaled_tpm' or 'length_scaled_tpm'.")

# Scale the counts from abundance to the original sequencing depth of each sample
column_counts = counts.sum(axis=0)
new_counts = counts_transformed.sum(axis=0)
ratio = column_counts / new_counts
counts_transformed = (counts_transformed.T * ratio).T
counts_transformed = (counts_transformed.T * (counts.sum(axis=0) / counts_transformed.sum(axis=0))).T

return counts_transformed
3 changes: 1 addition & 2 deletions pytximport/utils/_convert_counts_to_tpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,4 @@ def convert_counts_to_tpm(
Returns:
np.ndarray: The transcript-level expression data with the TPM.
"""
normalization_factor = 1e6 / np.sum(counts / length)
return np.array(counts * normalization_factor / length)
return np.array(counts * (1e6 / np.sum(counts / length)) / length)
15 changes: 8 additions & 7 deletions pytximport/utils/_convert_transcripts_to_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def convert_transcripts_to_genes(
xr.Dataset: The gene-level expression data from multiple samples.
"""
transcript_ids: Union[np.ndarray, List[str]] = transcript_data.coords["transcript_id"].values
transcript_ids = transcript_data.coords["transcript_id"].values

if ignore_after_bar:
# Ignore the part of the transcript ID after the bar
Expand Down Expand Up @@ -72,19 +71,21 @@ def convert_transcripts_to_genes(
)
# Remove the missing transcripts by only keeping the data for the transcripts present in the mapping
transcript_ids_intersect = list(set(unique_transcripts).intersection(set(transcript_gene_map["transcript_id"])))
transcript_ids_intersect_boolean = np.isin(transcript_ids, transcript_ids_intersect)
transcript_data = transcript_data.isel(
transcript_id=transcript_ids_intersect_boolean,
transcript_id=np.isin(transcript_ids, transcript_ids_intersect),
drop=True,
)
transcript_ids = transcript_data.coords["transcript_id"].values
# transcript_ids = transcript_data.coords["transcript_id"].values
transcript_gene_map = transcript_gene_map[transcript_gene_map["transcript_id"].isin(transcript_ids_intersect)]

# Add the corresponding gene to the transcript-level expression
log(25, "Matching gene_ids.")
transcript_gene_dict = transcript_gene_map.set_index("transcript_id")["gene_id"].to_dict()
gene_ids_raw = transcript_data["transcript_id"].to_series().map(transcript_gene_dict).values
# gene_ids = np.repeat(gene_ids_raw, transcript_data["abundance"].shape[1]) # type: ignore
gene_ids_raw = (
transcript_data["transcript_id"]
.to_series()
.map(transcript_gene_map.set_index("transcript_id")["gene_id"])
.values
)

# Remove the transcript_id coordinate and rename the variable to gene_id
transcript_data = (
Expand Down
2 changes: 1 addition & 1 deletion pytximport/utils/_create_transcript_gene_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def create_transcript_gene_map_from_annotation(
# Each attribute line looks like this:
# gene_id ""; transcript_id ""; gene_name ""; gene_source ""; gene_biotype "";
# transcript_name ""; transcript_source "";
# We are only interested in the gene_id, gene_name, and transcript_id
# We are only interested in the gene_id, gene_name, transcript_id, transcript_name and gene_biotype
attribute_columns = [
"transcript_id",
"transcript_name",
Expand Down
38 changes: 21 additions & 17 deletions pytximport/utils/_get_median_length_over_isoform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,35 @@ def get_median_length_over_isoform(
assert "length" in transcript_data.data_vars, "The transcript data does not contain a `length` variable."

# Get the gene ids for each transcript
transcript_gene_dict = transcript_gene_map.set_index("transcript_id")["gene_id"].to_dict()
gene_ids = transcript_data["transcript_id"].to_series().map(transcript_gene_dict).values
gene_ids = (
transcript_data["transcript_id"]
.to_series()
.map(transcript_gene_map.set_index("transcript_id")["gene_id"].to_dict())
.values
)

# Check that no gene ids is nan
assert not any(pd.isna(gene_ids)), "Not all transcript ids could be mapped to gene ids. Please check the mapping."

transcript_data_copy = transcript_data.drop_vars("transcript_id")
transcript_data_copy = transcript_data_copy.assign_coords(gene_id=gene_ids)
transcript_data_copy = transcript_data_copy.rename({"transcript_id": "gene_id"})

# Get the row mean across samples for each transcript
average_transcript_length_across_samples = transcript_data_copy["length"].mean(axis=1)
median_gene_length = average_transcript_length_across_samples.groupby("gene_id").median().to_dataframe()

transcript_median_gene_length = [median_gene_length.loc[gene_id, "length"] for gene_id in gene_ids]
transcript_median_gene_length_repeated = np.reshape(
np.repeat(
transcript_median_gene_length,
transcript_data["abundance"].shape[1],
),
transcript_data["abundance"].shape,
median_gene_length = (
transcript_data.drop_vars("transcript_id")
.assign_coords(gene_id=gene_ids)
.rename({"transcript_id": "gene_id"})["length"]
.mean(dim="file")
.groupby("gene_id")
.median()
.to_dataframe()
)

transcript_data["median_isoform_length"] = xr.DataArray(
transcript_median_gene_length_repeated,
np.reshape(
np.repeat(
pd.Series(gene_ids).map(median_gene_length["length"]).to_numpy(),
transcript_data["abundance"].shape[1],
),
transcript_data["abundance"].shape,
),
dims=("transcript_id", "file"),
)

Expand Down
8 changes: 7 additions & 1 deletion pytximport/utils/_replace_transcript_ids_with_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ def replace_transcript_ids_with_names(
"""
# Read the transcript to gene mapping
if isinstance(transcript_name_map, str) or isinstance(transcript_name_map, Path):
transcript_name_map = pd.read_table(transcript_name_map, header=0)
transcript_name_map = pd.read_table(
transcript_name_map,
header=0,
engine="c",
usecols=["transcript_id", "transcript_name"],
dtype=str,
)
transcript_name_map = transcript_name_map.drop_duplicates()

# Check that transcript_id and transcript_name are present in the mapping
Expand Down

0 comments on commit 89c8e45

Please sign in to comment.