Skip to content

Commit

Permalink
update for changes to dit interface
Browse files Browse the repository at this point in the history
  • Loading branch information
robince committed Feb 25, 2019
1 parent eea10f8 commit 24d14e0
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 11 deletions.
2 changes: 1 addition & 1 deletion mme2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import dit
import scipy.io
import itertools
from dit.algorithms.scipy_optimizers import maxent_dist
from dit.algorithms.distribution_optimizers import maxent_dist

fname = sys.argv[1]
dat = sp.io.loadmat(fname)
Expand Down
3 changes: 2 additions & 1 deletion mme3pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import scipy as sp
import dit
import scipy.io
from dit.algorithms.scipy_optimizers import maxent_dist
from dit.algorithms.distribution_optimizers import maxent_dist


fname = sys.argv[1]
dat = sp.io.loadmat(fname)
Expand Down
12 changes: 6 additions & 6 deletions pidbroja.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import scipy as sp
import dit
import scipy.io
from dit.algorithms.scipy_optimizers import pid_broja
from dit.pid import PID_BROJA


fname = sys.argv[1]
Expand All @@ -17,13 +17,13 @@

d = dit.Distribution(*zip(*np.ndenumerate(P)))

x = pid_broja(d, [[0],[1]], [2])
x = PID_BROJA(d, [[0],[1]], [2])

pid = np.zeros(4)
pid[0] = x.R
pid[1] = x.U0
pid[2] = x.U1
pid[3] = x.S
pid[0] = x.get_partial(((0,), (1,)))
pid[1] = x.get_partial(((0,),))
pid[2] = x.get_partial(((1,),))
pid[3] = x.get_partial(((0,1),))

dat['pid'] = pid
sp.io.savemat(fname,dat)
6 changes: 3 additions & 3 deletions pidbrojadist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import scipy as sp
import dit
import scipy.io
from dit.algorithms.scipy_optimizers import BROJAOptimizer
from dit.algorithms.distribution_optimizers import BROJABivariateOptimizer

def pid_broja_dist(dist, sources, target, rv_mode=None):
broja = BROJAOptimizer(dist, sources, target, rv_mode)
broja.optimize()
broja = BROJABivariateOptimizer(dist, sources, target, rv_mode)
broja.optimize(niter=10)
opt_dist = broja.construct_dist()
return opt_dist

Expand Down

0 comments on commit 24d14e0

Please sign in to comment.