Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support amplicon schemes in artic bed format #113

Merged
merged 1 commit into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 55 additions & 21 deletions tests/scheme_id_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,9 @@
]


def test_scheme_init_from_tsv():
tmp_tsv = "test_scheme_init_from_tsv.tsv"
with open(tmp_tsv, "w") as f:
print(*SCHEME_TSV_HEADER_FIELDS, sep="\t", file=f)
print("amp1", "amp1_left", "left", "ACGT", "20", sep="\t", file=f)
print("amp1", "amp1_left_alt", "left", "CAG", "15", sep="\t", file=f)
print("amp1", "amp1_left_alt2", "left", "CA", "15", sep="\t", file=f)
print("amp1", "amp1_right", "right", "AAA", "50", sep="\t", file=f)
print("amp2", "amp2_left", "left", "A", "40", sep="\t", file=f)
print("amp2", "amp2_right", "right", "ATGTT", "90", sep="\t", file=f)
print("amp2", "amp2_right_alt", "right", "GTA", "101", sep="\t", file=f)
print("amp2", "amp2_right_alt2", "right", "GGTA", "100", sep="\t", file=f)

scheme = scheme_id.Scheme(tsv_file=tmp_tsv)
assert scheme.amplicons == [
@pytest.fixture()
def amp_scheme_data():
amps = [
{
"name": "amp1",
"start": 15,
Expand All @@ -53,14 +41,60 @@ def test_scheme_init_from_tsv():
},
},
]
assert scheme.left_starts == {15: 0, 20: 0, 40: 1}
assert scheme.right_ends == {52: 0, 94: 1, 103: 1}
assert scheme.amplicon_name_indexes == {"amp1": 0, "amp2": 1}
mean_amp_length = statistics.mean([52 - 15 + 1, 40 - 15, 103 - 52, 103 - 40 + 1])
assert scheme.mean_amp_length == mean_amp_length
return {
"amplicons": amps,
"left_starts": {15: 0, 20: 0, 40: 1},
"right_ends": {52: 0, 94: 1, 103: 1},
"amplicon_name_indexes": {"amp1": 0, "amp2": 1},
"mean_amp_length": statistics.mean(
[52 - 15 + 1, 40 - 15, 103 - 52, 103 - 40 + 1]
),
}


def test_scheme_init_from_tsv(amp_scheme_data):
tmp_tsv = "test_scheme_init_from_tsv.tsv"
with open(tmp_tsv, "w") as f:
print(*SCHEME_TSV_HEADER_FIELDS, sep="\t", file=f)
print("amp1", "amp1_left", "left", "ACGT", "20", sep="\t", file=f)
print("amp1", "amp1_left_alt", "left", "CAG", "15", sep="\t", file=f)
print("amp1", "amp1_left_alt2", "left", "CA", "15", sep="\t", file=f)
print("amp1", "amp1_right", "right", "AAA", "50", sep="\t", file=f)
print("amp2", "amp2_left", "left", "A", "40", sep="\t", file=f)
print("amp2", "amp2_right", "right", "ATGTT", "90", sep="\t", file=f)
print("amp2", "amp2_right_alt", "right", "GTA", "101", sep="\t", file=f)
print("amp2", "amp2_right_alt2", "right", "GGTA", "100", sep="\t", file=f)

scheme = scheme_id.Scheme(amp_scheme_file=tmp_tsv)
assert scheme.amplicons == amp_scheme_data["amplicons"]
assert scheme.left_starts == amp_scheme_data["left_starts"]
assert scheme.right_ends == amp_scheme_data["right_ends"]
assert scheme.amplicon_name_indexes == amp_scheme_data["amplicon_name_indexes"]
assert scheme.mean_amp_length == amp_scheme_data["mean_amp_length"]
os.unlink(tmp_tsv)


