Skip to content

Commit

Permalink
Switch from None to raising exceptions in ASI2.
Browse files Browse the repository at this point in the history
  • Loading branch information
donkirkby committed Jan 25, 2018
1 parent b85bab7 commit a61ed59
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 103 deletions.
73 changes: 28 additions & 45 deletions pyvdrm/asi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,10 @@
from functools import reduce, total_ordering
from pyparsing import (Literal, nums, Word, Forward, Optional, Regex,
infixNotation, delimitedList, opAssoc, ParseException)
from pyvdrm.drm import AsiExpr, AsiBinaryExpr, DRMParser
from pyvdrm.drm import AsiExpr, AsiBinaryExpr, DRMParser, MissingPositionError
from pyvdrm.vcf import MutationSet


def maybe_foldl(func, noneable):
    """Safely fold a function over a potentially empty list of
    potentially null values.

    Args:
        func: two-argument callable passed to ``reduce``.
        noneable: an iterable that may contain ``None`` items, or ``None``.

    Returns:
        ``None`` when ``noneable`` is ``None`` or has no non-``None`` items;
        otherwise ``reduce(func, items)`` over the non-``None`` items in
        their original order.
    """
    if noneable is None:
        return None
    # Drop null entries first so func never has to cope with None operands.
    clean = [x for x in noneable if x is not None]
    if not clean:
        # reduce() without an initial value raises on an empty sequence.
        return None
    return reduce(func, clean)


def maybe_map(func, noneable):
    """Map a function over a potentially empty list of potentially null
    values, discarding nulls on both sides.

    Args:
        func: one-argument callable applied to each non-``None`` item.
        noneable: an iterable that may contain ``None`` items, or ``None``.

    Returns:
        A list of the non-``None`` results of ``func`` in input order, or
        ``None`` when ``noneable`` is ``None`` or no result survives.
    """
    if noneable is None:
        return None
    # func is only called on non-None inputs; None outputs are discarded.
    mapped = (func(x) for x in noneable if x is not None)
    r_list = [result for result in mapped if result is not None]
    if not r_list:
        return None
    return r_list


@total_ordering
class Score(object):
"""Encapsulate a score and the residues that support it"""
Expand Down Expand Up @@ -169,10 +142,17 @@ class ScoreList(AsiExpr):
def __call__(self, mutations):
operation, *rest = self.children
if operation == 'MAX':
return maybe_foldl(max, [f(mutations) for f in rest])

# the default operation is sum
return maybe_foldl(lambda x, y: x+y, [f(mutations) for f in self.children])
terms = rest
func = max
else:
# the default operation is sum
terms = self.children
func = sum
scores = [f(mutations) for f in terms]
matched_scores = [score.score for score in scores if score.score]
residues = reduce(lambda x, y: x | y,
(score.residues for score in scores))
return Score(bool(matched_scores) and func(matched_scores), residues)


class SelectFrom(AsiExpr):
Expand All @@ -186,15 +166,13 @@ def typecheck(self, tokens):
def __call__(self, mutations):
operation, *rest = self.children
# the head of the arg list must be an equality expression

scored = list(maybe_map(lambda f: f(mutations), rest))
passing = len(scored)

if operation(passing):
return Score(True, maybe_foldl(
lambda x, y: x.residues.union(y.residues), scored))
else:
return None
scored = [f(mutations) for f in rest]
passing = sum(bool(score.score) for score in scored)

return Score(operation(passing),
reduce(lambda x, y: x | y,
(item.residues for item in scored)))


class AsiScoreCond(AsiExpr):
Expand All @@ -204,7 +182,7 @@ class AsiScoreCond(AsiExpr):

def __call__(self, args):
"""Score conditions evaluate a list of expressions and sum scores"""
return maybe_foldl(lambda x, y: x+y, map(lambda x: x(args), self.children))
return sum((f(args) for f in self.children), Score(False, set()))


class AsiMutations(object):
Expand All @@ -213,19 +191,24 @@ class AsiMutations(object):
def __init__(self, _label=None, _pos=None, args=None):
"""Initialize set of mutations from a potentially ambiguous residue
"""
self.mutations = args and MutationSet(''.join(args))
self.mutations = MutationSet(''.join(args))

