-
Notifications
You must be signed in to change notification settings - Fork 15
/
Immi_mvn.m
103 lines (91 loc) · 2.5 KB
/
Immi_mvn.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
function Immi = Immi(A, Cfull, varsizes)
% IMMI  Redundancy between a set of Gaussian sources from minimum mutual
% information (MMI).
%
%   Immi = Immi(A, Cfull, varsizes)
%
% Inputs
%   A        - cell array of sources; each cell is a vector of indices
%              (1..Nx) of the X variables that make up that source,
%              e.g. {1, 2} or {1, 2, [1 2]}
%   Cfull    - full covariance matrix of the system [X_1 ... X_Nx S]
%   varsizes - number of dimensions of each X_i followed by S (S last);
%              must sum to size(Cfull,1). Only two X variables plus S are
%              supported, i.e. length(varsizes) must be 3.
%
% Output
%   Immi     - minimum over sources of the Gaussian mutual information
%              I(A_i; S), in bits
%
% Errors are raised when varsizes is inconsistent with Cfull, when A is
% empty, or when a source refers to a variable outside 1..Nx. chol raises
% an error if any required covariance block is not positive definite.

varsizes = varsizes(:)';   % accept row or column vectors

if sum(varsizes) ~= size(Cfull,1)
    error('wrong number of variables specified')
end
if length(varsizes)~=3
    error('only 2 variables supported')
end

NA = length(A);
if NA == 0
    error('at least one source must be specified')
end

Nx  = length(varsizes) - 1;   % number of X variables (S is the last entry)
NVs = varsizes(end);          % dimensionality of S

% first row/column of each variable inside Cfull
varstart = cumsum([1 varsizes(1:end-1)]);
sidx = varstart(end):(varstart(end)+NVs-1);

% entropy of S; sum(log(diag(chol(C)))) equals 0.5*log(det(C)).
% The 0.5*N*log(2*pi*e) normalisation terms cancel in the mutual
% information, so they are omitted from every entropy below.
HS = sum(log(diag(chol(Cfull(sidx,sidx)))));

I = zeros(1,NA);
for ai = 1:NA
    src = A{ai};
    src = src(:)';
    if any(src < 1 | src > Nx | src ~= fix(src))
        error('source %d refers to a variable outside 1..%d', ai, Nx)
    end

    % rows/columns of Cfull belonging to this source, in the order given
    idx = [];
    for v = src
        idx = [idx, varstart(v):(varstart(v)+varsizes(v)-1)]; %#ok<AGROW>
    end
    jidx = [idx sidx];   % source variables followed by S

    % closed-form Gaussian entropies of the source and of (source, S)
    HA  = sum(log(diag(chol(Cfull(idx,idx)))));
    HAS = sum(log(diag(chol(Cfull(jidx,jidx)))));

    I(ai) = (HA + HS - HAS) / log(2);
end

% MMI redundancy: the least informative source bounds the shared information
Immi = min(I);