NOTE(review): line breaks, indentation and blank context lines below were
reconstructed from a whitespace-collapsed copy of this patch, guided by the
hunk-header line counts. Runs of spaces inside lines (e.g. column alignment
in modules/header.nf) and any non-printing characters could not be recovered;
confirm against the originating commit before applying.

diff --git a/bin/extract_tss.py b/bin/extract_tss.py
index 529e7ba..8d2b974 100755
--- a/bin/extract_tss.py
+++ b/bin/extract_tss.py
@@ -15,26 +15,42 @@ def get_tss_interval(transcript, length):
     return start, end
 
 
+def get_interval_record(tx, length):
+    start, end = get_tss_interval(tx, length)
+    if start > 0:
+        return (
+            tx.seqname,
+            start - 1,
+            end,
+            f"{tx['gene_id']}::{tx['transcript_id']}::{tx.strand}",
+            tx.score,
+            tx.strand,
+        )
+    else:
+        print(
+            f"{tx['transcript_id']} skipped: Start Coordinate detected < 0.",
+            file=sys.stderr,
+        )
+        return None
+
+
 def get_intervals(transcripts, length):
     intervals = []
     for transcript in transcripts:
-        start, end = get_tss_interval(transcript, length)
-        if start > 0:
-            intervals.append(
-                (
-                    transcript.seqname,
-                    start - 1,
-                    end,
-                    f"{transcript['gene_id']}::{transcript['transcript_id']}",
-                    transcript.score,
-                    transcript.strand,
-                )
-            )
+        # If defined strand, test only TSS
+        if transcript.strand in ("+", "-"):
+            interval_record = get_interval_record(transcript, length)
+            if interval_record is not None:
+                intervals.append(interval_record)
+
+        # If undefined strand, simulate and test both extremities
         else:
-            print(
-                f"{transcript['transcript_id']} skipped: Start Coordinate detected < 0.",
-                file=sys.stderr,
-            )
+            for strand in ("+", "-"):
+                transcript.strand = strand
+                interval_record = get_interval_record(transcript, length)
+                if interval_record is not None:
+                    intervals.append(interval_record)
+
     return intervals
 
 
diff --git a/bin/filter_gtf_ndr.py b/bin/filter_gtf_ndr.py
index e6bba39..038ec7a 100755
--- a/bin/filter_gtf_ndr.py
+++ b/bin/filter_gtf_ndr.py
@@ -1,45 +1,65 @@
-#! /usr/bin/env python3
-from typing import Set
+#! /usr/bin/env python3
+from typing import Dict, Set, Tuple
+from collections import namedtuple
+
 from GTF import GTF
 
+TranscriptProb = namedtuple("TranscriptProb", ["gene_id", "tx_id", "ndr"])
+
 
-def parse_bambu(line):
-    return tuple(line)
+def parse_bambu(line) -> TranscriptProb:
+    return TranscriptProb(line[1], line[0].lower(), float(line[2]))
 
 
-def parse_tfkmers(line):
+def parse_tfkmers(line) -> Tuple[TranscriptProb, str]:
     ids = line[0].split("::")
-    return ids[1], ids[0], line[1]
+    return TranscriptProb(ids[0], ids[1].lower(), float(line[1])), ids[2]
+
+
+StrandRecord = namedtuple("StrandRecord", ["ndr", "strand"])
 
 
-def parse_ndr(csv, origin, th) -> Set[str]:
+def parse_ndr(csv, origin, th) -> Tuple[Set[str], Dict[str, StrandRecord]]:
     s = set()
+    strand_dict = dict()
 
     # Skip header
     next(csv)
 
+    strand = None
     for line in csv:
         line = line.split(",")
 
         if origin == "bambu":
-            line = parse_bambu(line)
+            tx_prob = parse_bambu(line)
         elif origin == "tfkmers":
-            line = parse_tfkmers(line)
+            tx_prob, strand = parse_tfkmers(line)
+        else:
+            exit("Unknown method")
 
-        tx_id, _, ndr = line
-        ndr = float(ndr)
+        if tx_prob.ndr < th:
+            s.add(tx_prob.tx_id)
 
-        if ndr < th:
-            s.add(tx_id.lower())
+        # Extract strand from sequence name to restrand GTF records
+        if origin == "tfkmers":
+            # If both extremities are tested, keep only lower extremity prob
+            if (
+                tx_prob.tx_id not in strand_dict
+                or tx_prob.ndr < strand_dict[tx_prob.tx_id].ndr
+            ):
+                strand_dict[tx_prob.tx_id] = StrandRecord(tx_prob.ndr, strand)
 
-    return s
+    return s, strand_dict
 
 
 def filter_count_matrix(file, transcripts, wr):
     print(next(file), file=wr)
     for line in file:
         line_splitted = line.split("\t")
-        if line_splitted[0].startswith("BambuTx") and line_splitted[0].lower() not in transcripts:
+        if (
+            line_splitted[0].startswith("BambuTx")
+            and line_splitted[0].lower() not in transcripts
+        ):
             continue
         print(line.rstrip(), file=wr)
 
@@ -99,8 +119,10 @@ def filter_count_matrix(file, transcripts, wr):
     args = parser.parse_args()
 
     ###################################################
-    filter_bambu = parse_ndr(args.bambu, "bambu", args.bambu_threshold)
-    filter_tfkmers = parse_ndr(args.tfkmers, "tfkmers", args.tfkmers_threshold)
+    filter_bambu, _ = parse_ndr(args.bambu, "bambu", args.bambu_threshold)
+    filter_tfkmers, strand_dict = parse_ndr(
+        args.tfkmers, "tfkmers", args.tfkmers_threshold
+    )
 
     if args.operation == "union":
         filter = filter_bambu | filter_tfkmers
@@ -109,8 +131,16 @@ def filter_count_matrix(file, transcripts, wr):
 
     with open("unformat.novel.filter.gtf", "w") as wr:
         for record in GTF.parse_by_line(args.gtf):
-            if "transcript_id" in record and record["transcript_id"].lower() in filter:
-                print(record, file=wr)
+            if "transcript_id" in record:
+                tx_id = record["transcript_id"].lower()
+
+                if tx_id in filter:
+                    # If operation == "union", tx_id can be OK in bambu
+                    # but not in TFKmers. So strand not defined
+                    if tx_id in strand_dict:
+                        record.strand = strand_dict[tx_id].strand
+
+                    print(record, file=wr)
 
     with open("counts_transcript.filter.txt", "w") as wr:
         filter_count_matrix(args.counts_tx, filter, wr)
diff --git a/modules/header.nf b/modules/header.nf
index 15b5272..1077b00 100755
--- a/modules/header.nf
+++ b/modules/header.nf
@@ -24,6 +24,7 @@ Tfkmers Tokenizer : ${params.tfkmers_tokenizer}
 Tfkmers Threshold : ${params.tfkmers_threshold}
 Bambu Threshold : ${params.bambu_threshold}
 Filtering operation : ${params.operation}
+Stranded : ${params.bambu_strand}
 -${c_dim}-------------------------------------${c_reset}-
 """.stripIndent()
 }