Skip to content

Commit

Permalink
Merge pull request #11 from IGDRion/dev
Browse files Browse the repository at this point in the history
Merging dev into main
  • Loading branch information
vlebars authored Apr 8, 2024
2 parents a6d83e1 + 5fcab0d commit 672c898
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 35 deletions.
48 changes: 32 additions & 16 deletions bin/extract_tss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,42 @@ def get_tss_interval(transcript, length):
return start, end


def get_interval_record(tx, length):
    """Build one BED-like interval record for the TSS region of *tx*.

    Delegates the coordinate computation to ``get_tss_interval`` and, when
    the interval start is positive, returns a 6-tuple shaped like a BED line:
    (seqname, 0-based start, end, "gene_id::transcript_id::strand", score,
    strand). Transcripts whose computed start would fall below zero are
    reported on stderr and yield ``None``.
    """
    start, end = get_tss_interval(tx, length)

    # Guard clause: a non-positive start means the interval would extend
    # past the chromosome origin, so the transcript is skipped.
    if start <= 0:
        print(
            f"{tx['transcript_id']} skipped: Start Coordinate detected < 0.",
            file=sys.stderr,
        )
        return None

    # BED is 0-based half-open, hence the "start - 1" shift.
    record_name = f"{tx['gene_id']}::{tx['transcript_id']}::{tx.strand}"
    return (tx.seqname, start - 1, end, record_name, tx.score, tx.strand)


def get_intervals(transcripts, length):
    """Collect TSS interval records for every transcript.

    For stranded transcripts ("+" or "-") a single TSS interval is tested.
    For transcripts with an undefined strand, both orientations are
    simulated: the transcript's strand is set to "+" then "-" and an
    interval record is produced for each, so both extremities get tested
    downstream. Records rejected by ``get_interval_record`` (negative
    start) are silently dropped from the result.

    NOTE: the diff-pasted source interleaved the removed and added
    versions of this loop; this body is the reconstructed merged version.

    :param transcripts: iterable of GTF transcript records (mutated in
        place when strand is undefined — strand is overwritten).
    :param length: interval length passed through to get_tss_interval.
    :return: list of BED-like 6-tuples.
    """
    intervals = []
    for transcript in transcripts:
        # If defined strand, test only TSS
        if transcript.strand in ("+", "-"):
            interval_record = get_interval_record(transcript, length)
            if interval_record is not None:
                intervals.append(interval_record)

        # If undefined strand, simulate and test both extremities
        else:
            for strand in ("+", "-"):
                transcript.strand = strand
                interval_record = get_interval_record(transcript, length)
                if interval_record is not None:
                    intervals.append(interval_record)

    return intervals


Expand Down
68 changes: 49 additions & 19 deletions bin/filter_gtf_ndr.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,65 @@
#! /usr/bin/env python3
from typing import Set
#! /usr/bin/env python3
from typing import Dict, Set, Tuple
from collections import namedtuple

from GTF import GTF

# One scored transcript: its gene, its (lower-cased) transcript id, and the
# NDR (novel discovery rate) probability assigned by the caller's tool.
TranscriptProb = namedtuple("TranscriptProb", ["gene_id", "tx_id", "ndr"])


def parse_bambu(line) -> TranscriptProb:
    """Parse one split bambu CSV row [tx_id, gene_id, ndr, ...].

    Transcript ids are lower-cased so they compare consistently with the
    ids extracted elsewhere; the NDR field is coerced to float.
    """
    return TranscriptProb(line[1], line[0].lower(), float(line[2]))


def parse_tfkmers(line) -> Tuple[TranscriptProb, str]:
    """Parse one split tfkmers CSV row.

    The first field packs "gene_id::transcript_id::strand"; the second is
    the NDR. Returns the TranscriptProb plus the strand carried in the
    sequence name (needed later to re-strand GTF records).
    """
    ids = line[0].split("::")
    return TranscriptProb(ids[0], ids[1].lower(), float(line[1])), ids[2]


# Best NDR seen for a transcript together with the strand it was tested on.
StrandRecord = namedtuple("StrandRecord", ["ndr", "strand"])