def __repr__(self):
if self.mutations is None:
return "AsiMutations()"
return "AsiMutations(args={!r})".format(str(self.mutations))

def __call__(self, env):
is_found = False
for mutation_set in env:
is_found |= mutation_set.pos == self.mutations.pos
intersection = self.mutations.mutations & mutation_set.mutations
if len(intersection) > 0:
return Score(True, intersection)
return None

if not is_found:
# Some required positions were not found in the environment.
raise MissingPositionError('Missing position {}.'.format(
self.mutations.pos))
return Score(False, set())


class ASI2(DRMParser):
Expand Down
2 changes: 0 additions & 2 deletions pyvdrm/hcvr.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,6 @@ def __init__(self, _label=None, _pos=None, args=None):
self.mutations = MutationSet(''.join(args))

def __repr__(self):
if self.mutations is None:
return "AsiMutations()"
return "AsiMutations(args={!r})".format(str(self.mutations))

def __call__(self, env):
Expand Down
71 changes: 32 additions & 39 deletions pyvdrm/tests/test_asi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,31 @@
from pyparsing import ParseException

from pyvdrm.asi2 import ASI2, AsiMutations, Score
from pyvdrm.drm import MissingPositionError
from pyvdrm.vcf import Mutation, MutationSet, VariantCalls

from pyvdrm.tests.test_vcf import add_mutations


# noinspection SqlNoDataSourceInspection,SqlDialectInspection
class TestRuleParser(unittest.TestCase):

def test_stanford_ex1(self):
ASI2("151M OR 69i")

def test_stanford_ex2(self):
def test_atleast_true(self):
rule = ASI2("SELECT ATLEAST 2 FROM (41L, 67N, 70R, 210W, 215F, 219Q)")
self.assertTrue(rule(VariantCalls('41L 67N 70d 210d 215d 219d')))

def test_atleast_false(self):
rule = ASI2("SELECT ATLEAST 2 FROM (41L, 67N, 70R, 210W, 215F, 219Q)")
m1 = MutationSet('41L')
m2 = MutationSet('67N')
m3 = MutationSet('70N')
self.assertTrue(rule([m1, m2]))
self.assertFalse(rule([m1, m3]))
self.assertFalse(rule(VariantCalls('41L 67d 70d 210d 215d 219d')))

def test_atleast_missing(self):
rule = ASI2("SELECT ATLEAST 2 FROM (41L, 67N, 70R, 210W, 215F, 219Q)")
with self.assertRaisesRegex(MissingPositionError,
r'Missing position 70.'):
rule(VariantCalls('41L 67N'))

def test_stanford_ex3(self):
ASI2("SELECT ATLEAST 2 AND NOTMORETHAN 2 FROM (41L, 67N, 70R, 210W, 215FY, 219QE)")
Expand Down Expand Up @@ -53,50 +62,47 @@ def test_asi2_compat(self):
class TestRuleSemantics(unittest.TestCase):
def test_score_from(self):
rule = ASI2("SCORE FROM ( 100G => 10, 101D => 20 )")
self.assertEqual(rule(VariantCalls("100G 102G")), 10)
self.assertEqual(rule(VariantCalls("100G 101d")), 10)

def test_score_negate(self):
rule = ASI2("SCORE FROM ( NOT 100G => 10, NOT 101SD => 20 )")
self.assertEqual(rule(VariantCalls("100G 102G")), 20)
self.assertEqual(rule(VariantCalls("100G 101d")), 20)
self.assertEqual(rule(VariantCalls("100S 101S")), 10)

def test_score_residues(self):
rule = ASI2("SCORE FROM ( 100G => 10, 101D => 20 )")
expected_residue = repr({Mutation('S100G')})

result = rule.dtree(VariantCalls("S100G R102G"))
result = rule.dtree(VariantCalls("S100G R101d"))

self.assertEqual(expected_residue, repr(result.residues))

