diff --git a/pbxplore/test/test_functions.py b/pbxplore/test/test_functions.py index 7b68276..05500f1 100644 --- a/pbxplore/test/test_functions.py +++ b/pbxplore/test/test_functions.py @@ -216,7 +216,7 @@ def test_make_profile_partial(self): 'ghijkl', 'hijklm', 'ijklmn', 'ghijkl', 'hijklm', 'ijklmn', 'jklmno', 'klmnop', # ignore in the test - 'ijklmn'] + 'ijklmn'] # Using 10 sequences makes things easier indices = [0, 1, 2, 6, 7, 8, 9, 10, 11, 14] ref_profile = numpy.array([[0.1, 0.0, 0.0, 0.0, 0.0, 0.0], # a @@ -238,5 +238,16 @@ def test_make_profile_partial(self): profile = kmeans.make_profile_partial(sequences, indices) assert(numpy.allclose(ref_profile, profile)) + def test_argmax(self): + reference = (([0, 1, 2, 3, 4], 4), # Ordered + ([4, 3, 2, 1, 0], 0), # Reverse ordered + ([1, 3, 2, 4, 0], 3), # Random order + ([0, 0, 3, 4, 4], 3), # Duplicates + ([4, 4, 4, 4, 4], 0), # All the same + ) + for test_case, expectation in reference: + self.assertEqual(kmeans._argmax(test_case), expectation) + + if __name__ == '__main__': unittest.main()