From 24d14e05df7fe63e2c1e3ca7e29587247c32d755 Mon Sep 17 00:00:00 2001 From: Robin Ince Date: Mon, 25 Feb 2019 17:23:29 +0000 Subject: [PATCH] update for changes to dit interface --- mme2.py | 2 +- mme3pred.py | 3 ++- pidbroja.py | 12 ++++++------ pidbrojadist.py | 6 +++--- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/mme2.py b/mme2.py index 0871cfd..3971cc7 100644 --- a/mme2.py +++ b/mme2.py @@ -5,7 +5,7 @@ import dit import scipy.io import itertools -from dit.algorithms.scipy_optimizers import maxent_dist +from dit.algorithms.distribution_optimizers import maxent_dist fname = sys.argv[1] dat = sp.io.loadmat(fname) diff --git a/mme3pred.py b/mme3pred.py index c9058a3..739daa7 100644 --- a/mme3pred.py +++ b/mme3pred.py @@ -4,7 +4,8 @@ import scipy as sp import dit import scipy.io -from dit.algorithms.scipy_optimizers import maxent_dist +from dit.algorithms.distribution_optimizers import maxent_dist + fname = sys.argv[1] dat = sp.io.loadmat(fname) diff --git a/pidbroja.py b/pidbroja.py index e150a9f..be7bf20 100644 --- a/pidbroja.py +++ b/pidbroja.py @@ -4,7 +4,7 @@ import scipy as sp import dit import scipy.io -from dit.algorithms.scipy_optimizers import pid_broja +from dit.pid import PID_BROJA fname = sys.argv[1] @@ -17,13 +17,13 @@ d = dit.Distribution(*zip(*np.ndenumerate(P))) -x = pid_broja(d, [[0],[1]], [2]) +x = PID_BROJA(d, [[0],[1]], [2]) pid = np.zeros(4) -pid[0] = x.R -pid[1] = x.U0 -pid[2] = x.U1 -pid[3] = x.S +pid[0] = x.get_partial(((0,), (1,))) +pid[1] = x.get_partial(((0,),)) +pid[2] = x.get_partial(((1,),)) +pid[3] = x.get_partial(((0,1),)) dat['pid'] = pid sp.io.savemat(fname,dat) diff --git a/pidbrojadist.py b/pidbrojadist.py index 43bda36..f983772 100644 --- a/pidbrojadist.py +++ b/pidbrojadist.py @@ -4,11 +4,11 @@ import scipy as sp import dit import scipy.io -from dit.algorithms.scipy_optimizers import BROJAOptimizer +from dit.algorithms.distribution_optimizers import BROJABivariateOptimizer def pid_broja_dist(dist, sources, target, rv_mode=None): - broja = BROJAOptimizer(dist, sources, target, rv_mode) - broja.optimize() + broja = BROJABivariateOptimizer(dist, sources, target, rv_mode) + broja.optimize(niter=10) opt_dist = broja.construct_dist() return opt_dist