def test_scheme_init_from_bed(amp_scheme_data):
tmp_bed = "test_scheme_init_from_tsv.bed"
with open(tmp_bed, "w") as f:
print("REF", 20, 24, "amp1_LEFT_1", "1", "+", "ACGT", sep="\t", file=f)
print("REF", 15, 18, "amp1_LEFT_alt", "1", "+", "CAG", sep="\t", file=f)
print("REF", 15, 17, "amp1_LEFT_alt2", "1", "+", "CA", sep="\t", file=f)
print("REF", 50, 53, "amp1_RIGHT_1", "1", "+", "AAA", sep="\t", file=f)
print("REF", 40, 41, "amp2_LEFT_1", "1", "+", "A", sep="\t", file=f)
print("REF", 90, 95, "amp2_RIGHT_1", "1", "+", "ATGTT", sep="\t", file=f)
print("REF", 101, 104, "amp2_RIGHT_1", "1", "+", "GTA", sep="\t", file=f)
print("REF", 100, 104, "amp2_RIGHT_2", "1", "+", "GGTA", sep="\t", file=f)

scheme = scheme_id.Scheme(amp_scheme_file=tmp_bed)
assert scheme.amplicons == amp_scheme_data["amplicons"]
assert scheme.left_starts == amp_scheme_data["left_starts"]
assert scheme.right_ends == amp_scheme_data["right_ends"]
assert scheme.amplicon_name_indexes == amp_scheme_data["amplicon_name_indexes"]
assert scheme.mean_amp_length == amp_scheme_data["mean_amp_length"]
os.unlink(tmp_bed)


def test_scheme_init_distance_lists():
ref_length = 14
scheme = scheme_id.Scheme(end_tolerance=0)
Expand Down Expand Up @@ -123,7 +157,7 @@ def test_count_primer_hits():
print("amp1", "amp1_left_alt2", "left", "AAAAAA", "30", sep="\t", file=f)
print("amp1", "amp1_right", "right", "ACGTACG", "50", sep="\t", file=f)
print("amp1", "amp1_right_alt1", "right", "ACGT", "54", sep="\t", file=f)
scheme = scheme_id.Scheme(tsv_file=tmp_tsv)
scheme = scheme_id.Scheme(amp_scheme_file=tmp_tsv)
scheme.init_distance_lists(65)
expect_amplicons = copy.deepcopy(scheme.amplicons)
left_hits = [0] * 65
Expand Down
129 changes: 85 additions & 44 deletions viridian/scheme_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


class Scheme:
def __init__(self, tsv_file=None, end_tolerance=3):
def __init__(self, amp_scheme_file=None, end_tolerance=3):
self.left_starts = {}
self.right_ends = {}
self.left_dists = []
Expand All @@ -35,64 +35,105 @@ def __init__(self, tsv_file=None, end_tolerance=3):
self.end_tolerance = end_tolerance
self.last_amplicon_end = -1

if tsv_file is not None:
if amp_scheme_file is not None:
try:
self.load_from_tsv_file(tsv_file)
self.load_from_file(amp_scheme_file)
except:
raise Exception(f"Error loading primer scheme from TSV file {tsv_file}")
raise Exception(
f"Error loading primer scheme from file {amp_scheme_file}"
)

self.amp_coords = [(a["start"], a["end"]) for a in self.amplicons]
self.amp_coords.sort()
self._calculate_mean_amp_length()

def load_from_tsv_file(self, tsv_file):
def read_tsv_lines(self, tsv_file):
with open(tsv_file) as f:
for d in csv.DictReader(f, delimiter="\t"):
if d["Left_or_right"] not in ["left", "right"]:
raise Exception(
f"Left_or_right column not left or right. Got: {d['Left_or_right']}"
)
yield d