def test_score_from_max(self):
rule = ASI2("SCORE FROM (MAX (100G => 10, 101D => 20, 102D => 30))")
self.assertEqual(rule(VariantCalls("100G 101D")), 20)
self.assertEqual(rule(VariantCalls("10G 11D")), False)
self.assertEqual(rule(VariantCalls("100G 101D 102d")), 20)
self.assertEqual(rule(VariantCalls("100d 101d 102d")), False)

def test_score_from_max_neg(self):
rule = ASI2("SCORE FROM (MAX (100G => -10, 101D => -20, 102D => 30))")
self.assertEqual(rule(VariantCalls("100G 101D")), -10)
self.assertEqual(rule(VariantCalls("10G 11D")), False)
self.assertEqual(rule(VariantCalls("100G 101D 102d")), -10)

def test_bool_and(self):
rule = ASI2("1G AND (2T AND 7Y)")
self.assertEqual(rule(VariantCalls("2T 7Y 1G")), True)
self.assertEqual(rule(VariantCalls("2T 3Y 1G")), False)
self.assertEqual(rule(VariantCalls("2T 7d 1G")), False)
self.assertEqual(rule(VariantCalls("7Y 1G 2T")), True)
self.assertEqual(rule([]), False)

def test_bool_or(self):
rule = ASI2("1G OR (2T OR 7Y)")
self.assertTrue(rule(VariantCalls("2T")))
self.assertFalse(rule(VariantCalls("3T")))
self.assertTrue(rule(VariantCalls("1G")))
self.assertFalse(rule([]))
self.assertTrue(rule(VariantCalls("1d 2T 7d")))
self.assertFalse(rule(VariantCalls("1d 2d 7d")))
self.assertTrue(rule(VariantCalls("1G 2d 7d")))

def test_select_from_atleast(self):
rule = ASI2("SELECT ATLEAST 2 FROM (2T, 7Y, 3G)")
self.assertTrue(rule(VariantCalls("2T 7Y 1G")))
self.assertFalse(rule(VariantCalls("2T 4Y 5G")))
self.assertTrue(rule(VariantCalls("3G 9Y 2T")))
self.assertTrue(rule(VariantCalls("2T 7Y 3d")))
self.assertFalse(rule(VariantCalls("2T 7d 3d")))
self.assertTrue(rule(VariantCalls("3G 7d 2T")))