def parse_ndr(csv, origin, th) -> Tuple[Set[str], Dict[str, StrandRecord]]:
    """Read an NDR CSV and keep transcripts scoring below threshold *th*.

    NOTE: the diff-pasted source interleaved the removed and added
    versions of this function; this body is the reconstructed merged
    version.

    :param csv: iterable of comma-separated lines, first line is a header.
    :param origin: "bambu" or "tfkmers" — selects the row parser; anything
        else aborts the script via exit().
    :param th: NDR threshold; rows with ndr < th are kept.
    :return: (set of kept lower-cased transcript ids,
              dict tx_id -> StrandRecord — populated only for tfkmers,
              keeping the strand of the lowest-NDR extremity).
    """
    s = set()
    strand_dict = dict()

    # Skip header
    next(csv)

    strand = None
    for line in csv:
        line = line.split(",")

        if origin == "bambu":
            tx_prob = parse_bambu(line)
        elif origin == "tfkmers":
            tx_prob, strand = parse_tfkmers(line)
        else:
            exit("Unknown method")

        if tx_prob.ndr < th:
            s.add(tx_prob.tx_id)

        # Extract strand from sequence name to restrand GTF records
        if origin == "tfkmers":
            # If both extremities are tested, keep only lower extremity prob
            if (
                tx_prob.tx_id not in strand_dict
                or tx_prob.ndr < strand_dict[tx_prob.tx_id].ndr
            ):
                strand_dict[tx_prob.tx_id] = StrandRecord(tx_prob.ndr, strand)

    return s, strand_dict


def filter_count_matrix(file, transcripts, wr):
    """Copy a tab-separated count matrix, dropping filtered novel rows.

    The header line is passed through; every following row whose first
    column is a novel Bambu transcript ("BambuTx...") is kept only if its
    lower-cased id is in *transcripts*. Known (non-Bambu) transcripts are
    always kept.

    :param file: iterable of TSV lines; first line is the header.
    :param transcripts: set of lower-cased transcript ids that passed
        filtering.
    :param wr: writable text stream for the filtered matrix.
    """
    # NOTE(review): next(file) keeps its trailing newline and print() adds
    # another, so the header is followed by a blank line in the output —
    # preserved here since downstream consumers may rely on it; confirm.
    print(next(file), file=wr)
    for line in file:
        line_splitted = line.split("\t")
        # Drop novel Bambu transcripts that did not survive filtering.
        if (
            line_splitted[0].startswith("BambuTx")
            and line_splitted[0].lower() not in transcripts
        ):
            continue
        print(line.rstrip(), file=wr)

Expand Down Expand Up @@ -99,8 +119,10 @@ def filter_count_matrix(file, transcripts, wr):
args = parser.parse_args()

###################################################
filter_bambu = parse_ndr(args.bambu, "bambu", args.bambu_threshold)
filter_tfkmers = parse_ndr(args.tfkmers, "tfkmers", args.tfkmers_threshold)
filter_bambu, _ = parse_ndr(args.bambu, "bambu", args.bambu_threshold)
filter_tfkmers, strand_dict = parse_ndr(
args.tfkmers, "tfkmers", args.tfkmers_threshold
)

if args.operation == "union":
filter = filter_bambu | filter_tfkmers
Expand All @@ -109,8 +131,16 @@ def filter_count_matrix(file, transcripts, wr):

with open("unformat.novel.filter.gtf", "w") as wr:
for record in GTF.parse_by_line(args.gtf):
if "transcript_id" in record and record["transcript_id"].lower() in filter:
print(record, file=wr)
if "transcript_id" in record:
tx_id = record["transcript_id"].lower()

if tx_id in filter:
# If operation == "union", tx_id can be OK in bambu
# but not in TFKmers. So strand not defined
if tx_id in strand_dict:
record.strand = strand_dict[tx_id].strand

print(record, file=wr)

with open("counts_transcript.filter.txt", "w") as wr:
filter_count_matrix(args.counts_tx, filter, wr)
1 change: 1 addition & 0 deletions modules/header.nf
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Tfkmers Tokenizer : ${params.tfkmers_tokenizer}
Tfkmers Threshold : ${params.tfkmers_threshold}
Bambu Threshold : ${params.bambu_threshold}
Filtering operation : ${params.operation}
Stranded : ${params.bambu_strand}
-${c_dim}-------------------------------------${c_reset}-
""".stripIndent()
}

0 comments on commit 672c898

Please sign in to comment.