forked from mikedewar/EDHMM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
initial.py
54 lines (43 loc) · 1.31 KB
/
initial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
import logging
import pymc
# Module-level logger named 'initial'; Initial.report() writes through it.
log = logging.getLogger('initial')
from utils import *
class Initial:
    """
    Defines an Initial distribution over the first (state, duration) pair.

    The initial state is drawn uniformly from ``range(K)`` with a
    ``pymc.Categorical``; the initial duration comes from a
    ``pymc.Exponential`` with parameter ``beta`` (the rate in PyMC 2's
    parameterisation) and is rounded to an integer when sampled.

    Parameters
    ----------
    K : int
        Number of hidden states.
    beta : float, optional
        Parameter of the exponential duration prior (default 0.001).
    """

    def __init__(self, K, beta=0.001):
        self.K = K
        self.beta = beta
        # Uniform prior over the K states.
        state_dist = pymc.Categorical('state_init', [1. / K for _ in range(K)])
        dur_dist = pymc.Exponential('dur_init', beta)
        # Bundle both stochastics so they are drawn together and reachable
        # as self.dist.s_init / self.dist.d_init.
        self.dist = pymc.Model({
            "s_init": state_dist,
            "d_init": dur_dist,
        })

    def __call__(self, z=None):
        """
        Sample when called with no argument; otherwise score ``z``.

        Returns the ``(state, duration)`` tuple from :meth:`sample` when
        ``z`` is None, else the log-likelihood from :meth:`likelihood`.
        """
        if z is None:
            return self.sample()
        return self.likelihood(z)

    def __len__(self):
        """Return the number of states, K."""
        return self.K

    def sample(self):
        """
        Draw an initial (state, duration) pair from the prior.

        Returns
        -------
        (int, int)
            The sampled state index and an integer duration >= 1.
        """
        self.dist.draw_from_prior()
        x = int(self.dist.s_init.value)
        # Rounding a small exponential draw (< 0.5) would give 0, a duration
        # that likelihood() rejects (it requires z[1] > 0); clamp to 1 so
        # every sample is inside the support the likelihood accepts.
        d = max(1, int(round(self.dist.d_init.value)))
        return x, d

    def likelihood(self, z):
        """
        Log-probability of an initial (state, duration) pair.

        Side effect: the values of the underlying stochastics are set to
        ``z[0]`` and ``z[1]`` before their ``logp`` is read.

        Parameters
        ----------
        z : sequence
            ``(state, duration)`` with ``state`` in ``range(K)`` and
            ``duration > 0``.

        Returns
        -------
        The sum of the state and duration log-probabilities.

        Raises
        ------
        AssertionError
            If ``z`` lies outside the support. These are ``assert`` checks,
            so they are skipped when Python runs with ``-O``.
        """
        assert z[0] in range(self.K), z
        assert z[1] > 0, z
        self.dist.s_init.set_value(z[0])
        self.dist.d_init.set_value(z[1])
        return self.dist.s_init.logp + self.dist.d_init.logp

    def update(self, E):
        """Re-estimate the distribution from ``E``; not implemented."""
        raise NotImplementedError

    def report(self):
        """Log the distribution's parameters at INFO level."""
        # Lazy %-style args: formatting happens only if INFO is enabled.
        log.info("initial distribution: K=%s, beta=%s", self.K, self.beta)