if d["Amplicon_name"] not in self.amplicon_name_indexes:
self.amplicons.append(
{
"name": d["Amplicon_name"],
"start": float("inf"),
"end": -1,
"primers": {"left": [], "right": []},
}
def read_bed_lines(self, bed_file):
with open(bed_file) as f:
for line in f:
fields = line.rstrip().split("\t")
if len(fields) != 7:
raise Exception(
f"Error reading amplicon scheme BED file {bed_file}. Expected 7 columns, but got {len(fields)}:\n{line}"
)
self.amplicon_name_indexes[d["Amplicon_name"]] = (
len(self.amplicons) - 1

d = {}
try:
d["Position"] = int(fields[1])
except:
raise Exception(
f"Error reading amplicon scheme BED file {bed_file}. Could not get amplicon start position from second column:\n{line}"
)

amp_index = self.amplicon_name_indexes[d["Amplicon_name"]]
amp = self.amplicons[amp_index]
primer_start = int(d["Position"])
primer_end = primer_start + len(d["Sequence"]) - 1

if d["Left_or_right"] == "left":
primers = amp["primers"]["left"]
same = [x for x in primers if x[0] == primer_start]
if len(same):
same[0][1] = max(primer_end, same[0][1])
else:
primers.append([primer_start, primer_end])
primers.sort()
self.left_starts[primer_start] = amp_index
amp["start"] = min(amp["start"], primer_start)
elif d["Left_or_right"] == "right":
primers = amp["primers"]["right"]
same = [x for x in primers if x[1] == primer_end]
if len(same):
same[0][0] = min(primer_start, same[0][0])
else:
primers.append([primer_start, primer_end])
primers.sort()

self.right_ends[primer_end] = amp_index
amp["end"] = max(amp["end"], primer_end)
self.last_amplicon_end = max(amp["end"], self.last_amplicon_end)
try:
d["Amplicon_name"], d["Left_or_right"], primer_name = fields[
3
].rsplit("_", maxsplit=2)
except:
raise Exception(
f"Error reading amplicon scheme BED file {bed_file}. Could not get amplicon name, left/right, primer name from column 4:\n{line}"
)

d["Left_or_right"] = d["Left_or_right"].lower()
if d["Left_or_right"] not in ["left", "right"]:
raise Exception(
f"Error reading amplicon scheme BED file {bed_file}. Could not get left/right from column 4:\n{line}"
)

d["Sequence"] = fields[6]
yield d

def load_from_file(self, filename):
read_func = (
self.read_bed_lines if filename.endswith(".bed") else self.read_tsv_lines
)
for d in read_func(filename):
if d["Amplicon_name"] not in self.amplicon_name_indexes:
self.amplicons.append(
{
"name": d["Amplicon_name"],
"start": float("inf"),
"end": -1,
"primers": {"left": [], "right": []},
}
)
self.amplicon_name_indexes[d["Amplicon_name"]] = len(self.amplicons) - 1

amp_index = self.amplicon_name_indexes[d["Amplicon_name"]]
amp = self.amplicons[amp_index]
primer_start = int(d["Position"])
primer_end = primer_start + len(d["Sequence"]) - 1

if d["Left_or_right"] == "left":
primers = amp["primers"]["left"]
same = [x for x in primers if x[0] == primer_start]
if len(same):
same[0][1] = max(primer_end, same[0][1])
else:
primers.append([primer_start, primer_end])
primers.sort()
self.left_starts[primer_start] = amp_index
amp["start"] = min(amp["start"], primer_start)
elif d["Left_or_right"] == "right":
primers = amp["primers"]["right"]
same = [x for x in primers if x[1] == primer_end]
if len(same):
same[0][0] = min(primer_start, same[0][0])
else:
primers.append([primer_start, primer_end])
primers.sort()

self.right_ends[primer_end] = amp_index
amp["end"] = max(amp["end"], primer_end)
self.last_amplicon_end = max(amp["end"], self.last_amplicon_end)

def _calculate_mean_amp_length(self):
left_lengths = [
Expand Down Expand Up @@ -261,7 +302,7 @@ def simulate_reads(
random.seed(42)

with open(outfile, "w") as f:
for (start, end) in self.amp_coords:
for start, end in self.amp_coords:
if read_length is None:
print(f">{start}_{end}", file=f)
print(ref_seq[start : end + 1], file=f)
Expand Down Expand Up @@ -506,7 +547,7 @@ def analyse_bam(
for scheme_name, scheme_tsv in scheme_tsvs.items():
logging.info(f"{LOG_PREFIX} Analysing amplicon scheme {scheme_name}")
logging.debug(f"{LOG_PREFIX} {scheme_name} Load TSV file {scheme_tsv}")
scheme = Scheme(tsv_file=scheme_tsv, end_tolerance=end_tolerance)
scheme = Scheme(amp_scheme_file=scheme_tsv, end_tolerance=end_tolerance)
if scheme.last_amplicon_end > ref_length:
return (
json_dict,
Expand Down
2 changes: 1 addition & 1 deletion viridian/scheme_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def simulate_all_schemes(

for scheme_name, scheme_tsv in amplicon_scheme_name_to_tsv.items():
logging.info(f"Processing scheme {scheme_name}")
scheme = scheme_id.Scheme(tsv_file=scheme_tsv)
scheme = scheme_id.Scheme(amp_scheme_file=scheme_tsv)

for fragment in False, True:
if fragment:
Expand Down
Loading