Skip to content

Commit

Permalink
Bug fixes, error messages and biotype support from GTF files
Browse files Browse the repository at this point in the history
  • Loading branch information
maltekuehl committed Jul 15, 2024
1 parent 5ecd2b0 commit e0eb2dd
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 103 deletions.
137 changes: 73 additions & 64 deletions docs/source/example.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pytximport/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@
)
@click.option(
"-ow",
"--save_path_override",
"--save-path-override",
"--save_path_overwrite",
"--save-path-overwrite",
is_flag=True,
help="Whether to override the save path.",
help="Whether to overwrite the save path.",
)
@click.option(
"--ignore_after_bar",
Expand Down
15 changes: 10 additions & 5 deletions pytximport/core/_tximport.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def tximport(
output_type: Literal["xarray", "anndata"] = "anndata",
output_format: Literal["csv", "h5ad"] = "csv",
save_path: Optional[Union[str, Path]] = None,
save_path_override: bool = False,
save_path_overwrite: bool = False,
return_data: bool = True,
biotype_filter: Optional[List[str]] = None,
) -> Union[xr.Dataset, ad.AnnData, None]:
Expand Down Expand Up @@ -95,13 +95,15 @@ def tximport(
output_type (Literal["xarray", "anndata"], optional): The type of output. Defaults to "anndata".
output_format (Literal["csv", "h5ad"], optional): The type of output file. Defaults to "csv".
save_path (Optional[Union[str, Path]], optional): The path to save the gene-level expression. Defaults to None.
save_path_override (bool, optional): Whether to override the save path if it already exists. Defaults to False.
save_path_overwrite (bool, optional): Whether to overwrite the save path if it already exists.
Defaults to False.
return_data (bool, optional): Whether to return the gene-level expression. Defaults to True.
biotype_filter (List[str], optional): Filter the transcripts by biotype, including only those provided.
Defaults to None.
Returns:
Union[xr.Dataset, ad.AnnData, None]: The estimated gene-level expression data if `return_data` is True.
Union[xr.Dataset, ad.AnnData, None]: The estimated gene-level or transcript-level expression data if
`return_data` is True, else None.
"""
# start a timer
log(25, "Starting the import.")
Expand Down Expand Up @@ -342,6 +344,9 @@ def tximport(
if counts_from_abundance is not None:
length_key = "length"

if counts_from_abundance == "length_scaled_tpm":
raise ValueError("The `length_scaled_tpm` option is not supported for transcript-level expression.")

if counts_from_abundance == "dtu_scaled_tpm":
if transcript_gene_map is None:
raise ValueError("A transcript to gene mapping must be provided for `dtu_scaled_tpm`.")
Expand Down Expand Up @@ -415,9 +420,9 @@ def tximport(
)

if save_path is not None:
if save_path.exists() and not save_path_override:
if save_path.exists() and not save_path_overwrite:
raise FileExistsError(
f"The file already exists: {save_path}. Set `save_path_override` to True to override."
f"The file already exists: {save_path}. Set `save_path_overwrite` to True to overwrite."
)

if not save_path.parent.exists():
Expand Down
6 changes: 3 additions & 3 deletions pytximport/importers/_read_kallisto.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from logging import warning
from pathlib import Path
from typing import Literal, Optional, Union
from warnings import warn
from typing import Optional, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -116,7 +116,7 @@ def read_kallisto(

# calculate the transcript-level TPM if the abundance was not included
if abundance_column is None:
warn("Abundance column not provided, calculating TPM.", UserWarning)
warning("Abundance column not provided, calculating TPM.")
abundance = convert_counts_to_tpm(counts, length)
else:
assert len(transcript_ids) == len(abundance), "The transcript ids and abundance have different length."
Expand Down
4 changes: 2 additions & 2 deletions pytximport/utils/_convert_abundance_to_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ def convert_abundance_to_counts(
log(25, "Setting the counts to scaled TPM.")
counts_transformed = abundance
elif counts_from_abundance == "length_scaled_tpm":
# convert the TPM to counts and scale by the length
# convert the TPM to counts and scale by the gene length across samples
log(25, "Setting counts to length scaled TPM.")
counts_transformed = abundance * length.mean(axis=1)
else:
raise ValueError("The count transform must be 'scaled_tpm' or 'length_scaled_tpm'.")

# scale the counts
# scale the counts from abundance to the original sequencing depth of each sample
column_counts = counts.sum(axis=0)
new_counts = counts_transformed.sum(axis=0)
ratio = column_counts / new_counts
Expand Down
2 changes: 1 addition & 1 deletion pytximport/utils/_convert_transcripts_to_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def convert_transcripts_to_genes(
warning(
"Not all transcripts are present in the mapping."
+ f" {len(set(unique_transcripts) - set(transcript_gene_map['transcript_id']))}"
+ f" out of {len(unique_transcripts)} missing."
+ f" out of {len(unique_transcripts)} missing. Removing the missing transcripts."
)
# remove the missing transcripts by only keeping the data for the transcripts present in the mapping
transcript_ids_intersect = list(set(unique_transcripts).intersection(set(transcript_gene_map["transcript_id"])))
Expand Down
21 changes: 15 additions & 6 deletions pytximport/utils/_create_transcript_to_gene_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def create_transcript_to_gene_map_from_gtf_annotation(
field: Literal["gene_id", "gene_name"] = "gene_id",
chunk_size: int = 100000,
keep_gene_name: bool = True,
keep_biotype: bool = False,
) -> pd.DataFrame:
"""Create a mapping from transcript ids to gene ids using a GTF annotation file.
Expand All @@ -57,11 +58,12 @@ def create_transcript_to_gene_map_from_gtf_annotation(
chunk_size (int, optional): The number of lines to read at a time. Defaults to 100000.
keep_gene_name (bool, optional): Whether to keep the gene_name column when field is "gene_id".
Defaults to True.
keep_biotype (bool, optional): Whether to keep the gene_biotype column. Defaults to False.
Returns:
pd.DataFrame: The mapping from transcript ids to gene ids.
"""
transcript_gene_map = pd.DataFrame(columns=["transcript_id", "gene_id", "gene_name"])
transcript_gene_map = pd.DataFrame(columns=["transcript_id", "gene_id", "gene_name", "gene_biotype"])

for chunk in pd.read_csv(file_path, sep="\t", chunksize=chunk_size, header=None, comment="#"):
# see: https://www.ensembl.org/info/website/upload/gff.html
Expand All @@ -72,9 +74,10 @@ def create_transcript_to_gene_map_from_gtf_annotation(
# transcript_name ""; transcript_source "";
# however, we are only interested in the gene_id, gene_name, and transcript_id
attribute_columns = [
"gene_id",
"transcript_id",
"gene_id",
"gene_name",
"gene_biotype",
]
for column in attribute_columns:
chunk[column] = chunk["attribute"].apply(
Expand All @@ -87,7 +90,7 @@ def create_transcript_to_gene_map_from_gtf_annotation(
transcript_gene_map = pd.concat(
[
transcript_gene_map,
chunk[["transcript_id", "gene_id", "gene_name"]],
chunk[["transcript_id", "gene_id", "gene_name", "gene_biotype"]],
]
)

Expand All @@ -101,12 +104,18 @@ def create_transcript_to_gene_map_from_gtf_annotation(
if field == "gene_name":
transcript_gene_map.drop("gene_id", axis=1, inplace=True)
transcript_gene_map.rename(columns={"gene_name": "gene_id"}, inplace=True)
elif field == "gene_id" and not keep_gene_name:

if not keep_gene_name and "gene_name" in transcript_gene_map.columns:
transcript_gene_map.drop("gene_name", axis=1, inplace=True)

transcript_gene_map.replace("", np.nan, inplace=True)
if not keep_biotype and "gene_biotype" in transcript_gene_map.columns:
transcript_gene_map.drop("gene_biotype", axis=1, inplace=True)

transcript_gene_map[["gene_id", "transcript_id"]] = transcript_gene_map[["gene_id", "transcript_id"]].replace(
"", np.nan
)
transcript_gene_map.dropna(inplace=True)
transcript_gene_map.drop_duplicates(inplace=True)
transcript_gene_map.drop_duplicates(subset=["gene_id", "transcript_id"], inplace=True)
transcript_gene_map.reset_index(drop=True, inplace=True)

return transcript_gene_map
49 changes: 30 additions & 19 deletions test/test_transcriptome_to_gene_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,33 @@ def test_transcript_to_gene_map_from_gtf_annotation(
) -> None:
"""Test creating a transcript to gene map from a GTF annotation file."""
for keep_gene_name in [True, False]:
df_transcript_to_gene = create_transcript_to_gene_map_from_gtf_annotation(
gtf_annotation_file,
field="gene_id",
keep_gene_name=keep_gene_name,
)

assert isinstance(df_transcript_to_gene, pd.DataFrame), "The output is not a DataFrame."

if keep_gene_name:
assert df_transcript_to_gene.shape[1] == 3, "The output has the wrong number of columns."
else:
assert df_transcript_to_gene.shape[1] == 2, "The output has the wrong number of columns."

df_transcript_to_gene_reference = pd.read_csv(
gtf_annotation_file.parent / f"transcript_to_gene_map{'_gene_name' if keep_gene_name else ''}.csv",
header=0,
).reset_index(drop=True)

pd.testing.assert_frame_equal(df_transcript_to_gene, df_transcript_to_gene_reference)
for keep_biotype in [True, False]:
df_transcript_to_gene = create_transcript_to_gene_map_from_gtf_annotation(
gtf_annotation_file,
field="gene_id",
keep_gene_name=keep_gene_name,
keep_biotype=keep_biotype,
)

assert isinstance(df_transcript_to_gene, pd.DataFrame), "The output is not a DataFrame."

if keep_gene_name and keep_biotype:
assert df_transcript_to_gene.shape[1] == 4, "The output has the wrong number of columns."
elif keep_gene_name or keep_biotype:
assert df_transcript_to_gene.shape[1] == 3, "The output has the wrong number of columns."
else:
assert df_transcript_to_gene.shape[1] == 2, "The output has the wrong number of columns."

df_transcript_to_gene_reference = pd.read_csv(
gtf_annotation_file.parent / f"transcript_to_gene_map{'_gene_name' if keep_gene_name else ''}.csv",
header=0,
).reset_index(drop=True)

pd.testing.assert_frame_equal(
(
df_transcript_to_gene.drop(columns=["gene_biotype"])
if "gene_biotype" in df_transcript_to_gene.columns
else df_transcript_to_gene
),
df_transcript_to_gene_reference,
)

0 comments on commit e0eb2dd

Please sign in to comment.