From cd1b27591da26ea17637cc01fd6d4c427c06e7d9 Mon Sep 17 00:00:00 2001 From: Sebastian Mattheis Date: Fri, 22 Dec 2017 08:19:51 +0100 Subject: [PATCH] Fixed bug in k-state for hmm break 'no transition'. Change-Id: Id6a3b7a7642f1552c7ea7ef9fd3c833a9f9a79fd --- pom.xml | 2 +- .../com/bmwcarit/barefoot/markov/KState.java | 49 +++++----- .../bmwcarit/barefoot/markov/KStateTest.java | 97 ++++++++++++------- util/submit/batch.py | 2 +- 4 files changed, 88 insertions(+), 62 deletions(-) diff --git a/pom.xml b/pom.xml index b7aa61b0..70282a76 100755 --- a/pom.xml +++ b/pom.xml @@ -2,7 +2,7 @@ 4.0.0 com.bmw-carit barefoot - 0.1.1 + 0.1.2 diff --git a/src/main/java/com/bmwcarit/barefoot/markov/KState.java b/src/main/java/com/bmwcarit/barefoot/markov/KState.java index 594e38a7..95e323e9 100755 --- a/src/main/java/com/bmwcarit/barefoot/markov/KState.java +++ b/src/main/java/com/bmwcarit/barefoot/markov/KState.java @@ -27,7 +27,7 @@ import org.json.JSONException; import org.json.JSONObject; -import com.bmwcarit.barefoot.util.Tuple; +import com.bmwcarit.barefoot.util.Triple; /** * k-State data structure for organizing state memory in HMM inference. @@ -40,7 +40,7 @@ public class KState, T extends StateTransition extends StateMemory { private final int k; private final long t; - private final LinkedList, S>> sequence; + private final LinkedList, S, C>> sequence; private final Map counters; /** @@ -100,13 +100,15 @@ public KState(JSONObject json, Factory factory) throws JSONException { } S sample = factory.sample(jsonseqelement.getJSONObject("sample")); + String kestid = jsonseqelement.getString("kestid"); + C kestimate = candidates.get(kestid); - sequence.add(new Tuple<>(vector, sample)); + sequence.add(new Triple<>(vector, sample, kestimate)); } - Collections.sort(sequence, new Comparator, S>>() { + Collections.sort(sequence, new Comparator, S, C>>() { @Override - public int compare(Tuple, S> left, Tuple, S> right) { + public int compare(Triple, S, C> left, Triple, S, C> right) { if (left.two().time() < right.two().time()) { return -1; } else if (left.two().time() > right.two().time()) { @@ -167,7 +169,7 @@ public S sample() { */ public List samples() { LinkedList samples = new LinkedList<>(); - for (Tuple, S> element : sequence) { + for (Triple, S, C> element : sequence) { samples.add(element.two()); } return samples; @@ -183,6 +185,7 @@ public void update(Set vector, S sample) { throw new RuntimeException("out-of-order state update is prohibited"); } + C kestimate = null; for (C candidate : vector) { counters.put(candidate, 0); if (candidate.predecessor() != null) { @@ -192,16 +195,16 @@ public void update(Set vector, S sample) { } counters.put(candidate.predecessor(), counters.get(candidate.predecessor()) + 1); } + if (kestimate == null || candidate.seqprob() > kestimate.seqprob()) { + kestimate = candidate; + } } if (!sequence.isEmpty()) { + Triple, S, C> last = sequence.peekLast(); Set deletes = new HashSet<>(); - C estimate = null; - for (C candidate : sequence.peekLast().one()) { - if (estimate == null || candidate.seqprob() > estimate.seqprob()) { - estimate = candidate; - } + for (C candidate : last.one()) { if (counters.get(candidate) == 0) { deletes.add(candidate); } @@ -210,13 +213,13 @@ public void update(Set vector, S sample) { int size = sequence.peekLast().one().size(); for (C candidate : deletes) { - if (deletes.size() != size || candidate != estimate) { + if (deletes.size() != size || candidate != last.three()) { remove(candidate, sequence.size() - 1); } } } - sequence.add(new Tuple<>(vector, sample)); + sequence.add(new Triple<>(vector, sample, kestimate)); while ((t > 0 && sample.time() - sequence.peekFirst().two().time() > t) || (k >= 0 && sequence.size() > k + 1)) { @@ -234,6 +237,10 @@ public void update(Set vector, S sample) { } protected void remove(C candidate, int index) { + if (sequence.get(index).three() == candidate) { + return; + } + Set vector = sequence.get(index).one(); counters.remove(candidate); vector.remove(candidate); @@ -282,14 +289,7 @@ public List sequence() { return null; } - C kestimate = null; - - for (C candidate : sequence.peekLast().one()) { - if (kestimate == null || candidate.seqprob() > kestimate.seqprob()) { - kestimate = candidate; - } - } - + C kestimate = sequence.peekLast().three(); LinkedList ksequence = new LinkedList<>(); for (int i = sequence.size() - 1; i >= 0; --i) { @@ -297,8 +297,8 @@ public List sequence() { ksequence.push(kestimate); kestimate = kestimate.predecessor(); } else { - ksequence.push(sequence.get(i).one().iterator().next()); - assert (sequence.get(i).one().size() == 1); + ksequence.push(sequence.get(i).three()); + kestimate = sequence.get(i).three().predecessor(); } } @@ -309,7 +309,7 @@ public List sequence() { public JSONObject toJSON() throws JSONException { JSONObject json = new JSONObject(); JSONArray jsonsequence = new JSONArray(); - for (Tuple, S> element : sequence) { + for (Triple, S, C> element : sequence) { JSONObject jsonseqelement = new JSONObject(); JSONArray jsonvector = new JSONArray(); for (C candidate : element.one()) { @@ -321,6 +321,7 @@ public JSONObject toJSON() throws JSONException { } jsonseqelement.put("vector", jsonvector); jsonseqelement.put("sample", element.two().toJSON()); + jsonseqelement.put("kestid", element.three().id()); jsonsequence.put(jsonseqelement); } diff --git a/src/test/java/com/bmwcarit/barefoot/markov/KStateTest.java b/src/test/java/com/bmwcarit/barefoot/markov/KStateTest.java index 8bbf1381..35b3656d 100644 --- a/src/test/java/com/bmwcarit/barefoot/markov/KStateTest.java +++ b/src/test/java/com/bmwcarit/barefoot/markov/KStateTest.java @@ -72,11 +72,10 @@ public void TestKStateUnbound() { elements.put(1, new MockElem(1, Math.log10(0.2), 0.2, null)); elements.put(2, new MockElem(2, Math.log10(0.5), 0.5, null)); - KState state = - new KState<>(); + KState state = new KState<>(); { - Set vector = new HashSet<>( - Arrays.asList(elements.get(0), elements.get(1), elements.get(2))); + Set vector = + new HashSet<>(Arrays.asList(elements.get(0), elements.get(1), elements.get(2))); state.update(vector, new Sample(0)); @@ -90,8 +89,8 @@ public void TestKStateUnbound() { elements.put(6, new MockElem(6, Math.log10(0.1), 0.1, elements.get(2))); { - Set vector = new HashSet<>(Arrays.asList(elements.get(3), - elements.get(4), elements.get(5), elements.get(6))); + Set vector = new HashSet<>(Arrays.asList(elements.get(3), elements.get(4), + elements.get(5), elements.get(6))); state.update(vector, new Sample(1)); @@ -110,8 +109,8 @@ public void TestKStateUnbound() { elements.put(10, new MockElem(10, Math.log10(0.1), 0.1, elements.get(6))); { - Set vector = new HashSet<>(Arrays.asList(elements.get(7), - elements.get(8), elements.get(9), elements.get(10))); + Set vector = new HashSet<>(Arrays.asList(elements.get(7), elements.get(8), + elements.get(9), elements.get(10))); state.update(vector, new Sample(2)); @@ -130,12 +129,12 @@ public void TestKStateUnbound() { elements.put(14, new MockElem(14, Math.log10(0.1), 0.1, null)); { - Set vector = new HashSet<>(Arrays.asList(elements.get(11), - elements.get(12), elements.get(13), elements.get(14))); + Set vector = new HashSet<>(Arrays.asList(elements.get(11), elements.get(12), + elements.get(13), elements.get(14))); state.update(vector, new Sample(3)); - assertEquals(7, state.size()); + assertEquals(8, state.size()); assertEquals(13, state.estimate().numid()); List sequence = new LinkedList<>(Arrays.asList(2, 6, 9, 13)); @@ -148,7 +147,7 @@ public void TestKStateUnbound() { state.update(vector, new Sample(4)); - assertEquals(7, state.size()); + assertEquals(8, state.size()); assertEquals(13, state.estimate().numid()); List sequence = new LinkedList<>(Arrays.asList(2, 6, 9, 13)); @@ -158,6 +157,36 @@ public void TestKStateUnbound() { } } + @Test + public void TestBreak() { + // Test k-state in case of HMM break 'no transition' as reported in barefoot issue #83. + // Tests only 'no transitions', no emissions is empty vector and, hence, input to update + // operation. + + KState state = new KState<>(); + Map elements = new HashMap<>(); + elements.put(0, new MockElem(0, Math.log10(0.4), 0.4, null)); + { + Set vector = new HashSet<>(Arrays.asList(elements.get(0))); + state.update(vector, new Sample(0)); + } + elements.put(1, new MockElem(1, Math.log(0.7), 0.6, null)); + elements.put(2, new MockElem(2, Math.log(0.3), 0.4, elements.get(0))); + { + Set vector = new HashSet<>(Arrays.asList(elements.get(1), elements.get(2))); + state.update(vector, new Sample(1)); + } + elements.put(3, new MockElem(3, Math.log(0.5), 0.6, null)); + { + Set vector = new HashSet<>(Arrays.asList(elements.get(3))); + state.update(vector, new Sample(2)); + } + List seq = state.sequence(); + assertEquals(seq.get(0).numid(), 0); + assertEquals(seq.get(1).numid(), 1); + assertEquals(seq.get(2).numid(), 3); + } + @Test public void TestKState() { Map elements = new HashMap<>(); @@ -165,11 +194,10 @@ public void TestKState() { elements.put(1, new MockElem(1, Math.log10(0.2), 0.2, null)); elements.put(2, new MockElem(2, Math.log10(0.5), 0.5, null)); - KState state = - new KState<>(1, -1); + KState state = new KState<>(1, -1); { - Set vector = new HashSet<>( - Arrays.asList(elements.get(0), elements.get(1), elements.get(2))); + Set vector = + new HashSet<>(Arrays.asList(elements.get(0), elements.get(1), elements.get(2))); state.update(vector, new Sample(0)); @@ -183,8 +211,8 @@ public void TestKState() { elements.put(6, new MockElem(6, Math.log10(0.1), 0.1, elements.get(2))); { - Set vector = new HashSet<>(Arrays.asList(elements.get(3), - elements.get(4), elements.get(5), elements.get(6))); + Set vector = new HashSet<>(Arrays.asList(elements.get(3), elements.get(4), + elements.get(5), elements.get(6))); state.update(vector, new Sample(1)); @@ -203,8 +231,8 @@ public void TestKState() { elements.put(10, new MockElem(10, Math.log10(0.1), 0.1, elements.get(6))); { - Set vector = new HashSet<>(Arrays.asList(elements.get(7), - elements.get(8), elements.get(9), elements.get(10))); + Set vector = new HashSet<>(Arrays.asList(elements.get(7), elements.get(8), + elements.get(9), elements.get(10))); state.update(vector, new Sample(2)); @@ -223,8 +251,8 @@ public void TestKState() { elements.put(14, new MockElem(14, Math.log10(0.1), 0.1, null)); { - Set vector = new HashSet<>(Arrays.asList(elements.get(11), - elements.get(12), elements.get(13), elements.get(14))); + Set vector = new HashSet<>(Arrays.asList(elements.get(11), elements.get(12), + elements.get(13), elements.get(14))); state.update(vector, new Sample(3)); @@ -258,11 +286,10 @@ public void TestTState() { elements.put(1, new MockElem(1, Math.log10(0.2), 0.2, null)); elements.put(2, new MockElem(2, Math.log10(0.5), 0.5, null)); - KState state = - new KState<>(-1, 1); + KState state = new KState<>(-1, 1); { - Set vector = new HashSet<>( - Arrays.asList(elements.get(0), elements.get(1), elements.get(2))); + Set vector = + new HashSet<>(Arrays.asList(elements.get(0), elements.get(1), elements.get(2))); state.update(vector, new Sample(0)); @@ -276,8 +303,8 @@ public void TestTState() { elements.put(6, new MockElem(6, Math.log10(0.1), 0.1, elements.get(2))); { - Set vector = new HashSet<>(Arrays.asList(elements.get(3), - elements.get(4), elements.get(5), elements.get(6))); + Set vector = new HashSet<>(Arrays.asList(elements.get(3), elements.get(4), + elements.get(5), elements.get(6))); state.update(vector, new Sample(1)); @@ -296,8 +323,8 @@ public void TestTState() { elements.put(10, new MockElem(10, Math.log10(0.1), 0.1, elements.get(6))); { - Set vector = new HashSet<>(Arrays.asList(elements.get(7), - elements.get(8), elements.get(9), elements.get(10))); + Set vector = new HashSet<>(Arrays.asList(elements.get(7), elements.get(8), + elements.get(9), elements.get(10))); state.update(vector, new Sample(2)); @@ -316,8 +343,8 @@ public void TestTState() { elements.put(14, new MockElem(14, Math.log10(0.1), 0.1, null)); { - Set vector = new HashSet<>(Arrays.asList(elements.get(11), - elements.get(12), elements.get(13), elements.get(14))); + Set vector = new HashSet<>(Arrays.asList(elements.get(11), elements.get(12), + elements.get(13), elements.get(14))); state.update(vector, new Sample(3)); @@ -348,8 +375,7 @@ public void TestTState() { public void TestKStateJSON() throws JSONException { Map elements = new HashMap<>(); - KState state = - new KState<>(1, -1); + KState state = new KState<>(1, -1); { JSONObject json = state.toJSON(); @@ -361,8 +387,7 @@ public void TestKStateJSON() throws JSONException { elements.put(2, new MockElem(2, Math.log10(0.5), 0.5, null)); state.update( - new HashSet<>( - Arrays.asList(elements.get(0), elements.get(1), elements.get(2))), + new HashSet<>(Arrays.asList(elements.get(0), elements.get(1), elements.get(2))), new Sample(0)); { diff --git a/util/submit/batch.py b/util/submit/batch.py index 504ae61c..ad06c551 100644 --- a/util/submit/batch.py +++ b/util/submit/batch.py @@ -54,7 +54,7 @@ tmp = "batch-%s" % random.randint(0, sys.maxint) file = open(tmp, "w") -file.write("{\"format\": \"%s\", \"request\": %s}" % (options.format, json.dumps(samples))) +file.write("{\"format\": \"%s\", \"request\": %s}\n" % (options.format, json.dumps(samples))) file.close() subprocess.call("cat %s | netcat %s %s" % (tmp, options.host, options.port), shell=True)