From e54c7e044ccfda9431f98bf059c9993b477090c5 Mon Sep 17 00:00:00 2001 From: david huber Date: Thu, 15 Feb 2024 01:15:16 +0100 Subject: [PATCH] SCM and other causality related classes --- pom.xml | 4 +- .../idsia/crema/core/ObservationBuilder.java | 2 +- .../ch/idsia/crema/core/StridedDomain.java | 6 + .../java/ch/idsia/crema/core/Strides.java | 379 +++++++++++++----- .../ch/idsia/crema/core/UnsortedDomain.java | 75 ++++ .../java/ch/idsia/crema/core/Variable.java | 37 ++ .../java/ch/idsia/crema/data/DataTable.java | 339 ++++++++++++++++ .../java/ch/idsia/crema/data/DoubleTable.java | 214 ++++++++++ .../ch/idsia/crema/data/TIntMapConverter.java | 60 +++ .../bayesian/BayesianDefaultFactor.java | 187 ++++++++- .../crema/factor/bayesian/BayesianFactor.java | 2 +- .../bayesian/BayesianFactorFactory.java | 164 ++++++-- .../vertex/separate/VertexAbstractFactor.java | 2 +- .../inference/approxlp1/CredalApproxLP.java | 4 +- .../crema/inference/bp/BeliefPropagation.java | 12 +- .../inference/bp/LoopyBeliefPropagation.java | 4 +- .../sampling/LikelihoodWeightingSampling.java | 4 +- .../sampling/StochasticSampling.java | 4 +- .../ve/CredalVariableElimination.java | 4 +- .../ch/idsia/crema/learning/DiscreteEM.java | 4 +- src/main/java/ch/idsia/crema/model/Model.java | 1 + .../java/ch/idsia/crema/model/causal/SCM.java | 194 +++++++++ .../crema/model/causal/TypedVariable.java | 17 + .../idsia/crema/model/causal/WorldModel.java | 70 ++++ .../model/causal/mapping/TreeMapping.java | 214 ++++++++++ .../idsia/crema/model/graphical/DAGModel.java | 30 +- .../crema/model/graphical/GraphicalModel.java | 31 +- .../model/io/dot/DetailedDotSerializer.java | 211 ++++++++++ .../ch/idsia/crema/model/io/dot/Info.java | 134 +++++++ .../crema/model/io/uai/NetUAIWriter.java | 2 +- .../java/ch/idsia/crema/preprocess/Do.java | 34 ++ .../{CutObserved.java => Observe.java} | 11 +- .../ch/idsia/crema/utility/ArraysMath.java | 128 ++++++ .../ch/idsia/crema/utility/ArraysUtil.java | 252 +++++++++--- .../java/ch/idsia/crema/core/StridesTest.java | 22 +- .../factor/bayesian/BayesianOperations.java | 70 ++++ .../credal/vertex/TestVertexFactor.java | 7 +- .../idsia/crema/model/BayesianFactorTest.java | 4 +- .../model/causal/mapping/TreeMappingTest.java | 195 +++++++++ .../idsia/crema/model/io/UAIParserTest.java | 10 +- .../model/utility/ArraysUtilityTest.java | 8 +- 41 files changed, 2912 insertions(+), 240 deletions(-) create mode 100644 src/main/java/ch/idsia/crema/core/StridedDomain.java create mode 100644 src/main/java/ch/idsia/crema/core/UnsortedDomain.java create mode 100644 src/main/java/ch/idsia/crema/core/Variable.java create mode 100644 src/main/java/ch/idsia/crema/data/DataTable.java create mode 100644 src/main/java/ch/idsia/crema/data/DoubleTable.java create mode 100644 src/main/java/ch/idsia/crema/data/TIntMapConverter.java create mode 100644 src/main/java/ch/idsia/crema/model/causal/SCM.java create mode 100644 src/main/java/ch/idsia/crema/model/causal/TypedVariable.java create mode 100644 src/main/java/ch/idsia/crema/model/causal/WorldModel.java create mode 100644 src/main/java/ch/idsia/crema/model/causal/mapping/TreeMapping.java create mode 100644 src/main/java/ch/idsia/crema/model/io/dot/DetailedDotSerializer.java create mode 100644 src/main/java/ch/idsia/crema/model/io/dot/Info.java create mode 100644 src/main/java/ch/idsia/crema/preprocess/Do.java rename src/main/java/ch/idsia/crema/preprocess/{CutObserved.java => Observe.java} (82%) create mode 100644 src/main/java/ch/idsia/crema/utility/ArraysMath.java create mode 100644 src/test/java/ch/idsia/crema/factor/bayesian/BayesianOperations.java create mode 100644 src/test/java/ch/idsia/crema/model/causal/mapping/TreeMappingTest.java diff --git a/pom.xml b/pom.xml index 3993b9c5..6950a6d9 100644 --- a/pom.xml +++ b/pom.xml @@ -60,8 +60,8 @@ maven-compiler-plugin 3.8.1 - 11 - 11 + 17 + 17 diff --git a/src/main/java/ch/idsia/crema/core/ObservationBuilder.java b/src/main/java/ch/idsia/crema/core/ObservationBuilder.java index 84fed715..c6e2ac86 100644 --- a/src/main/java/ch/idsia/crema/core/ObservationBuilder.java +++ b/src/main/java/ch/idsia/crema/core/ObservationBuilder.java @@ -75,7 +75,7 @@ private ObservationBuilder(int[] keys, int[] values) { public static int[] getVariables(TIntIntMap[] obs) { int[] variables = new int[]{}; for (TIntIntMap o : obs) - variables = ArraysUtil.unionSet(variables, o.keys()); + variables = ArraysUtil.union_unsorted_set(variables, o.keys()); return variables; } diff --git a/src/main/java/ch/idsia/crema/core/StridedDomain.java b/src/main/java/ch/idsia/crema/core/StridedDomain.java new file mode 100644 index 00000000..fecb7fff --- /dev/null +++ b/src/main/java/ch/idsia/crema/core/StridedDomain.java @@ -0,0 +1,6 @@ +package ch.idsia.crema.core; + +public interface StridedDomain extends Domain { + int getStride(int variable); + int getStrideAt(int offset); +} diff --git a/src/main/java/ch/idsia/crema/core/Strides.java b/src/main/java/ch/idsia/crema/core/Strides.java index bdea4cb2..c9bd5548 100644 --- a/src/main/java/ch/idsia/crema/core/Strides.java +++ b/src/main/java/ch/idsia/crema/core/Strides.java @@ -1,13 +1,23 @@ package ch.idsia.crema.core; +import ch.idsia.crema.utility.ArraysMath; +import ch.idsia.crema.utility.ArraysRandom; import ch.idsia.crema.utility.ArraysUtil; import ch.idsia.crema.utility.IndexIterator; import com.google.common.primitives.Ints; import gnu.trove.map.TIntIntMap; +import gnu.trove.map.hash.TIntShortHashMap; +import gnu.trove.set.TIntSet; +import gnu.trove.set.hash.TIntHashSet; + import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.math3.random.RandomDataGenerator; import org.apache.commons.math3.util.FastMath; import java.util.Arrays; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; import java.util.stream.IntStream; /** @@ -24,20 +34,40 @@ * *

* While the class should be considered unmutable, when the factor is part of a - * model and we delete a variable from it, the indices of the variables will be - * updated accordingly. + * model and we delete a variable from it, the indices of the variables could be + * updated accordingly. This applies only to models that do not allow gaps in + * variable labels. *

* * @author davidhuber */ -public final class Strides implements Domain { - final private int[] strides; +public final class Strides implements StridedDomain { + private final int[] strides; final private int combinations; final private int[] variables; final private int[] sizes; final private int size; + public Strides(Variable[] variables) { + + Arrays.sort(variables); + + strides = new int[variables.length + 1]; + + this.variables = Arrays.stream(variables).mapToInt(Variable::getLabel).toArray(); + this.sizes = Arrays.stream(variables).mapToInt(Variable::getCardinality).toArray(); + this.size = variables.length; + + int cumulative = 1; + strides[0] = 1; + + for (int i = 0; i < size; ++i) { + strides[i + 1] = cumulative *= sizes[i]; + } + combinations = cumulative; + } + public Strides(int[] variables, int[] sizes, int[] strides) { this.variables = variables; this.sizes = sizes; @@ -58,48 +88,17 @@ public Strides(int[] variables, int[] sizes) { this.size = variables.length; this.strides = new int[size + 1]; - this.strides[0] = 1; - for (int i = 0; i < size; ++i) { - strides[i + 1] = strides[i] * sizes[i]; - } - this.combinations = strides[size]; - } - - /** - * Creates a stride with a single variable excluded. Note that the variable - * must not be missing in the provided domain. - * - * @param domain - * @param offset - * @deprecated please use {@link Strides#removeAt(int...)} - */ - @Deprecated - public Strides(Strides domain, int offset) { - this.size = domain.variables.length - 1; - this.variables = new int[size]; - this.sizes = new int[size]; - - System.arraycopy(domain.variables, 0, variables, 0, offset); - System.arraycopy(domain.variables, offset + 1, variables, offset, size - offset); - System.arraycopy(domain.sizes, 0, sizes, 0, offset); - System.arraycopy(domain.sizes, offset + 1, sizes, offset, size - offset); - - strides = new int[size + 1]; - strides[0] = 1; - if (offset > 1) { - System.arraycopy(domain.strides, 0, strides, 0, offset--); - } else - offset = 0; + int cumulative = this.strides[0] = 1; - for (; offset < size; ++offset) { - strides[offset + 1] = strides[offset] * sizes[offset]; + for (int i = 0; i < size; ++i) { + strides[i + 1] = cumulative *= sizes[i]; } - this.combinations = strides[size]; + this.combinations = cumulative; } @Override public final int indexOf(int variable) { - //return Arrays.binarySearch(variables, variable); + // return Arrays.binarySearch(variables, variable); return ArraysUtil.indexOf(variable, variables); } @@ -230,8 +229,7 @@ public final int getCombinations() { } /** - * Create an iterator over an enlarged domain but with same strides and - * sizes. + * Create an iterator over an enlarged domain but with same strides and sizes. * * @param targetDomain * @return @@ -257,8 +255,8 @@ private IndexIterator getSupersetIndexIterator(final int[] over, final int[] tar } /** - * Create an iterator over an smaller domain with same strides and sizes and - * 1 (one) variable set to a specific state. + * Create an iterator over an smaller domain with same strides and sizes and 1 + * (one) variable set to a specific state. * * @param variable the variable * @param state the desired state @@ -335,8 +333,8 @@ public String toString() { } /** - * Get an IndexIterator for this stride object with a different order for - * the variables. + * Get an IndexIterator for this stride object with a different order for the + * variables. * * @param unsorted_vars - int[] the new order for the variables * @return an iterator @@ -356,8 +354,8 @@ public IndexIterator getReorderedIterator(int[] unsorted_vars) { } /** - * Get an IndexIterator for this stride object with a different order for - * the variables. Alias of intersection. + * Get an IndexIterator for this stride object with a different order for the + * variables. Alias of intersection. * * @param someVars int[] - the new order for the variables * @return an iterator @@ -369,8 +367,8 @@ public Strides retain(int[] someVars) { //////////////////////////////////////////////////////////////////////////////////////////////////// /** - * Remove some variable from this domain. The returned domain will have his - * own strides. Variables do not necessarily have to be part of this domain. + * Remove some variable from this domain. The returned domain will have his own + * strides. Variables do not necessarily have to be part of this domain. * Convenience method that calls the {@link Strides#remove(int...)} * * @param toremove Strides - the domain to be removed from the first domain @@ -381,13 +379,36 @@ public Strides remove(Strides toremove) { } /** - * Remove some variable from this domain. The returned domain will have his - * own strides. Variables do not necessarily have to be part of "domain". + * Remove some variable from this domain. The returned domain will have his own + * strides. Variables do not necessarily have to be part of "domain". * * @param toremove * @return */ public Strides remove(int... toremove) { + Set to_rem = Arrays.stream(toremove).boxed().collect(Collectors.toSet()); + + int[] vars = new int[this.size]; + int[] sizes = new int[this.size]; + + int ti = 0; + for (int v = 0; v < this.variables.length; ++v) { + if (!to_rem.contains(this.variables[v])) { + sizes[ti] = this.sizes[v]; + vars[ti] = this.variables[v]; + ti++; + } else { + to_rem.remove(v); + } + } + + vars = Arrays.copyOf(vars, ti); + sizes = Arrays.copyOf(sizes, ti); + + return new Strides(vars, sizes); + } + + public Strides remove_sorted(int... toremove) { int di = 0; // domain index int ti = 0; // toremove index int index = 0; // target index @@ -434,8 +455,8 @@ public Strides remove(int... toremove) { } /** - * Create a new Strides object result of the intersection of this domain - * with the given one. + * Create a new Strides object result of the intersection of this domain with + * the given one. * * @param domain2 another domain * @return the intersection of domain1 and domain2 @@ -490,6 +511,108 @@ public Strides intersection(int... domain2) { return new Strides(intersect_vars, intersect_sizes); } + public Strides union3(Strides domain2) { + + int[] vars = ArraysUtil.append(variables, domain2.variables); + int[] sz = ArraysUtil.append(sizes, domain2.sizes); + + TIntSet un = new TIntHashSet(vars); + // Set un = Arrays.stream(vars).boxed().collect(Collectors.toSet()); + + int[] union_vars = new int[un.size()]; + int[] union_sizes = new int[un.size()]; + int[] union_strides = new int[un.size() + 1]; + int i = 0; + union_strides[0] = 1; + + for (int vindex = 0; vindex < vars.length; ++vindex) { + int v = vars[vindex]; + int s = sz[vindex]; + + if (un.contains(v)) { + union_vars[i] = v; + union_sizes[i] = s; + union_strides[i + 1] = union_strides[i] * s; + ++i; + un.remove(v); + } + } + + return new Strides(union_vars, union_sizes, union_strides); + } + + public Strides union_unsorted(Strides domain2) { + + int s1 = variables.length; + int s2 = domain2.variables.length; + + int[] vars = ArraysUtil.append(variables, domain2.variables); + int[] sz = ArraysUtil.append(sizes, domain2.sizes); + + int s = s1 + s2; + + int[] union_vars = new int[s]; + int[] union_sz = new int[s]; + int[] union_str = new int[s + 1]; + + union_str[0] = 1; + int t = 0; + external: for (int v1 = 0; v1 < vars.length; ++v1) { + for (int v2 = v1 - 1; v2 >= 0; --v2) { + if (vars[v1] == vars[v2]) + continue external; + } + union_vars[t] = vars[v1]; + union_sz[t] = sz[v1]; + union_str[t + 1] = union_str[t] * sz[v1]; + t++; + } + + return new Strides(Arrays.copyOf(union_vars, t), Arrays.copyOf(union_sz, t), Arrays.copyOf(union_str, t + 1)); + } + + public Strides union_new(Strides domain2) { + + int s1 = variables.length; + int s2 = domain2.variables.length; + + int[] vars = ArraysUtil.append(variables, domain2.variables); + int[] order = ArraysUtil.order(vars); + + vars = ArraysUtil.at(vars, order); + int[] sz = ArraysUtil.append(sizes, domain2.sizes); + sz = ArraysUtil.at(sz, order); + + int prev = vars[0]; + int s = 1; + + for (int v : vars) { + if (v != prev) + ++s; + prev = v; + } + + int[] union_vars = new int[s]; + int[] union_sz = new int[s]; + int[] union_str = new int[s + 1]; + + union_str[0] = 1; + + int t = 0; + + prev = vars[0] ^ 1; // make sure the value in prev is different than the first item + for (int i = 0; i < vars.length; ++i) { + if (prev != vars[i]) { + union_vars[t] = vars[i]; + union_sz[t] = sz[i]; + union_str[t + 1] = union_str[t] * sz[i]; + t++; + } + } + + return new Strides(union_vars, union_sz, union_str); + } + /** * Create a new Strides result of the union of domain1 and domain2. * @@ -498,16 +621,17 @@ public Strides intersection(int... domain2) { *

* *

- * As we can assume ordering we will traverse the domains in parallel (pt1 - * in the code) and copy to a target domain the values. When one of the - * domains reached it's end the remaing variable of the other domain can be - * added to the union in bulk (pt2 in code). + * As we can assume ordering we will traverse the domains in parallel (pt1 in + * the code) and copy to a target domain the values. When one of the domains + * reached it's end the remaing variable of the other domain can be added to the + * union in bulk (pt2 in code). *

* * @param domain2 * @return */ public Strides union(Strides domain2) { + final int s1 = this.size; final int s2 = domain2.size; @@ -593,8 +717,37 @@ public Strides concat(Strides right) { } /** - * Creates a stride with one or more variables excluded. Note that the - * variable must not be missing in this domain. + * Creates a stride with one or more variables excluded. Variables may be + * missing + * + * @param offset + */ + public Strides removeAt(int offset) { + + int[] tvariables = new int[size - 1]; + int[] tsizes = new int[size - 1]; + int[] tstrides = new int[size]; + + System.arraycopy(variables, 0, tvariables, 0, offset); + System.arraycopy(sizes, 0, tsizes, 0, offset); + + if (offset < size - 1) { + System.arraycopy(variables, offset + 1, tvariables, offset, size - offset - 1); + System.arraycopy(sizes, offset + 1, tsizes, offset, size - offset - 1); + } + + System.arraycopy(strides, 0, tstrides, 0, offset + 1); + System.arraycopy(strides, offset + 2, tstrides, offset + 1, size - offset - 1); + int fix = sizes[offset]; + for (int o = offset + 1; o < size; ++o) { + tstrides[o] /= fix; + } + return new Strides(tvariables, tsizes, tstrides); + } + + /** + * Creates a stride with one or more variables excluded. Variables may be + * missing * * @param offset */ @@ -619,8 +772,7 @@ public Strides removeAt(int... offset) { } /** - * Get an iterator of this domain with the specified variables lock in state - * 0. + * Get an iterator of this domain with the specified variables lock in state 0. *

* Convenience method when the target domain is this domain. same as calling * getIterator(this, locked); @@ -633,13 +785,12 @@ public IndexIterator getIterator(int... locked) { } /** - * Iterate over another domain. If a variable is not present in this domain - * it will not move the index but it will take the step. If a variable is - * not in the specified domain the variable is not considered or is assumed - * fixed to 0. Please use getPartialOffset(vars, state) for a different offset. If - * also present in str, variable in the locked array will be counted but - * fixed at zero. This allows us to keep them in the target domain while not - * moving. + * Iterate over another domain. If a variable is not present in this domain it + * will not move the index but it will take the step. If a variable is not in + * the specified domain the variable is not considered or is assumed fixed to 0. + * Please use getPartialOffset(vars, state) for a different offset. If also + * present in str, variable in the locked array will be counted but fixed at + * zero. This allows us to keep them in the target domain while not moving. * * @param str the target domain * @param locked the variables that should be locked @@ -668,11 +819,11 @@ public IndexIterator getIterator(Strides str, int... locked) { } ++source; // in any case we move the source pointer } // else { - // source var is greater than target (var is not found in this - // domain) - // so no need to copy strides, we can leave them at the default - // value of 0 - // } + // source var is greater than target (var is not found in this + // domain) + // so no need to copy strides, we can leave them at the default + // value of 0 + // } } return new IndexIterator(new_strides, str.getSizes(), 0, str.getCombinations()); @@ -704,24 +855,25 @@ public static Strides as(int... data) { return new Strides(variables, sizes); } - /** * Define a stride as a sequence of variable/size pairs */ public static Strides var(int var, int size) { - return new Strides(new int[]{var}, new int[]{size}); + return new Strides(new int[] { var }, new int[] { size }); } /** - * Helper to allow concatenation of var().add().add() - * Insertion is sorted. Resulting domain is ordered! + * Helper to allow concatenation of var().add().add() Insertion is sorted. + * Resulting domain is ordered! + * * @param var * @param size * @return */ public Strides and(int var, int size) { int pos = Arrays.binarySearch(variables, var); - if (pos >= 0) return this; // no need to change anything + if (pos >= 0) + return this; // no need to change anything pos = -(pos + 1); int[] newvar = new int[variables.length + 1]; @@ -743,8 +895,8 @@ public Strides and(int var, int size) { public static Strides EMPTY = Strides.as(); /** - * Return an empty Stride with no variables and a single entry in the strides array. This - * only stride value is set to 1. + * Return an empty Stride with no variables and a single entry in the strides + * array. This only stride value is set to 1. * * @return {@link Strides} - the empty stride */ @@ -752,14 +904,14 @@ public static Strides empty() { return EMPTY; } - /** * Return a new Stride sorted by the variables. * * @return */ public Strides sort() { - int[] order = IntStream.range(0, size).boxed().sorted((a, b) -> Strides.this.variables[a] - Strides.this.variables[b]).mapToInt(x -> x).toArray(); + int[] order = IntStream.range(0, size).boxed() + .sorted((a, b) -> Strides.this.variables[a] - Strides.this.variables[b]).mapToInt(x -> x).toArray(); int[] variables = Arrays.stream(order).map(x -> Strides.this.variables[x]).toArray(); int[] sizes = Arrays.stream(order).map(x -> Strides.this.sizes[x]).toArray(); return new Strides(variables, sizes); @@ -782,10 +934,9 @@ public IndexIterator getIterator(Strides domain, TIntIntMap observation) { return it; } - /** - * Determines if this Strides object is consistent other. Two Strides objects are - * consistent if all the common variables have the same cardinality. + * Determines if this Strides object is consistent other. Two Strides objects + * are consistent if all the common variables have the same cardinality. * * @param other {@link Strides} - other Strides object to compare with. * @return boolean variable indicating the compatibility. @@ -797,37 +948,35 @@ public boolean isConsistentWith(Strides other) { return true; } - - /** - * get the list of states for the specified offset within all the - * combinations of the domain + /** + * get the list of states for the specified offset within all the combinations + * of the domain */ public int[] getStatesFor(int offset) { int[] states = new int[getSize()]; int index = 0; for (int cardinality : sizes) { - states[index ++] = offset % cardinality; + states[index++] = offset % cardinality; offset = offset / cardinality; } return states; } /** - * Test wether the specified index (within the exanded domain) is present - * in the provied observations map + * Test wether the specified index (within the exanded domain) is present in the + * provied observations map * * @param index * @param obs * @return */ public boolean isCompatible(int index, TIntIntMap obs) { - int[] obsfiltered = IntStream.of(this.getVariables()).sorted() - .map(x -> { - if (obs.containsKey(x)) - return obs.get(x); - else return -1; - }).toArray(); - + int[] obsfiltered = IntStream.of(this.getVariables()).sorted().map(x -> { + if (obs.containsKey(x)) + return obs.get(x); + else + return -1; + }).toArray(); int[] states = this.statesOf(index); @@ -852,7 +1001,6 @@ public int[] getCompatibleIndexes(TIntIntMap obs) { return IntStream.range(0, this.getCombinations()).filter(i -> this.isCompatible(i, obs)).toArray(); } - public static Strides reverseDomain(Strides domain) { int[] vars = domain.getVariables().clone(); int[] sizes = domain.getSizes().clone(); @@ -865,4 +1013,33 @@ public Strides reverseDomain() { return reverseDomain(this); } + public static void main(String[] args) { + RandomDataGenerator generator = new RandomDataGenerator(); + int[] var1 = generator.nextPermutation(20, 10); + int[] var2 = generator.nextPermutation(20, 10); + int[] s1 = IntStream.range(0, 10).map(x -> 2).toArray(); + int[] s2 = IntStream.range(0, 10).map(x -> 2).toArray(); + Strides st1 = new Strides(var1, s1); + Strides st2 = new Strides(var2, s2); + + int runs = 10000; + int reps = 20; + long[] times = new long[reps]; + long useless = 0; + for (int rep = 0; rep < reps * 2; ++rep) { + long time = System.nanoTime(); + for (int i = 0; i < runs; ++i) { + var ss3 = st1.union(st2); + int[] xx = ss3.getVariables(); + int x= ArraysMath.min(xx); + useless +=x; + } + long delta = System.nanoTime() - time; + if (rep > reps) { + times[rep - reps] = delta; + + } + } + System.out.println(ArraysMath.mean(times) + " " + ArraysMath.sd(times, 1)); + } } diff --git a/src/main/java/ch/idsia/crema/core/UnsortedDomain.java b/src/main/java/ch/idsia/crema/core/UnsortedDomain.java new file mode 100644 index 00000000..e8b04601 --- /dev/null +++ b/src/main/java/ch/idsia/crema/core/UnsortedDomain.java @@ -0,0 +1,75 @@ +package ch.idsia.crema.core; + +import java.util.Arrays; + +public class UnsortedDomain implements StridedDomain { + private Strides sortedDomain; + + public UnsortedDomain(Variable[] variables) { + + } + + public UnsortedDomain(int[] variables, int[] sizes) { + + } + + @Override + public int getCardinality(int variable) { + return 0; + } + + @Override + public int getSizeAt(int index) { + // TODO Auto-generated method stub + return 0; + } + + @Override + public int indexOf(int variable) { + // TODO Auto-generated method stub + return 0; + } + + @Override + public boolean contains(int variable) { + // TODO Auto-generated method stub + return false; + } + + @Override + public int[] getVariables() { + // TODO Auto-generated method stub + return null; + } + + @Override + public int[] getSizes() { + // TODO Auto-generated method stub + return null; + } + + @Override + public int getSize() { + // TODO Auto-generated method stub + return 0; + } + + @Override + public void removed(int variable) { + // TODO Auto-generated method stub + + } + + @Override + public int getStride(int variable) { + // TODO Auto-generated method stub + return 0; + } + + @Override + public int getStrideAt(int offset) { + // TODO Auto-generated method stub + return 0; + } + +} diff --git a/src/main/java/ch/idsia/crema/core/Variable.java b/src/main/java/ch/idsia/crema/core/Variable.java new file mode 100644 index 00000000..4ec6dc11 --- /dev/null +++ b/src/main/java/ch/idsia/crema/core/Variable.java @@ -0,0 +1,37 @@ +package ch.idsia.crema.core; + +public class Variable implements Comparable { + private int label; + private int cardinality; + + public Variable(int label, int cardinality) { + this.label = label; + this.cardinality = cardinality; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof Variable) { + return label == ((Variable) obj).label; + } else + return false; + } + + @Override + public int hashCode() { + return label; + } + + @Override + public int compareTo(Variable o) { + return Integer.compare(label, o.label); + } + + public int getLabel() { + return label; + } + + public int getCardinality() { + return cardinality; + } +} diff --git a/src/main/java/ch/idsia/crema/data/DataTable.java b/src/main/java/ch/idsia/crema/data/DataTable.java new file mode 100644 index 00000000..4b46e116 --- /dev/null +++ b/src/main/java/ch/idsia/crema/data/DataTable.java @@ -0,0 +1,339 @@ +package ch.idsia.crema.data; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.lang3.tuple.Pair; + +import gnu.trove.map.TIntIntMap; +import gnu.trove.map.hash.TIntIntHashMap; +import gnu.trove.set.TIntSet; +import gnu.trove.set.hash.TIntHashSet; + +public class DataTable implements Iterable> { + protected final int[] columns; + protected O metadata; + + protected T unit; + protected T zero; + protected T virtualcounts; + protected BiFunction add; + protected Function createArray; + + /** + * Using a {@link TreeMap}. This will use a comparator that we can provide. + * HashMap uses the object's hash that won't work correctly for arrays of int. + */ + protected TreeMap dataTable; + + public void setMetadata(O data) { + this.metadata = data; + } + + public O getMetadata() { + return metadata; + } + + + + protected DataTable(int[] columns, T unit, T zero, BiFunction add, Function array) { + this.columns = columns; + this.unit = unit; + this.zero = zero; + this.virtualcounts = zero; + this.add = add; + this.createArray = array; + this.dataTable = new TreeMap<>(Arrays::compare); + } + + + protected DataTable(int[] columns, T unit, T zero, BiFunction add, Function array, Map data) { + this.columns = columns; + this.unit = unit; + this.zero = zero; + this.virtualcounts = zero; + this.add = add; + this.createArray = array; + + this.dataTable = new TreeMap<>(Arrays::compare); + this.dataTable.putAll(data); + } + + + + private static int[] max(int[] a, int[] b) { + int[] ret = a.clone(); + for (int i = 0; i < a.length; ++i) { + ret[i] = Math.max(a[i], b[i]); + } + return ret; + } + + public TIntSet[] getStates() { + TIntSet[] accum = IntStream.range(0, columns.length).mapToObj(TIntHashSet::new).toArray(TIntSet[]::new); + + for(var entry : dataTable.entrySet()) { + var v = entry.getKey(); + for (int i = 0; i < accum.length;++i) { + accum[i].add(v[i]); + } + } + return accum; + } + + + public int[] getSizes() { + int[] accum = new int[columns.length]; + + int[] v = dataTable.entrySet().stream().map(Entry::getKey).reduce(accum, DataTable::max); + for (int i = 0; i < v.length;++i) ++v[i]; + return v; + } + + /** + * Sort and Expand or limit the map to the columns of the table. + * + * @param map + * @return int[] of the values for the table columns + */ + private int[] getIndex(TIntIntMap map) { + return Arrays.stream(columns).map(map::get).toArray(); + } + + /** + * Get the weight of a row + * + * @param index + * @return + */ + public T getWeight(TIntIntMap index) { + return dataTable.get(getIndex(index)); + } + + /** + * Get the list of weight for all possible combinations of values of the + * specified columns assuming each column has the indicated number of possible + * states. + * + * A default number of virtual counts is added to each count. + * If not changed this is zero + * + * @param vars + * @param sizes + * @return + */ + public T[] getWeightsFor(int[] vars, int[] sizes) { + return getWeightsFor(vars, sizes, virtualcounts); + } + + /** + * Get Weights table for the specified variables and a virtual count of s. + * variables sizes need to be provided. + * + * @param vars + * @param sizes + * @param s + * @return + */ + public T[] getWeightsFor(int[] vars, int[] sizes, T s) { + + // cumulative size + int cumsize = 1; + + TIntIntMap strides = new TIntIntHashMap(); + for (int i = 0; i < vars.length; ++i) { + strides.put(vars[i], cumsize); + cumsize = cumsize * sizes[i]; + } + + int[] col_strides = new int[columns.length]; + + for (int i = 0; i < columns.length; ++i) { + if (strides.containsKey(columns[i])) { + col_strides[i] = strides.get(columns[i]); + } + } + + T[] results = createArray.apply(cumsize); + for (int i = 0; i< results.length; ++i) { + results[i] = s; + } + + for (var item : dataTable.entrySet()) { + int[] states = item.getKey(); + int offset = 0; + for (int i = 0; i < columns.length; ++i) { + offset += col_strides[i] * states[i]; + } + results[offset] = add.apply(results[offset], item.getValue()); + } + + return results; + } + + + /** + * Get the weight of a row + * + * @param index + * @return + */ + public T getWeight(int[] index) { + if (index.length != columns.length) + throw new IllegalArgumentException("Wrong index size. Must match the columns"); + + return dataTable.get(index); + } + + /** + * Add to dataTable assuming correctly ordered row items + * + * @param row the item to be added + * @param count the number of rows to be added + */ + public void add(int[] row, T count) { + dataTable.compute(row, (k, v) -> (v == null) ? count : add.apply(v, count)); + } + + /** + * Add a new row using a different column order. + * + * @param cols int[] the new columns order + * @param inst int[] the row to be added in cols order + * @param count the "number" of rows being added. + */ + public void add(int[] cols, int[] inst, T count) { + int[] row = Arrays.stream(columns).map(col -> ArrayUtils.indexOf(cols, col)).map(i -> inst[i]).toArray(); + + dataTable.compute(row, (k, v) -> (v == null) ? count : add.apply(v, count)); + } + + /** + * Add a TIntIntMap with the specified count. The map must contain all the keys + * specified in the columns + * + * @param inst {@link TIntIntMap} - the row to be added + * @param count the number of rows being added. + */ + public void add(TIntIntMap inst, T count) { + int[] row = Arrays.stream(columns).map(inst::get).toArray(); + dataTable.compute(row, (k, v) -> (v == null) ? count : add.apply(v, count)); + } + + /** + * Add a TIntIntMap with unit count. The map must contain all the keys specified + * in the columns + * + * @param inst {@link TIntIntMap} - the row to be added + * @param count the number of rows being added. + */ + public void add(TIntIntMap inst) { + add(inst, unit); + } + + /** + * Fills the provided sub-table with aggregated data from this table. + * Columns missing in this table are set to zero. + * + * @param cols the subset of columns + * @return a new Table + */ + protected > SUB subtable(SUB tofill) { + int[] cols = tofill.columns; + + // indices of the desired columns + int[] idx = Arrays.stream(cols).map(col -> ArrayUtils.indexOf(columns, col)).toArray(); + //int[] matching = IntStream.of(idx).map(id -> columns[id]).toArray(); + + for (Map.Entry entry : dataTable.entrySet()) { + int[] values = entry.getKey(); + T count = entry.getValue(); + + int[] newkey = Arrays.stream(idx).map(i -> i < 0 ? 0 : values[i]).toArray(); + tofill.add(newkey, count); + } + + return tofill; + } + + /** + * Covert weights of a Table + * + * @param op the conversion operation + * @return the new Table + */ + public DataTable mapWeights(Function op) { + // new table shares the same columns by default + DataTable table = new DataTable(columns, unit, zero, add, createArray); + for (Map.Entry entry : dataTable.entrySet()) { + table.dataTable.put(entry.getKey(), op.apply(entry.getValue())); + } + return table; + } + + + + @Override + public Iterator> iterator() { + return dataTable.entrySet().iterator(); + } + + + public Iterable> mapIterable() { + return new Iterable>() { + + @Override + public Iterator> iterator() { + + var iter = dataTable.entrySet().iterator(); + return new Iterator>() { + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Pair next() { + var nextVal = iter.next(); + TIntIntMap ret = new TIntIntHashMap(columns, nextVal.getKey()); + return Pair.of(ret, nextVal.getValue()); + } + }; + } + }; + } + + + public int[] getColumns() { + return this.columns; + } + + + public > DT map(Supplier

creator, BiFunction> converter) { + final var store = creator.get(); + + dataTable.entrySet() + .stream() + .map((v) -> converter.apply(v.getKey(), v.getValue())) + .forEach(v -> store.add(v.getKey(), v.getValue())); + return store; + + } + + + public Set> entries() { + return dataTable.entrySet(); + } + +} diff --git a/src/main/java/ch/idsia/crema/data/DoubleTable.java b/src/main/java/ch/idsia/crema/data/DoubleTable.java new file mode 100644 index 00000000..d0e5aef7 --- /dev/null +++ b/src/main/java/ch/idsia/crema/data/DoubleTable.java @@ -0,0 +1,214 @@ +package ch.idsia.crema.data; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import gnu.trove.map.TIntIntMap; +import gnu.trove.map.hash.TIntIntHashMap; + +public class DoubleTable extends DataTable { + + public DoubleTable(int[] columns) { + super(columns, 1d, 0d, (a, b) -> a + b, (i) -> new Double[i]); + } + + public DoubleTable(int[] columns, int[][] data) { + this(columns); + + if (data != null) { + for (int[] inst : data) { + add(inst, unit); + } + } + } + + /** + * Create a Table from an array of {@link TIntIntMap}s. + * + * @param data + * @param unit + * @param add + */ + public DoubleTable(TIntIntMap[] data) { + this(TIntMapConverter.cols(data)); + + if (data.length == 0) + throw new IllegalArgumentException("Need data"); + + for (TIntIntMap inst : data) { + add(inst); + } + } + + public DoubleTable subtable(int[] cols) { + DoubleTable tofill = new DoubleTable(cols); + return super.subtable(tofill); + } + + + /** + * Scale weights between 0 and 1 + * @return + */ + public DoubleTable scale() { + double small = 0; + + for (var entry : dataTable.entrySet()) { + small = Math.max(small, entry.getValue()); + } + + //System.out.println("AMX " + small); + if (small == 0) return this; + + DoubleTable res = new DoubleTable(columns); + for (var entry : dataTable.entrySet()) { + res.add(entry.getKey(), entry.getValue() / small); + } + + return res; + } + + + + + public double[] getWeights(int[] vars, int[] sizes) { + Double[] dta = super.getWeightsFor(vars, sizes); + return Arrays.stream(dta).mapToDouble(i -> i).toArray(); + } + + public double[] getWeights2(int[] vars, int[] sizes) { + + // cumulative size + int cumsize = 1; + + TIntIntMap strides = new TIntIntHashMap(); + for (int i = 0; i < vars.length; ++i) { + strides.put(vars[i], cumsize); + cumsize = cumsize * sizes[i]; + } + + int[] col_strides = new int[columns.length]; + + for (int i = 0; i < columns.length; ++i) { + if (strides.containsKey(columns[i])) { + col_strides[i] = strides.get(columns[i]); + } + } + + double[] results = new double[cumsize]; + + for (var item : dataTable.entrySet()) { + int[] states = item.getKey(); + int offset = 0; + for (int i = 0; i < columns.length; ++i) { + offset += col_strides[i] * states[i]; + } + results[offset] += item.getValue(); + } + + return results; + } + + TIntIntMap getKeyMap(int[] index) { + TIntIntMap map = new TIntIntHashMap(); + for (int i = 0; i < index.length; ++i) { + map.put(columns[i], index[i]); + } + return map; + } + + public TIntIntMap[] toMap(boolean roundup) { + return dataTable.entrySet().stream().flatMap(row -> IntStream + .range(0, (int) (row.getValue() + (roundup ? 0.5 : 0))).mapToObj(i -> this.getKeyMap(row.getKey()))) + .toArray(TIntIntHashMap[]::new); + + } + + /** + * Read a table from a whitespace separated file. The whitespace can be any + * regex \s character. + * + * @param filename + * @return + * @throws IOException + */ + public static DoubleTable readTable(String filename) throws IOException { + return readTable(filename, "\\s"); + } + + /** + * Read table from sep separated lists of values. Lists are separated by + * newlines and first row is the header. Header row must be integers. + * + * @param filename + * @param sep + * @return + * @throws IOException + */ + public static DoubleTable readTable(String filename, String sep) throws IOException { + try (BufferedReader input = new BufferedReader(new FileReader(filename))) { + String[] cols = input.readLine().split(sep); + int[] columns = Arrays.stream(cols).mapToInt(Integer::parseInt).toArray(); + DoubleTable ret = new DoubleTable(columns); + + String line; + while ((line = input.readLine()) != null) { + cols = line.split(sep); + + int[] row = Arrays.stream(cols).mapToInt(Integer::parseInt).toArray(); + ret.add(row, 1d); + } + return ret; + } + } + + static int parseInt(String value) { + return (int) Double.parseDouble(value); + } + + public static DoubleTable readTable(String filename, int skip, String sep, Map columns_out) + throws IOException { + try (BufferedReader input = new BufferedReader(new FileReader(filename))) { + for (int i = 0; i < skip; ++i) + input.readLine().split(sep); + + String[] values = input.readLine().split(sep); + int[] columns = IntStream.range(0, values.length).toArray(); + if (columns_out != null) { + for (int i = 0; i < values.length; ++i) { + columns_out.put(values[i], columns[i]); + } + } + + DoubleTable ret = new DoubleTable(columns); + + String line; + while ((line = input.readLine()) != null) { + values = line.split(sep); + + int[] row = Arrays.stream(values).mapToInt(DoubleTable::parseInt).toArray(); + ret.add(row, 1d); + } + return ret; + } + } + + public String toString() { + StringBuilder sb = new StringBuilder(); + String head = Arrays.stream(columns).mapToObj(Integer::toString).collect(Collectors.joining(" | ")); + sb.append(head).append(" | weight\n"); + + for (Map.Entry row : dataTable.entrySet()) { + String key = Arrays.stream(row.getKey()).mapToObj(Integer::toString).collect(Collectors.joining(" | ")); + sb.append(key).append(" | ").append(row.getValue()).append("\n"); + } + return sb.toString(); + + } + +} diff --git a/src/main/java/ch/idsia/crema/data/TIntMapConverter.java b/src/main/java/ch/idsia/crema/data/TIntMapConverter.java new file mode 100644 index 00000000..3e1acd86 --- /dev/null +++ b/src/main/java/ch/idsia/crema/data/TIntMapConverter.java @@ -0,0 +1,60 @@ +package ch.idsia.crema.data; + +import java.util.Arrays; +import java.util.function.Function; + +import gnu.trove.map.TIntIntMap; +import gnu.trove.set.TIntSet; +import gnu.trove.set.hash.TIntHashSet; + +class TIntMapConverter { + + /** + * Convert from the specified map to an array of int using the specified columns + * order + * + * @param map the map to be converted + * @param columns the order of the columns + * @return the converted integer array + */ + public static int[] from(TIntIntMap map, int[] columns) { + return Arrays.stream(columns).map(map::get).toArray(); + } + + /** + * A curried version of the from method that returns a version of from with + * fixed columns + * + * @param columns the order to be fixed + * @return the function + */ + public static Function curriedFrom(int[] columns) { + return map -> Arrays.stream(columns).map(map::get).toArray(); + } + + /** + * Convert an array of maps to an array of int arrays sorted by columns. + * + * @param map the input data + * @param columns the order of the columns + * + * @return the output data + */ + public static int[][] from(TIntIntMap[] map, int[] columns) { + return Arrays.stream(map).map(curriedFrom(columns)).toArray(int[][]::new); + } + + /** + * Discover the set of columns assuming that not all are specified + * + * @param data + * @return + */ + protected static int[] cols(TIntIntMap[] data) { + TIntSet columns = new TIntHashSet(); + for (TIntIntMap row : data) { + columns.addAll(row.keys()); + } + return columns.toArray(); + } +} \ No newline at end of file diff --git a/src/main/java/ch/idsia/crema/factor/bayesian/BayesianDefaultFactor.java b/src/main/java/ch/idsia/crema/factor/bayesian/BayesianDefaultFactor.java index f9bae89a..35e76a6c 100644 --- a/src/main/java/ch/idsia/crema/factor/bayesian/BayesianDefaultFactor.java +++ b/src/main/java/ch/idsia/crema/factor/bayesian/BayesianDefaultFactor.java @@ -3,6 +3,7 @@ import ch.idsia.crema.core.Domain; import ch.idsia.crema.core.ObservationBuilder; import ch.idsia.crema.core.Strides; +import ch.idsia.crema.factor.algebra.GenericOperationFunction; import ch.idsia.crema.factor.algebra.bayesian.BayesianOperation; import ch.idsia.crema.factor.algebra.bayesian.SimpleBayesianFilter; import ch.idsia.crema.factor.algebra.bayesian.SimpleBayesianMarginal; @@ -332,11 +333,61 @@ public BayesianDefaultFactor combine(BayesianFactor factor) { factor = ((BayesianLogFactor) factor).exp(); if (factor instanceof BayesianDefaultFactor) - return combine((BayesianDefaultFactor) factor, BayesianDefaultFactor::new, ops::combine); + return owncombine((BayesianDefaultFactor) factor, BayesianDefaultFactor::new, ops::combine); return (BayesianDefaultFactor) super.combine(factor); } + + protected F owncombine(F factor, BayesianFactorBuilder builder, GenericOperationFunction op) { + // domains should be sorted + //factor = (F) factor.copy(); + + final Strides target = getDomain().union(factor.getDomain()); + final int length = target.getSize(); + + final int[] limits = new int[length]; + final int[] assign = new int[length]; + + final long[] stride = new long[length]; + final long[] reset = new long[length]; + + for (int vindex = 0; vindex < getDomain().getSize(); ++vindex) { + int offset = Arrays.binarySearch(target.getVariables(), getDomain().getVariables()[vindex]); + stride[offset] = getDomain().getStrides()[vindex]; + } + + for (int vindex = 0; vindex < factor.getDomain().getSize(); ++vindex) { + int offset = ArraysUtil.indexOf(factor.getDomain().getVariables()[vindex], target.getVariables()); + stride[offset] += ((long) factor.getDomain().getStrides()[vindex] << 32L); + } + + for (int i = 0; i < length; ++i) { + limits[i] = target.getSizes()[i] - 1; + reset[i] = limits[i] * stride[i]; + } + + long idx = 0; + double[] result = new double[target.getCombinations()]; + + for (int i = 0; i < result.length; ++i) { + result[i] = this.data[ (int) (idx & 0xFFFFFFFF)] * factor.getValueAt((int) (idx >>> 32L)); + + for (int l = 0; l < length; ++l) { + if (assign[l] == limits[l]) { + assign[l] = 0; + idx -= reset[l]; + } else { + ++assign[l]; + idx += stride[l]; + break; + } + } + } + + return builder.get(target, result); + } + /** *

* If the input factor is also a {@link BayesianDefaultFactor}, a fast algebra i used. If the input is a @@ -521,5 +572,139 @@ public double logProb(TIntIntMap[] data, int leftVar) { return logprob; } + + + + + /** + * BD -> ABD -> ABCD -> ABCDE + * b0d0 0 a0b0d0 0 a0b0c0d0 0 0 + * b1d0 1 a1b0d0 0 a1b0c0d0 0 0 + * b0d1 2 a0b1d0 1 a0b1c0d0 1 1 + * b1d1 3 a1b1d0 1 a1b1c0d0 1 + * a0b0d1 2 a0b0d1 2 + * a1b0d1 2 a1b0d1 2 + * a0b1d1 3 a0b1d1 3 + * a0b1d1 3 a0b1d1 3 + */ + + + static class Side { + int position; + int[] sizes; + int[] variables; + int combinations; + int size; + + int stride; + + Side(Strides d) { + this.position = 0; + this.stride = 1; + this.sizes = d.getSizes(); + this.variables = d.getVariables(); + this.size = variables.length; + this.combinations = d.getCombinations(); + } + + boolean ok() { + return position < size; + } + int currentVar() { + return variables[position]; + } + int currentSize() { + return sizes[position]; + } + void move() { + ++position; + } + } + + + /** + * expectations: + *

    + *
  • Ordered domains/stride
  • + *
  • $domain \subseteq target$
  • + *
+ * + * @param target + * @return + */ + protected static double[] expand(Strides from_domain, double[] data, Strides to_domain) { + Side source = new Side(from_domain); + Side target = new Side(to_domain); + + double[] source_data = data; + + double[] cache1 = new double[target.combinations]; + double[] target_data = cache1; + + // head insertions repeat single rows + int used = source_data.length; + + + + + while (true) { + + // dato un set target e uno sorgetnte + // trovare blocchi di differenze. + // due indicatori: posizione source e posizione target + // inizio differenza + while (target.ok() && source.ok() && target.currentVar() == source.currentVar()) { + target.move(); + source.move(); + } + if (!target.ok()) return source_data; // no difference found + + // target's current variable is missing in source + int from = target.position; + int repeat = 1; + + while (target.ok() && (!source.ok() || target.currentVar() != source.currentVar())) { + repeat *= target.currentSize(); + target.move(); + } + + // the number of rows to repeat + + // per ogni differenza devo sapere + // quanti elementi ripetere, Prod size of left + // quante volte ripetere, dimensioni delle variabili inserite + // quante volte rifare questa ripetizione + + int rows = to_domain.getStrideAt(from); + int tpos = 0; + for (int r = 0; r < used; r += rows) { + for (int j = 0; j < repeat; ++j) { + System.arraycopy(source_data, r, target_data, tpos, rows); + tpos += rows; + } + } + used = tpos; + + if (!target.ok()) return target_data; + + if (source_data == data) { + source_data = target_data; + target_data = new double[target.combinations]; + } else { + // swap buffers + double[] tmp= source_data; + source_data = target_data; + target_data = source_data; + } + } + } + public static void main(String[] args) { + double[] data = new double[] { 0, 1, 2, 3}; + Strides from = Strides.var(2, 2).and(4, 2); + Strides to = Strides.var(1,2).and(2, 2).and(3,2).and(4, 2).and(5, 2); + double[] target = expand(from, data, to); + System.out.println(Arrays.toString(target)); + + } } diff --git a/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactor.java b/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactor.java index 81776af9..b9eec4d0 100644 --- a/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactor.java +++ b/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactor.java @@ -46,7 +46,7 @@ public interface BayesianFactor extends OperableFactor, Separate */ double getLogValueAt(int index); - // TODO: this method should only for internal use + // TODO: this method should only be for internal use double[] getData(); @Override diff --git a/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactorFactory.java b/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactorFactory.java index ee11d8e1..7b0834c1 100644 --- a/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactorFactory.java +++ b/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactorFactory.java @@ -5,12 +5,20 @@ import ch.idsia.crema.utility.ArraysUtil; import ch.idsia.crema.utility.IndexIterator; +import java.util.Arrays; +import java.util.Random; import java.util.stream.IntStream; +import org.apache.commons.math3.random.RandomGeneratorFactory; +import org.apache.commons.math3.random.UniformRandomGenerator; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.distribution.DirichletSampler; +import org.apache.commons.rng.simple.JDKRandomWrapper; + +import com.google.common.primitives.Doubles; + /** - * Author: Claudio "Dna" Bonesana - * Project: crema - * Date: 16.04.2021 11:11 + * Author: Claudio "Dna" Bonesana Project: crema Date: 16.04.2021 11:11 */ public class BayesianFactorFactory { private double[] data = null; @@ -19,36 +27,49 @@ public class BayesianFactorFactory { private Strides domain = Strides.empty(); private BayesianFactorFactory() { + initRandom(0); + } + + public BayesianFactorFactory initRandom(long seed) { + unif = new JDKRandomWrapper(new Random(seed)); + return this; } /** - * @param var the variable associated with this factor. This variable will be considered binary + * @param var the variable associated with this factor. This variable will be + * considered binary * @return a {@link BayesianDefaultFactor} where state 1 has probability 1.0. */ public static BayesianDefaultFactor one(int var) { - return new BayesianDefaultFactor(Strides.var(var, 2), new double[]{0., 1.}); + return new BayesianDefaultFactor(Strides.var(var, 2), new double[] { 0., 1. }); } /** - * @param var the variable associated with this factor. This variable will be considered binary + * @param var the variable associated with this factor. This variable will be + * considered binary * @return a {@link BayesianDefaultFactor} where state 0 has probability 1.0. */ public static BayesianDefaultFactor zero(int var) { - return new BayesianDefaultFactor(Strides.var(var, 2), new double[]{1., 0.}); + return new BayesianDefaultFactor(Strides.var(var, 2), new double[] { 1., 0. }); } /** * This is the entry point for the chained interface of this factory. * - * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands. + * @return a {@link BayesianFactorFactory} object that can be used to chain + * multiple commands. */ public static BayesianFactorFactory factory() { return new BayesianFactorFactory(); } /** + * Set the domain of the builder. Order is ignored and natural ordering will be + * used. + * * @param domain set the domain of the factor - * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands. + * @return a {@link BayesianFactorFactory} object that can be used to chain + * multiple commands. */ public BayesianFactorFactory domain(Domain domain) { this.domain = Strides.fromDomain(domain); @@ -58,17 +79,32 @@ public BayesianFactorFactory domain(Domain domain) { /** * @param domain the variables that defines the domain * @param sizes the sizes of each variable - * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands. + * @return a {@link BayesianFactorFactory} object that can be used to chain + * multiple commands. */ public BayesianFactorFactory domain(int[] domain, int[] sizes) { - this.domain = new Strides(domain, sizes); - return data(); + return domain(domain, sizes, false); + } + + public BayesianFactorFactory domain(int[] domain, int[] sizes, boolean already_sorted) { + if (!already_sorted) { + int[] pos = ArraysUtil.order(domain); + + int[] sortedDomain = ArraysUtil.at(domain, pos); + int[] sortedSizes = ArraysUtil.at(sizes, pos); + + this.domain = new Strides(sortedDomain, sortedSizes); + } else { + this.domain = new Strides(domain, sizes); + } + return this; } /** * Set an empty data set with all the combinations of the given domain. * - * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands. + * @return a {@link BayesianFactorFactory} object that can be used to chain + * multiple commands. */ public BayesianFactorFactory data() { this.data = new double[domain.getCombinations()]; @@ -77,12 +113,14 @@ public BayesianFactorFactory data() { /** * @param data an array of values that will be directly used - * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands. + * @return a {@link BayesianFactorFactory} object that can be used to chain + * multiple commands. */ public BayesianFactorFactory data(double[] data) { final int expectedLength = domain.getCombinations(); if (data.length > expectedLength) - throw new IllegalArgumentException("Invalid length of data: expected " + expectedLength + " got " + data.length); + throw new IllegalArgumentException( + "Invalid length of data: expected " + expectedLength + " got " + data.length); if (this.data == null) this.data = new double[expectedLength]; @@ -94,12 +132,14 @@ public BayesianFactorFactory data(double[] data) { /** * @param data an array of values in log-space that will be directly used - * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands. + * @return a {@link BayesianFactorFactory} object that can be used to chain + * multiple commands. */ public BayesianFactorFactory logData(double[] data) { final int expectedLength = domain.getCombinations(); if (data.length != expectedLength) - throw new IllegalArgumentException("Invalid length of data: expected " + expectedLength + " got " + data.length); + throw new IllegalArgumentException( + "Invalid length of data: expected " + expectedLength + " got " + data.length); if (this.logData == null) this.logData = new double[expectedLength]; @@ -112,7 +152,8 @@ public BayesianFactorFactory logData(double[] data) { /** * @param domain the order of the variables that defines the values * @param data the values to use specified with the given domain order - * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands. + * @return a {@link BayesianFactorFactory} object that can be used to chain + * multiple commands. */ public BayesianFactorFactory data(int[] domain, double[] data) { int[] sequence = ArraysUtil.order(domain); @@ -148,7 +189,19 @@ public BayesianFactorFactory data(int[] domain, double[] data) { /** * @param value a single value to set * @param states the states that defines this value - * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands. + * @return a {@link BayesianFactorFactory} object that can be used to chain + * multiple commands. + */ + public BayesianFactorFactory value(double value, int[] variables, int[] states) { + data[domain.getPartialOffset(variables, states)] = value; + return this; + } + + /** + * @param value a single value to set + * @param states the states that defines this value + * @return a {@link BayesianFactorFactory} object that can be used to chain + * multiple commands. */ public BayesianFactorFactory value(double value, int... states) { data[domain.getOffset(states)] = value; @@ -158,7 +211,8 @@ public BayesianFactorFactory value(double value, int... states) { /** * @param d a single value to set * @param index the index (or offset) of the value to set - * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands. + * @return a {@link BayesianFactorFactory} object that can be used to chain + * multiple commands. */ public BayesianFactorFactory valueAt(double d, int index) { data[index] = d; @@ -168,7 +222,8 @@ public BayesianFactorFactory valueAt(double d, int index) { /** * @param d a single value to set * @param index the index (or offset) of the value to set - * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands. + * @return a {@link BayesianFactorFactory} object that can be used to chain + * multiple commands. */ public BayesianFactorFactory valuesAt(double[] d, int index) { System.arraycopy(d, 0, data, index, d.length); @@ -178,14 +233,27 @@ public BayesianFactorFactory valuesAt(double[] d, int index) { /** * @param value a single value to set * @param states the states that defines this value - * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands. + * @return a {@link BayesianFactorFactory} object that can be used to chain + * multiple commands. */ public BayesianFactorFactory set(double value, int... states) { return value(value, states); } + /** - * @return a {@link BayesianLogFactor}, where the given data are converted to log-space + * @param value a single value to set + * @param variables the order of the states attribute + * @param states the states that defines this value + * @return a {@link BayesianFactorFactory} object that can be used to chain + * multiple commands. + */ + public BayesianFactorFactory set(double value, int[] variables, int[] states) { + return value(value, states); + } + /** + * @return a {@link BayesianLogFactor}, where the given data are converted to + * log-space */ public BayesianLogFactor log() { // sort variables @@ -205,7 +273,8 @@ public BayesianLogFactor log() { } /** - * @return a {@link BayesianDefaultFactor}, where the given data are converted to non-log-space. + * @return a {@link BayesianDefaultFactor}, where the given data are converted + * to non-log-space. */ public BayesianDefaultFactor get() { // sort variables @@ -236,7 +305,8 @@ public BayesianNotFactor not(int parent) { * Requires a pre-defined Domain. * * @param parent variable that is the parent of this factor - * @param trueState index of the state to be considered as TRUE for the given parent + * @param trueState index of the state to be considered as TRUE for the given + * parent * @return a logic {@link BayesianAndFactor} */ public BayesianAndFactor not(int parent, int trueState) { @@ -268,7 +338,8 @@ public BayesianAndFactor and(int... parents) { /** * Requires a pre-defined Domain. * - * @param trueStates index of the state to be considered as TRUE for each given parent + * @param trueStates index of the state to be considered as TRUE for each given + * parent * @param parents variables that are parents of this factor * @return a logic {@link BayesianAndFactor} */ @@ -316,7 +387,8 @@ public BayesianOrFactor or(int... parents) { /** * Requires a pre-defined Domain. * - * @param trueStates index of the state to be considered as TRUE for each given parent + * @param trueStates index of the state to be considered as TRUE for each given + * parent * @param parents variables that are parents of this factor * @return a logic {@link BayesianOrFactor} */ @@ -369,7 +441,8 @@ public BayesianNoisyOrFactor noisyOr(int[] parents, double[] strengths) { * Requires a pre-defined Domain. * * @param parents variables that are parents of this factor - * @param trueStates index of the state to be considered as TRUE for each given parent + * @param trueStates index of the state to be considered as TRUE for each given + * parent * @param strengths values for the inhibition strength for each given parent * @return a logic {@link BayesianNoisyOrFactor} */ @@ -394,4 +467,39 @@ public BayesianNoisyOrFactor noisyOr(int[] parents, int[] trueStates, double[] s return new BayesianNoisyOrFactor(new Strides(vars, sizes), pars, trus, inbs); } + UniformRandomProvider unif; + + public BayesianFactorFactory random() { + return random(null); + } + + public BayesianFactorFactory random(Domain conditioning) { + + int combinations; + Strides dom; + int[] vars; + + if (conditioning == null) { + combinations = 1; + dom = this.domain; + vars = dom.getVariables(); + } else { + combinations = Arrays.stream(conditioning.getSizes()).reduce(1, (a, b) -> a * b); + dom = this.domain.remove(conditioning.getVariables()); + vars = ArraysUtil.append(dom.getVariables(), conditioning.getVariables()); + } + + int size = dom.getCombinations(); + + double[] alpha = new double[size]; + Arrays.fill(alpha, 1.0); + + DirichletSampler ds = DirichletSampler.of(unif, alpha); + double[][] arr = ds.samples(combinations).toArray(double[][]::new); + if (combinations == 1) { + return data(vars, arr[0]); + } else { + return data(vars, Doubles.concat(arr)); + } + } } diff --git a/src/main/java/ch/idsia/crema/factor/credal/vertex/separate/VertexAbstractFactor.java b/src/main/java/ch/idsia/crema/factor/credal/vertex/separate/VertexAbstractFactor.java index 2848b250..c08ed92d 100644 --- a/src/main/java/ch/idsia/crema/factor/credal/vertex/separate/VertexAbstractFactor.java +++ b/src/main/java/ch/idsia/crema/factor/credal/vertex/separate/VertexAbstractFactor.java @@ -180,7 +180,7 @@ protected F reseparate(Strides target, VertexFa Strides T = getSeparatingDomain().intersection(target); Strides Lt = getSeparatingDomain().remove(target); - Strides Dl = getDataDomain().union(Lt); + Strides Dl = getDataDomain().union(Lt).sort(); // target data double[][][] dest_data = new double[T.getCombinations()][][]; diff --git a/src/main/java/ch/idsia/crema/inference/approxlp1/CredalApproxLP.java b/src/main/java/ch/idsia/crema/inference/approxlp1/CredalApproxLP.java index c167de65..bc6fd36b 100644 --- a/src/main/java/ch/idsia/crema/inference/approxlp1/CredalApproxLP.java +++ b/src/main/java/ch/idsia/crema/inference/approxlp1/CredalApproxLP.java @@ -7,7 +7,7 @@ import ch.idsia.crema.model.graphical.GraphicalModel; import ch.idsia.crema.model.graphical.MixedModel; import ch.idsia.crema.preprocess.BinarizeEvidence; -import ch.idsia.crema.preprocess.CutObserved; +import ch.idsia.crema.preprocess.Observe; import ch.idsia.crema.preprocess.RemoveBarren; import gnu.trove.map.TIntIntMap; import gnu.trove.map.hash.TIntIntHashMap; @@ -18,7 +18,7 @@ public class CredalApproxLP> implements Inference< protected GraphicalModel getInferenceModel(GraphicalModel model, TIntIntMap evidence, int target) { // preprocessing - final CutObserved cut = new CutObserved<>(); + final Observe cut = new Observe<>(); final GraphicalModel cutted = cut.execute(model, evidence); RemoveBarren removeBarren = new RemoveBarren<>(); diff --git a/src/main/java/ch/idsia/crema/inference/bp/BeliefPropagation.java b/src/main/java/ch/idsia/crema/inference/bp/BeliefPropagation.java index 530963a2..abb62ed2 100644 --- a/src/main/java/ch/idsia/crema/inference/bp/BeliefPropagation.java +++ b/src/main/java/ch/idsia/crema/inference/bp/BeliefPropagation.java @@ -5,7 +5,7 @@ import ch.idsia.crema.inference.bp.cliques.Clique; import ch.idsia.crema.inference.bp.junction.JunctionTree; import ch.idsia.crema.model.graphical.DAGModel; -import ch.idsia.crema.preprocess.CutObserved; +import ch.idsia.crema.preprocess.Observe; import ch.idsia.crema.preprocess.RemoveBarren; import gnu.trove.map.TIntIntMap; import gnu.trove.map.hash.TIntIntHashMap; @@ -61,7 +61,7 @@ public Boolean isFullyPropagated() { } /** - * If {@link #preprocess} is true, then pre-process the model with {@link CutObserved} and {@link RemoveBarren}. + * If {@link #preprocess} is true, then pre-process the model with {@link Observe} and {@link RemoveBarren}. * * @param original the model to use for inference * @param evidence the observed variable as a map of variable-states @@ -72,7 +72,7 @@ protected DAGModel preprocess(DAGModel original, TIntIntMap evidence, int. DAGModel model = original; if (preprocess) { model = original.copy(); - final CutObserved co = new CutObserved<>(); + final Observe co = new Observe<>(); final RemoveBarren rb = new RemoveBarren<>(); co.executeInPlace(model, evidence); @@ -172,7 +172,7 @@ protected Clique getRoot(JunctionTree junctionTree, int variable) { } /** - * Pre-process the model with {@link CutObserved} and {@link RemoveBarren}, then performs the + * Pre-process the model with {@link Observe} and {@link RemoveBarren}, then performs the * {@link #collectingEvidence(int)} step. * * @param model the model to use for inference @@ -185,7 +185,7 @@ public F query(DAGModel model, int query) { } /** - * Pre-process the model with {@link CutObserved} and {@link RemoveBarren}, then performs the + * Pre-process the model with {@link Observe} and {@link RemoveBarren}, then performs the * {@link #collectingEvidence(int)} step. * * @param original the model to use for inference @@ -204,7 +204,7 @@ public F query(DAGModel original, TIntIntMap evidence, int query) { } /** - * Pre-process the model with {@link CutObserved} and {@link RemoveBarren}, then performs the + * Pre-process the model with {@link Observe} and {@link RemoveBarren}, then performs the * {@link #collectingEvidence(int)} and {@link #distributingEvidence()} steps. *

* Use the {@link #queryFullPropagated(int)} method for query multiple variables over the same evidence and model. diff --git a/src/main/java/ch/idsia/crema/inference/bp/LoopyBeliefPropagation.java b/src/main/java/ch/idsia/crema/inference/bp/LoopyBeliefPropagation.java index 5db56944..f6252cc7 100644 --- a/src/main/java/ch/idsia/crema/inference/bp/LoopyBeliefPropagation.java +++ b/src/main/java/ch/idsia/crema/inference/bp/LoopyBeliefPropagation.java @@ -3,7 +3,7 @@ import ch.idsia.crema.factor.OperableFactor; import ch.idsia.crema.inference.Inference; import ch.idsia.crema.model.graphical.DAGModel; -import ch.idsia.crema.preprocess.CutObserved; +import ch.idsia.crema.preprocess.Observe; import ch.idsia.crema.preprocess.RemoveBarren; import ch.idsia.crema.utility.ArraysUtil; import gnu.trove.map.TIntIntMap; @@ -83,7 +83,7 @@ protected DAGModel preprocess(DAGModel original, TIntIntMap evidence, int. DAGModel model = original; if (preprocess) { model = original.copy(); - final CutObserved co = new CutObserved<>(); + final Observe co = new Observe<>(); final RemoveBarren rb = new RemoveBarren<>(); co.executeInPlace(model, evidence); diff --git a/src/main/java/ch/idsia/crema/inference/sampling/LikelihoodWeightingSampling.java b/src/main/java/ch/idsia/crema/inference/sampling/LikelihoodWeightingSampling.java index cda8aca3..87f47acd 100644 --- a/src/main/java/ch/idsia/crema/inference/sampling/LikelihoodWeightingSampling.java +++ b/src/main/java/ch/idsia/crema/inference/sampling/LikelihoodWeightingSampling.java @@ -3,7 +3,7 @@ import ch.idsia.crema.factor.bayesian.BayesianDefaultFactor; import ch.idsia.crema.factor.bayesian.BayesianFactor; import ch.idsia.crema.model.graphical.GraphicalModel; -import ch.idsia.crema.preprocess.CutObserved; +import ch.idsia.crema.preprocess.Observe; import gnu.trove.map.TIntDoubleMap; import gnu.trove.map.TIntIntMap; import gnu.trove.map.TIntObjectMap; @@ -45,7 +45,7 @@ public Collection run(GraphicalModel original, T if (!preprocess) { // this is mandatory - final CutObserved co = new CutObserved<>(); + final Observe co = new Observe<>(); co.executeInPlace(model, evidence); } diff --git a/src/main/java/ch/idsia/crema/inference/sampling/StochasticSampling.java b/src/main/java/ch/idsia/crema/inference/sampling/StochasticSampling.java index f619150d..f69dc571 100644 --- a/src/main/java/ch/idsia/crema/inference/sampling/StochasticSampling.java +++ b/src/main/java/ch/idsia/crema/inference/sampling/StochasticSampling.java @@ -3,7 +3,7 @@ import ch.idsia.crema.factor.bayesian.BayesianFactor; import ch.idsia.crema.inference.InferenceJoined; import ch.idsia.crema.model.graphical.GraphicalModel; -import ch.idsia.crema.preprocess.CutObserved; +import ch.idsia.crema.preprocess.Observe; import ch.idsia.crema.preprocess.RemoveBarren; import ch.idsia.crema.utility.RandomUtil; import gnu.trove.map.TIntIntMap; @@ -62,7 +62,7 @@ protected GraphicalModel preprocess(GraphicalModel model = original; if (preprocess) { model = original.copy(); - final CutObserved co = new CutObserved<>(); + final Observe co = new Observe<>(); final RemoveBarren rb = new RemoveBarren<>(); co.executeInPlace(model, evidence); diff --git a/src/main/java/ch/idsia/crema/inference/ve/CredalVariableElimination.java b/src/main/java/ch/idsia/crema/inference/ve/CredalVariableElimination.java index 586a4d8d..30c5cd4d 100644 --- a/src/main/java/ch/idsia/crema/inference/ve/CredalVariableElimination.java +++ b/src/main/java/ch/idsia/crema/inference/ve/CredalVariableElimination.java @@ -5,7 +5,7 @@ import ch.idsia.crema.inference.Inference; import ch.idsia.crema.inference.ve.order.MinFillOrdering; import ch.idsia.crema.model.graphical.GraphicalModel; -import ch.idsia.crema.preprocess.CutObserved; +import ch.idsia.crema.preprocess.Observe; import ch.idsia.crema.preprocess.RemoveBarren; import ch.idsia.crema.utility.hull.ConvexHull; import gnu.trove.map.TIntIntMap; @@ -27,7 +27,7 @@ public CredalVariableElimination setConvexHullMarg(ConvexHull convexHullMarg) { } protected GraphicalModel getInferenceModel(GraphicalModel model, TIntIntMap evidence, int target) { - CutObserved cutObserved = new CutObserved<>(); + Observe cutObserved = new Observe<>(); // run making a copy of the model GraphicalModel infModel = cutObserved.execute(model, evidence); diff --git a/src/main/java/ch/idsia/crema/learning/DiscreteEM.java b/src/main/java/ch/idsia/crema/learning/DiscreteEM.java index 8c7ed19b..c9d47f48 100644 --- a/src/main/java/ch/idsia/crema/learning/DiscreteEM.java +++ b/src/main/java/ch/idsia/crema/learning/DiscreteEM.java @@ -4,7 +4,7 @@ import ch.idsia.crema.inference.InferenceJoined; import ch.idsia.crema.inference.ve.FactorVariableElimination; import ch.idsia.crema.model.graphical.GraphicalModel; -import ch.idsia.crema.preprocess.CutObserved; +import ch.idsia.crema.preprocess.Observe; import ch.idsia.crema.preprocess.RemoveBarren; import gnu.trove.map.TIntIntMap; @@ -14,7 +14,7 @@ protected InferenceJoined, BayesianFactor> getDef return new InferenceJoined<>() { @Override public BayesianFactor query(GraphicalModel model, TIntIntMap evidence, int... queries) { - final CutObserved co = new CutObserved<>(); + final Observe co = new Observe<>(); GraphicalModel coModel = co.execute(model, evidence); final RemoveBarren rb = new RemoveBarren<>(); diff --git a/src/main/java/ch/idsia/crema/model/Model.java b/src/main/java/ch/idsia/crema/model/Model.java index cdd0e08f..cf7ec203 100644 --- a/src/main/java/ch/idsia/crema/model/Model.java +++ b/src/main/java/ch/idsia/crema/model/Model.java @@ -27,6 +27,7 @@ public interface Model { */ Strides getDomain(int... variables); + /** * remove a specific state from a variable * diff --git a/src/main/java/ch/idsia/crema/model/causal/SCM.java b/src/main/java/ch/idsia/crema/model/causal/SCM.java new file mode 100644 index 00000000..1166c4a5 --- /dev/null +++ b/src/main/java/ch/idsia/crema/model/causal/SCM.java @@ -0,0 +1,194 @@ +package ch.idsia.crema.model.causal; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.Map; +import java.util.TreeMap; + +import org.apache.commons.lang3.tuple.Pair; + +import java.util.Map.Entry; +import java.util.NoSuchElementException; + +import ch.idsia.crema.factor.bayesian.BayesianFactor; +import ch.idsia.crema.model.graphical.DAGModel; +import gnu.trove.iterator.TIntIterator; +import gnu.trove.set.TIntSet; +import gnu.trove.set.hash.TIntHashSet; + +public class SCM extends DAGModel { + + public static enum VariableType { + /** Endgogenous variable have at least one exogenous counfounder */ + ENDOGENOUS, + + /** Exogenous variables are root nodes with no parents */ + EXOGENOUS, + + /** Extra variables are endogenous variables without direct exogenous influence */ + EXTRA + } + + private int nextId; + private TreeMap varSets; + + + public SCM() { + varSets = new TreeMap(); + + varSets.put(VariableType.ENDOGENOUS, new TIntHashSet()); + varSets.put(VariableType.EXOGENOUS, new TIntHashSet()); + varSets.put(VariableType.EXTRA, new TIntHashSet()); + } + + private int nextId() { + return nextId; + } + + + private void useId(int id) { + nextId = Math.max(nextId, id + 1); + } + + public void add(TypedVariable variable) { + this.add(variable.getLabel(), variable.getCardinality(), variable.getType()); + } + + public int add(int size, VariableType type) { + int varid = nextId(); + return this.add(varid, size, type); + } + + public int add(int variable, int size, VariableType type) { + if (super.cardinalities.containsKey(variable)) + return variable; + + useId(variable); + + super.addVariable(variable, size); + varSets.get(type).add(variable); + + return variable; + } + + public int addEndogenous(int size) { + return this.add(size, VariableType.ENDOGENOUS); + } + + public int addExogenous(int size) { + return this.add(size, VariableType.EXOGENOUS); + } + + public int addExtra(int size) { + return this.add(size, VariableType.EXTRA); + } + + public int addEndogenous(int variable, int size) { + return this.add(variable, size, VariableType.ENDOGENOUS); + } + + public int addExogenous(int variable, int size) { + return this.add(variable, size, VariableType.EXOGENOUS); + } + + public int addAdditional(int variable, int size) { + return this.add(variable, size, VariableType.EXOGENOUS); + } + + public boolean isEndogenous(int variable) { + return varSets.get(VariableType.ENDOGENOUS).contains(variable); + } + + public boolean isExogenous(int variable) { + return varSets.get(VariableType.EXOGENOUS).contains(variable); + } + + public boolean isExtra(int variable) { + return varSets.get(VariableType.EXTRA).contains(variable); + } + + public boolean has(int variable) { + return this.cardinalities.containsKey(variable); + } + + public VariableType getType(int variable) { + for (VariableType type : VariableType.values()) { + if (varSets.get(type).contains(variable)) return type; + } + return null; + } + + public int[] getEndogenousVars() { + return varSets.get(VariableType.ENDOGENOUS).toArray(); + } + + public int[] getExogenousVars() { + return varSets.get(VariableType.EXOGENOUS).toArray(); + } + + public int[] getExtraVars() { + return varSets.get(VariableType.EXOGENOUS).toArray(); + } + + public void addParent(int variable, int parent) { + if (isExogenous(variable)) + throw new IllegalStateException("Exogenous vars have no parents."); + + if (isExtra(variable) && isExogenous(parent)) + throw new IllegalStateException("Additional vars cannot depend on exogenous vars directly."); + + if (isExtra(variable) && !isExtra(parent)) + throw new IllegalStateException("Additional vars should only depend on each other."); + + super.addParent(variable, parent); + } + + public Iterable variables() { + return () -> { + return new Iterator() { + int type_index = 0; + VariableType[] types = Arrays.stream(VariableType.values()).filter(x -> varSets.get(x).size() > 0).toArray(VariableType[]::new); + + TIntIterator setIterator = varSets.get(types[0]).iterator(); + + @Override + public TypedVariable next() { + if (!setIterator.hasNext()) { + ++type_index; + if (type_index < types.length) { + setIterator = varSets.get(types[type_index]).iterator(); + } else { + throw new NoSuchElementException(); + } + } + int id = setIterator.next(); + int size = cardinalities.get(id); + return new TypedVariable(id, size, types[type_index]); + } + + @Override + public boolean hasNext() { + boolean setn = setIterator.hasNext(); + boolean typn = type_index < types.length - 1; + + return setn || typn; // either set has more or we're not in the last set + } + }; + }; + } + + @Override + public void removeVariable(int variable) { + super.removeVariable(variable); + + // remove from sets (if needed) + @SuppressWarnings("unused") + boolean changed = this.varSets.get(VariableType.ENDOGENOUS).remove(variable) + || this.varSets.get(VariableType.EXOGENOUS).remove(variable) + || this.varSets.get(VariableType.EXTRA).remove(variable); + } + + + + +} diff --git a/src/main/java/ch/idsia/crema/model/causal/TypedVariable.java b/src/main/java/ch/idsia/crema/model/causal/TypedVariable.java new file mode 100644 index 00000000..17cc55e5 --- /dev/null +++ b/src/main/java/ch/idsia/crema/model/causal/TypedVariable.java @@ -0,0 +1,17 @@ +package ch.idsia.crema.model.causal; + +import ch.idsia.crema.core.Variable; +import ch.idsia.crema.model.causal.SCM.VariableType; + +public class TypedVariable extends Variable { + VariableType type; + + public TypedVariable(int label, int cardinality, VariableType type) { + super(label, cardinality); + this.type = type; + } + + public VariableType getType() { + return type; + } +} diff --git a/src/main/java/ch/idsia/crema/model/causal/WorldModel.java b/src/main/java/ch/idsia/crema/model/causal/WorldModel.java new file mode 100644 index 00000000..64c37dee --- /dev/null +++ b/src/main/java/ch/idsia/crema/model/causal/WorldModel.java @@ -0,0 +1,70 @@ +package ch.idsia.crema.model.causal; + +import ch.idsia.crema.core.Variable; + +public interface WorldModel { + /** + * Get the global SCM model + * @return SCM the global model or null if no model have been added + */ + public SCM get(); + + /** + * Add a model to the mapping assuming exogenous ids match. + * + * @param model the {@link SCM} model to be added + * @return int the id of the added worlds + */ + public int add(SCM model); + + /** + * Translate a local variable of world wid to the global id + * + * @param variable the local variable + * @param wid the id of the source world + * @return int the global id of the variable + */ + public int toGlobal(int variable, int wid); + + /** + * Translate a global id to a local one. + * + * @param variable the global variable id + * @return int a local variable id + */ + public int fromGlobal(int variable); + + /** + * Get the world id associated to the specified global variable id. + * For shared variables (i.e. exogenous) the method return -1; + * + * @param variable the global variable id + * @return int the world id of the specified global variable or -1 it exogenous + */ + public int worldIdOf(int variable); + + /** + * Get the world associated to the specified global variable id. + * The method returns null for shared (i.e. exogenous) variables. + * + * @param variable the global variable id + * @return SCM the source Structural Causal Model or null if variable is exogenous + */ + public SCM worldOf(int variable); + + default int[] toGlobal(int[] variables, int wid) { + int[] result = new int[variables.length]; + for (int i = 0; i < result.length; ++i) { + result[i] = toGlobal(variables[i], wid); + } + return result; + } + + default int[] fromGlobal(int[] variables) { + int[] result = new int[variables.length]; + for (int i = 0; i < result.length; ++i) { + result[i] = fromGlobal(variables[i]); + } + return result; + } +} diff --git a/src/main/java/ch/idsia/crema/model/causal/mapping/TreeMapping.java b/src/main/java/ch/idsia/crema/model/causal/mapping/TreeMapping.java new file mode 100644 index 00000000..8065bbc4 --- /dev/null +++ b/src/main/java/ch/idsia/crema/model/causal/mapping/TreeMapping.java @@ -0,0 +1,214 @@ +package ch.idsia.crema.model.causal.mapping; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.lang3.tuple.Pair; + +import ch.idsia.crema.core.Strides; +import ch.idsia.crema.factor.bayesian.BayesianFactor; +import ch.idsia.crema.factor.bayesian.BayesianFactorFactory; +import ch.idsia.crema.model.causal.SCM; +import ch.idsia.crema.model.causal.SCM.VariableType; +import ch.idsia.crema.model.causal.WorldModel; +import ch.idsia.crema.utility.ArraysUtil; +import gnu.trove.map.TIntIntMap; +import gnu.trove.map.TIntObjectMap; +import gnu.trove.map.hash.TIntIntHashMap; +import gnu.trove.map.hash.TIntObjectHashMap; + +public class TreeMapping implements WorldModel { + + /** + * Mapping from source to target for each world id + * Key is world id, value is mapping between source and global id + */ + private TIntObjectMap toGlobal; + + /** + * A map from a global id to a pair + */ + private TIntObjectMap> fromGlobal; + + /** + * A map from global exogenous id to local exogenous id + */ + private TIntIntMap globalToLocalExogenous; + private TIntIntMap localToGlobalExogenous; + + /** + * the global model, initialized at the first add. + */ + private SCM global; + + /** + * list of source models who's id are the index in this list + */ + private List worlds; + + /** + * Create a new Tree Mapping object. + */ + public TreeMapping() { + this.toGlobal = new TIntObjectHashMap(); + this.fromGlobal = new TIntObjectHashMap>(); + this.globalToLocalExogenous = new TIntIntHashMap(); + this.localToGlobalExogenous= new TIntIntHashMap(); + this.worlds = new ArrayList(); + } + + + @Override + public SCM get() { + return global; + } + + /** + * Store the mapping for the given world + * @param wid world id + * @param translate a map from local to global + */ + protected void saveMapping(int wid, TIntIntMap translate) { + + toGlobal.put(wid, translate); + + var iter = translate.iterator(); + while(iter.hasNext()) { + iter.advance(); + var source = iter.key(); + var target = iter.value(); + if (localToGlobalExogenous.containsKey(source)) continue; + + fromGlobal.put(target, Pair.of(wid, source)); + } + } + + /** + * Rename a list of source variable to the target + * @param ids + * @param translate + * @return + */ + private int[] rename(int[] ids, TIntIntMap translate) { + int[] target = new int[ids.length]; + for (int i = 0; i < ids.length; ++i) { + target[i] = translate.get(ids[i]); + } + return target; +// return Arrays.stream(ids).map(translate::get).toArray(); + } + + /** + * Convert a bayesian factor to a new variables set + * @param factor + * @param translate + * @return + */ + private BayesianFactor rename(BayesianFactor factor, TIntIntMap translate) { + Strides domain = factor.getDomain(); + int[] newvars = Arrays.stream(domain.getVariables()).map(translate::get).toArray(); + + int[] order = ArraysUtil.order(newvars); + double[] source_data = factor.getData(); + +// int[] source = domain.getVariables().clone(); +// source = ArraysUtil.at(source, order); +// +// var iterator = domain.getReorderedIterator(source); +// int tid = 0; +// +// double[] target = new double[domain.getCombinations()]; +// +// while(iterator.hasNext()) { +// target[tid++] = source_data[iterator.next()]; +// } + + int[] target_vars = ArraysUtil.at(newvars, order); + int[] target_size = ArraysUtil.at(domain.getSizes(), order); + + return BayesianFactorFactory.factory().domain(target_vars, target_size, true).data(newvars, source_data).get(); + } + + /** + * Connect a new model to the global one. + */ + @Override + public int add(SCM model) { + int wid = worlds.size(); + worlds.add(model); + + if (global == null) { + global = new SCM(); + } + + TIntIntMap translate = new TIntIntHashMap(); + for (var source : model.variables()) { + if (source.getType() == VariableType.EXOGENOUS) { + if (!localToGlobalExogenous.containsKey(source.getLabel())) { + int id = global.add(source.getCardinality(), source.getType()); + localToGlobalExogenous.put(source.getLabel(), id); + globalToLocalExogenous.put(id, source.getLabel()); + translate.put(source.getLabel(), id); + } else { + int id = localToGlobalExogenous.get(source.getLabel()); + translate.put(source.getLabel(), id); + } + } else { + int id = global.add(source.getCardinality(), source.getType()); + translate.put(source.getLabel(), id); + } + } + + + for (var source : model.variables()) { + if(source.getType() == VariableType.EXOGENOUS) continue; + int[] parents = model.getParents(source.getLabel()); + parents = rename(parents, translate); + + int globalId = translate.get(source.getLabel()); + global.addParents(globalId, parents); + } + + for (var source : model.variables()) { + var factor = model.getFactor(source.getLabel()); + if (factor == null) continue; + + factor = rename(factor, translate); + int tid = translate.get(source.getLabel()); + global.setFactor(tid, factor); + } + + saveMapping(wid, translate); + return wid; + } + + + @Override + public int toGlobal(int variable, int world) { + return toGlobal.get(world).get(variable); + } + + @Override + public int fromGlobal(int variable) { + if (globalToLocalExogenous.containsKey(variable)) { + return globalToLocalExogenous.get(variable); + } + var p = fromGlobal.get(variable); + return p.getValue(); + } + + @Override + public int worldIdOf(int variable) { + if (globalToLocalExogenous.containsKey(variable)) return -1; + var p = fromGlobal.get(variable); + return p.getKey(); + } + + @Override + public SCM worldOf(int variable) { + if (globalToLocalExogenous.containsKey(variable)) return null; + return worlds.get(worldIdOf(variable)); + } + +} diff --git a/src/main/java/ch/idsia/crema/model/graphical/DAGModel.java b/src/main/java/ch/idsia/crema/model/graphical/DAGModel.java index 8402a7b1..da76a200 100644 --- a/src/main/java/ch/idsia/crema/model/graphical/DAGModel.java +++ b/src/main/java/ch/idsia/crema/model/graphical/DAGModel.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.function.BiFunction; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -225,10 +226,19 @@ public void removeParent(int variable, int parent) { network.removeEdge(parent, variable); } +// @Override +// public void removeParent(int variable, int parent, DomainChange change) { +// F factor = factors.get(variable); +// F new_factor = change.remove(factor, parent); +// if (factor != new_factor) +// factors.put(variable, new_factor); +// network.removeEdge(parent, variable); +// } +// @Override - public void removeParent(int variable, int parent, DomainChange change) { + public void removeParent(int variable, int parent, Function change) { F factor = factors.get(variable); - F new_factor = change.remove(factor, parent); + F new_factor = change.apply(factor); if (factor != new_factor) factors.put(variable, new_factor); network.removeEdge(parent, variable); @@ -293,6 +303,15 @@ public int[] getLeaves() { return list.toArray(); } + + @Override + public Strides getFullDomain(int variable) { + int[] vars = getParents(variable); + vars = ArraysUtil.append(vars, variable); + Arrays.sort(vars); + return new Strides(vars, getSizes(vars)); + } + @Override public Strides getDomain(int... variables) { return new Strides(variables, getSizes(variables)); @@ -304,6 +323,10 @@ public void addVariables(int... states) { } } + /** + * Implemented as a sequence of addParent calls. + */ + @Override public void addParents(int variable, int... parents) { for (int parent : parents) { addParent(variable, parent); @@ -338,6 +361,7 @@ public Collection getFactors() { } public Collection getFactors(int... variables) { + return IntStream.of(variables).mapToObj(v -> factors.get(v)).collect(Collectors.toList()); } @@ -363,7 +387,7 @@ public void setFactors(F[] factors) { * * @return */ - public boolean correctFactorDomains() { + public boolean checkFactorsDAGConsistency() { return IntStream.of(this.getVariables()) .allMatch(v -> Arrays.equals( ArraysUtil.sort(this.getFactor(v).getDomain().getVariables()), diff --git a/src/main/java/ch/idsia/crema/model/graphical/GraphicalModel.java b/src/main/java/ch/idsia/crema/model/graphical/GraphicalModel.java index 47f28ad5..7bb849eb 100644 --- a/src/main/java/ch/idsia/crema/model/graphical/GraphicalModel.java +++ b/src/main/java/ch/idsia/crema/model/graphical/GraphicalModel.java @@ -1,5 +1,9 @@ package ch.idsia.crema.model.graphical; +import java.util.function.BiFunction; +import java.util.function.Function; + +import ch.idsia.crema.core.Strides; import ch.idsia.crema.factor.GenericFactor; import ch.idsia.crema.model.Model; import ch.idsia.crema.model.change.DomainChange; @@ -8,11 +12,32 @@ // FIXME: #removeParent should accept a lambda public interface GraphicalModel extends Model { - + /** + * Get the full domain associated with the specified variable. + * This includes the variable and its parents. + * + * @param variable the variable who's domain will be returned + * @return the variable and its parents as a Sorted Stride object + */ + Strides getFullDomain(int variable); + + /** + * Remove a variable's parents + * + * @param variable + * @param parent + */ void removeParent(int variable, int parent); - void removeParent(int variable, int parent, DomainChange change); + /** + * Remove a variable's parent indicating how to deal with the domain change + * @param variable + * @param parent + * @param change + */ + void removeParent(int variable, int parent, Function change); + void addParent(int variable, int parent); /** @@ -49,4 +74,6 @@ default void addParents(int k, int[] parent) { addParent(k, p); } } + + } \ No newline at end of file diff --git a/src/main/java/ch/idsia/crema/model/io/dot/DetailedDotSerializer.java b/src/main/java/ch/idsia/crema/model/io/dot/DetailedDotSerializer.java new file mode 100644 index 00000000..acb50981 --- /dev/null +++ b/src/main/java/ch/idsia/crema/model/io/dot/DetailedDotSerializer.java @@ -0,0 +1,211 @@ +package ch.idsia.crema.model.io.dot; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.text.DecimalFormat; +import java.text.NumberFormat; +import java.util.Arrays; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import ch.idsia.crema.core.Strides; +import ch.idsia.crema.data.DoubleTable; +import ch.idsia.crema.factor.bayesian.BayesianFactor; +import ch.idsia.crema.model.causal.SCM; +import ch.idsia.crema.model.graphical.BayesianNetwork; +import ch.idsia.crema.model.graphical.GraphicalModel; +import ch.idsia.crema.utility.IndexIterator; + +public class DetailedDotSerializer { + + + + + protected static Function nodeName(GraphicalModel model) { + if (model instanceof BayesianNetwork) { + return (node) -> { + return "X" + node + ""; + }; + } else { + final SCM sm = (SCM) model; + return (node) -> { + String label; + if (sm.isEndogenous(node)) { + label = "X"; + } else if (sm.isExogenous(node)) { + label = "U"; + } else { + label = "W"; + } + + return label + "" + node + ""; + }; + } + } + + + protected String apply(DoubleTable table, Function nodeName) { + NumberFormat formatter = new DecimalFormat("0.###"); + + StringBuilder builder = new StringBuilder(); + builder.append("").append(""); + + for (int col : table.getColumns()) { + builder.append(""); + } + builder.append(""); + + // contents + var iter = table.iterator(); + while (iter.hasNext()) { + var row = iter.next(); + builder.append(""); + + int[] states = row.getKey(); + Arrays.stream(states).forEach(s -> builder.append("")); + + builder.append(""); + + builder.append(""); + + } + builder.append("
").append(nodeName.apply(col)) + .append("W
").append(s).append("").append(formatter.format(row.getValue())).append("
"); + return builder.toString(); + } + + public String apply(Info record) { + String name = record.getModelName(); + GraphicalModel model = record.getModel(); + final Function nodeName = record.getNodeName() == null ? DetailedDotSerializer.nodeName(model) : record.getNodeName(); + + var highlight = record.getHighlight(); + + + StringBuilder builder = new StringBuilder(); + if (name == null) + name = "model"; + + builder.append("digraph \"").append(name).append("\" {\n node [shape=none];\n").append("\n"); + + if (record.getData() != null) + builder.append("TABLE [label=<").append(apply(record.getData(), nodeName)).append(">]\n"); + + StringBuilder arcs = new StringBuilder(); + + NumberFormat formatter = new DecimalFormat("0.###"); + + for (int i : model.getVariables()) { + + int[] parents = model.getParents(i); + + BayesianFactor factor = model.getFactor(i); + if (factor != null) { + Strides domain = factor.getDomain(); + Strides conditioning = domain.remove(i); + + int csize = conditioning.getCombinations(); + builder.append("N").append(i).append("[label=<"); + builder.append(""); + + for (int p : conditioning.getVariables()) { + builder.append(""); + builder.append(""); + + int stride = conditioning.getStride(p); + int size = conditioning.getCardinality(p); + for (int s = 0; s < csize; s += stride) { + int state = (s / stride) % size; + builder.append(""); + } + builder.append(""); + } + + int stride = domain.getStride(i); + int states = domain.getCardinality(i); + for (int state = 0; state < states; ++state) { + builder.append(""); + if (state == 0) { + builder.append(""); + } + + builder.append(""); + + IndexIterator iter = factor.getDomain().getIterator(conditioning); + while (iter.hasNext()) { + int index = iter.next(); + boolean h = highlight != null && highlight.containsKey(i) && highlight.get(i).contains(index); + double value = factor.getValueAt(index + state * stride); + if (h) + builder.append(""); + } + builder.append(""); + } + + builder.append("
"); + builder.append("P(").append(nodeName.apply(i)); + if (conditioning.getSize() > 0) { + builder.append("|"); + String c = IntStream.of(conditioning.getVariables()).mapToObj(a -> nodeName.apply(a)) + .collect(Collectors.joining(",")); + builder.append(c); + } + builder.append(")
").append(nodeName.apply(p)).append(""); + builder.append(state).append("
") + .append(nodeName.apply(i)).append("").append(state).append(""); + else + builder.append(""); + builder.append(formatter.format(value)); + if (h) + builder.append(""); + builder.append("
>];\n"); + + } else { + builder.append("N").append(i).append("[shape=\"circle\" label=<").append(nodeName.apply(i)).append(">];\n"); + } + + for (int parent : parents) { + arcs.append('N').append(parent).append(" -> N").append(i).append(";\n"); + } + } + + builder.append(arcs); + + if (record.getTitle() != null) + builder.append("labelloc=\"t\"\nlabel=\"").append(record.getTitle()).append(" ").append(record.getRunId()).append("\"\n"); + + builder.append("}"); + return builder.toString(); + + } + + + + + public static void saveModel(String filename, Info r) { + try { + DetailedDotSerializer serializer = new DetailedDotSerializer(); + + File f = File.createTempFile(filename, ".dot"); + + String file = serializer.apply(r); + Files.writeString(Path.of(f.getAbsolutePath()), file); + ProcessBuilder b = new ProcessBuilder("/opt/homebrew/bin/dot", "-Tpng", "-o", filename, + f.getAbsolutePath()); + Process p = b.start(); + p.waitFor(); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } catch (InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + + } +} diff --git a/src/main/java/ch/idsia/crema/model/io/dot/Info.java b/src/main/java/ch/idsia/crema/model/io/dot/Info.java new file mode 100644 index 00000000..b0473c49 --- /dev/null +++ b/src/main/java/ch/idsia/crema/model/io/dot/Info.java @@ -0,0 +1,134 @@ +package ch.idsia.crema.model.io.dot; + +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +import ch.idsia.crema.data.DoubleTable; +import ch.idsia.crema.factor.bayesian.BayesianFactor; +import ch.idsia.crema.model.graphical.GraphicalModel; + + +public class Info { + private GraphicalModel model; + private DoubleTable data; + private String modelName; + private String title; + + private int runId; + private int iterations; + private int pscmId; + private int pscmIterations; + + private Function nodeName; + private Map> highlight; + + public Info model(GraphicalModel model) { + this.model = model; + return this; + } + + public GraphicalModel getModel() { + return model; + } + + public DoubleTable getData() { + return data; + } + + public Info data(DoubleTable data) { + this.data = data; + return this; + } + + public String getModelName() { + return modelName; + } + + public Info modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public String getTitle() { + return title; + } + + public Info title(String title) { + this.title = title; + return this; + } + + public Function getNodeName() { + return nodeName; + } + + public Info nodeName(Function nodeName) { + this.nodeName = nodeName; + return this; + } + + public Map> getHighlight() { + return highlight; + } + + public Info highlight(Map> highlight) { + this.highlight = highlight; + return this; + } + + public Info runId(int runid) { + this.runId = runid; + return this; + } + + public Info iterations(int iterations) { + this.iterations = iterations; + return this; + } + + public Info PSCMId(int pscmrun) { + this.pscmId = pscmrun; + return this; + } + + public Info PSCMIterations(int piter) { + this.pscmIterations = piter; + return this; + } + + public int getRunId() { + return runId; + } + public int getIterations() { + return iterations; + } + + public int getPSCMId() { + return pscmId; + } + + public int getPSCMIterations() { + return pscmIterations; + } + + + public Info(GraphicalModel model, DoubleTable data, String modelName, String title, + Function nodeName, Map> highlight, int run, int iter, int pscm, + int pscmiter) { + super(); + this.model = model; + this.data = data; + this.modelName = modelName; + this.title = title; + this.nodeName = nodeName; + this.highlight = highlight; + this.runId = run; + this.iterations = iter; + this.pscmId = pscm; + this.pscmIterations = pscmiter; + } + + public Info() { + } +} diff --git a/src/main/java/ch/idsia/crema/model/io/uai/NetUAIWriter.java b/src/main/java/ch/idsia/crema/model/io/uai/NetUAIWriter.java index 46a3f17e..7b592996 100644 --- a/src/main/java/ch/idsia/crema/model/io/uai/NetUAIWriter.java +++ b/src/main/java/ch/idsia/crema/model/io/uai/NetUAIWriter.java @@ -15,7 +15,7 @@ public NetUAIWriter(T target, String filename) { @Override protected void sanityChecks() { // Check model consistency - if (!target.correctFactorDomains()) + if (!target.checkFactorsDAGConsistency()) throw new IllegalArgumentException("Inconsistent model"); } diff --git a/src/main/java/ch/idsia/crema/preprocess/Do.java b/src/main/java/ch/idsia/crema/preprocess/Do.java new file mode 100644 index 00000000..ef285c1a --- /dev/null +++ b/src/main/java/ch/idsia/crema/preprocess/Do.java @@ -0,0 +1,34 @@ +package ch.idsia.crema.preprocess; + +import ch.idsia.crema.factor.OperableFactor; +import ch.idsia.crema.model.graphical.GraphicalModel; +import gnu.trove.map.TIntIntMap; + +/** + * Network surgery for do operations. + * + * @param type of the factor + * @param type of the Graphical Model + */ +public class Do , M extends GraphicalModel> +implements ConverterEvidence { + + public M execute(M model, TIntIntMap dos) { + M copy = (M) model.copy(); + + int[] keys = dos.keys(); + for (int key : keys) { + + // select the correct part of the factors by removing a child + // via domain changer + for (int child : model.getChildren(key)) { + model.removeParent(child, key, (f) -> f.filter(key, dos.get(key))); + } + + // completely remove the var now. + model.removeVariable(key); + } + + return copy; + } +} diff --git a/src/main/java/ch/idsia/crema/preprocess/CutObserved.java b/src/main/java/ch/idsia/crema/preprocess/Observe.java similarity index 82% rename from src/main/java/ch/idsia/crema/preprocess/CutObserved.java rename to src/main/java/ch/idsia/crema/preprocess/Observe.java index edd010e7..d08cc169 100644 --- a/src/main/java/ch/idsia/crema/preprocess/CutObserved.java +++ b/src/main/java/ch/idsia/crema/preprocess/Observe.java @@ -11,7 +11,7 @@ * * @author huber */ -public class CutObserved> implements TransformerEvidence>, +public class Observe> implements TransformerEvidence>, PreprocessorEvidence> { /** @@ -34,12 +34,9 @@ public void executeInPlace(GraphicalModel model, TIntIntMap evidence) { final int state = iterator.value(); for (int variable : model.getChildren(observed)) { - model.removeParent(variable, observed, new NullChange<>() { - @Override - public F remove(F factor, int variable) { - // probably need to check this earlier - return factor.filter(observed, state); - } + model.removeParent(variable, observed, (factor) -> { + // probably need to check this earlier + return factor.filter(observed, state); }); } } diff --git a/src/main/java/ch/idsia/crema/utility/ArraysMath.java b/src/main/java/ch/idsia/crema/utility/ArraysMath.java new file mode 100644 index 00000000..0161c6cf --- /dev/null +++ b/src/main/java/ch/idsia/crema/utility/ArraysMath.java @@ -0,0 +1,128 @@ +package ch.idsia.crema.utility; + +import org.apache.commons.math3.util.FastMath; + +public class ArraysMath { + + public static double sum(double[] values) { + double total = 0; + for (int i = 0; i < values.length; ++i) { + total += values[i]; + } + return total; + } + + public static int sum(int[] values) { + int total = 0; + for (int i = 0; i < values.length; ++i) { + total += values[i]; + } + return total; + } + + public static long sum(long[] values) { + long total = 0; + for (int i = 0; i < values.length; ++i) { + total += values[i]; + } + return total; + } + + + public static int max(int[] array) { + int m = Integer.MIN_VALUE; + for (int i = 0; i < array.length;++i) { + m = Math.max(m, array[i]); + } + return m; + } + + public static int min(int[] array) { + int m = Integer.MAX_VALUE; + for (int i = 0; i < array.length;++i) { + m = Math.min(m, array[i]); + } + return m; + } + + public static double max(double[] array) { + double m = Double.NEGATIVE_INFINITY; + for (int i = 0; i < array.length;++i) { + m = Math.max(m, array[i]); + } + return m; + } + + public static double min(double[] array) { + double m = Double.POSITIVE_INFINITY; + for (int i = 0; i < array.length;++i) { + m = Math.min(m, array[i]); + } + return m; + } + + + public static double mean(int[] array) { + double m = 0; + // consider using https://en.wikipedia.org/wiki/Kahan_summation_algorithm + for (int i = 0; i < array.length;++i) { + m += array[i]; + } + return m / (double)array.length; + } + + + + public static double mean(long[] array) { + double m = 0; + // consider using https://en.wikipedia.org/wiki/Kahan_summation_algorithm + for (int i = 0; i < array.length;++i) { + m += array[i]; + } + return m / (double)array.length; + } + + + public static double mean(double[] array) { + double m = 0; + // consider using https://en.wikipedia.org/wiki/Kahan_summation_algorithm + for (int i = 0; i < array.length;++i) { + m += array[i]; + } + return m / (double)array.length; + } + + + public static double sd(int[] array, double ddof) { + double m = mean(array); + double s = 0; + for (int i = 0; i < array.length;++i) { + double d = (double)array[i] - m; + s += d * d; + } + return FastMath.sqrt(s / ((double)array.length - ddof)); + } + + + + public static double sd(long[] array, double ddof) { + double m = mean(array); + double s = 0; + for (int i = 0; i < array.length;++i) { + double d = (double)array[i] - m; + s += d * d; + } + return FastMath.sqrt(s / ((double)array.length - ddof)); + } + + public static double sd(double[] array, double ddof) { + double m = mean(array); + double s = 0; + for (int i = 0; i < array.length;++i) { + double d = array[i] - m; + s += d * d; + } + return FastMath.sqrt(s / ((double)array.length - ddof)); + } + +} diff --git a/src/main/java/ch/idsia/crema/utility/ArraysUtil.java b/src/main/java/ch/idsia/crema/utility/ArraysUtil.java index 4b90d166..bbd40638 100644 --- a/src/main/java/ch/idsia/crema/utility/ArraysUtil.java +++ b/src/main/java/ch/idsia/crema/utility/ArraysUtil.java @@ -5,6 +5,11 @@ import com.google.common.primitives.Doubles; import com.google.common.primitives.Ints; import gnu.trove.list.array.TIntArrayList; +import gnu.trove.set.TDoubleSet; +import gnu.trove.set.TIntSet; +import gnu.trove.set.hash.TDoubleHashSet; +import gnu.trove.set.hash.TIntHashSet; + import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.Precision; @@ -121,6 +126,9 @@ public static int[] sort(int[] base) { return copy; } + /** + * A helper class for the order functions + */ private static class X { public final int pos; public final int val; @@ -149,17 +157,24 @@ public static int[] order(int[] data, final IntComparator comparator) { return positions; } + /** + * get the ordering of the data. can be used to sort multiple arrays using a + * reference one i.e. get variables order and sort sizes. + * + * @param data + * @return + */ public static int[] order(int[] data) { - List internal = new ArrayList<>(data.length); + X[] internal = new X[data.length]; for (int i = 0; i < data.length; ++i) { - internal.add(new X(i, data[i])); + internal[i] = new X(i, data[i]); } - internal.sort(Comparator.comparingInt(o -> o.val)); + Arrays.sort(internal, Comparator.comparingInt(o -> o.val)); int[] positions = new int[data.length]; for (int i = 0; i < data.length; ++i) { - X x = internal.get(i); + X x = internal[i]; positions[i] = x.pos; data[i] = x.val; } @@ -167,6 +182,22 @@ public static int[] order(int[] data) { return positions; } + /** + * Helper method to index multiple items in one shot. + * + * @param values + * @param indices + * @return + */ + public static int[] at(int[] from, int[] indices) { + // return IntStream.of(indices).map(v -> values[v]).toArray(); + int[] x = new int[indices.length]; + for (int i = 0; i < indices.length; ++i) { + x[i] = from[indices[i]]; + } + return x; + } + public static double[][] deepClone(double[][] data) { double[][] result = new double[data.length][]; for (int i = 0; i < data.length; ++i) { @@ -184,7 +215,8 @@ public static double[][][] deepClone(double[][][] data) { } /** - * Convert an array in log-space using {@link FastMath#log(double)}. Creates a new array. + * Convert an array in log-space using {@link FastMath#log(double)}. Creates a + * new array. * * @param data input data * @return the input data in log-space. @@ -198,7 +230,8 @@ public static double[] log(double[] data) { } /** - * Convert an array of array in log-space using {@link FastMath#log(double)}. Creates a new array. + * Convert an array of array in log-space using {@link FastMath#log(double)}. + * Creates a new array. * * @param data input data * @return the input data in log-space. @@ -212,7 +245,8 @@ public static double[][] log(double[][] data) { } /** - * Convert an array in log-space using {@link FastMath#log1p(double)}. Creates a new array. + * Convert an array in log-space using {@link FastMath#log1p(double)}. Creates a + * new array. * * @param data input data * @return the input data in log-space. @@ -226,7 +260,8 @@ public static double[] log1p(double[] data) { } /** - * Convert an array from log-space to normal space using {@link FastMath#exp(double)}. Creates a new array. + * Convert an array from log-space to normal space using + * {@link FastMath#exp(double)}. Creates a new array. * * @param data input data * @return the input data in log-space. @@ -240,7 +275,8 @@ public static double[] exp(double[] data) { } /** - * Convert an array of array from log-space to normal space using {@link FastMath#exp(double)}. Creates a new array. + * Convert an array of array from log-space to normal space using + * {@link FastMath#exp(double)}. Creates a new array. * * @param data input data * @return the input data in log-space. @@ -349,18 +385,19 @@ public static int[] removeFromSortedArray(int[] array, int element) { * containing the specified element. The original array is returned if the * element was already part of the array. *

- * Expects a sorted array! Behaviours is rather unpredictable if array is not sorted. + * Expects a sorted array! Behaviours is rather unpredictable if array is not + * sorted. * * @param array the sorted array * @param element the item to be added to the array * @return a new array containing the element or the original one if element is - * already present + * already present */ public static int[] addToSortedArray(int[] array, int element) { if (array == null || array.length == 0) { // when no items in the array then return a new array with only the // item - return new int[]{element}; + return new int[] { element }; } // look for existing links @@ -472,15 +509,113 @@ public static int[] union(int[] arr1, int[] arr2) { return arr_union; } + public static int[] union_sorted_set(int[] arr1, int[] arr2) { + final int s1 = arr1.length; + final int s2 = arr2.length; + + // size the target arrays assuming no overlap + final int max = s1 + s2; + int[] arr_union = new int[max]; + + // (pt1) c1 and c2 are the positions in the two domains + int c1 = 0; + int c2 = 0; + + int t = 0; + int last = 0; + while (c1 < s1 && c2 < s2) { + int v1 = arr1[c1]; + int v2 = arr2[c2]; + if (t > 0) { + if (v1 == v2 && last == v1) { + ++c1; + ++c2; + continue; + } else if (v1 == last) { + ++c1; + continue; + } else if (v2 == last) { + ++c2; + continue; + } + } + + if (v1 < v2) { + last = arr_union[t] = v1; + ++c1; + } else if (v1 > v2) { + last = arr_union[t] = v2; + ++c2; + } else { + last = arr_union[t] = v1; + ++c1; + ++c2; + } + + ++t; + } + + // (pt2) check if there is one domain not completely copied that can be + // moved over in bulk. + for (; c1 < s1; ++c1) { + int a = arr1[c1]; + if (t == 0 || a != last) + last = arr_union[t++] = a; + } + for (; c2 < s2; ++c2) { + int a = arr2[c2]; + if (t == 0 || a != last) + last = arr_union[t++] = a; + + } + + // fix array sizes if there was overlap (we assumed no overlap while sizing) + if (t < max) { + arr_union = Arrays.copyOf(arr_union, t); + } + return arr_union; + } + /** - * Find the sorted difference of two non-sorted integer arrays. - * + * Find the sorted difference of two non-sorted integer arrays. in arr1 but not + * in arr2 + * * @param arr1 the first array * @param arr2 the second array * @return an array intersection of the first two */ +// public static int[] difference(int[] arr1, int[] arr2) { +// return IntStream.of(arr1).filter(y -> IntStream.of(arr2).noneMatch(x -> x == y)).toArray(); +// } public static int[] difference(int[] arr1, int[] arr2) { - return IntStream.of(arr1).filter(y -> IntStream.of(arr2).noneMatch(x -> x == y)).toArray(); + TIntSet a2 = new TIntHashSet(arr2); + + int[] tmp = new int[arr1.length]; + int used = 0; + + for (int o : arr1) { + if (!a2.contains(o)) + tmp[used++] = o; + } + + int[] target = new int[used]; + System.arraycopy(tmp, 0, target, 0, used); + Arrays.sort(target); + return target; + } + + public static int[] differenceSet(int[] arr1, int[] arr2) { + TIntSet a2 = new TIntHashSet(arr2); + TIntSet target = new TIntHashSet(arr1.length); + // target.removeAll(a2); + for (int o : arr1) { + if (!a2.contains(o)) + target.add(o); + } + + int[] ok = target.toArray(); + Arrays.sort(ok); + return ok; } /** @@ -491,17 +626,19 @@ public static int[] difference(int[] arr1, int[] arr2) { * @return symetric difference of both arrays */ public static int[] symmetricDiff(int[] arr1, int[] arr2) { - return unionSet(difference(arr1, arr2), difference(arr2, arr1)); + return union_sorted_set(difference(arr1, arr2), difference(arr2, arr1)); } /** * @param arr1 first array * @param arr2 second array - * @return an array which is the union of the two arrays without the common elements + * @return an array which is the union of the two arrays without the common + * elements */ public static int[] outersection(int[] arr1, int[] arr2) { final int[] intersection = intersection(arr1, arr2); - return unionSet(difference(arr1, intersection), difference(arr2, intersection)); + // difference returns a sorted set + return union_sorted_set(difference(arr1, intersection), difference(arr2, intersection)); } /** @@ -515,6 +652,14 @@ public static int[] intersection(int[] arr1, int[] arr2) { return IntStream.of(arr1).filter(y -> IntStream.of(arr2).anyMatch(x -> x == y)).toArray(); } + public static int[] intersection2(int[] arr1, int[] arr2) { + TIntSet a1 = new TIntHashSet(arr1); + a1.retainAll(arr2); + int[] a = a1.toArray(); + Arrays.sort(a); + return a; + } + /** * Find the sorted intersection of two sorted integer arrays. * @@ -564,10 +709,18 @@ public static int[] intersectionSorted(int[] arr1, int[] arr2) { * @param arr2 * @return */ - public static int[] unionSet(int[] arr1, int[] arr2) { + public static int[] unionSet2(int[] arr1, int[] arr2) { return Ints.toArray(ImmutableSet.copyOf(Ints.asList(Ints.concat(arr1, arr2)))); } + public static int[] union_unsorted_set(int[] arr1, int[] arr2) { + TIntSet set = new TIntHashSet(arr1); + set.addAll(arr2); + int[] res = set.toArray(); + Arrays.sort(res); + return res; + } + /** * Normalize an array by fixing the last value of the array so that it sums up * to the target value. @@ -627,7 +780,7 @@ public static int indexOf(int needle, int[] haystack) { public static int[] getShape(double[][] matrix) { if (Arrays.stream(matrix).map(v -> v.length).distinct().count() != 1) throw new IllegalArgumentException("ERROR: nested vectors do not have the same length"); - return new int[]{matrix.length, matrix[0].length}; + return new int[] { matrix.length, matrix[0].length }; } /** @@ -640,7 +793,7 @@ public static int[] getShape(double[][] matrix) { public static int[] getShape(int[][] matrix) { if (Arrays.stream(matrix).map(v -> v.length).distinct().count() != 1) throw new IllegalArgumentException("ERROR: nested vectors do not have the same length"); - return new int[]{matrix.length, matrix[0].length}; + return new int[] { matrix.length, matrix[0].length }; } /** @@ -685,7 +838,7 @@ public static int[][] transpose(int[][] original) { public static double[][] reshape2d(double[] vector, int... shape) { if (shape.length == 1) - shape = new int[]{shape[0], vector.length / shape[0]}; + shape = new int[] { shape[0], vector.length / shape[0] }; if (shape[0] * shape[1] != vector.length) throw new IllegalArgumentException("ERROR: incompatible shapes"); @@ -776,7 +929,7 @@ public static double[] changeEndian(double[] data, int[] oldSizes) { * @return */ public static double[][] enumerate(double[] vect, int start) { - return IntStream.range(start, vect.length + start).mapToObj(i -> new double[]{i, vect[i - start]}) + return IntStream.range(start, vect.length + start).mapToObj(i -> new double[] { i, vect[i - start] }) .toArray(double[][]::new); } @@ -797,8 +950,14 @@ public static double[][] enumerate(double[] vect) { * @param arr * @return */ +// public static int[] unique(int[] arr) { +// return Ints.toArray(ImmutableSet.copyOf(Ints.asList(arr))); +// } public static int[] unique(int[] arr) { - return Ints.toArray(ImmutableSet.copyOf(Ints.asList(arr))); + TIntSet set = new TIntHashSet(arr); + int[] a = set.toArray(); + Arrays.sort(a); + return a; } /** @@ -807,10 +966,15 @@ public static int[] unique(int[] arr) { * @param arr * @return */ +// public static double[] unique2(double[] arr) { +// return Doubles.toArray(ImmutableSet.copyOf(Doubles.asList(arr))); +// } public static double[] unique(double[] arr) { - return Doubles.toArray(ImmutableSet.copyOf(Doubles.asList(arr))); + TDoubleSet set = new TDoubleHashSet(arr); + double[] a = set.toArray(); + Arrays.sort(a); + return a; } - /** * Round all the values in a vector with a number of decimals. * @@ -1038,7 +1202,7 @@ private static int ndimWithClass(Class array) { * @param a * @return */ - @SuppressWarnings({"rawtypes", "unchecked"}) + @SuppressWarnings({ "rawtypes", "unchecked" }) public static double[] flattenDoubles(List a) { int ndims = ArraysUtil.ndim(a.get(0)); if (ndims > 1) { @@ -1071,7 +1235,7 @@ public static double[] flattenDoubles(List a) { * @param a * @return */ - @SuppressWarnings({"rawtypes", "unchecked"}) + @SuppressWarnings({ "rawtypes", "unchecked" }) public static int[] flattenInts(List a) { int ndims = ArraysUtil.ndim(a.get(0)); if (ndims > 1) { @@ -1097,17 +1261,6 @@ public static int[] flattenInts(List a) { return out; } - /** - * Helper method to index multiple items in one shot. - * - * @param values - * @param indices - * @return - */ - public static int[] at(int[] values, int[] indices) { - return IntStream.of(indices).map(v -> values[v]).toArray(); - } - /** * Non inline version of the Apache commons reverse methods * @@ -1121,8 +1274,7 @@ public static int[] reverse(int[] is) { } public static boolean isOneHot(double[] arr) { - return where(arr, x -> x == 1).length == 1 && - where(arr, x -> x == 0).length == arr.length - 1; + return where(arr, x -> x == 1).length == 1 && where(arr, x -> x == 0).length == arr.length - 1; } public static double[] replace(double[] arr, Predicate pred, double replacement) { @@ -1154,7 +1306,7 @@ public static boolean equals(int[] arr1, int[] arr2, boolean sort, boolean uniqu arr1 = unique(arr1); arr2 = unique(arr2); } - if (sort) { + if (sort && !unique) { // unique also sorts! arr1 = sort(arr1); arr2 = sort(arr2); } @@ -1185,22 +1337,26 @@ public static void shuffle(int[] array) { } /** - * Function that transforms an array such that all the non-zero elements sum exactly 1.0 + * Function that transforms an array such that all the non-zero elements sum + * exactly 1.0 + * * @param vector * @return */ - public static double[] normalizeNonZeros(double[] vector){ - if(Arrays.stream(vector).sum()!=1.0) { + public static double[] normalizeNonZeros(double[] vector) { + if (Arrays.stream(vector).sum() != 1.0) { // Find the last element different to zero int idx = vector.length - 1; - while (idx > 0 && vector[idx] == 0.0) idx--; + while (idx > 0 && vector[idx] == 0.0) + idx--; double sum = 0.0; - for (int i = 0; i < idx; i++) sum += vector[i]; + for (int i = 0; i < idx; i++) + sum += vector[i]; vector[idx] = 1.0 - sum; } return vector; } -} +} \ No newline at end of file diff --git a/src/test/java/ch/idsia/crema/core/StridesTest.java b/src/test/java/ch/idsia/crema/core/StridesTest.java index 018ba9c4..6eda7987 100644 --- a/src/test/java/ch/idsia/crema/core/StridesTest.java +++ b/src/test/java/ch/idsia/crema/core/StridesTest.java @@ -8,14 +8,6 @@ public class StridesTest { - @SuppressWarnings("deprecation") - public void testFilteredIterator() { - Strides domain = new Strides(new int[]{0, 1, 2, 3}, new int[]{2, 3, 4, 2}); - - IndexIterator iterator1 = domain.getFiteredIndexIterator(1, 2); - //int offset = domain.getPartialOffset(vars, states) - //IndexIterator iterator2 = domain.getIterator(filtered.getVariable()).offset(offset); - } @Test public void testEmpty() { @@ -221,19 +213,18 @@ public void testStridesStridesInt() { } - @SuppressWarnings("deprecation") @Test public void testRemoveAt() { - Strides s1 = new Strides(new int[]{0, 2, 3, 9}, new int[]{2, 2, 3, 7}); + Strides s1 = new Strides(new int[]{ 0, 2, 3, 9 }, new int[]{2, 2, 3, 7}); - Strides n1 = new Strides(s1, 1); + Strides n1 = s1.removeAt(1); Strides n2 = s1.removeAt(1); assertArrayEquals(n1.getVariables(), n2.getVariables()); assertArrayEquals(n1.getSizes(), n2.getSizes()); - n1 = new Strides(s1, 1); - n1 = new Strides(n1, 2); + n1 = s1.removeAt(1); + n1 = n1.removeAt(2); n2 = s1.removeAt(1, 3); assertArrayEquals(n1.getVariables(), n2.getVariables()); @@ -243,6 +234,11 @@ public void testRemoveAt() { assertArrayEquals(new int[0], n2.getVariables()); assertArrayEquals(new int[0], n2.getSizes()); assertEquals(1, n2.getCombinations()); + + n2 = s1.removeAt(3).removeAt(2).removeAt(0).removeAt(0); + assertArrayEquals(new int[0], n2.getVariables()); + assertArrayEquals(new int[0], n2.getSizes()); + assertEquals(1, n2.getCombinations()); } @Test diff --git a/src/test/java/ch/idsia/crema/factor/bayesian/BayesianOperations.java b/src/test/java/ch/idsia/crema/factor/bayesian/BayesianOperations.java new file mode 100644 index 00000000..3e595dee --- /dev/null +++ b/src/test/java/ch/idsia/crema/factor/bayesian/BayesianOperations.java @@ -0,0 +1,70 @@ +package ch.idsia.crema.factor.bayesian; + +import java.util.Arrays; +import org.junit.jupiter.api.Test; +import ch.idsia.crema.core.Domain; +import ch.idsia.crema.core.DomainBuilder; +import ch.idsia.crema.utility.ArraysMath; + +public class BayesianOperations { +//0.01656,0.0168,0.04284,0.018,0.02496,0.04728,0.03024,0.05856,0.06984 0.32508 +//0.01656,0.0168,0.04284,0.018,0.02496,0.04728,0.03024,0.05856,0.06984 + @Test + public void test() { + Domain abc = DomainBuilder.var(0,1,2).size(3,2,2).strides(); + Domain bc = DomainBuilder.var(1,2).size(2,2).strides(); + Domain bde = DomainBuilder.var(1,3,4).size(2,3,3).strides(); + Domain de = DomainBuilder.var(3,4).size(3,3).strides(); + + Domain d = DomainBuilder.var(3).size(3).strides(); + Domain a = DomainBuilder.var(0).size(3).strides(); + + Domain ce = DomainBuilder.var(2,4).size(2,3).strides(); + Domain e = DomainBuilder.var(4).size(3).strides(); + + Domain ed = DomainBuilder.var(4,3).size(3,3).strides(); + + var pa = BayesianFactorFactory.factory().domain(a).data(new double[]{0.3, 0.4, 0.3}).get(); + var pabc = BayesianFactorFactory.factory().domain(abc).data(new double[]{ + 0.1, 0.2, 0.7, + 0.3, 0.1, 0.6, + 0.3, 0.5, 0.2, + 0.5, 0.4, 0.1 + }).get(); + var pbde = BayesianFactorFactory.factory().domain(bde).data(new double[] { + 0.1, 0.9, + 0.3, 0.7, + 0.6, 0.4, + 0.5, 0.5, + 0.8, 0.2, + 0.2, 0.8, + 0.9, 0.1, + 0.4, 0.6, + 0.7, 0.3 + }).get(); + var pd = BayesianFactorFactory.factory().domain(d).data(new double[]{ + 0.6, 0.2, 0.2 + }).get(); + var pce = BayesianFactorFactory.factory().domain(ce).data(new double[]{ + 0.9, 0.1, + 0.8, 0.2, + 0.6, 0.4 + }).get(); + var ped = BayesianFactorFactory.factory().domain(de).data(new double[]{ + 0.1, 0.1, 0.8, + 0.2, 0.7, 0.1, + 0.7, 0.2, 0.1 + }).get(); + + var xx = pa.combine(pabc).combine(pbde) + .marginalize(1) + .combine(pce) + .marginalize(2) + .combine(pd, ped) + .marginalize(3); + // var xx = pd.combine(ped); + double o = ArraysMath.sum(xx.getData()); + System.out.println(Arrays.toString(xx.getData())); + } + +} diff --git a/src/test/java/ch/idsia/crema/factor/credal/vertex/TestVertexFactor.java b/src/test/java/ch/idsia/crema/factor/credal/vertex/TestVertexFactor.java index 9ff77ad9..28d57d5c 100644 --- a/src/test/java/ch/idsia/crema/factor/credal/vertex/TestVertexFactor.java +++ b/src/test/java/ch/idsia/crema/factor/credal/vertex/TestVertexFactor.java @@ -84,7 +84,7 @@ public void testAdd() { @Test public void testExpansion() { // p(X1|X0,X5) - VertexFactor vf = VertexFactorFactory.factory().domain(Strides.as(1, 4), Strides.as(0, 3).and(5, 2)) + VertexFactor vf = VertexFactorFactory.factory().domain(Strides.as(1, 4).sort(), Strides.as(0, 3).and(5, 2).sort()) .addVertex(new double[]{0.1, 0.3, 0.2, 0.4}, 0, 1) .addVertex(new double[]{0.3, 0.2, 0.1, 0.4}, 0, 1) .addVertex(new double[]{0.1, 0.3, 0.4, 0.2}, 0, 1) @@ -101,6 +101,8 @@ public void testExpansion() { .addVertex(new double[]{0.1, 0.2, 0.2, 0.5}, 2, 0) .get(); + + // P(X1|X0,X5) -> diventa estensivo rispetto a X5 VertexFactor v2 = vf.reseparate(Strides.as(0, 3)); double[][][] dta = v2.getData(); @@ -120,7 +122,7 @@ public void testExpansion() { assertArrayEquals(new double[]{0.6, 0.1, 0.2, 0.1, 0.4, 0.2, 0.2, 0.2}, dta[2][2], 0.00001); v2 = vf.reseparate(Strides.as(5, 2)); - dta = v2.getData(); + dta = v2.getData(); assertEquals(2, dta.length); // we removed var 0 from the separation domain @@ -129,6 +131,7 @@ public void testExpansion() { assertEquals(12, dta[1][0].length); // test some vertices + assertArrayEquals(new double[]{0.7, 0.1, 0.6, 0.1, 0.1, 0.1, 0.1, 0.7, 0.2, 0.1, 0.1, 0.1}, dta[0][0], 0.00001); assertArrayEquals(new double[]{0.7, 0.1, 0.6, 0.1, 0.1, 0.1, 0.1, 0.7, 0.2, 0.1, 0.1, 0.1}, dta[0][0], diff --git a/src/test/java/ch/idsia/crema/model/BayesianFactorTest.java b/src/test/java/ch/idsia/crema/model/BayesianFactorTest.java index 32e5aba7..7557e49e 100644 --- a/src/test/java/ch/idsia/crema/model/BayesianFactorTest.java +++ b/src/test/java/ch/idsia/crema/model/BayesianFactorTest.java @@ -257,7 +257,7 @@ public void testLogCombineDivision() { assertEquals(ibf, r); } - @Test +// @Test public void testtime() { int a = 0; int b = 1; @@ -347,7 +347,7 @@ public void testtime() { - @Test +// @Test public void testLogSpeed() { int[] vars = new int[]{0, 1, 2}; diff --git a/src/test/java/ch/idsia/crema/model/causal/mapping/TreeMappingTest.java b/src/test/java/ch/idsia/crema/model/causal/mapping/TreeMappingTest.java new file mode 100644 index 00000000..0d3e7f28 --- /dev/null +++ b/src/test/java/ch/idsia/crema/model/causal/mapping/TreeMappingTest.java @@ -0,0 +1,195 @@ +/** + * + */ +package ch.idsia.crema.model.causal.mapping; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.Arrays; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import ch.idsia.crema.core.Domain; +import ch.idsia.crema.core.Strides; +import ch.idsia.crema.factor.bayesian.BayesianFactor; +import ch.idsia.crema.factor.bayesian.BayesianFactorFactory; +import ch.idsia.crema.model.causal.SCM; +import ch.idsia.crema.model.causal.SCM.VariableType; +import ch.idsia.crema.model.io.dot.DetailedDotSerializer; +import ch.idsia.crema.model.io.dot.Info; + +/** + * + */ +class TreeMappingTest { + SCM one; + SCM two; + + int v1; + int v2; + int e1; + int e2; + + @BeforeEach + void setup() { + SCM scm1 = new SCM(); + v1 = scm1.addEndogenous(2); + v2 = scm1.addEndogenous(2); + e1 = scm1.addExogenous(2); + e2 = scm1.addExogenous(2); + scm1.addParents(v2, v1, e2); + scm1.addParent(v1, e1); + + var f1 = BayesianFactorFactory.factory().domain(scm1.getFullDomain(v1)) + .data(new int[] { v1, e1 }, new double[] { 0.1, 0.9, 0.3, 0.7 }).get(); + scm1.setFactor(v1, f1); + + var f2 = BayesianFactorFactory.factory().domain(scm1.getFullDomain(v2)) + .data(new int[] { v2, v1, e2 }, new double[] { 0.4, 0.6, 0.2, 0.8, 0.1, 0.9, 0.7, 0.3 }).get(); + scm1.setFactor(v2, f2); + + var f3 = BayesianFactorFactory.factory().domain(scm1.getFullDomain(e1)) + .data(new int[] { e1 }, new double[] { 0.25, 0.75 }).get(); + scm1.setFactor(e1, f3); + + var f4 = BayesianFactorFactory.factory().domain(scm1.getFullDomain(e2)) + .data(new int[] { e2 }, new double[] { 0.2, 0.8 }).get(); + scm1.setFactor(e2, f4); + + SCM scm2 = new SCM(); + scm2.addEndogenous(v1, 2); + scm2.addEndogenous(v2, 2); + scm2.addExogenous(e1, 2); + scm2.addExogenous(e2, 2); + + scm2.addParent(v2, e2); + scm2.addParent(v1, e1); + + f1 = BayesianFactorFactory.factory().domain(scm2.getFullDomain(v1)) + .data(new int[] { v1, e1 }, new double[] { 0.45, 0.55, 0.35, 0.65 }).get(); + scm2.setFactor(v1, f1); + + f2 = BayesianFactorFactory.factory().domain(scm2.getFullDomain(v2)) + .data(new int[] { v2, e2 }, new double[] { 0.6, 0.4, 0.15, 0.85 }).get(); + scm2.setFactor(v2, f2); + + f3 = BayesianFactorFactory.factory().domain(scm2.getFullDomain(e1)) + .data(new int[] { e1 }, new double[] { 0.55, 0.45 }).get(); + scm2.setFactor(e1, f3); + + f4 = BayesianFactorFactory.factory().domain(scm2.getFullDomain(e2)) + .data(new int[] { e2 }, new double[] { 0.65, 0.35 }).get(); + scm2.setFactor(e2, f4); + + one = scm1; + two = scm2; + } + + /** + * Test method for + * {@link ch.idsia.crema.model.causal.mapping.TreeMapping#get()}. + */ + @Test + void testGet() { + TreeMapping tm = new TreeMapping(); + assertNull(tm.get()); + tm.add(one); + assertNotNull(tm.get()); + } + + static String name(TreeMapping tm, int id) { + SCM global = tm.get(); + int lid = tm.fromGlobal(id); + if (global.isExogenous(id)) { + return "U" + lid; + } + + int world = tm.worldIdOf(id); + return "X" + lid + "" + world + ""; + } + + /** + * Test method for + * {@link ch.idsia.crema.model.causal.mapping.TreeMapping#add(ch.idsia.crema.model.causal.SCM)}. + */ + @Test + void testAdd() { + + final TreeMapping tm = new TreeMapping(); + assertNull(tm.get()); + tm.add(one); + tm.add(two); + SCM global = tm.get(); + // DetailedDotSerializer.saveModel("test.png", new + // Info().model(global).nodeName((id)->name(tm, id))); + assertNotNull(tm.get()); + } + + /** + * Test method for + * {@link ch.idsia.crema.model.causal.mapping.TreeMapping#toGlobal(int, int)}. + */ + @Test + void testToGlobal() { + final TreeMapping tm = new TreeMapping(); + int w1 = tm.add(one); + int w2 = tm.add(two); + SCM global = tm.get(); + + for (int v : one.getVariables()) { + int g = tm.toGlobal(v, w1); + assertEquals(tm.fromGlobal(g), v); + } + for (int v : two.getVariables()) { + int g = tm.toGlobal(v, w2); + assertEquals(tm.fromGlobal(g), v); + } + + } + + /** + * Test method for + * {@link ch.idsia.crema.model.causal.mapping.TreeMapping#fromGlobal(int)}. + */ + @Test + void testExogenous() { + final TreeMapping tm = new TreeMapping(); + int w1 = tm.add(one); + int w2 = tm.add(two); + + int ge1 = tm.toGlobal(e1, w1); + int ge2 = tm.toGlobal(e1, w2); + assertEquals(ge1, ge2); + } + + /** + * Test parenting preservalfor + */ + @Test + void testParenting() { + final TreeMapping tm = new TreeMapping(); + int w1 = tm.add(one); + int w2 = tm.add(two); + SCM global = tm.get(); + for (int v : global.getVariables()) { + int[] p = global.getParents(v); + + SCM w = tm.worldOf(v); + if (w == null) { + w = one; + assert(global.isExogenous(v)); + } + + int vs = tm.fromGlobal(v); + int[] ps = tm.fromGlobal(p); + int[] pt = w.getParents(vs); + + Arrays.sort(ps); + Arrays.sort(pt); + assertArrayEquals(ps, pt); + } + } + + +} diff --git a/src/test/java/ch/idsia/crema/model/io/UAIParserTest.java b/src/test/java/ch/idsia/crema/model/io/UAIParserTest.java index 61b8ae35..ae25a475 100644 --- a/src/test/java/ch/idsia/crema/model/io/UAIParserTest.java +++ b/src/test/java/ch/idsia/crema/model/io/UAIParserTest.java @@ -42,11 +42,11 @@ void numvars(String name, String num) { Assertions.assertEquals(((DAGModel) models.get(name)).getVariables().length, Integer.parseInt(num)); } - @ParameterizedTest - @ValueSource(strings = {"simple-hcredal.uai", "simple-vcredal.uai"}) - void checkDomains(String name) { - Assertions.assertTrue(((DAGModel) models.get(name)).correctFactorDomains()); - } +// @ParameterizedTest +// @ValueSource(strings = {"simple-hcredal.uai", "simple-vcredal.uai"}) +// void checkDomains(String name) { +// Assertions.assertTrue(((DAGModel) models.get(name)).correctFactorDomains()); +// } @ParameterizedTest @ValueSource(strings = {"simple-hcredal.uai"}) diff --git a/src/test/java/ch/idsia/crema/model/utility/ArraysUtilityTest.java b/src/test/java/ch/idsia/crema/model/utility/ArraysUtilityTest.java index 7c8d0a8c..24ad320b 100644 --- a/src/test/java/ch/idsia/crema/model/utility/ArraysUtilityTest.java +++ b/src/test/java/ch/idsia/crema/model/utility/ArraysUtilityTest.java @@ -277,7 +277,7 @@ public void testroundArrayToTarget() { @ParameterizedTest @CsvSource({"1&2,3&4,1&2&3&4", "3&4, 3&4, 3&4", - "4&4, 3&4, 4&3", + "4&4, 3&4, 3&4", "3&4, 3&3, 3&4", "1&2, 3&4, 1&2&3&4", }) @@ -286,7 +286,7 @@ void unionSet(String arr1, String arr2, String res) { int[] v2 = ArraysUtil.latexToIntVector(arr2); int[] expected = ArraysUtil.latexToIntVector(res); - assertArrayEquals(expected, ArraysUtil.unionSet(v1, v2)); + assertArrayEquals(expected, ArraysUtil.union_unsorted_set(v1, v2)); } @@ -375,7 +375,7 @@ public void enumerate(String vect, int start, String expected) { @ParameterizedTest @CsvSource({"1&2&3&4,1&2&3&4", "1&2&1&3&4,1&2&3&4", - "0&10&4&1&4,0&10&4&1" + "0&10&4&1&4,0&1&4&10" }) void unique(String arr, String expected) { int[] arr_ = ArraysUtil.latexToIntVector(arr); @@ -389,7 +389,7 @@ void unique(String arr, String expected) { @ParameterizedTest @CsvSource({"1&2&3&4,1&2&3&4", "1&2&1&3&4,1&2&3&4", - "0&10&4&1&4,0&10&4&1" + "0&10&4&1&4,0&1&4&10" }) void uniqueDoubles(String arr, String expected) { double[] arr_ = ArraysUtil.latexToDoubleVector(arr);