def test_score_from_exactly(self):
rule = ASI2("SELECT EXACTLY 1 FROM (2T, 7Y)")
Expand Down Expand Up @@ -155,10 +161,10 @@ def test_chained_and(self):
215FY) => 10), MAX ((41L AND 215ACDEILNSV) => 5, (41L AND 215FY) =>
15))
""")
self.assertEqual(rule(VariantCalls("40F 41L 210W 215Y")), 65)
self.assertEqual(rule(VariantCalls("41L 210W 215F")), 60)
self.assertEqual(rule(VariantCalls("40F 210W 215Y")), 25)
self.assertEqual(rule(VariantCalls("40F 67G 215Y")), 15)
self.assertEqual(rule(add_mutations("40F 41L 210W 215Y")), 65)
self.assertEqual(rule(add_mutations("41L 210W 215F")), 60)
self.assertEqual(rule(add_mutations("40F 210W 215Y")), 25)
self.assertEqual(rule(add_mutations("40F 67G 215Y")), 15)


class TestAsiMutations(unittest.TestCase):
Expand All @@ -169,11 +175,6 @@ def test_init_args(self):
self.assertEqual(expected_mutation_set, m.mutations)
self.assertEqual(expected_mutation_set.wildtype, m.mutations.wildtype)

def test_init_none(self):
m = AsiMutations()

self.assertIsNone(m.mutations)

def test_repr(self):
expected_repr = "AsiMutations(args='Q80KR')"
m = AsiMutations(args='Q80KR')
Expand All @@ -182,14 +183,6 @@ def test_repr(self):

self.assertEqual(expected_repr, r)

def test_repr_none(self):
expected_repr = "AsiMutations()"
m = AsiMutations()

r = repr(m)

self.assertEqual(expected_repr, r)


class TestScore(unittest.TestCase):
def test_init(self):
Expand Down
19 changes: 3 additions & 16 deletions pyvdrm/tests/test_hcvr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from pyvdrm.hcvr import HCVR, AsiMutations, Score
from pyvdrm.vcf import Mutation, MutationSet, VariantCalls

from pyvdrm.tests.test_vcf import add_mutations


# noinspection SqlNoDataSourceInspection,SqlDialectInspection
class TestRuleParser(unittest.TestCase):

def test_stanford_ex1(self):
HCVR("151M OR 69i")

def test(self):
def test_atleast_true(self):
rule = HCVR("SELECT ATLEAST 2 FROM (41L, 67N, 70R, 210W, 215F, 219Q)")
self.assertTrue(rule(VariantCalls('41L 67N 70d 210d 215d 219d')))

Expand Down Expand Up @@ -158,21 +160,6 @@ def test_parse_exception_multiline(self):
self.assertEqual(expected_error_message, str(context.exception))


def add_mutations(text):
    """ Add a small set of mutations to an RT wild type.

    The text is parsed with ``VariantCalls``; for every mutation set found,
    the reference residue at that position is replaced by the list of that
    set's variants. The result is a ``VariantCalls`` built from the
    reference and the modified sample sequence.

    :param text: mutation text accepted by ``VariantCalls``
    :return: ``VariantCalls(reference=ref, sample=seq)``
    """

    # Start of RT reference.
    ref = ("PISPIETVPVKLKPGMDGPKVKQWPLTEEKIKALVEICTEMEKEGKISKIGPENPYNTPVFA"
           "IKKKDSTKWRKLVDFRELNKRTQDFWEVQLGIPHPAGLKKKKSVTVLDVGDAYFSVPLDEDF"
           "RKYTAFTIPSINNETPGIRYQYNVLPQGWKGSPAIFQSSMTKILEPFRKQNPDIVIYQYMDD"
           "LYVGSDLEIGQHRTKIEELRQHLLRWGLTTPDKKHQK")
    seq = list(ref)
    changes = VariantCalls(text)
    for mutation_set in changes:
        # pos is treated as 1-based, hence the -1 when indexing the sequence.
        seq[mutation_set.pos - 1] = [m.variant for m in mutation_set]
    return VariantCalls(reference=ref, sample=seq)


class TestActualRules(unittest.TestCase):
def test_hivdb_rules_parse(self):
folder = os.path.dirname(__file__)
Expand Down
16 changes: 16 additions & 0 deletions pyvdrm/tests/test_vcf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
from pyvdrm.vcf import Mutation, MutationSet, VariantCalls
from vcf import VariantCalls


class TestMutation(unittest.TestCase):
Expand Down Expand Up @@ -429,3 +430,18 @@ def test_immutable(self):

if __name__ == '__main__':
unittest.main()


def add_mutations(text):
    """ Add a small set of mutations to an RT wild type.

    The text is parsed with ``VariantCalls``; for every mutation set found,
    the reference residue at that position is replaced by the list of that
    set's variants. The result is a ``VariantCalls`` built from the
    reference and the modified sample sequence.

    :param text: mutation text accepted by ``VariantCalls``
    :return: ``VariantCalls(reference=ref, sample=seq)``
    """

    # Start of RT reference.
    ref = ("PISPIETVPVKLKPGMDGPKVKQWPLTEEKIKALVEICTEMEKEGKISKIGPENPYNTPVFA"
           "IKKKDSTKWRKLVDFRELNKRTQDFWEVQLGIPHPAGLKKKKSVTVLDVGDAYFSVPLDEDF"
           "RKYTAFTIPSINNETPGIRYQYNVLPQGWKGSPAIFQSSMTKILEPFRKQNPDIVIYQYMDD"
           "LYVGSDLEIGQHRTKIEELRQHLLRWGLTTPDKKHQK")
    seq = list(ref)
    changes = VariantCalls(text)
    for mutation_set in changes:
        # pos is treated as 1-based, hence the -1 when indexing the sequence.
        seq[mutation_set.pos - 1] = [m.variant for m in mutation_set]
    return VariantCalls(reference=ref, sample=seq)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='pyvdrm',
version='0.2.0',
version='0.3.0',
description='',

url='',
Expand Down

0 comments on commit a61ed59

Please sign in to comment.