Skip to content

Commit

Permalink
More matching tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 30, 2024
1 parent e8435e5 commit bad0019
Showing 1 changed file with 96 additions and 41 deletions.
137 changes: 96 additions & 41 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,44 +576,99 @@ def test_exact_matches(
assert len(s.path) == 1
assert s.path[0].parent == parent

# def test_stuff(
# self, tmp_path, fx_ts_2020_02_10, fx_alignment_store, fx_metadata_db
# ):
# # SRR11597207 0 42 0
# # SRR11597218 1 10 1

# # date = "2020-02-11" # 2 samples
# date = "2020-02-13" # 4 samples
# samples = sc2ts.preprocess(
# date,
# metadata_db=fx_metadata_db,
# alignment_store=fx_alignment_store,
# base_ts=fx_ts_2020_02_10,
# )
# # print(samples)

# num_mismatches = 3
# sc2ts.match_tsinfer(
# samples=samples,
# ts=fx_ts_2020_02_10,
# num_mismatches=3,
# precision=12,
# num_threads=0,
# )
# for sample in samples:
# print(
# sample.strain,
# sample.get_hmm_cost(num_mismatches),
# sample.path[0].parent,
# len(sample.mutations),
# )

# # match_db = sc2ts.MatchDb.initialise(tmp_path / "match.db")
# # ts = sc2ts.extend(
# # alignment_store=fx_alignment_store,
# # metadata_db=fx_metadata_db,
# # base_ts=fx_ts_2020_02_10,
# # date="2020-02-11",
# # match_db=match_db,
# # min_group_size=2,
# # )
@pytest.mark.parametrize(
("strain", "parent", "position", "derived_state"),
[("SRR11597218", 10, 289, "T"), ("ERR4206593", 62, 26994, "T")],
)
@pytest.mark.parametrize("num_mismatches", [1, 2, 3, 4])
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
def test_one_mismatch(
self,
fx_ts_2020_02_10,
fx_alignment_store,
fx_metadata_db,
strain,
parent,
position,
derived_state,
num_mismatches,
precision,
):
samples = sc2ts.preprocess(
[fx_metadata_db[strain]], fx_ts_2020_02_10, "2020-02-20", fx_alignment_store
)
sc2ts.match_tsinfer(
samples=samples,
ts=fx_ts_2020_02_10,
num_mismatches=num_mismatches,
precision=precision,
num_threads=0,
)
s = samples[0]
assert len(s.mutations) == 1
assert s.mutations[0].site_position == position
assert s.mutations[0].derived_state == derived_state
assert len(s.path) == 1
assert s.path[0].parent == parent

@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
def test_two_mismatches(
self,
fx_ts_2020_02_10,
fx_alignment_store,
fx_metadata_db,
num_mismatches,
precision,
):
strain = "ERR4204459"
samples = sc2ts.preprocess(
[fx_metadata_db[strain]], fx_ts_2020_02_10, "2020-02-20", fx_alignment_store
)
sc2ts.match_tsinfer(
samples=samples,
ts=fx_ts_2020_02_10,
num_mismatches=num_mismatches,
precision=precision,
num_threads=0,
)
s = samples[0]
assert len(s.path) == 1
assert s.path[0].parent == 5
assert len(s.mutations) == 2
# assert s.mutations[0].site_position == position
# assert s.mutations[0].derived_state == derived_state


# def test_stuff(
# self,
# fx_ts_2020_02_10,
# fx_alignment_store,
# fx_metadata_db):

# # date = "2020-02-11" # 2 samples
# date = "2020-02-13" # 4 samples

# # metadata_matches =
# samples = sc2ts.preprocess(
# list(fx_metadata_db.get(date)),
# fx_ts_2020_02_10,
# date, fx_alignment_store
# )
# # print(samples)

# num_mismatches = 3
# sc2ts.match_tsinfer(
# samples=samples,
# ts=fx_ts_2020_02_10,
# num_mismatches=3,
# precision=12,
# num_threads=0,
# )
# for sample in samples:
# print(
# sample.strain,
# sample.get_hmm_cost(num_mismatches),
# sample.path[0].parent,
# len(sample.mutations),
# )

0 comments on commit bad0019

Please sign in to comment.