diff --git a/tests/scheme_id_test.py b/tests/scheme_id_test.py index 2abcfe4..c54252b 100644 --- a/tests/scheme_id_test.py +++ b/tests/scheme_id_test.py @@ -19,21 +19,9 @@ ] -def test_scheme_init_from_tsv(): - tmp_tsv = "test_scheme_init_from_tsv.tsv" - with open(tmp_tsv, "w") as f: - print(*SCHEME_TSV_HEADER_FIELDS, sep="\t", file=f) - print("amp1", "amp1_left", "left", "ACGT", "20", sep="\t", file=f) - print("amp1", "amp1_left_alt", "left", "CAG", "15", sep="\t", file=f) - print("amp1", "amp1_left_alt2", "left", "CA", "15", sep="\t", file=f) - print("amp1", "amp1_right", "right", "AAA", "50", sep="\t", file=f) - print("amp2", "amp2_left", "left", "A", "40", sep="\t", file=f) - print("amp2", "amp2_right", "right", "ATGTT", "90", sep="\t", file=f) - print("amp2", "amp2_right_alt", "right", "GTA", "101", sep="\t", file=f) - print("amp2", "amp2_right_alt2", "right", "GGTA", "100", sep="\t", file=f) - - scheme = scheme_id.Scheme(tsv_file=tmp_tsv) - assert scheme.amplicons == [ +@pytest.fixture() +def amp_scheme_data(): + amps = [ { "name": "amp1", "start": 15, @@ -53,14 +41,60 @@ def test_scheme_init_from_tsv(): }, }, ] - assert scheme.left_starts == {15: 0, 20: 0, 40: 1} - assert scheme.right_ends == {52: 0, 94: 1, 103: 1} - assert scheme.amplicon_name_indexes == {"amp1": 0, "amp2": 1} - mean_amp_length = statistics.mean([52 - 15 + 1, 40 - 15, 103 - 52, 103 - 40 + 1]) - assert scheme.mean_amp_length == mean_amp_length + return { + "amplicons": amps, + "left_starts": {15: 0, 20: 0, 40: 1}, + "right_ends": {52: 0, 94: 1, 103: 1}, + "amplicon_name_indexes": {"amp1": 0, "amp2": 1}, + "mean_amp_length": statistics.mean( + [52 - 15 + 1, 40 - 15, 103 - 52, 103 - 40 + 1] + ), + } + + +def test_scheme_init_from_tsv(amp_scheme_data): + tmp_tsv = "test_scheme_init_from_tsv.tsv" + with open(tmp_tsv, "w") as f: + print(*SCHEME_TSV_HEADER_FIELDS, sep="\t", file=f) + print("amp1", "amp1_left", "left", "ACGT", "20", sep="\t", file=f) + print("amp1", "amp1_left_alt", "left", "CAG", "15", sep="\t", file=f) + print("amp1", "amp1_left_alt2", "left", "CA", "15", sep="\t", file=f) + print("amp1", "amp1_right", "right", "AAA", "50", sep="\t", file=f) + print("amp2", "amp2_left", "left", "A", "40", sep="\t", file=f) + print("amp2", "amp2_right", "right", "ATGTT", "90", sep="\t", file=f) + print("amp2", "amp2_right_alt", "right", "GTA", "101", sep="\t", file=f) + print("amp2", "amp2_right_alt2", "right", "GGTA", "100", sep="\t", file=f) + + scheme = scheme_id.Scheme(amp_scheme_file=tmp_tsv) + assert scheme.amplicons == amp_scheme_data["amplicons"] + assert scheme.left_starts == amp_scheme_data["left_starts"] + assert scheme.right_ends == amp_scheme_data["right_ends"] + assert scheme.amplicon_name_indexes == amp_scheme_data["amplicon_name_indexes"] + assert scheme.mean_amp_length == amp_scheme_data["mean_amp_length"] os.unlink(tmp_tsv) +def test_scheme_init_from_bed(amp_scheme_data): + tmp_bed = "test_scheme_init_from_tsv.bed" + with open(tmp_bed, "w") as f: + print("REF", 20, 24, "amp1_LEFT_1", "1", "+", "ACGT", sep="\t", file=f) + print("REF", 15, 18, "amp1_LEFT_alt", "1", "+", "CAG", sep="\t", file=f) + print("REF", 15, 17, "amp1_LEFT_alt2", "1", "+", "CA", sep="\t", file=f) + print("REF", 50, 53, "amp1_RIGHT_1", "1", "+", "AAA", sep="\t", file=f) + print("REF", 40, 41, "amp2_LEFT_1", "1", "+", "A", sep="\t", file=f) + print("REF", 90, 95, "amp2_RIGHT_1", "1", "+", "ATGTT", sep="\t", file=f) + print("REF", 101, 104, "amp2_RIGHT_1", "1", "+", "GTA", sep="\t", file=f) + print("REF", 100, 104, "amp2_RIGHT_2", "1", "+", "GGTA", sep="\t", file=f) + + scheme = scheme_id.Scheme(amp_scheme_file=tmp_bed) + assert scheme.amplicons == amp_scheme_data["amplicons"] + assert scheme.left_starts == amp_scheme_data["left_starts"] + assert scheme.right_ends == amp_scheme_data["right_ends"] + assert scheme.amplicon_name_indexes == amp_scheme_data["amplicon_name_indexes"] + assert scheme.mean_amp_length == amp_scheme_data["mean_amp_length"] + os.unlink(tmp_bed) + + def test_scheme_init_distance_lists(): ref_length = 14 scheme = scheme_id.Scheme(end_tolerance=0) @@ -123,7 +157,7 @@ def test_count_primer_hits(): print("amp1", "amp1_left_alt2", "left", "AAAAAA", "30", sep="\t", file=f) print("amp1", "amp1_right", "right", "ACGTACG", "50", sep="\t", file=f) print("amp1", "amp1_right_alt1", "right", "ACGT", "54", sep="\t", file=f) - scheme = scheme_id.Scheme(tsv_file=tmp_tsv) + scheme = scheme_id.Scheme(amp_scheme_file=tmp_tsv) scheme.init_distance_lists(65) expect_amplicons = copy.deepcopy(scheme.amplicons) left_hits = [0] * 65 diff --git a/viridian/scheme_id.py b/viridian/scheme_id.py index 5ca6e92..0d3e065 100644 --- a/viridian/scheme_id.py +++ b/viridian/scheme_id.py @@ -18,7 +18,7 @@ class Scheme: - def __init__(self, tsv_file=None, end_tolerance=3): + def __init__(self, amp_scheme_file=None, end_tolerance=3): self.left_starts = {} self.right_ends = {} self.left_dists = [] @@ -35,64 +35,105 @@ def __init__(self, tsv_file=None, end_tolerance=3): self.end_tolerance = end_tolerance self.last_amplicon_end = -1 - if tsv_file is not None: + if amp_scheme_file is not None: try: - self.load_from_tsv_file(tsv_file) + self.load_from_file(amp_scheme_file) except: - raise Exception(f"Error loading primer scheme from TSV file {tsv_file}") + raise Exception( + f"Error loading primer scheme from file {amp_scheme_file}" + ) self.amp_coords = [(a["start"], a["end"]) for a in self.amplicons] self.amp_coords.sort() self._calculate_mean_amp_length() - def load_from_tsv_file(self, tsv_file): + def read_tsv_lines(self, tsv_file): with open(tsv_file) as f: for d in csv.DictReader(f, delimiter="\t"): if d["Left_or_right"] not in ["left", "right"]: raise Exception( f"Left_or_right column not left or right. Got: {d['Left_or_right']}" ) + yield d - if d["Amplicon_name"] not in self.amplicon_name_indexes: - self.amplicons.append( - { - "name": d["Amplicon_name"], - "start": float("inf"), - "end": -1, - "primers": {"left": [], "right": []}, - } + def read_bed_lines(self, bed_file): + with open(bed_file) as f: + for line in f: + fields = line.rstrip().split("\t") + if len(fields) != 7: + raise Exception( + f"Error reading amplicon scheme BED file {bed_file}. Expected 7 columns, but got {len(fields)}:\n{line}" ) - self.amplicon_name_indexes[d["Amplicon_name"]] = ( - len(self.amplicons) - 1 + + d = {} + try: + d["Position"] = int(fields[1]) + except: + raise Exception( + f"Error reading amplicon scheme BED file {bed_file}. Could not get amplicon start position from second column:\n{line}" ) - amp_index = self.amplicon_name_indexes[d["Amplicon_name"]] - amp = self.amplicons[amp_index] - primer_start = int(d["Position"]) - primer_end = primer_start + len(d["Sequence"]) - 1 - - if d["Left_or_right"] == "left": - primers = amp["primers"]["left"] - same = [x for x in primers if x[0] == primer_start] - if len(same): - same[0][1] = max(primer_end, same[0][1]) - else: - primers.append([primer_start, primer_end]) - primers.sort() - self.left_starts[primer_start] = amp_index - amp["start"] = min(amp["start"], primer_start) - elif d["Left_or_right"] == "right": - primers = amp["primers"]["right"] - same = [x for x in primers if x[1] == primer_end] - if len(same): - same[0][0] = min(primer_start, same[0][0]) - else: - primers.append([primer_start, primer_end]) - primers.sort() - - self.right_ends[primer_end] = amp_index - amp["end"] = max(amp["end"], primer_end) - self.last_amplicon_end = max(amp["end"], self.last_amplicon_end) + try: + d["Amplicon_name"], d["Left_or_right"], primer_name = fields[ + 3 + ].rsplit("_", maxsplit=2) + except: + raise Exception( + f"Error reading amplicon scheme BED file {bed_file}. Could not get amplicon name, left/right, primer name from column 4:\n{line}" + ) + + d["Left_or_right"] = d["Left_or_right"].lower() + if d["Left_or_right"] not in ["left", "right"]: + raise Exception( + f"Error reading amplicon scheme BED file {bed_file}. Could not get left/right from column 4:\n{line}" + ) + + d["Sequence"] = fields[6] + yield d + + def load_from_file(self, filename): + read_func = ( + self.read_bed_lines if filename.endswith(".bed") else self.read_tsv_lines + ) + for d in read_func(filename): + if d["Amplicon_name"] not in self.amplicon_name_indexes: + self.amplicons.append( + { + "name": d["Amplicon_name"], + "start": float("inf"), + "end": -1, + "primers": {"left": [], "right": []}, + } + ) + self.amplicon_name_indexes[d["Amplicon_name"]] = len(self.amplicons) - 1 + + amp_index = self.amplicon_name_indexes[d["Amplicon_name"]] + amp = self.amplicons[amp_index] + primer_start = int(d["Position"]) + primer_end = primer_start + len(d["Sequence"]) - 1 + + if d["Left_or_right"] == "left": + primers = amp["primers"]["left"] + same = [x for x in primers if x[0] == primer_start] + if len(same): + same[0][1] = max(primer_end, same[0][1]) + else: + primers.append([primer_start, primer_end]) + primers.sort() + self.left_starts[primer_start] = amp_index + amp["start"] = min(amp["start"], primer_start) + elif d["Left_or_right"] == "right": + primers = amp["primers"]["right"] + same = [x for x in primers if x[1] == primer_end] + if len(same): + same[0][0] = min(primer_start, same[0][0]) + else: + primers.append([primer_start, primer_end]) + primers.sort() + + self.right_ends[primer_end] = amp_index + amp["end"] = max(amp["end"], primer_end) + self.last_amplicon_end = max(amp["end"], self.last_amplicon_end) def _calculate_mean_amp_length(self): left_lengths = [ @@ -261,7 +302,7 @@ def simulate_reads( random.seed(42) with open(outfile, "w") as f: - for (start, end) in self.amp_coords: + for start, end in self.amp_coords: if read_length is None: print(f">{start}_{end}", file=f) print(ref_seq[start : end + 1], file=f) @@ -506,7 +547,7 @@ def analyse_bam( for scheme_name, scheme_tsv in scheme_tsvs.items(): logging.info(f"{LOG_PREFIX} Analysing amplicon scheme {scheme_name}") logging.debug(f"{LOG_PREFIX} {scheme_name} Load TSV file {scheme_tsv}") - scheme = Scheme(tsv_file=scheme_tsv, end_tolerance=end_tolerance) + scheme = Scheme(amp_scheme_file=scheme_tsv, end_tolerance=end_tolerance) if scheme.last_amplicon_end > ref_length: return ( json_dict, diff --git a/viridian/scheme_simulate.py b/viridian/scheme_simulate.py index b4651e2..175c0c7 100644 --- a/viridian/scheme_simulate.py +++ b/viridian/scheme_simulate.py @@ -86,7 +86,7 @@ def simulate_all_schemes( for scheme_name, scheme_tsv in amplicon_scheme_name_to_tsv.items(): logging.info(f"Processing scheme {scheme_name}") - scheme = scheme_id.Scheme(tsv_file=scheme_tsv) + scheme = scheme_id.Scheme(amp_scheme_file=scheme_tsv) for fragment in False, True: if fragment: