* While the class should be considered unmutable, when the factor is part of a
- * model and we delete a variable from it, the indices of the variables will be
- * updated accordingly.
+ * model and we delete a variable from it, the indices of the variables could be
+ * updated accordingly. This applies only to models that do not allow gaps in
+ * variable labels.
*
*
* @author davidhuber
*/
-public final class Strides implements Domain {
- final private int[] strides;
+public final class Strides implements StridedDomain {
+ private final int[] strides;
final private int combinations;
final private int[] variables;
final private int[] sizes;
final private int size;
+ public Strides(Variable[] variables) {
+
+ Arrays.sort(variables);
+
+ strides = new int[variables.length + 1];
+
+ this.variables = Arrays.stream(variables).mapToInt(Variable::getLabel).toArray();
+ this.sizes = Arrays.stream(variables).mapToInt(Variable::getCardinality).toArray();
+ this.size = variables.length;
+
+ int cumulative = 1;
+ strides[0] = 1;
+
+ for (int i = 0; i < size; ++i) {
+ strides[i + 1] = cumulative *= sizes[i];
+ }
+ combinations = cumulative;
+ }
+
public Strides(int[] variables, int[] sizes, int[] strides) {
this.variables = variables;
this.sizes = sizes;
@@ -58,48 +88,17 @@ public Strides(int[] variables, int[] sizes) {
this.size = variables.length;
this.strides = new int[size + 1];
- this.strides[0] = 1;
- for (int i = 0; i < size; ++i) {
- strides[i + 1] = strides[i] * sizes[i];
- }
- this.combinations = strides[size];
- }
-
- /**
- * Creates a stride with a single variable excluded. Note that the variable
- * must not be missing in the provided domain.
- *
- * @param domain
- * @param offset
- * @deprecated please use {@link Strides#removeAt(int...)}
- */
- @Deprecated
- public Strides(Strides domain, int offset) {
- this.size = domain.variables.length - 1;
- this.variables = new int[size];
- this.sizes = new int[size];
-
- System.arraycopy(domain.variables, 0, variables, 0, offset);
- System.arraycopy(domain.variables, offset + 1, variables, offset, size - offset);
- System.arraycopy(domain.sizes, 0, sizes, 0, offset);
- System.arraycopy(domain.sizes, offset + 1, sizes, offset, size - offset);
-
- strides = new int[size + 1];
- strides[0] = 1;
- if (offset > 1) {
- System.arraycopy(domain.strides, 0, strides, 0, offset--);
- } else
- offset = 0;
+ int cumulative = this.strides[0] = 1;
- for (; offset < size; ++offset) {
- strides[offset + 1] = strides[offset] * sizes[offset];
+ for (int i = 0; i < size; ++i) {
+ strides[i + 1] = cumulative *= sizes[i];
}
- this.combinations = strides[size];
+ this.combinations = cumulative;
}
@Override
public final int indexOf(int variable) {
- //return Arrays.binarySearch(variables, variable);
+ // return Arrays.binarySearch(variables, variable);
return ArraysUtil.indexOf(variable, variables);
}
@@ -230,8 +229,7 @@ public final int getCombinations() {
}
/**
- * Create an iterator over an enlarged domain but with same strides and
- * sizes.
+ * Create an iterator over an enlarged domain but with same strides and sizes.
*
* @param targetDomain
* @return
@@ -257,8 +255,8 @@ private IndexIterator getSupersetIndexIterator(final int[] over, final int[] tar
}
/**
- * Create an iterator over an smaller domain with same strides and sizes and
- * 1 (one) variable set to a specific state.
+ * Create an iterator over an smaller domain with same strides and sizes and 1
+ * (one) variable set to a specific state.
*
* @param variable the variable
* @param state the desired state
@@ -335,8 +333,8 @@ public String toString() {
}
/**
- * Get an IndexIterator for this stride object with a different order for
- * the variables.
+ * Get an IndexIterator for this stride object with a different order for the
+ * variables.
*
* @param unsorted_vars - int[] the new order for the variables
* @return an iterator
@@ -356,8 +354,8 @@ public IndexIterator getReorderedIterator(int[] unsorted_vars) {
}
/**
- * Get an IndexIterator for this stride object with a different order for
- * the variables. Alias of intersection.
+ * Get an IndexIterator for this stride object with a different order for the
+ * variables. Alias of intersection.
*
* @param someVars int[] - the new order for the variables
* @return an iterator
@@ -369,8 +367,8 @@ public Strides retain(int[] someVars) {
////////////////////////////////////////////////////////////////////////////////////////////////////
/**
- * Remove some variable from this domain. The returned domain will have his
- * own strides. Variables do not necessarily have to be part of this domain.
+ * Remove some variable from this domain. The returned domain will have his own
+ * strides. Variables do not necessarily have to be part of this domain.
* Convenience method that calls the {@link Strides#remove(int...)}
*
* @param toremove Strides - the domain to be removed from the first domain
@@ -381,13 +379,36 @@ public Strides remove(Strides toremove) {
}
/**
- * Remove some variable from this domain. The returned domain will have his
- * own strides. Variables do not necessarily have to be part of "domain".
+ * Remove some variable from this domain. The returned domain will have his own
+ * strides. Variables do not necessarily have to be part of "domain".
*
* @param toremove
* @return
*/
public Strides remove(int... toremove) {
+ Set to_rem = Arrays.stream(toremove).boxed().collect(Collectors.toSet());
+
+ int[] vars = new int[this.size];
+ int[] sizes = new int[this.size];
+
+ int ti = 0;
+ for (int v = 0; v < this.variables.length; ++v) {
+ if (!to_rem.contains(this.variables[v])) {
+ sizes[ti] = this.sizes[v];
+ vars[ti] = this.variables[v];
+ ti++;
+ } else {
+ to_rem.remove(v);
+ }
+ }
+
+ vars = Arrays.copyOf(vars, ti);
+ sizes = Arrays.copyOf(sizes, ti);
+
+ return new Strides(vars, sizes);
+ }
+
+ public Strides remove_sorted(int... toremove) {
int di = 0; // domain index
int ti = 0; // toremove index
int index = 0; // target index
@@ -434,8 +455,8 @@ public Strides remove(int... toremove) {
}
/**
- * Create a new Strides object result of the intersection of this domain
- * with the given one.
+ * Create a new Strides object result of the intersection of this domain with
+ * the given one.
*
* @param domain2 another domain
* @return the intersection of domain1 and domain2
@@ -490,6 +511,108 @@ public Strides intersection(int... domain2) {
return new Strides(intersect_vars, intersect_sizes);
}
+ public Strides union3(Strides domain2) {
+
+ int[] vars = ArraysUtil.append(variables, domain2.variables);
+ int[] sz = ArraysUtil.append(sizes, domain2.sizes);
+
+ TIntSet un = new TIntHashSet(vars);
+ // Set un = Arrays.stream(vars).boxed().collect(Collectors.toSet());
+
+ int[] union_vars = new int[un.size()];
+ int[] union_sizes = new int[un.size()];
+ int[] union_strides = new int[un.size() + 1];
+ int i = 0;
+ union_strides[0] = 1;
+
+ for (int vindex = 0; vindex < vars.length; ++vindex) {
+ int v = vars[vindex];
+ int s = sz[vindex];
+
+ if (un.contains(v)) {
+ union_vars[i] = v;
+ union_sizes[i] = s;
+ union_strides[i + 1] = union_strides[i] * s;
+ ++i;
+ un.remove(v);
+ }
+ }
+
+ return new Strides(union_vars, union_sizes, union_strides);
+ }
+
+ public Strides union_unsorted(Strides domain2) {
+
+ int s1 = variables.length;
+ int s2 = domain2.variables.length;
+
+ int[] vars = ArraysUtil.append(variables, domain2.variables);
+ int[] sz = ArraysUtil.append(sizes, domain2.sizes);
+
+ int s = s1 + s2;
+
+ int[] union_vars = new int[s];
+ int[] union_sz = new int[s];
+ int[] union_str = new int[s + 1];
+
+ union_str[0] = 1;
+ int t = 0;
+ external: for (int v1 = 0; v1 < vars.length; ++v1) {
+ for (int v2 = v1 - 1; v2 >= 0; --v2) {
+ if (vars[v1] == vars[v2])
+ continue external;
+ }
+ union_vars[t] = vars[v1];
+ union_sz[t] = sz[v1];
+ union_str[t + 1] = union_str[t] * sz[v1];
+ t++;
+ }
+
+ return new Strides(Arrays.copyOf(union_vars, t), Arrays.copyOf(union_sz, t), Arrays.copyOf(union_str, t + 1));
+ }
+
+ public Strides union_new(Strides domain2) {
+
+ int s1 = variables.length;
+ int s2 = domain2.variables.length;
+
+ int[] vars = ArraysUtil.append(variables, domain2.variables);
+ int[] order = ArraysUtil.order(vars);
+
+ vars = ArraysUtil.at(vars, order);
+ int[] sz = ArraysUtil.append(sizes, domain2.sizes);
+ sz = ArraysUtil.at(sz, order);
+
+ int prev = vars[0];
+ int s = 1;
+
+ for (int v : vars) {
+ if (v != prev)
+ ++s;
+ prev = v;
+ }
+
+ int[] union_vars = new int[s];
+ int[] union_sz = new int[s];
+ int[] union_str = new int[s + 1];
+
+ union_str[0] = 1;
+
+ int t = 0;
+
+ prev = vars[0] ^ 1; // make sure the value in prev is different than the first item
+ for (int i = 0; i < vars.length; ++i) {
+ if (prev != vars[i]) {
+ union_vars[t] = vars[i];
+ union_sz[t] = sz[i];
+ union_str[t + 1] = union_str[t] * sz[i];
+ t++;
+ }
+ }
+
+ return new Strides(union_vars, union_sz, union_str);
+ }
+
/**
* Create a new Strides result of the union of domain1 and domain2.
*
@@ -498,16 +621,17 @@ public Strides intersection(int... domain2) {
*
*
*
- * As we can assume ordering we will traverse the domains in parallel (pt1
- * in the code) and copy to a target domain the values. When one of the
- * domains reached it's end the remaing variable of the other domain can be
- * added to the union in bulk (pt2 in code).
+ * As we can assume ordering we will traverse the domains in parallel (pt1 in
+ * the code) and copy to a target domain the values. When one of the domains
+ * reached it's end the remaing variable of the other domain can be added to the
+ * union in bulk (pt2 in code).
*
*
* @param domain2
* @return
*/
public Strides union(Strides domain2) {
+
final int s1 = this.size;
final int s2 = domain2.size;
@@ -593,8 +717,37 @@ public Strides concat(Strides right) {
}
/**
- * Creates a stride with one or more variables excluded. Note that the
- * variable must not be missing in this domain.
+ * Creates a stride with one or more variables excluded. Variables may be
+ * missing
+ *
+ * @param offset
+ */
+ public Strides removeAt(int offset) {
+
+ int[] tvariables = new int[size - 1];
+ int[] tsizes = new int[size - 1];
+ int[] tstrides = new int[size];
+
+ System.arraycopy(variables, 0, tvariables, 0, offset);
+ System.arraycopy(sizes, 0, tsizes, 0, offset);
+
+ if (offset < size - 1) {
+ System.arraycopy(variables, offset + 1, tvariables, offset, size - offset - 1);
+ System.arraycopy(sizes, offset + 1, tsizes, offset, size - offset - 1);
+ }
+
+ System.arraycopy(strides, 0, tstrides, 0, offset + 1);
+ System.arraycopy(strides, offset + 2, tstrides, offset + 1, size - offset - 1);
+ int fix = sizes[offset];
+ for (int o = offset + 1; o < size; ++o) {
+ tstrides[o] /= fix;
+ }
+ return new Strides(tvariables, tsizes, tstrides);
+ }
+
+ /**
+ * Creates a stride with one or more variables excluded. Variables may be
+ * missing
*
* @param offset
*/
@@ -619,8 +772,7 @@ public Strides removeAt(int... offset) {
}
/**
- * Get an iterator of this domain with the specified variables lock in state
- * 0.
+ * Get an iterator of this domain with the specified variables lock in state 0.
*
* Convenience method when the target domain is this domain. same as calling
* getIterator(this, locked);
@@ -633,13 +785,12 @@ public IndexIterator getIterator(int... locked) {
}
/**
- * Iterate over another domain. If a variable is not present in this domain
- * it will not move the index but it will take the step. If a variable is
- * not in the specified domain the variable is not considered or is assumed
- * fixed to 0. Please use getPartialOffset(vars, state) for a different offset. If
- * also present in str, variable in the locked array will be counted but
- * fixed at zero. This allows us to keep them in the target domain while not
- * moving.
+ * Iterate over another domain. If a variable is not present in this domain it
+ * will not move the index but it will take the step. If a variable is not in
+ * the specified domain the variable is not considered or is assumed fixed to 0.
+ * Please use getPartialOffset(vars, state) for a different offset. If also
+ * present in str, variable in the locked array will be counted but fixed at
+ * zero. This allows us to keep them in the target domain while not moving.
*
* @param str the target domain
* @param locked the variables that should be locked
@@ -668,11 +819,11 @@ public IndexIterator getIterator(Strides str, int... locked) {
}
++source; // in any case we move the source pointer
} // else {
- // source var is greater than target (var is not found in this
- // domain)
- // so no need to copy strides, we can leave them at the default
- // value of 0
- // }
+ // source var is greater than target (var is not found in this
+ // domain)
+ // so no need to copy strides, we can leave them at the default
+ // value of 0
+ // }
}
return new IndexIterator(new_strides, str.getSizes(), 0, str.getCombinations());
@@ -704,24 +855,25 @@ public static Strides as(int... data) {
return new Strides(variables, sizes);
}
-
/**
* Define a stride as a sequence of variable/size pairs
*/
public static Strides var(int var, int size) {
- return new Strides(new int[]{var}, new int[]{size});
+ return new Strides(new int[] { var }, new int[] { size });
}
/**
- * Helper to allow concatenation of var().add().add()
- * Insertion is sorted. Resulting domain is ordered!
+ * Helper to allow concatenation of var().add().add() Insertion is sorted.
+ * Resulting domain is ordered!
+ *
* @param var
* @param size
* @return
*/
public Strides and(int var, int size) {
int pos = Arrays.binarySearch(variables, var);
- if (pos >= 0) return this; // no need to change anything
+ if (pos >= 0)
+ return this; // no need to change anything
pos = -(pos + 1);
int[] newvar = new int[variables.length + 1];
@@ -743,8 +895,8 @@ public Strides and(int var, int size) {
public static Strides EMPTY = Strides.as();
/**
- * Return an empty Stride with no variables and a single entry in the strides array. This
- * only stride value is set to 1.
+ * Return an empty Stride with no variables and a single entry in the strides
+ * array. This only stride value is set to 1.
*
* @return {@link Strides} - the empty stride
*/
@@ -752,14 +904,14 @@ public static Strides empty() {
return EMPTY;
}
-
/**
* Return a new Stride sorted by the variables.
*
* @return
*/
public Strides sort() {
- int[] order = IntStream.range(0, size).boxed().sorted((a, b) -> Strides.this.variables[a] - Strides.this.variables[b]).mapToInt(x -> x).toArray();
+ int[] order = IntStream.range(0, size).boxed()
+ .sorted((a, b) -> Strides.this.variables[a] - Strides.this.variables[b]).mapToInt(x -> x).toArray();
int[] variables = Arrays.stream(order).map(x -> Strides.this.variables[x]).toArray();
int[] sizes = Arrays.stream(order).map(x -> Strides.this.sizes[x]).toArray();
return new Strides(variables, sizes);
@@ -782,10 +934,9 @@ public IndexIterator getIterator(Strides domain, TIntIntMap observation) {
return it;
}
-
/**
- * Determines if this Strides object is consistent other. Two Strides objects are
- * consistent if all the common variables have the same cardinality.
+ * Determines if this Strides object is consistent other. Two Strides objects
+ * are consistent if all the common variables have the same cardinality.
*
* @param other {@link Strides} - other Strides object to compare with.
* @return boolean variable indicating the compatibility.
@@ -797,37 +948,35 @@ public boolean isConsistentWith(Strides other) {
return true;
}
-
- /**
- * get the list of states for the specified offset within all the
- * combinations of the domain
+ /**
+ * get the list of states for the specified offset within all the combinations
+ * of the domain
*/
public int[] getStatesFor(int offset) {
int[] states = new int[getSize()];
int index = 0;
for (int cardinality : sizes) {
- states[index ++] = offset % cardinality;
+ states[index++] = offset % cardinality;
offset = offset / cardinality;
}
return states;
}
/**
- * Test wether the specified index (within the exanded domain) is present
- * in the provied observations map
+ * Test wether the specified index (within the exanded domain) is present in the
+ * provied observations map
*
* @param index
* @param obs
* @return
*/
public boolean isCompatible(int index, TIntIntMap obs) {
- int[] obsfiltered = IntStream.of(this.getVariables()).sorted()
- .map(x -> {
- if (obs.containsKey(x))
- return obs.get(x);
- else return -1;
- }).toArray();
-
+ int[] obsfiltered = IntStream.of(this.getVariables()).sorted().map(x -> {
+ if (obs.containsKey(x))
+ return obs.get(x);
+ else
+ return -1;
+ }).toArray();
int[] states = this.statesOf(index);
@@ -852,7 +1001,6 @@ public int[] getCompatibleIndexes(TIntIntMap obs) {
return IntStream.range(0, this.getCombinations()).filter(i -> this.isCompatible(i, obs)).toArray();
}
-
public static Strides reverseDomain(Strides domain) {
int[] vars = domain.getVariables().clone();
int[] sizes = domain.getSizes().clone();
@@ -865,4 +1013,33 @@ public Strides reverseDomain() {
return reverseDomain(this);
}
+ public static void main(String[] args) {
+ RandomDataGenerator generator = new RandomDataGenerator();
+ int[] var1 = generator.nextPermutation(20, 10);
+ int[] var2 = generator.nextPermutation(20, 10);
+ int[] s1 = IntStream.range(0, 10).map(x -> 2).toArray();
+ int[] s2 = IntStream.range(0, 10).map(x -> 2).toArray();
+ Strides st1 = new Strides(var1, s1);
+ Strides st2 = new Strides(var2, s2);
+
+ int runs = 10000;
+ int reps = 20;
+ long[] times = new long[reps];
+ long useless = 0;
+ for (int rep = 0; rep < reps * 2; ++rep) {
+ long time = System.nanoTime();
+ for (int i = 0; i < runs; ++i) {
+ var ss3 = st1.union(st2);
+ int[] xx = ss3.getVariables();
+ int x= ArraysMath.min(xx);
+ useless +=x;
+ }
+ long delta = System.nanoTime() - time;
+ if (rep > reps) {
+ times[rep - reps] = delta;
+
+ }
+ }
+ System.out.println(ArraysMath.mean(times) + " " + ArraysMath.sd(times, 1));
+ }
}
diff --git a/src/main/java/ch/idsia/crema/core/UnsortedDomain.java b/src/main/java/ch/idsia/crema/core/UnsortedDomain.java
new file mode 100644
index 00000000..e8b04601
--- /dev/null
+++ b/src/main/java/ch/idsia/crema/core/UnsortedDomain.java
@@ -0,0 +1,75 @@
+package ch.idsia.crema.core;
+
+import java.util.Arrays;
+
+public class UnsortedDomain implements StridedDomain {
+ private Strides sortedDomain;
+
+ public UnsortedDomain(Variable[] variables) {
+
+ }
+
+ public UnsortedDomain(int[] variables, int[] sizes) {
+
+ }
+
+ @Override
+ public int getCardinality(int variable) {
+ return 0;
+ }
+
+ @Override
+ public int getSizeAt(int index) {
+ // TODO Auto-generated method stub
+ return 0;
+ }
+
+ @Override
+ public int indexOf(int variable) {
+ // TODO Auto-generated method stub
+ return 0;
+ }
+
+ @Override
+ public boolean contains(int variable) {
+ // TODO Auto-generated method stub
+ return false;
+ }
+
+ @Override
+ public int[] getVariables() {
+ // TODO Auto-generated method stub
+ return null;
+ }
+
+ @Override
+ public int[] getSizes() {
+ // TODO Auto-generated method stub
+ return null;
+ }
+
+ @Override
+ public int getSize() {
+ // TODO Auto-generated method stub
+ return 0;
+ }
+
+ @Override
+ public void removed(int variable) {
+ // TODO Auto-generated method stub
+
+ }
+
+ @Override
+ public int getStride(int variable) {
+ // TODO Auto-generated method stub
+ return 0;
+ }
+
+ @Override
+ public int getStrideAt(int offset) {
+ // TODO Auto-generated method stub
+ return 0;
+ }
+
+}
diff --git a/src/main/java/ch/idsia/crema/core/Variable.java b/src/main/java/ch/idsia/crema/core/Variable.java
new file mode 100644
index 00000000..4ec6dc11
--- /dev/null
+++ b/src/main/java/ch/idsia/crema/core/Variable.java
@@ -0,0 +1,37 @@
+package ch.idsia.crema.core;
+
+public class Variable implements Comparable {
+ private int label;
+ private int cardinality;
+
+ public Variable(int label, int cardinality) {
+ this.label = label;
+ this.cardinality = cardinality;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj instanceof Variable) {
+ return label == ((Variable) obj).label;
+ } else
+ return false;
+ }
+
+ @Override
+ public int hashCode() {
+ return label;
+ }
+
+ @Override
+ public int compareTo(Variable o) {
+ return Integer.compare(label, o.label);
+ }
+
+ public int getLabel() {
+ return label;
+ }
+
+ public int getCardinality() {
+ return cardinality;
+ }
+}
diff --git a/src/main/java/ch/idsia/crema/data/DataTable.java b/src/main/java/ch/idsia/crema/data/DataTable.java
new file mode 100644
index 00000000..4b46e116
--- /dev/null
+++ b/src/main/java/ch/idsia/crema/data/DataTable.java
@@ -0,0 +1,339 @@
+package ch.idsia.crema.data;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.tuple.Pair;
+
+import gnu.trove.map.TIntIntMap;
+import gnu.trove.map.hash.TIntIntHashMap;
+import gnu.trove.set.TIntSet;
+import gnu.trove.set.hash.TIntHashSet;
+
+public class DataTable implements Iterable> {
+ protected final int[] columns;
+ protected O metadata;
+
+ protected T unit;
+ protected T zero;
+ protected T virtualcounts;
+ protected BiFunction add;
+ protected Function createArray;
+
+ /**
+ * Using a {@link TreeMap}. This will use a comparator that we can provide.
+ * HashMap uses the object's hash that won't work correctly for arrays of int.
+ */
+ protected TreeMap dataTable;
+
+ public void setMetadata(O data) {
+ this.metadata = data;
+ }
+
+ public O getMetadata() {
+ return metadata;
+ }
+
+
+
+ protected DataTable(int[] columns, T unit, T zero, BiFunction add, Function array) {
+ this.columns = columns;
+ this.unit = unit;
+ this.zero = zero;
+ this.virtualcounts = zero;
+ this.add = add;
+ this.createArray = array;
+ this.dataTable = new TreeMap<>(Arrays::compare);
+ }
+
+
+ protected DataTable(int[] columns, T unit, T zero, BiFunction add, Function array, Map data) {
+ this.columns = columns;
+ this.unit = unit;
+ this.zero = zero;
+ this.virtualcounts = zero;
+ this.add = add;
+ this.createArray = array;
+
+ this.dataTable = new TreeMap<>(Arrays::compare);
+ this.dataTable.putAll(data);
+ }
+
+
+
+ private static int[] max(int[] a, int[] b) {
+ int[] ret = a.clone();
+ for (int i = 0; i < a.length; ++i) {
+ ret[i] = Math.max(a[i], b[i]);
+ }
+ return ret;
+ }
+
+ public TIntSet[] getStates() {
+ TIntSet[] accum = IntStream.range(0, columns.length).mapToObj(TIntHashSet::new).toArray(TIntSet[]::new);
+
+ for(var entry : dataTable.entrySet()) {
+ var v = entry.getKey();
+ for (int i = 0; i < accum.length;++i) {
+ accum[i].add(v[i]);
+ }
+ }
+ return accum;
+ }
+
+
+ public int[] getSizes() {
+ int[] accum = new int[columns.length];
+
+ int[] v = dataTable.entrySet().stream().map(Entry::getKey).reduce(accum, DataTable::max);
+ for (int i = 0; i < v.length;++i) ++v[i];
+ return v;
+ }
+
+ /**
+ * Sort and Expand or limit the map to the columns of the table.
+ *
+ * @param map
+ * @return int[] of the values for the table columns
+ */
+ private int[] getIndex(TIntIntMap map) {
+ return Arrays.stream(columns).map(map::get).toArray();
+ }
+
+ /**
+ * Get the weight of a row
+ *
+ * @param index
+ * @return
+ */
+ public T getWeight(TIntIntMap index) {
+ return dataTable.get(getIndex(index));
+ }
+
+ /**
+ * Get the list of weight for all possible combinations of values of the
+ * specified columns assuming each column has the indicated number of possible
+ * states.
+ *
+ * A default number of virtual counts is added to each count.
+ * If not changed this is zero
+ *
+ * @param vars
+ * @param sizes
+ * @return
+ */
+ public T[] getWeightsFor(int[] vars, int[] sizes) {
+ return getWeightsFor(vars, sizes, virtualcounts);
+ }
+
+ /**
+ * Get Weights table for the specified variables and a virtual count of s.
+ * variables sizes need to be provided.
+ *
+ * @param vars
+ * @param sizes
+ * @param s
+ * @return
+ */
+ public T[] getWeightsFor(int[] vars, int[] sizes, T s) {
+
+ // cumulative size
+ int cumsize = 1;
+
+ TIntIntMap strides = new TIntIntHashMap();
+ for (int i = 0; i < vars.length; ++i) {
+ strides.put(vars[i], cumsize);
+ cumsize = cumsize * sizes[i];
+ }
+
+ int[] col_strides = new int[columns.length];
+
+ for (int i = 0; i < columns.length; ++i) {
+ if (strides.containsKey(columns[i])) {
+ col_strides[i] = strides.get(columns[i]);
+ }
+ }
+
+ T[] results = createArray.apply(cumsize);
+ for (int i = 0; i< results.length; ++i) {
+ results[i] = s;
+ }
+
+ for (var item : dataTable.entrySet()) {
+ int[] states = item.getKey();
+ int offset = 0;
+ for (int i = 0; i < columns.length; ++i) {
+ offset += col_strides[i] * states[i];
+ }
+ results[offset] = add.apply(results[offset], item.getValue());
+ }
+
+ return results;
+ }
+
+
+ /**
+ * Get the weight of a row
+ *
+ * @param index
+ * @return
+ */
+ public T getWeight(int[] index) {
+ if (index.length != columns.length)
+ throw new IllegalArgumentException("Wrong index size. Must match the columns");
+
+ return dataTable.get(index);
+ }
+
+ /**
+ * Add to dataTable assuming correctly ordered row items
+ *
+ * @param row the item to be added
+ * @param count the number of rows to be added
+ */
+ public void add(int[] row, T count) {
+ dataTable.compute(row, (k, v) -> (v == null) ? count : add.apply(v, count));
+ }
+
+ /**
+ * Add a new row using a different column order.
+ *
+ * @param cols int[] the new columns order
+ * @param inst int[] the row to be added in cols order
+ * @param count the "number" of rows being added.
+ */
+ public void add(int[] cols, int[] inst, T count) {
+ int[] row = Arrays.stream(columns).map(col -> ArrayUtils.indexOf(cols, col)).map(i -> inst[i]).toArray();
+
+ dataTable.compute(row, (k, v) -> (v == null) ? count : add.apply(v, count));
+ }
+
+ /**
+ * Add a TIntIntMap with the specified count. The map must contain all the keys
+ * specified in the columns
+ *
+ * @param inst {@link TIntIntMap} - the row to be added
+ * @param count the number of rows being added.
+ */
+ public void add(TIntIntMap inst, T count) {
+ int[] row = Arrays.stream(columns).map(inst::get).toArray();
+ dataTable.compute(row, (k, v) -> (v == null) ? count : add.apply(v, count));
+ }
+
+ /**
+ * Add a TIntIntMap with unit count. The map must contain all the keys specified
+ * in the columns
+ *
+ * @param inst {@link TIntIntMap} - the row to be added
+ * @param count the number of rows being added.
+ */
+ public void add(TIntIntMap inst) {
+ add(inst, unit);
+ }
+
+ /**
+ * Fills the provided sub-table with aggregated data from this table.
+ * Columns missing in this table are set to zero.
+ *
+ * @param cols the subset of columns
+ * @return a new Table
+ */
+ protected > SUB subtable(SUB tofill) {
+ int[] cols = tofill.columns;
+
+ // indices of the desired columns
+ int[] idx = Arrays.stream(cols).map(col -> ArrayUtils.indexOf(columns, col)).toArray();
+ //int[] matching = IntStream.of(idx).map(id -> columns[id]).toArray();
+
+ for (Map.Entry entry : dataTable.entrySet()) {
+ int[] values = entry.getKey();
+ T count = entry.getValue();
+
+ int[] newkey = Arrays.stream(idx).map(i -> i < 0 ? 0 : values[i]).toArray();
+ tofill.add(newkey, count);
+ }
+
+ return tofill;
+ }
+
+ /**
+ * Covert weights of a Table
+ *
+ * @param op the conversion operation
+ * @return the new Table
+ */
+ public DataTable mapWeights(Function op) {
+ // new table shares the same columns by default
+ DataTable table = new DataTable(columns, unit, zero, add, createArray);
+ for (Map.Entry entry : dataTable.entrySet()) {
+ table.dataTable.put(entry.getKey(), op.apply(entry.getValue()));
+ }
+ return table;
+ }
+
+
+
+ @Override
+ public Iterator> iterator() {
+ return dataTable.entrySet().iterator();
+ }
+
+
+ public Iterable> mapIterable() {
+ return new Iterable>() {
+
+ @Override
+ public Iterator> iterator() {
+
+ var iter = dataTable.entrySet().iterator();
+ return new Iterator>() {
+
+ @Override
+ public boolean hasNext() {
+ return iter.hasNext();
+ }
+
+ @Override
+ public Pair next() {
+ var nextVal = iter.next();
+ TIntIntMap ret = new TIntIntHashMap(columns, nextVal.getKey());
+ return Pair.of(ret, nextVal.getValue());
+ }
+ };
+ }
+ };
+ }
+
+
+ public int[] getColumns() {
+ return this.columns;
+ }
+
+
+ public > DT map(Supplier
creator, BiFunction> converter) {
+ final var store = creator.get();
+
+ dataTable.entrySet()
+ .stream()
+ .map((v) -> converter.apply(v.getKey(), v.getValue()))
+ .forEach(v -> store.add(v.getKey(), v.getValue()));
+ return store;
+
+ }
+
+
+ public Set> entries() {
+ return dataTable.entrySet();
+ }
+
+}
diff --git a/src/main/java/ch/idsia/crema/data/DoubleTable.java b/src/main/java/ch/idsia/crema/data/DoubleTable.java
new file mode 100644
index 00000000..d0e5aef7
--- /dev/null
+++ b/src/main/java/ch/idsia/crema/data/DoubleTable.java
@@ -0,0 +1,214 @@
+package ch.idsia.crema.data;
+
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import gnu.trove.map.TIntIntMap;
+import gnu.trove.map.hash.TIntIntHashMap;
+
+public class DoubleTable extends DataTable {
+
+ public DoubleTable(int[] columns) {
+ super(columns, 1d, 0d, (a, b) -> a + b, (i) -> new Double[i]);
+ }
+
+ public DoubleTable(int[] columns, int[][] data) {
+ this(columns);
+
+ if (data != null) {
+ for (int[] inst : data) {
+ add(inst, unit);
+ }
+ }
+ }
+
+ /**
+ * Create a Table from an array of {@link TIntIntMap}s.
+ *
+ * @param data
+ * @param unit
+ * @param add
+ */
+ public DoubleTable(TIntIntMap[] data) {
+ this(TIntMapConverter.cols(data));
+
+ if (data.length == 0)
+ throw new IllegalArgumentException("Need data");
+
+ for (TIntIntMap inst : data) {
+ add(inst);
+ }
+ }
+
+ public DoubleTable subtable(int[] cols) {
+ DoubleTable tofill = new DoubleTable(cols);
+ return super.subtable(tofill);
+ }
+
+
+ /**
+ * Scale weights between 0 and 1
+ * @return
+ */
+ public DoubleTable scale() {
+ double small = 0;
+
+ for (var entry : dataTable.entrySet()) {
+ small = Math.max(small, entry.getValue());
+ }
+
+ //System.out.println("AMX " + small);
+ if (small == 0) return this;
+
+ DoubleTable res = new DoubleTable(columns);
+ for (var entry : dataTable.entrySet()) {
+ res.add(entry.getKey(), entry.getValue() / small);
+ }
+
+ return res;
+ }
+
+
+
+
+ public double[] getWeights(int[] vars, int[] sizes) {
+ Double[] dta = super.getWeightsFor(vars, sizes);
+ return Arrays.stream(dta).mapToDouble(i -> i).toArray();
+ }
+
+ public double[] getWeights2(int[] vars, int[] sizes) {
+
+ // cumulative size
+ int cumsize = 1;
+
+ TIntIntMap strides = new TIntIntHashMap();
+ for (int i = 0; i < vars.length; ++i) {
+ strides.put(vars[i], cumsize);
+ cumsize = cumsize * sizes[i];
+ }
+
+ int[] col_strides = new int[columns.length];
+
+ for (int i = 0; i < columns.length; ++i) {
+ if (strides.containsKey(columns[i])) {
+ col_strides[i] = strides.get(columns[i]);
+ }
+ }
+
+ double[] results = new double[cumsize];
+
+ for (var item : dataTable.entrySet()) {
+ int[] states = item.getKey();
+ int offset = 0;
+ for (int i = 0; i < columns.length; ++i) {
+ offset += col_strides[i] * states[i];
+ }
+ results[offset] += item.getValue();
+ }
+
+ return results;
+ }
+
+ TIntIntMap getKeyMap(int[] index) {
+ TIntIntMap map = new TIntIntHashMap();
+ for (int i = 0; i < index.length; ++i) {
+ map.put(columns[i], index[i]);
+ }
+ return map;
+ }
+
+ public TIntIntMap[] toMap(boolean roundup) {
+ return dataTable.entrySet().stream().flatMap(row -> IntStream
+ .range(0, (int) (row.getValue() + (roundup ? 0.5 : 0))).mapToObj(i -> this.getKeyMap(row.getKey())))
+ .toArray(TIntIntHashMap[]::new);
+
+ }
+
+ /**
+ * Read a table from a whitespace separated file. The whitespace can be any
+ * regex \s character.
+ *
+ * @param filename
+ * @return
+ * @throws IOException
+ */
+ public static DoubleTable readTable(String filename) throws IOException {
+ return readTable(filename, "\\s");
+ }
+
+ /**
+ * Read table from sep separated lists of values. Lists are separated by
+ * newlines and first row is the header. Header row must be integers.
+ *
+ * @param filename
+ * @param sep
+ * @return
+ * @throws IOException
+ */
+ public static DoubleTable readTable(String filename, String sep) throws IOException {
+ try (BufferedReader input = new BufferedReader(new FileReader(filename))) {
+ String[] cols = input.readLine().split(sep);
+ int[] columns = Arrays.stream(cols).mapToInt(Integer::parseInt).toArray();
+ DoubleTable ret = new DoubleTable(columns);
+
+ String line;
+ while ((line = input.readLine()) != null) {
+ cols = line.split(sep);
+
+ int[] row = Arrays.stream(cols).mapToInt(Integer::parseInt).toArray();
+ ret.add(row, 1d);
+ }
+ return ret;
+ }
+ }
+
+ static int parseInt(String value) {
+ return (int) Double.parseDouble(value);
+ }
+
+ public static DoubleTable readTable(String filename, int skip, String sep, Map columns_out)
+ throws IOException {
+ try (BufferedReader input = new BufferedReader(new FileReader(filename))) {
+ for (int i = 0; i < skip; ++i)
+ input.readLine().split(sep);
+
+ String[] values = input.readLine().split(sep);
+ int[] columns = IntStream.range(0, values.length).toArray();
+ if (columns_out != null) {
+ for (int i = 0; i < values.length; ++i) {
+ columns_out.put(values[i], columns[i]);
+ }
+ }
+
+ DoubleTable ret = new DoubleTable(columns);
+
+ String line;
+ while ((line = input.readLine()) != null) {
+ values = line.split(sep);
+
+ int[] row = Arrays.stream(values).mapToInt(DoubleTable::parseInt).toArray();
+ ret.add(row, 1d);
+ }
+ return ret;
+ }
+ }
+
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ String head = Arrays.stream(columns).mapToObj(Integer::toString).collect(Collectors.joining(" | "));
+ sb.append(head).append(" | weight\n");
+
+ for (Map.Entry row : dataTable.entrySet()) {
+ String key = Arrays.stream(row.getKey()).mapToObj(Integer::toString).collect(Collectors.joining(" | "));
+ sb.append(key).append(" | ").append(row.getValue()).append("\n");
+ }
+ return sb.toString();
+
+ }
+
+}
diff --git a/src/main/java/ch/idsia/crema/data/TIntMapConverter.java b/src/main/java/ch/idsia/crema/data/TIntMapConverter.java
new file mode 100644
index 00000000..3e1acd86
--- /dev/null
+++ b/src/main/java/ch/idsia/crema/data/TIntMapConverter.java
@@ -0,0 +1,60 @@
+package ch.idsia.crema.data;
+
+import java.util.Arrays;
+import java.util.function.Function;
+
+import gnu.trove.map.TIntIntMap;
+import gnu.trove.set.TIntSet;
+import gnu.trove.set.hash.TIntHashSet;
+
+class TIntMapConverter {
+
+ /**
+ * Convert from the specified map to an array of int using the specified columns
+ * order
+ *
+ * @param map the map to be converted
+ * @param columns the order of the columns
+ * @return the converted integer array
+ */
+ public static int[] from(TIntIntMap map, int[] columns) {
+ return Arrays.stream(columns).map(map::get).toArray();
+ }
+
+ /**
+ * A curried version of the from method that returns a version of from with
+ * fixed columns
+ *
+ * @param columns the order to be fixed
+ * @return the function
+ */
+ public static Function curriedFrom(int[] columns) {
+ return map -> Arrays.stream(columns).map(map::get).toArray();
+ }
+
+ /**
+ * Convert an array of maps to an array of int arrays sorted by columns.
+ *
+ * @param map the input data
+ * @param columns the order of the columns
+ *
+ * @return the output data
+ */
+ public static int[][] from(TIntIntMap[] map, int[] columns) {
+ return Arrays.stream(map).map(curriedFrom(columns)).toArray(int[][]::new);
+ }
+
+ /**
+ * Discover the set of columns assuming that not all are specified
+ *
+ * @param data
+ * @return
+ */
+ protected static int[] cols(TIntIntMap[] data) {
+ TIntSet columns = new TIntHashSet();
+ for (TIntIntMap row : data) {
+ columns.addAll(row.keys());
+ }
+ return columns.toArray();
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/ch/idsia/crema/factor/bayesian/BayesianDefaultFactor.java b/src/main/java/ch/idsia/crema/factor/bayesian/BayesianDefaultFactor.java
index f9bae89a..35e76a6c 100644
--- a/src/main/java/ch/idsia/crema/factor/bayesian/BayesianDefaultFactor.java
+++ b/src/main/java/ch/idsia/crema/factor/bayesian/BayesianDefaultFactor.java
@@ -3,6 +3,7 @@
import ch.idsia.crema.core.Domain;
import ch.idsia.crema.core.ObservationBuilder;
import ch.idsia.crema.core.Strides;
+import ch.idsia.crema.factor.algebra.GenericOperationFunction;
import ch.idsia.crema.factor.algebra.bayesian.BayesianOperation;
import ch.idsia.crema.factor.algebra.bayesian.SimpleBayesianFilter;
import ch.idsia.crema.factor.algebra.bayesian.SimpleBayesianMarginal;
@@ -332,11 +333,61 @@ public BayesianDefaultFactor combine(BayesianFactor factor) {
factor = ((BayesianLogFactor) factor).exp();
if (factor instanceof BayesianDefaultFactor)
- return combine((BayesianDefaultFactor) factor, BayesianDefaultFactor::new, ops::combine);
+ return owncombine((BayesianDefaultFactor) factor, BayesianDefaultFactor::new, ops::combine);
return (BayesianDefaultFactor) super.combine(factor);
}
+
+ protected F owncombine(F factor, BayesianFactorBuilder builder, GenericOperationFunction op) {
+ // domains should be sorted
+ //factor = (F) factor.copy();
+
+ final Strides target = getDomain().union(factor.getDomain());
+ final int length = target.getSize();
+
+ final int[] limits = new int[length];
+ final int[] assign = new int[length];
+
+ final long[] stride = new long[length];
+ final long[] reset = new long[length];
+
+ for (int vindex = 0; vindex < getDomain().getSize(); ++vindex) {
+ int offset = Arrays.binarySearch(target.getVariables(), getDomain().getVariables()[vindex]);
+ stride[offset] = getDomain().getStrides()[vindex];
+ }
+
+ for (int vindex = 0; vindex < factor.getDomain().getSize(); ++vindex) {
+ int offset = ArraysUtil.indexOf(factor.getDomain().getVariables()[vindex], target.getVariables());
+ stride[offset] += ((long) factor.getDomain().getStrides()[vindex] << 32L);
+ }
+
+ for (int i = 0; i < length; ++i) {
+ limits[i] = target.getSizes()[i] - 1;
+ reset[i] = limits[i] * stride[i];
+ }
+
+ long idx = 0;
+ double[] result = new double[target.getCombinations()];
+
+ for (int i = 0; i < result.length; ++i) {
+ result[i] = this.data[ (int) (idx & 0xFFFFFFFF)] * factor.getValueAt((int) (idx >>> 32L));
+
+ for (int l = 0; l < length; ++l) {
+ if (assign[l] == limits[l]) {
+ assign[l] = 0;
+ idx -= reset[l];
+ } else {
+ ++assign[l];
+ idx += stride[l];
+ break;
+ }
+ }
+ }
+
+ return builder.get(target, result);
+ }
+
/**
*
* If the input factor is also a {@link BayesianDefaultFactor}, a fast algebra i used. If the input is a
@@ -521,5 +572,139 @@ public double logProb(TIntIntMap[] data, int leftVar) {
return logprob;
}
+
+
+
+
+ /**
+ * BD -> ABD -> ABCD -> ABCDE
+ * b0d0 0 a0b0d0 0 a0b0c0d0 0 0
+ * b1d0 1 a1b0d0 0 a1b0c0d0 0 0
+ * b0d1 2 a0b1d0 1 a0b1c0d0 1 1
+ * b1d1 3 a1b1d0 1 a1b1c0d0 1
+ * a0b0d1 2 a0b0d1 2
+ * a1b0d1 2 a1b0d1 2
+ * a0b1d1 3 a0b1d1 3
+ * a0b1d1 3 a0b1d1 3
+ */
+
+
+ static class Side {
+ int position;
+ int[] sizes;
+ int[] variables;
+ int combinations;
+ int size;
+
+ int stride;
+
+ Side(Strides d) {
+ this.position = 0;
+ this.stride = 1;
+ this.sizes = d.getSizes();
+ this.variables = d.getVariables();
+ this.size = variables.length;
+ this.combinations = d.getCombinations();
+ }
+
+ boolean ok() {
+ return position < size;
+ }
+ int currentVar() {
+ return variables[position];
+ }
+ int currentSize() {
+ return sizes[position];
+ }
+ void move() {
+ ++position;
+ }
+ }
+
+
+ /**
+ * expectations:
+ *
+ *
Ordered domains/stride
+ *
$domain \subseteq target$
+ *
+ *
+ * @param target
+ * @return
+ */
+ protected static double[] expand(Strides from_domain, double[] data, Strides to_domain) {
+ Side source = new Side(from_domain);
+ Side target = new Side(to_domain);
+
+ double[] source_data = data;
+
+ double[] cache1 = new double[target.combinations];
+ double[] target_data = cache1;
+
+ // head insertions repeat single rows
+ int used = source_data.length;
+
+
+
+
+ while (true) {
+
+ // dato un set target e uno sorgetnte
+ // trovare blocchi di differenze.
+ // due indicatori: posizione source e posizione target
+ // inizio differenza
+ while (target.ok() && source.ok() && target.currentVar() == source.currentVar()) {
+ target.move();
+ source.move();
+ }
+ if (!target.ok()) return source_data; // no difference found
+
+ // target's current variable is missing in source
+ int from = target.position;
+ int repeat = 1;
+
+ while (target.ok() && (!source.ok() || target.currentVar() != source.currentVar())) {
+ repeat *= target.currentSize();
+ target.move();
+ }
+
+ // the number of rows to repeat
+
+ // per ogni differenza devo sapere
+ // quanti elementi ripetere, Prod size of left
+ // quante volte ripetere, dimensioni delle variabili inserite
+ // quante volte rifare questa ripetizione
+
+ int rows = to_domain.getStrideAt(from);
+ int tpos = 0;
+ for (int r = 0; r < used; r += rows) {
+ for (int j = 0; j < repeat; ++j) {
+ System.arraycopy(source_data, r, target_data, tpos, rows);
+ tpos += rows;
+ }
+ }
+ used = tpos;
+
+ if (!target.ok()) return target_data;
+
+ if (source_data == data) {
+ source_data = target_data;
+ target_data = new double[target.combinations];
+ } else {
+ // swap buffers
+ double[] tmp= source_data;
+ source_data = target_data;
+ target_data = source_data;
+ }
+ }
+ }
+ public static void main(String[] args) {
+ double[] data = new double[] { 0, 1, 2, 3};
+ Strides from = Strides.var(2, 2).and(4, 2);
+ Strides to = Strides.var(1,2).and(2, 2).and(3,2).and(4, 2).and(5, 2);
+ double[] target = expand(from, data, to);
+ System.out.println(Arrays.toString(target));
+
+ }
}
diff --git a/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactor.java b/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactor.java
index 81776af9..b9eec4d0 100644
--- a/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactor.java
+++ b/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactor.java
@@ -46,7 +46,7 @@ public interface BayesianFactor extends OperableFactor, Separate
*/
double getLogValueAt(int index);
- // TODO: this method should only for internal use
+ // TODO: this method should only be for internal use
double[] getData();
@Override
diff --git a/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactorFactory.java b/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactorFactory.java
index ee11d8e1..7b0834c1 100644
--- a/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactorFactory.java
+++ b/src/main/java/ch/idsia/crema/factor/bayesian/BayesianFactorFactory.java
@@ -5,12 +5,20 @@
import ch.idsia.crema.utility.ArraysUtil;
import ch.idsia.crema.utility.IndexIterator;
+import java.util.Arrays;
+import java.util.Random;
import java.util.stream.IntStream;
+import org.apache.commons.math3.random.RandomGeneratorFactory;
+import org.apache.commons.math3.random.UniformRandomGenerator;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.distribution.DirichletSampler;
+import org.apache.commons.rng.simple.JDKRandomWrapper;
+
+import com.google.common.primitives.Doubles;
+
/**
- * Author: Claudio "Dna" Bonesana
- * Project: crema
- * Date: 16.04.2021 11:11
+ * Author: Claudio "Dna" Bonesana Project: crema Date: 16.04.2021 11:11
*/
public class BayesianFactorFactory {
private double[] data = null;
@@ -19,36 +27,49 @@ public class BayesianFactorFactory {
private Strides domain = Strides.empty();
private BayesianFactorFactory() {
+ initRandom(0);
+ }
+
+ public BayesianFactorFactory initRandom(long seed) {
+ unif = new JDKRandomWrapper(new Random(seed));
+ return this;
}
/**
- * @param var the variable associated with this factor. This variable will be considered binary
+ * @param var the variable associated with this factor. This variable will be
+ * considered binary
* @return a {@link BayesianDefaultFactor} where state 1 has probability 1.0.
*/
public static BayesianDefaultFactor one(int var) {
- return new BayesianDefaultFactor(Strides.var(var, 2), new double[]{0., 1.});
+ return new BayesianDefaultFactor(Strides.var(var, 2), new double[] { 0., 1. });
}
/**
- * @param var the variable associated with this factor. This variable will be considered binary
+ * @param var the variable associated with this factor. This variable will be
+ * considered binary
* @return a {@link BayesianDefaultFactor} where state 0 has probability 1.0.
*/
public static BayesianDefaultFactor zero(int var) {
- return new BayesianDefaultFactor(Strides.var(var, 2), new double[]{1., 0.});
+ return new BayesianDefaultFactor(Strides.var(var, 2), new double[] { 1., 0. });
}
/**
* This is the entry point for the chained interface of this factory.
*
- * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands.
+ * @return a {@link BayesianFactorFactory} object that can be used to chain
+ * multiple commands.
*/
public static BayesianFactorFactory factory() {
return new BayesianFactorFactory();
}
/**
+ * Set the domain of the builder. Order is ignored and natural ordering will be
+ * used.
+ *
* @param domain set the domain of the factor
- * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands.
+ * @return a {@link BayesianFactorFactory} object that can be used to chain
+ * multiple commands.
*/
public BayesianFactorFactory domain(Domain domain) {
this.domain = Strides.fromDomain(domain);
@@ -58,17 +79,32 @@ public BayesianFactorFactory domain(Domain domain) {
/**
* @param domain the variables that defines the domain
* @param sizes the sizes of each variable
- * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands.
+ * @return a {@link BayesianFactorFactory} object that can be used to chain
+ * multiple commands.
*/
public BayesianFactorFactory domain(int[] domain, int[] sizes) {
- this.domain = new Strides(domain, sizes);
- return data();
+ return domain(domain, sizes, false);
+ }
+
+ public BayesianFactorFactory domain(int[] domain, int[] sizes, boolean already_sorted) {
+ if (!already_sorted) {
+ int[] pos = ArraysUtil.order(domain);
+
+ int[] sortedDomain = ArraysUtil.at(domain, pos);
+ int[] sortedSizes = ArraysUtil.at(sizes, pos);
+
+ this.domain = new Strides(sortedDomain, sortedSizes);
+ } else {
+ this.domain = new Strides(domain, sizes);
+ }
+ return this;
}
/**
* Set an empty data set with all the combinations of the given domain.
*
- * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands.
+ * @return a {@link BayesianFactorFactory} object that can be used to chain
+ * multiple commands.
*/
public BayesianFactorFactory data() {
this.data = new double[domain.getCombinations()];
@@ -77,12 +113,14 @@ public BayesianFactorFactory data() {
/**
* @param data an array of values that will be directly used
- * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands.
+ * @return a {@link BayesianFactorFactory} object that can be used to chain
+ * multiple commands.
*/
public BayesianFactorFactory data(double[] data) {
final int expectedLength = domain.getCombinations();
if (data.length > expectedLength)
- throw new IllegalArgumentException("Invalid length of data: expected " + expectedLength + " got " + data.length);
+ throw new IllegalArgumentException(
+ "Invalid length of data: expected " + expectedLength + " got " + data.length);
if (this.data == null)
this.data = new double[expectedLength];
@@ -94,12 +132,14 @@ public BayesianFactorFactory data(double[] data) {
/**
* @param data an array of values in log-space that will be directly used
- * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands.
+ * @return a {@link BayesianFactorFactory} object that can be used to chain
+ * multiple commands.
*/
public BayesianFactorFactory logData(double[] data) {
final int expectedLength = domain.getCombinations();
if (data.length != expectedLength)
- throw new IllegalArgumentException("Invalid length of data: expected " + expectedLength + " got " + data.length);
+ throw new IllegalArgumentException(
+ "Invalid length of data: expected " + expectedLength + " got " + data.length);
if (this.logData == null)
this.logData = new double[expectedLength];
@@ -112,7 +152,8 @@ public BayesianFactorFactory logData(double[] data) {
/**
* @param domain the order of the variables that defines the values
* @param data the values to use specified with the given domain order
- * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands.
+ * @return a {@link BayesianFactorFactory} object that can be used to chain
+ * multiple commands.
*/
public BayesianFactorFactory data(int[] domain, double[] data) {
int[] sequence = ArraysUtil.order(domain);
@@ -148,7 +189,19 @@ public BayesianFactorFactory data(int[] domain, double[] data) {
/**
* @param value a single value to set
* @param states the states that defines this value
- * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands.
+ * @return a {@link BayesianFactorFactory} object that can be used to chain
+ * multiple commands.
+ */
+ public BayesianFactorFactory value(double value, int[] variables, int[] states) {
+ data[domain.getPartialOffset(variables, states)] = value;
+ return this;
+ }
+
+ /**
+ * @param value a single value to set
+ * @param states the states that defines this value
+ * @return a {@link BayesianFactorFactory} object that can be used to chain
+ * multiple commands.
*/
public BayesianFactorFactory value(double value, int... states) {
data[domain.getOffset(states)] = value;
@@ -158,7 +211,8 @@ public BayesianFactorFactory value(double value, int... states) {
/**
* @param d a single value to set
* @param index the index (or offset) of the value to set
- * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands.
+ * @return a {@link BayesianFactorFactory} object that can be used to chain
+ * multiple commands.
*/
public BayesianFactorFactory valueAt(double d, int index) {
data[index] = d;
@@ -168,7 +222,8 @@ public BayesianFactorFactory valueAt(double d, int index) {
/**
* @param d a single value to set
* @param index the index (or offset) of the value to set
- * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands.
+ * @return a {@link BayesianFactorFactory} object that can be used to chain
+ * multiple commands.
*/
public BayesianFactorFactory valuesAt(double[] d, int index) {
System.arraycopy(d, 0, data, index, d.length);
@@ -178,14 +233,27 @@ public BayesianFactorFactory valuesAt(double[] d, int index) {
/**
* @param value a single value to set
* @param states the states that defines this value
- * @return a {@link BayesianFactorFactory} object that can be used to chain multiple commands.
+ * @return a {@link BayesianFactorFactory} object that can be used to chain
+ * multiple commands.
*/
public BayesianFactorFactory set(double value, int... states) {
return value(value, states);
}
+
/**
- * @return a {@link BayesianLogFactor}, where the given data are converted to log-space
+ * @param value a single value to set
+ * @param variables the order of the states attribute
+ * @param states the states that defines this value
+ * @return a {@link BayesianFactorFactory} object that can be used to chain
+ * multiple commands.
+ */
+ public BayesianFactorFactory set(double value, int[] variables, int[] states) {
+ return value(value, states);
+ }
+ /**
+ * @return a {@link BayesianLogFactor}, where the given data are converted to
+ * log-space
*/
public BayesianLogFactor log() {
// sort variables
@@ -205,7 +273,8 @@ public BayesianLogFactor log() {
}
/**
- * @return a {@link BayesianDefaultFactor}, where the given data are converted to non-log-space.
+ * @return a {@link BayesianDefaultFactor}, where the given data are converted
+ * to non-log-space.
*/
public BayesianDefaultFactor get() {
// sort variables
@@ -236,7 +305,8 @@ public BayesianNotFactor not(int parent) {
* Requires a pre-defined Domain.
*
* @param parent variable that is the parent of this factor
- * @param trueState index of the state to be considered as TRUE for the given parent
+ * @param trueState index of the state to be considered as TRUE for the given
+ * parent
* @return a logic {@link BayesianAndFactor}
*/
public BayesianAndFactor not(int parent, int trueState) {
@@ -268,7 +338,8 @@ public BayesianAndFactor and(int... parents) {
/**
* Requires a pre-defined Domain.
*
- * @param trueStates index of the state to be considered as TRUE for each given parent
+ * @param trueStates index of the state to be considered as TRUE for each given
+ * parent
* @param parents variables that are parents of this factor
* @return a logic {@link BayesianAndFactor}
*/
@@ -316,7 +387,8 @@ public BayesianOrFactor or(int... parents) {
/**
* Requires a pre-defined Domain.
*
- * @param trueStates index of the state to be considered as TRUE for each given parent
+ * @param trueStates index of the state to be considered as TRUE for each given
+ * parent
* @param parents variables that are parents of this factor
* @return a logic {@link BayesianOrFactor}
*/
@@ -369,7 +441,8 @@ public BayesianNoisyOrFactor noisyOr(int[] parents, double[] strengths) {
* Requires a pre-defined Domain.
*
* @param parents variables that are parents of this factor
- * @param trueStates index of the state to be considered as TRUE for each given parent
+ * @param trueStates index of the state to be considered as TRUE for each given
+ * parent
* @param strengths values for the inhibition strength for each given parent
* @return a logic {@link BayesianNoisyOrFactor}
*/
@@ -394,4 +467,39 @@ public BayesianNoisyOrFactor noisyOr(int[] parents, int[] trueStates, double[] s
return new BayesianNoisyOrFactor(new Strides(vars, sizes), pars, trus, inbs);
}
+ UniformRandomProvider unif;
+
+ public BayesianFactorFactory random() {
+ return random(null);
+ }
+
+ public BayesianFactorFactory random(Domain conditioning) {
+
+ int combinations;
+ Strides dom;
+ int[] vars;
+
+ if (conditioning == null) {
+ combinations = 1;
+ dom = this.domain;
+ vars = dom.getVariables();
+ } else {
+ combinations = Arrays.stream(conditioning.getSizes()).reduce(1, (a, b) -> a * b);
+ dom = this.domain.remove(conditioning.getVariables());
+ vars = ArraysUtil.append(dom.getVariables(), conditioning.getVariables());
+ }
+
+ int size = dom.getCombinations();
+
+ double[] alpha = new double[size];
+ Arrays.fill(alpha, 1.0);
+
+ DirichletSampler ds = DirichletSampler.of(unif, alpha);
+ double[][] arr = ds.samples(combinations).toArray(double[][]::new);
+ if (combinations == 1) {
+ return data(vars, arr[0]);
+ } else {
+ return data(vars, Doubles.concat(arr));
+ }
+ }
}
diff --git a/src/main/java/ch/idsia/crema/factor/credal/vertex/separate/VertexAbstractFactor.java b/src/main/java/ch/idsia/crema/factor/credal/vertex/separate/VertexAbstractFactor.java
index 2848b250..c08ed92d 100644
--- a/src/main/java/ch/idsia/crema/factor/credal/vertex/separate/VertexAbstractFactor.java
+++ b/src/main/java/ch/idsia/crema/factor/credal/vertex/separate/VertexAbstractFactor.java
@@ -180,7 +180,7 @@ protected F reseparate(Strides target, VertexFa
Strides T = getSeparatingDomain().intersection(target);
Strides Lt = getSeparatingDomain().remove(target);
- Strides Dl = getDataDomain().union(Lt);
+ Strides Dl = getDataDomain().union(Lt).sort();
// target data
double[][][] dest_data = new double[T.getCombinations()][][];
diff --git a/src/main/java/ch/idsia/crema/inference/approxlp1/CredalApproxLP.java b/src/main/java/ch/idsia/crema/inference/approxlp1/CredalApproxLP.java
index c167de65..bc6fd36b 100644
--- a/src/main/java/ch/idsia/crema/inference/approxlp1/CredalApproxLP.java
+++ b/src/main/java/ch/idsia/crema/inference/approxlp1/CredalApproxLP.java
@@ -7,7 +7,7 @@
import ch.idsia.crema.model.graphical.GraphicalModel;
import ch.idsia.crema.model.graphical.MixedModel;
import ch.idsia.crema.preprocess.BinarizeEvidence;
-import ch.idsia.crema.preprocess.CutObserved;
+import ch.idsia.crema.preprocess.Observe;
import ch.idsia.crema.preprocess.RemoveBarren;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.hash.TIntIntHashMap;
@@ -18,7 +18,7 @@ public class CredalApproxLP> implements Inference<
protected GraphicalModel getInferenceModel(GraphicalModel model, TIntIntMap evidence, int target) {
// preprocessing
- final CutObserved cut = new CutObserved<>();
+ final Observe cut = new Observe<>();
final GraphicalModel cutted = cut.execute(model, evidence);
RemoveBarren removeBarren = new RemoveBarren<>();
diff --git a/src/main/java/ch/idsia/crema/inference/bp/BeliefPropagation.java b/src/main/java/ch/idsia/crema/inference/bp/BeliefPropagation.java
index 530963a2..abb62ed2 100644
--- a/src/main/java/ch/idsia/crema/inference/bp/BeliefPropagation.java
+++ b/src/main/java/ch/idsia/crema/inference/bp/BeliefPropagation.java
@@ -5,7 +5,7 @@
import ch.idsia.crema.inference.bp.cliques.Clique;
import ch.idsia.crema.inference.bp.junction.JunctionTree;
import ch.idsia.crema.model.graphical.DAGModel;
-import ch.idsia.crema.preprocess.CutObserved;
+import ch.idsia.crema.preprocess.Observe;
import ch.idsia.crema.preprocess.RemoveBarren;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.hash.TIntIntHashMap;
@@ -61,7 +61,7 @@ public Boolean isFullyPropagated() {
}
/**
- * If {@link #preprocess} is true, then pre-process the model with {@link CutObserved} and {@link RemoveBarren}.
+ * If {@link #preprocess} is true, then pre-process the model with {@link Observe} and {@link RemoveBarren}.
*
* @param original the model to use for inference
* @param evidence the observed variable as a map of variable-states
@@ -72,7 +72,7 @@ protected DAGModel preprocess(DAGModel original, TIntIntMap evidence, int.
DAGModel model = original;
if (preprocess) {
model = original.copy();
- final CutObserved co = new CutObserved<>();
+ final Observe co = new Observe<>();
final RemoveBarren rb = new RemoveBarren<>();
co.executeInPlace(model, evidence);
@@ -172,7 +172,7 @@ protected Clique getRoot(JunctionTree junctionTree, int variable) {
}
/**
- * Pre-process the model with {@link CutObserved} and {@link RemoveBarren}, then performs the
+ * Pre-process the model with {@link Observe} and {@link RemoveBarren}, then performs the
* {@link #collectingEvidence(int)} step.
*
* @param model the model to use for inference
@@ -185,7 +185,7 @@ public F query(DAGModel model, int query) {
}
/**
- * Pre-process the model with {@link CutObserved} and {@link RemoveBarren}, then performs the
+ * Pre-process the model with {@link Observe} and {@link RemoveBarren}, then performs the
* {@link #collectingEvidence(int)} step.
*
* @param original the model to use for inference
@@ -204,7 +204,7 @@ public F query(DAGModel original, TIntIntMap evidence, int query) {
}
/**
- * Pre-process the model with {@link CutObserved} and {@link RemoveBarren}, then performs the
+ * Pre-process the model with {@link Observe} and {@link RemoveBarren}, then performs the
* {@link #collectingEvidence(int)} and {@link #distributingEvidence()} steps.
*
* Use the {@link #queryFullPropagated(int)} method for query multiple variables over the same evidence and model.
diff --git a/src/main/java/ch/idsia/crema/inference/bp/LoopyBeliefPropagation.java b/src/main/java/ch/idsia/crema/inference/bp/LoopyBeliefPropagation.java
index 5db56944..f6252cc7 100644
--- a/src/main/java/ch/idsia/crema/inference/bp/LoopyBeliefPropagation.java
+++ b/src/main/java/ch/idsia/crema/inference/bp/LoopyBeliefPropagation.java
@@ -3,7 +3,7 @@
import ch.idsia.crema.factor.OperableFactor;
import ch.idsia.crema.inference.Inference;
import ch.idsia.crema.model.graphical.DAGModel;
-import ch.idsia.crema.preprocess.CutObserved;
+import ch.idsia.crema.preprocess.Observe;
import ch.idsia.crema.preprocess.RemoveBarren;
import ch.idsia.crema.utility.ArraysUtil;
import gnu.trove.map.TIntIntMap;
@@ -83,7 +83,7 @@ protected DAGModel preprocess(DAGModel original, TIntIntMap evidence, int.
DAGModel model = original;
if (preprocess) {
model = original.copy();
- final CutObserved co = new CutObserved<>();
+ final Observe co = new Observe<>();
final RemoveBarren rb = new RemoveBarren<>();
co.executeInPlace(model, evidence);
diff --git a/src/main/java/ch/idsia/crema/inference/sampling/LikelihoodWeightingSampling.java b/src/main/java/ch/idsia/crema/inference/sampling/LikelihoodWeightingSampling.java
index cda8aca3..87f47acd 100644
--- a/src/main/java/ch/idsia/crema/inference/sampling/LikelihoodWeightingSampling.java
+++ b/src/main/java/ch/idsia/crema/inference/sampling/LikelihoodWeightingSampling.java
@@ -3,7 +3,7 @@
import ch.idsia.crema.factor.bayesian.BayesianDefaultFactor;
import ch.idsia.crema.factor.bayesian.BayesianFactor;
import ch.idsia.crema.model.graphical.GraphicalModel;
-import ch.idsia.crema.preprocess.CutObserved;
+import ch.idsia.crema.preprocess.Observe;
import gnu.trove.map.TIntDoubleMap;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.TIntObjectMap;
@@ -45,7 +45,7 @@ public Collection run(GraphicalModel original, T
if (!preprocess) {
// this is mandatory
- final CutObserved co = new CutObserved<>();
+ final Observe co = new Observe<>();
co.executeInPlace(model, evidence);
}
diff --git a/src/main/java/ch/idsia/crema/inference/sampling/StochasticSampling.java b/src/main/java/ch/idsia/crema/inference/sampling/StochasticSampling.java
index f619150d..f69dc571 100644
--- a/src/main/java/ch/idsia/crema/inference/sampling/StochasticSampling.java
+++ b/src/main/java/ch/idsia/crema/inference/sampling/StochasticSampling.java
@@ -3,7 +3,7 @@
import ch.idsia.crema.factor.bayesian.BayesianFactor;
import ch.idsia.crema.inference.InferenceJoined;
import ch.idsia.crema.model.graphical.GraphicalModel;
-import ch.idsia.crema.preprocess.CutObserved;
+import ch.idsia.crema.preprocess.Observe;
import ch.idsia.crema.preprocess.RemoveBarren;
import ch.idsia.crema.utility.RandomUtil;
import gnu.trove.map.TIntIntMap;
@@ -62,7 +62,7 @@ protected GraphicalModel preprocess(GraphicalModel model = original;
if (preprocess) {
model = original.copy();
- final CutObserved co = new CutObserved<>();
+ final Observe co = new Observe<>();
final RemoveBarren rb = new RemoveBarren<>();
co.executeInPlace(model, evidence);
diff --git a/src/main/java/ch/idsia/crema/inference/ve/CredalVariableElimination.java b/src/main/java/ch/idsia/crema/inference/ve/CredalVariableElimination.java
index 586a4d8d..30c5cd4d 100644
--- a/src/main/java/ch/idsia/crema/inference/ve/CredalVariableElimination.java
+++ b/src/main/java/ch/idsia/crema/inference/ve/CredalVariableElimination.java
@@ -5,7 +5,7 @@
import ch.idsia.crema.inference.Inference;
import ch.idsia.crema.inference.ve.order.MinFillOrdering;
import ch.idsia.crema.model.graphical.GraphicalModel;
-import ch.idsia.crema.preprocess.CutObserved;
+import ch.idsia.crema.preprocess.Observe;
import ch.idsia.crema.preprocess.RemoveBarren;
import ch.idsia.crema.utility.hull.ConvexHull;
import gnu.trove.map.TIntIntMap;
@@ -27,7 +27,7 @@ public CredalVariableElimination setConvexHullMarg(ConvexHull convexHullMarg) {
}
protected GraphicalModel getInferenceModel(GraphicalModel model, TIntIntMap evidence, int target) {
- CutObserved cutObserved = new CutObserved<>();
+ Observe cutObserved = new Observe<>();
// run making a copy of the model
GraphicalModel infModel = cutObserved.execute(model, evidence);
diff --git a/src/main/java/ch/idsia/crema/learning/DiscreteEM.java b/src/main/java/ch/idsia/crema/learning/DiscreteEM.java
index 8c7ed19b..c9d47f48 100644
--- a/src/main/java/ch/idsia/crema/learning/DiscreteEM.java
+++ b/src/main/java/ch/idsia/crema/learning/DiscreteEM.java
@@ -4,7 +4,7 @@
import ch.idsia.crema.inference.InferenceJoined;
import ch.idsia.crema.inference.ve.FactorVariableElimination;
import ch.idsia.crema.model.graphical.GraphicalModel;
-import ch.idsia.crema.preprocess.CutObserved;
+import ch.idsia.crema.preprocess.Observe;
import ch.idsia.crema.preprocess.RemoveBarren;
import gnu.trove.map.TIntIntMap;
@@ -14,7 +14,7 @@ protected InferenceJoined, BayesianFactor> getDef
return new InferenceJoined<>() {
@Override
public BayesianFactor query(GraphicalModel model, TIntIntMap evidence, int... queries) {
- final CutObserved co = new CutObserved<>();
+ final Observe co = new Observe<>();
GraphicalModel coModel = co.execute(model, evidence);
final RemoveBarren rb = new RemoveBarren<>();
diff --git a/src/main/java/ch/idsia/crema/model/Model.java b/src/main/java/ch/idsia/crema/model/Model.java
index cdd0e08f..cf7ec203 100644
--- a/src/main/java/ch/idsia/crema/model/Model.java
+++ b/src/main/java/ch/idsia/crema/model/Model.java
@@ -27,6 +27,7 @@ public interface Model {
*/
Strides getDomain(int... variables);
+
/**
* remove a specific state from a variable
*
diff --git a/src/main/java/ch/idsia/crema/model/causal/SCM.java b/src/main/java/ch/idsia/crema/model/causal/SCM.java
new file mode 100644
index 00000000..1166c4a5
--- /dev/null
+++ b/src/main/java/ch/idsia/crema/model/causal/SCM.java
@@ -0,0 +1,194 @@
+package ch.idsia.crema.model.causal;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.TreeMap;
+
+import org.apache.commons.lang3.tuple.Pair;
+
+import java.util.Map.Entry;
+import java.util.NoSuchElementException;
+
+import ch.idsia.crema.factor.bayesian.BayesianFactor;
+import ch.idsia.crema.model.graphical.DAGModel;
+import gnu.trove.iterator.TIntIterator;
+import gnu.trove.set.TIntSet;
+import gnu.trove.set.hash.TIntHashSet;
+
+public class SCM extends DAGModel {
+
+ public static enum VariableType {
+ /** Endgogenous variable have at least one exogenous counfounder */
+ ENDOGENOUS,
+
+ /** Exogenous variables are root nodes with no parents */
+ EXOGENOUS,
+
+ /** Extra variables are endogenous variables without direct exogenous influence */
+ EXTRA
+ }
+
+ private int nextId;
+ private TreeMap varSets;
+
+
+ public SCM() {
+ varSets = new TreeMap();
+
+ varSets.put(VariableType.ENDOGENOUS, new TIntHashSet());
+ varSets.put(VariableType.EXOGENOUS, new TIntHashSet());
+ varSets.put(VariableType.EXTRA, new TIntHashSet());
+ }
+
+ private int nextId() {
+ return nextId;
+ }
+
+
+ private void useId(int id) {
+ nextId = Math.max(nextId, id + 1);
+ }
+
+ public void add(TypedVariable variable) {
+ this.add(variable.getLabel(), variable.getCardinality(), variable.getType());
+ }
+
+ public int add(int size, VariableType type) {
+ int varid = nextId();
+ return this.add(varid, size, type);
+ }
+
+ public int add(int variable, int size, VariableType type) {
+ if (super.cardinalities.containsKey(variable))
+ return variable;
+
+ useId(variable);
+
+ super.addVariable(variable, size);
+ varSets.get(type).add(variable);
+
+ return variable;
+ }
+
+ public int addEndogenous(int size) {
+ return this.add(size, VariableType.ENDOGENOUS);
+ }
+
+ public int addExogenous(int size) {
+ return this.add(size, VariableType.EXOGENOUS);
+ }
+
+ public int addExtra(int size) {
+ return this.add(size, VariableType.EXTRA);
+ }
+
+ public int addEndogenous(int variable, int size) {
+ return this.add(variable, size, VariableType.ENDOGENOUS);
+ }
+
+ public int addExogenous(int variable, int size) {
+ return this.add(variable, size, VariableType.EXOGENOUS);
+ }
+
+ public int addAdditional(int variable, int size) {
+ return this.add(variable, size, VariableType.EXOGENOUS);
+ }
+
+ public boolean isEndogenous(int variable) {
+ return varSets.get(VariableType.ENDOGENOUS).contains(variable);
+ }
+
+ public boolean isExogenous(int variable) {
+ return varSets.get(VariableType.EXOGENOUS).contains(variable);
+ }
+
+ public boolean isExtra(int variable) {
+ return varSets.get(VariableType.EXTRA).contains(variable);
+ }
+
+ public boolean has(int variable) {
+ return this.cardinalities.containsKey(variable);
+ }
+
+ public VariableType getType(int variable) {
+ for (VariableType type : VariableType.values()) {
+ if (varSets.get(type).contains(variable)) return type;
+ }
+ return null;
+ }
+
+ public int[] getEndogenousVars() {
+ return varSets.get(VariableType.ENDOGENOUS).toArray();
+ }
+
+ public int[] getExogenousVars() {
+ return varSets.get(VariableType.EXOGENOUS).toArray();
+ }
+
+ public int[] getExtraVars() {
+ return varSets.get(VariableType.EXOGENOUS).toArray();
+ }
+
+ public void addParent(int variable, int parent) {
+ if (isExogenous(variable))
+ throw new IllegalStateException("Exogenous vars have no parents.");
+
+ if (isExtra(variable) && isExogenous(parent))
+ throw new IllegalStateException("Additional vars cannot depend on exogenous vars directly.");
+
+ if (isExtra(variable) && !isExtra(parent))
+ throw new IllegalStateException("Additional vars should only depend on each other.");
+
+ super.addParent(variable, parent);
+ }
+
+ public Iterable variables() {
+ return () -> {
+ return new Iterator() {
+ int type_index = 0;
+ VariableType[] types = Arrays.stream(VariableType.values()).filter(x -> varSets.get(x).size() > 0).toArray(VariableType[]::new);
+
+ TIntIterator setIterator = varSets.get(types[0]).iterator();
+
+ @Override
+ public TypedVariable next() {
+ if (!setIterator.hasNext()) {
+ ++type_index;
+ if (type_index < types.length) {
+ setIterator = varSets.get(types[type_index]).iterator();
+ } else {
+ throw new NoSuchElementException();
+ }
+ }
+ int id = setIterator.next();
+ int size = cardinalities.get(id);
+ return new TypedVariable(id, size, types[type_index]);
+ }
+
+ @Override
+ public boolean hasNext() {
+ boolean setn = setIterator.hasNext();
+ boolean typn = type_index < types.length - 1;
+
+ return setn || typn; // either set has more or we're not in the last set
+ }
+ };
+ };
+ }
+
+ @Override
+ public void removeVariable(int variable) {
+ super.removeVariable(variable);
+
+ // remove from sets (if needed)
+ @SuppressWarnings("unused")
+ boolean changed = this.varSets.get(VariableType.ENDOGENOUS).remove(variable)
+ || this.varSets.get(VariableType.EXOGENOUS).remove(variable)
+ || this.varSets.get(VariableType.EXTRA).remove(variable);
+ }
+
+
+
+
+}
diff --git a/src/main/java/ch/idsia/crema/model/causal/TypedVariable.java b/src/main/java/ch/idsia/crema/model/causal/TypedVariable.java
new file mode 100644
index 00000000..17cc55e5
--- /dev/null
+++ b/src/main/java/ch/idsia/crema/model/causal/TypedVariable.java
@@ -0,0 +1,17 @@
+package ch.idsia.crema.model.causal;
+
+import ch.idsia.crema.core.Variable;
+import ch.idsia.crema.model.causal.SCM.VariableType;
+
+public class TypedVariable extends Variable {
+ VariableType type;
+
+ public TypedVariable(int label, int cardinality, VariableType type) {
+ super(label, cardinality);
+ this.type = type;
+ }
+
+ public VariableType getType() {
+ return type;
+ }
+}
diff --git a/src/main/java/ch/idsia/crema/model/causal/WorldModel.java b/src/main/java/ch/idsia/crema/model/causal/WorldModel.java
new file mode 100644
index 00000000..64c37dee
--- /dev/null
+++ b/src/main/java/ch/idsia/crema/model/causal/WorldModel.java
@@ -0,0 +1,70 @@
+package ch.idsia.crema.model.causal;
+
+import ch.idsia.crema.core.Variable;
+
+public interface WorldModel {
+ /**
+ * Get the global SCM model
+ * @return SCM the global model or null if no model have been added
+ */
+ public SCM get();
+
+ /**
+ * Add a model to the mapping assuming exogenous ids match.
+ *
+ * @param model the {@link SCM} model to be added
+ * @return int the id of the added worlds
+ */
+ public int add(SCM model);
+
+ /**
+ * Translate a local variable of world wid to the global id
+ *
+ * @param variable the local variable
+ * @param wid the id of the source world
+ * @return int the global id of the variable
+ */
+ public int toGlobal(int variable, int wid);
+
+ /**
+ * Translate a global id to a local one.
+ *
+ * @param variable the global variable id
+ * @return int a local variable id
+ */
+ public int fromGlobal(int variable);
+
+ /**
+ * Get the world id associated to the specified global variable id.
+ * For shared variables (i.e. exogenous) the method return -1;
+ *
+ * @param variable the global variable id
+ * @return int the world id of the specified global variable or -1 it exogenous
+ */
+ public int worldIdOf(int variable);
+
+ /**
+ * Get the world associated to the specified global variable id.
+ * The method returns null for shared (i.e. exogenous) variables.
+ *
+ * @param variable the global variable id
+ * @return SCM the source Structural Causal Model or null if variable is exogenous
+ */
+ public SCM worldOf(int variable);
+
+ default int[] toGlobal(int[] variables, int wid) {
+ int[] result = new int[variables.length];
+ for (int i = 0; i < result.length; ++i) {
+ result[i] = toGlobal(variables[i], wid);
+ }
+ return result;
+ }
+
+ default int[] fromGlobal(int[] variables) {
+ int[] result = new int[variables.length];
+ for (int i = 0; i < result.length; ++i) {
+ result[i] = fromGlobal(variables[i]);
+ }
+ return result;
+ }
+}
diff --git a/src/main/java/ch/idsia/crema/model/causal/mapping/TreeMapping.java b/src/main/java/ch/idsia/crema/model/causal/mapping/TreeMapping.java
new file mode 100644
index 00000000..8065bbc4
--- /dev/null
+++ b/src/main/java/ch/idsia/crema/model/causal/mapping/TreeMapping.java
@@ -0,0 +1,214 @@
+package ch.idsia.crema.model.causal.mapping;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.commons.lang3.tuple.Pair;
+
+import ch.idsia.crema.core.Strides;
+import ch.idsia.crema.factor.bayesian.BayesianFactor;
+import ch.idsia.crema.factor.bayesian.BayesianFactorFactory;
+import ch.idsia.crema.model.causal.SCM;
+import ch.idsia.crema.model.causal.SCM.VariableType;
+import ch.idsia.crema.model.causal.WorldModel;
+import ch.idsia.crema.utility.ArraysUtil;
+import gnu.trove.map.TIntIntMap;
+import gnu.trove.map.TIntObjectMap;
+import gnu.trove.map.hash.TIntIntHashMap;
+import gnu.trove.map.hash.TIntObjectHashMap;
+
+public class TreeMapping implements WorldModel {
+
+ /**
+ * Mapping from source to target for each world id
+ * Key is world id, value is mapping between source and global id
+ */
+ private TIntObjectMap toGlobal;
+
+ /**
+ * A map from a global id to a pair
+ */
+ private TIntObjectMap> fromGlobal;
+
+ /**
+ * A map from global exogenous id to local exogenous id
+ */
+ private TIntIntMap globalToLocalExogenous;
+ private TIntIntMap localToGlobalExogenous;
+
+ /**
+ * the global model, initialized at the first add.
+ */
+ private SCM global;
+
+ /**
+ * list of source models who's id are the index in this list
+ */
+ private List worlds;
+
+ /**
+ * Create a new Tree Mapping object.
+ */
+ public TreeMapping() {
+ this.toGlobal = new TIntObjectHashMap();
+ this.fromGlobal = new TIntObjectHashMap>();
+ this.globalToLocalExogenous = new TIntIntHashMap();
+ this.localToGlobalExogenous= new TIntIntHashMap();
+ this.worlds = new ArrayList();
+ }
+
+
+ @Override
+ public SCM get() {
+ return global;
+ }
+
+ /**
+ * Store the mapping for the given world
+ * @param wid world id
+ * @param translate a map from local to global
+ */
+ protected void saveMapping(int wid, TIntIntMap translate) {
+
+ toGlobal.put(wid, translate);
+
+ var iter = translate.iterator();
+ while(iter.hasNext()) {
+ iter.advance();
+ var source = iter.key();
+ var target = iter.value();
+ if (localToGlobalExogenous.containsKey(source)) continue;
+
+ fromGlobal.put(target, Pair.of(wid, source));
+ }
+ }
+
+ /**
+ * Rename a list of source variable to the target
+ * @param ids
+ * @param translate
+ * @return
+ */
+ private int[] rename(int[] ids, TIntIntMap translate) {
+ int[] target = new int[ids.length];
+ for (int i = 0; i < ids.length; ++i) {
+ target[i] = translate.get(ids[i]);
+ }
+ return target;
+// return Arrays.stream(ids).map(translate::get).toArray();
+ }
+
+ /**
+ * Convert a bayesian factor to a new variables set
+ * @param factor
+ * @param translate
+ * @return
+ */
+ private BayesianFactor rename(BayesianFactor factor, TIntIntMap translate) {
+ Strides domain = factor.getDomain();
+ int[] newvars = Arrays.stream(domain.getVariables()).map(translate::get).toArray();
+
+ int[] order = ArraysUtil.order(newvars);
+ double[] source_data = factor.getData();
+
+// int[] source = domain.getVariables().clone();
+// source = ArraysUtil.at(source, order);
+//
+// var iterator = domain.getReorderedIterator(source);
+// int tid = 0;
+//
+// double[] target = new double[domain.getCombinations()];
+//
+// while(iterator.hasNext()) {
+// target[tid++] = source_data[iterator.next()];
+// }
+
+ int[] target_vars = ArraysUtil.at(newvars, order);
+ int[] target_size = ArraysUtil.at(domain.getSizes(), order);
+
+ return BayesianFactorFactory.factory().domain(target_vars, target_size, true).data(newvars, source_data).get();
+ }
+
+ /**
+ * Connect a new model to the global one.
+ */
+ @Override
+ public int add(SCM model) {
+ int wid = worlds.size();
+ worlds.add(model);
+
+ if (global == null) {
+ global = new SCM();
+ }
+
+ TIntIntMap translate = new TIntIntHashMap();
+ for (var source : model.variables()) {
+ if (source.getType() == VariableType.EXOGENOUS) {
+ if (!localToGlobalExogenous.containsKey(source.getLabel())) {
+ int id = global.add(source.getCardinality(), source.getType());
+ localToGlobalExogenous.put(source.getLabel(), id);
+ globalToLocalExogenous.put(id, source.getLabel());
+ translate.put(source.getLabel(), id);
+ } else {
+ int id = localToGlobalExogenous.get(source.getLabel());
+ translate.put(source.getLabel(), id);
+ }
+ } else {
+ int id = global.add(source.getCardinality(), source.getType());
+ translate.put(source.getLabel(), id);
+ }
+ }
+
+
+ for (var source : model.variables()) {
+ if(source.getType() == VariableType.EXOGENOUS) continue;
+ int[] parents = model.getParents(source.getLabel());
+ parents = rename(parents, translate);
+
+ int globalId = translate.get(source.getLabel());
+ global.addParents(globalId, parents);
+ }
+
+ for (var source : model.variables()) {
+ var factor = model.getFactor(source.getLabel());
+ if (factor == null) continue;
+
+ factor = rename(factor, translate);
+ int tid = translate.get(source.getLabel());
+ global.setFactor(tid, factor);
+ }
+
+ saveMapping(wid, translate);
+ return wid;
+ }
+
+
+ @Override
+ public int toGlobal(int variable, int world) {
+ return toGlobal.get(world).get(variable);
+ }
+
+ @Override
+ public int fromGlobal(int variable) {
+ if (globalToLocalExogenous.containsKey(variable)) {
+ return globalToLocalExogenous.get(variable);
+ }
+ var p = fromGlobal.get(variable);
+ return p.getValue();
+ }
+
+ @Override
+ public int worldIdOf(int variable) {
+ if (globalToLocalExogenous.containsKey(variable)) return -1;
+ var p = fromGlobal.get(variable);
+ return p.getKey();
+ }
+
+ @Override
+ public SCM worldOf(int variable) {
+ if (globalToLocalExogenous.containsKey(variable)) return null;
+ return worlds.get(worldIdOf(variable));
+ }
+
+}
diff --git a/src/main/java/ch/idsia/crema/model/graphical/DAGModel.java b/src/main/java/ch/idsia/crema/model/graphical/DAGModel.java
index 8402a7b1..da76a200 100644
--- a/src/main/java/ch/idsia/crema/model/graphical/DAGModel.java
+++ b/src/main/java/ch/idsia/crema/model/graphical/DAGModel.java
@@ -22,6 +22,7 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.function.BiFunction;
+import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -225,10 +226,19 @@ public void removeParent(int variable, int parent) {
network.removeEdge(parent, variable);
}
+// @Override
+// public void removeParent(int variable, int parent, DomainChange change) {
+// F factor = factors.get(variable);
+// F new_factor = change.remove(factor, parent);
+// if (factor != new_factor)
+// factors.put(variable, new_factor);
+// network.removeEdge(parent, variable);
+// }
+//
@Override
- public void removeParent(int variable, int parent, DomainChange change) {
+ public void removeParent(int variable, int parent, Function change) {
F factor = factors.get(variable);
- F new_factor = change.remove(factor, parent);
+ F new_factor = change.apply(factor);
if (factor != new_factor)
factors.put(variable, new_factor);
network.removeEdge(parent, variable);
@@ -293,6 +303,15 @@ public int[] getLeaves() {
return list.toArray();
}
+
+ @Override
+ public Strides getFullDomain(int variable) {
+ int[] vars = getParents(variable);
+ vars = ArraysUtil.append(vars, variable);
+ Arrays.sort(vars);
+ return new Strides(vars, getSizes(vars));
+ }
+
@Override
public Strides getDomain(int... variables) {
return new Strides(variables, getSizes(variables));
@@ -304,6 +323,10 @@ public void addVariables(int... states) {
}
}
+ /**
+ * Implemented as a sequence of addParent calls.
+ */
+ @Override
public void addParents(int variable, int... parents) {
for (int parent : parents) {
addParent(variable, parent);
@@ -338,6 +361,7 @@ public Collection getFactors() {
}
public Collection getFactors(int... variables) {
+
return IntStream.of(variables).mapToObj(v -> factors.get(v)).collect(Collectors.toList());
}
@@ -363,7 +387,7 @@ public void setFactors(F[] factors) {
*
* @return
*/
- public boolean correctFactorDomains() {
+ public boolean checkFactorsDAGConsistency() {
return IntStream.of(this.getVariables())
.allMatch(v -> Arrays.equals(
ArraysUtil.sort(this.getFactor(v).getDomain().getVariables()),
diff --git a/src/main/java/ch/idsia/crema/model/graphical/GraphicalModel.java b/src/main/java/ch/idsia/crema/model/graphical/GraphicalModel.java
index 47f28ad5..7bb849eb 100644
--- a/src/main/java/ch/idsia/crema/model/graphical/GraphicalModel.java
+++ b/src/main/java/ch/idsia/crema/model/graphical/GraphicalModel.java
@@ -1,5 +1,9 @@
package ch.idsia.crema.model.graphical;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+
+import ch.idsia.crema.core.Strides;
import ch.idsia.crema.factor.GenericFactor;
import ch.idsia.crema.model.Model;
import ch.idsia.crema.model.change.DomainChange;
@@ -8,11 +12,32 @@
// FIXME: #removeParent should accept a lambda
public interface GraphicalModel extends Model {
-
+ /**
+ * Get the full domain associated with the specified variable.
+ * This includes the variable and its parents.
+ *
+ * @param variable the variable who's domain will be returned
+ * @return the variable and its parents as a Sorted Stride object
+ */
+ Strides getFullDomain(int variable);
+
+ /**
+ * Remove a variable's parents
+ *
+ * @param variable
+ * @param parent
+ */
void removeParent(int variable, int parent);
- void removeParent(int variable, int parent, DomainChange change);
+ /**
+ * Remove a variable's parent indicating how to deal with the domain change
+ * @param variable
+ * @param parent
+ * @param change
+ */
+ void removeParent(int variable, int parent, Function change);
+
void addParent(int variable, int parent);
/**
@@ -49,4 +74,6 @@ default void addParents(int k, int[] parent) {
addParent(k, p);
}
}
+
+
}
\ No newline at end of file
diff --git a/src/main/java/ch/idsia/crema/model/io/dot/DetailedDotSerializer.java b/src/main/java/ch/idsia/crema/model/io/dot/DetailedDotSerializer.java
new file mode 100644
index 00000000..acb50981
--- /dev/null
+++ b/src/main/java/ch/idsia/crema/model/io/dot/DetailedDotSerializer.java
@@ -0,0 +1,211 @@
+package ch.idsia.crema.model.io.dot;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.text.DecimalFormat;
+import java.text.NumberFormat;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import ch.idsia.crema.core.Strides;
+import ch.idsia.crema.data.DoubleTable;
+import ch.idsia.crema.factor.bayesian.BayesianFactor;
+import ch.idsia.crema.model.causal.SCM;
+import ch.idsia.crema.model.graphical.BayesianNetwork;
+import ch.idsia.crema.model.graphical.GraphicalModel;
+import ch.idsia.crema.utility.IndexIterator;
+
+public class DetailedDotSerializer {
+
+
+
+
+ protected static Function nodeName(GraphicalModel model) {
+ if (model instanceof BayesianNetwork) {
+ return (node) -> {
+ return "X" + node + "";
+ };
+ } else {
+ final SCM sm = (SCM) model;
+ return (node) -> {
+ String label;
+ if (sm.isEndogenous(node)) {
+ label = "X";
+ } else if (sm.isExogenous(node)) {
+ label = "U";
+ } else {
+ label = "W";
+ }
+
+ return label + "" + node + "";
+ };
+ }
+ }
+
+
+ protected String apply(DoubleTable table, Function nodeName) {
+ NumberFormat formatter = new DecimalFormat("0.###");
+
+ StringBuilder builder = new StringBuilder();
+ builder.append("
").append("
");
+
+ for (int col : table.getColumns()) {
+ builder.append("
").append(nodeName.apply(col))
+ .append("
");
+ }
+ builder.append("
W
");
+
+ // contents
+ var iter = table.iterator();
+ while (iter.hasNext()) {
+ var row = iter.next();
+ builder.append("
");
+
+ for (int p : conditioning.getVariables()) {
+ builder.append("
");
+ builder.append("
").append(nodeName.apply(p)).append("
");
+
+ int stride = conditioning.getStride(p);
+ int size = conditioning.getCardinality(p);
+ for (int s = 0; s < csize; s += stride) {
+ int state = (s / stride) % size;
+ builder.append("
");
+ builder.append(state).append("
");
+ }
+ builder.append("
");
+ }
+
+ int stride = domain.getStride(i);
+ int states = domain.getCardinality(i);
+ for (int state = 0; state < states; ++state) {
+ builder.append("
");
+ if (state == 0) {
+ builder.append("
")
+ .append(nodeName.apply(i)).append("
");
+ }
+
+ builder.append("
").append(state).append("
");
+
+ IndexIterator iter = factor.getDomain().getIterator(conditioning);
+ while (iter.hasNext()) {
+ int index = iter.next();
+ boolean h = highlight != null && highlight.containsKey(i) && highlight.get(i).contains(index);
+ double value = factor.getValueAt(index + state * stride);
+ if (h)
+ builder.append("
");
+ else
+ builder.append("
");
+ builder.append(formatter.format(value));
+ if (h)
+ builder.append("");
+ builder.append("
");
+ }
+ builder.append("
");
+ }
+
+ builder.append("
>];\n");
+
+ } else {
+ builder.append("N").append(i).append("[shape=\"circle\" label=<").append(nodeName.apply(i)).append(">];\n");
+ }
+
+ for (int parent : parents) {
+ arcs.append('N').append(parent).append(" -> N").append(i).append(";\n");
+ }
+ }
+
+ builder.append(arcs);
+
+ if (record.getTitle() != null)
+ builder.append("labelloc=\"t\"\nlabel=\"").append(record.getTitle()).append(" ").append(record.getRunId()).append("\"\n");
+
+ builder.append("}");
+ return builder.toString();
+
+ }
+
+
+
+
+ public static void saveModel(String filename, Info r) {
+ try {
+ DetailedDotSerializer serializer = new DetailedDotSerializer();
+
+ File f = File.createTempFile(filename, ".dot");
+
+ String file = serializer.apply(r);
+ Files.writeString(Path.of(f.getAbsolutePath()), file);
+ ProcessBuilder b = new ProcessBuilder("/opt/homebrew/bin/dot", "-Tpng", "-o", filename,
+ f.getAbsolutePath());
+ Process p = b.start();
+ p.waitFor();
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } catch (InterruptedException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+
+ }
+}
diff --git a/src/main/java/ch/idsia/crema/model/io/dot/Info.java b/src/main/java/ch/idsia/crema/model/io/dot/Info.java
new file mode 100644
index 00000000..b0473c49
--- /dev/null
+++ b/src/main/java/ch/idsia/crema/model/io/dot/Info.java
@@ -0,0 +1,134 @@
+package ch.idsia.crema.model.io.dot;
+
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+
+import ch.idsia.crema.data.DoubleTable;
+import ch.idsia.crema.factor.bayesian.BayesianFactor;
+import ch.idsia.crema.model.graphical.GraphicalModel;
+
+
+public class Info {
+ private GraphicalModel model;
+ private DoubleTable data;
+ private String modelName;
+ private String title;
+
+ private int runId;
+ private int iterations;
+ private int pscmId;
+ private int pscmIterations;
+
+ private Function nodeName;
+ private Map> highlight;
+
+ public Info model(GraphicalModel model) {
+ this.model = model;
+ return this;
+ }
+
+ public GraphicalModel getModel() {
+ return model;
+ }
+
+ public DoubleTable getData() {
+ return data;
+ }
+
+ public Info data(DoubleTable data) {
+ this.data = data;
+ return this;
+ }
+
+ public String getModelName() {
+ return modelName;
+ }
+
+ public Info modelName(String modelName) {
+ this.modelName = modelName;
+ return this;
+ }
+
+ public String getTitle() {
+ return title;
+ }
+
+ public Info title(String title) {
+ this.title = title;
+ return this;
+ }
+
+ public Function getNodeName() {
+ return nodeName;
+ }
+
+ public Info nodeName(Function nodeName) {
+ this.nodeName = nodeName;
+ return this;
+ }
+
+ public Map> getHighlight() {
+ return highlight;
+ }
+
+ public Info highlight(Map> highlight) {
+ this.highlight = highlight;
+ return this;
+ }
+
+ public Info runId(int runid) {
+ this.runId = runid;
+ return this;
+ }
+
+ public Info iterations(int iterations) {
+ this.iterations = iterations;
+ return this;
+ }
+
+ public Info PSCMId(int pscmrun) {
+ this.pscmId = pscmrun;
+ return this;
+ }
+
+ public Info PSCMIterations(int piter) {
+ this.pscmIterations = piter;
+ return this;
+ }
+
+ public int getRunId() {
+ return runId;
+ }
+ public int getIterations() {
+ return iterations;
+ }
+
+ public int getPSCMId() {
+ return pscmId;
+ }
+
+ public int getPSCMIterations() {
+ return pscmIterations;
+ }
+
+
+ public Info(GraphicalModel model, DoubleTable data, String modelName, String title,
+ Function nodeName, Map> highlight, int run, int iter, int pscm,
+ int pscmiter) {
+ super();
+ this.model = model;
+ this.data = data;
+ this.modelName = modelName;
+ this.title = title;
+ this.nodeName = nodeName;
+ this.highlight = highlight;
+ this.runId = run;
+ this.iterations = iter;
+ this.pscmId = pscm;
+ this.pscmIterations = pscmiter;
+ }
+
+ public Info() {
+ }
+}
diff --git a/src/main/java/ch/idsia/crema/model/io/uai/NetUAIWriter.java b/src/main/java/ch/idsia/crema/model/io/uai/NetUAIWriter.java
index 46a3f17e..7b592996 100644
--- a/src/main/java/ch/idsia/crema/model/io/uai/NetUAIWriter.java
+++ b/src/main/java/ch/idsia/crema/model/io/uai/NetUAIWriter.java
@@ -15,7 +15,7 @@ public NetUAIWriter(T target, String filename) {
@Override
protected void sanityChecks() {
// Check model consistency
- if (!target.correctFactorDomains())
+ if (!target.checkFactorsDAGConsistency())
throw new IllegalArgumentException("Inconsistent model");
}
diff --git a/src/main/java/ch/idsia/crema/preprocess/Do.java b/src/main/java/ch/idsia/crema/preprocess/Do.java
new file mode 100644
index 00000000..ef285c1a
--- /dev/null
+++ b/src/main/java/ch/idsia/crema/preprocess/Do.java
@@ -0,0 +1,34 @@
+package ch.idsia.crema.preprocess;
+
+import ch.idsia.crema.factor.OperableFactor;
+import ch.idsia.crema.model.graphical.GraphicalModel;
+import gnu.trove.map.TIntIntMap;
+
+/**
+ * Network surgery for do operations.
+ *
+ * @param type of the factor
+ * @param type of the Graphical Model
+ */
+public class Do , M extends GraphicalModel>
+implements ConverterEvidence {
+
+ public M execute(M model, TIntIntMap dos) {
+ M copy = (M) model.copy();
+
+ int[] keys = dos.keys();
+ for (int key : keys) {
+
+ // select the correct part of the factors by removing a child
+ // via domain changer
+ for (int child : model.getChildren(key)) {
+ model.removeParent(child, key, (f) -> f.filter(key, dos.get(key)));
+ }
+
+ // completely remove the var now.
+ model.removeVariable(key);
+ }
+
+ return copy;
+ }
+}
diff --git a/src/main/java/ch/idsia/crema/preprocess/CutObserved.java b/src/main/java/ch/idsia/crema/preprocess/Observe.java
similarity index 82%
rename from src/main/java/ch/idsia/crema/preprocess/CutObserved.java
rename to src/main/java/ch/idsia/crema/preprocess/Observe.java
index edd010e7..d08cc169 100644
--- a/src/main/java/ch/idsia/crema/preprocess/CutObserved.java
+++ b/src/main/java/ch/idsia/crema/preprocess/Observe.java
@@ -11,7 +11,7 @@
*
* @author huber
*/
-public class CutObserved> implements TransformerEvidence>,
+public class Observe> implements TransformerEvidence>,
PreprocessorEvidence> {
/**
@@ -34,12 +34,9 @@ public void executeInPlace(GraphicalModel model, TIntIntMap evidence) {
final int state = iterator.value();
for (int variable : model.getChildren(observed)) {
- model.removeParent(variable, observed, new NullChange<>() {
- @Override
- public F remove(F factor, int variable) {
- // probably need to check this earlier
- return factor.filter(observed, state);
- }
+ model.removeParent(variable, observed, (factor) -> {
+ // probably need to check this earlier
+ return factor.filter(observed, state);
});
}
}
diff --git a/src/main/java/ch/idsia/crema/utility/ArraysMath.java b/src/main/java/ch/idsia/crema/utility/ArraysMath.java
new file mode 100644
index 00000000..0161c6cf
--- /dev/null
+++ b/src/main/java/ch/idsia/crema/utility/ArraysMath.java
@@ -0,0 +1,128 @@
+package ch.idsia.crema.utility;
+
+import org.apache.commons.math3.util.FastMath;
+
+public class ArraysMath {
+
+ public static double sum(double[] values) {
+ double total = 0;
+ for (int i = 0; i < values.length; ++i) {
+ total += values[i];
+ }
+ return total;
+ }
+
+ public static int sum(int[] values) {
+ int total = 0;
+ for (int i = 0; i < values.length; ++i) {
+ total += values[i];
+ }
+ return total;
+ }
+
+ public static long sum(long[] values) {
+ long total = 0;
+ for (int i = 0; i < values.length; ++i) {
+ total += values[i];
+ }
+ return total;
+ }
+
+
+ public static int max(int[] array) {
+ int m = Integer.MIN_VALUE;
+ for (int i = 0; i < array.length;++i) {
+ m = Math.max(m, array[i]);
+ }
+ return m;
+ }
+
+ public static int min(int[] array) {
+ int m = Integer.MAX_VALUE;
+ for (int i = 0; i < array.length;++i) {
+ m = Math.min(m, array[i]);
+ }
+ return m;
+ }
+
+ public static double max(double[] array) {
+ double m = Double.NEGATIVE_INFINITY;
+ for (int i = 0; i < array.length;++i) {
+ m = Math.max(m, array[i]);
+ }
+ return m;
+ }
+
+ public static double min(double[] array) {
+ double m = Double.POSITIVE_INFINITY;
+ for (int i = 0; i < array.length;++i) {
+ m = Math.min(m, array[i]);
+ }
+ return m;
+ }
+
+
+ public static double mean(int[] array) {
+ double m = 0;
+ // consider using https://en.wikipedia.org/wiki/Kahan_summation_algorithm
+ for (int i = 0; i < array.length;++i) {
+ m += array[i];
+ }
+ return m / (double)array.length;
+ }
+
+
+
+ public static double mean(long[] array) {
+ double m = 0;
+ // consider using https://en.wikipedia.org/wiki/Kahan_summation_algorithm
+ for (int i = 0; i < array.length;++i) {
+ m += array[i];
+ }
+ return m / (double)array.length;
+ }
+
+
+ public static double mean(double[] array) {
+ double m = 0;
+ // consider using https://en.wikipedia.org/wiki/Kahan_summation_algorithm
+ for (int i = 0; i < array.length;++i) {
+ m += array[i];
+ }
+ return m / (double)array.length;
+ }
+
+
+ public static double sd(int[] array, double ddof) {
+ double m = mean(array);
+ double s = 0;
+ for (int i = 0; i < array.length;++i) {
+ double d = (double)array[i] - m;
+ s += d * d;
+ }
+ return FastMath.sqrt(s / ((double)array.length - ddof));
+ }
+
+
+
+ public static double sd(long[] array, double ddof) {
+ double m = mean(array);
+ double s = 0;
+ for (int i = 0; i < array.length;++i) {
+ double d = (double)array[i] - m;
+ s += d * d;
+ }
+ return FastMath.sqrt(s / ((double)array.length - ddof));
+ }
+
+ public static double sd(double[] array, double ddof) {
+ double m = mean(array);
+ double s = 0;
+ for (int i = 0; i < array.length;++i) {
+ double d = array[i] - m;
+ s += d * d;
+ }
+ return FastMath.sqrt(s / ((double)array.length - ddof));
+ }
+
+}
diff --git a/src/main/java/ch/idsia/crema/utility/ArraysUtil.java b/src/main/java/ch/idsia/crema/utility/ArraysUtil.java
index 4b90d166..bbd40638 100644
--- a/src/main/java/ch/idsia/crema/utility/ArraysUtil.java
+++ b/src/main/java/ch/idsia/crema/utility/ArraysUtil.java
@@ -5,6 +5,11 @@
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Ints;
import gnu.trove.list.array.TIntArrayList;
+import gnu.trove.set.TDoubleSet;
+import gnu.trove.set.TIntSet;
+import gnu.trove.set.hash.TDoubleHashSet;
+import gnu.trove.set.hash.TIntHashSet;
+
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Precision;
@@ -121,6 +126,9 @@ public static int[] sort(int[] base) {
return copy;
}
+ /**
+ * A helper class for the order functions
+ */
private static class X {
public final int pos;
public final int val;
@@ -149,17 +157,24 @@ public static int[] order(int[] data, final IntComparator comparator) {
return positions;
}
+ /**
+ * get the ordering of the data. can be used to sort multiple arrays using a
+ * reference one i.e. get variables order and sort sizes.
+ *
+ * @param data
+ * @return
+ */
public static int[] order(int[] data) {
- List internal = new ArrayList<>(data.length);
+ X[] internal = new X[data.length];
for (int i = 0; i < data.length; ++i) {
- internal.add(new X(i, data[i]));
+ internal[i] = new X(i, data[i]);
}
- internal.sort(Comparator.comparingInt(o -> o.val));
+ Arrays.sort(internal, Comparator.comparingInt(o -> o.val));
int[] positions = new int[data.length];
for (int i = 0; i < data.length; ++i) {
- X x = internal.get(i);
+ X x = internal[i];
positions[i] = x.pos;
data[i] = x.val;
}
@@ -167,6 +182,22 @@ public static int[] order(int[] data) {
return positions;
}
+ /**
+ * Helper method to index multiple items in one shot.
+ *
+ * @param values
+ * @param indices
+ * @return
+ */
+ public static int[] at(int[] from, int[] indices) {
+ // return IntStream.of(indices).map(v -> values[v]).toArray();
+ int[] x = new int[indices.length];
+ for (int i = 0; i < indices.length; ++i) {
+ x[i] = from[indices[i]];
+ }
+ return x;
+ }
+
public static double[][] deepClone(double[][] data) {
double[][] result = new double[data.length][];
for (int i = 0; i < data.length; ++i) {
@@ -184,7 +215,8 @@ public static double[][][] deepClone(double[][][] data) {
}
/**
- * Convert an array in log-space using {@link FastMath#log(double)}. Creates a new array.
+ * Convert an array in log-space using {@link FastMath#log(double)}. Creates a
+ * new array.
*
* @param data input data
* @return the input data in log-space.
@@ -198,7 +230,8 @@ public static double[] log(double[] data) {
}
/**
- * Convert an array of array in log-space using {@link FastMath#log(double)}. Creates a new array.
+ * Convert an array of array in log-space using {@link FastMath#log(double)}.
+ * Creates a new array.
*
* @param data input data
* @return the input data in log-space.
@@ -212,7 +245,8 @@ public static double[][] log(double[][] data) {
}
/**
- * Convert an array in log-space using {@link FastMath#log1p(double)}. Creates a new array.
+ * Convert an array in log-space using {@link FastMath#log1p(double)}. Creates a
+ * new array.
*
* @param data input data
* @return the input data in log-space.
@@ -226,7 +260,8 @@ public static double[] log1p(double[] data) {
}
/**
- * Convert an array from log-space to normal space using {@link FastMath#exp(double)}. Creates a new array.
+ * Convert an array from log-space to normal space using
+ * {@link FastMath#exp(double)}. Creates a new array.
*
* @param data input data
* @return the input data in log-space.
@@ -240,7 +275,8 @@ public static double[] exp(double[] data) {
}
/**
- * Convert an array of array from log-space to normal space using {@link FastMath#exp(double)}. Creates a new array.
+ * Convert an array of array from log-space to normal space using
+ * {@link FastMath#exp(double)}. Creates a new array.
*
* @param data input data
* @return the input data in log-space.
@@ -349,18 +385,19 @@ public static int[] removeFromSortedArray(int[] array, int element) {
* containing the specified element. The original array is returned if the
* element was already part of the array.
*
- * Expects a sorted array! Behaviours is rather unpredictable if array is not sorted.
+ * Expects a sorted array! Behaviours is rather unpredictable if array is not
+ * sorted.
*
* @param array the sorted array
* @param element the item to be added to the array
* @return a new array containing the element or the original one if element is
- * already present
+ * already present
*/
public static int[] addToSortedArray(int[] array, int element) {
if (array == null || array.length == 0) {
// when no items in the array then return a new array with only the
// item
- return new int[]{element};
+ return new int[] { element };
}
// look for existing links
@@ -472,15 +509,113 @@ public static int[] union(int[] arr1, int[] arr2) {
return arr_union;
}
+ public static int[] union_sorted_set(int[] arr1, int[] arr2) {
+ final int s1 = arr1.length;
+ final int s2 = arr2.length;
+
+ // size the target arrays assuming no overlap
+ final int max = s1 + s2;
+ int[] arr_union = new int[max];
+
+ // (pt1) c1 and c2 are the positions in the two domains
+ int c1 = 0;
+ int c2 = 0;
+
+ int t = 0;
+ int last = 0;
+ while (c1 < s1 && c2 < s2) {
+ int v1 = arr1[c1];
+ int v2 = arr2[c2];
+ if (t > 0) {
+ if (v1 == v2 && last == v1) {
+ ++c1;
+ ++c2;
+ continue;
+ } else if (v1 == last) {
+ ++c1;
+ continue;
+ } else if (v2 == last) {
+ ++c2;
+ continue;
+ }
+ }
+
+ if (v1 < v2) {
+ last = arr_union[t] = v1;
+ ++c1;
+ } else if (v1 > v2) {
+ last = arr_union[t] = v2;
+ ++c2;
+ } else {
+ last = arr_union[t] = v1;
+ ++c1;
+ ++c2;
+ }
+
+ ++t;
+ }
+
+ // (pt2) check if there is one domain not completely copied that can be
+ // moved over in bulk.
+ for (; c1 < s1; ++c1) {
+ int a = arr1[c1];
+ if (t == 0 || a != last)
+ last = arr_union[t++] = a;
+ }
+ for (; c2 < s2; ++c2) {
+ int a = arr2[c2];
+ if (t == 0 || a != last)
+ last = arr_union[t++] = a;
+
+ }
+
+ // fix array sizes if there was overlap (we assumed no overlap while sizing)
+ if (t < max) {
+ arr_union = Arrays.copyOf(arr_union, t);
+ }
+ return arr_union;
+ }
+
/**
- * Find the sorted difference of two non-sorted integer arrays.
- *
+ * Find the sorted difference of two non-sorted integer arrays. in arr1 but not
+ * in arr2
+ *
* @param arr1 the first array
* @param arr2 the second array
* @return an array intersection of the first two
*/
+// public static int[] difference(int[] arr1, int[] arr2) {
+// return IntStream.of(arr1).filter(y -> IntStream.of(arr2).noneMatch(x -> x == y)).toArray();
+// }
public static int[] difference(int[] arr1, int[] arr2) {
- return IntStream.of(arr1).filter(y -> IntStream.of(arr2).noneMatch(x -> x == y)).toArray();
+ TIntSet a2 = new TIntHashSet(arr2);
+
+ int[] tmp = new int[arr1.length];
+ int used = 0;
+
+ for (int o : arr1) {
+ if (!a2.contains(o))
+ tmp[used++] = o;
+ }
+
+ int[] target = new int[used];
+ System.arraycopy(tmp, 0, target, 0, used);
+ Arrays.sort(target);
+ return target;
+ }
+
+ public static int[] differenceSet(int[] arr1, int[] arr2) {
+ TIntSet a2 = new TIntHashSet(arr2);
+ TIntSet target = new TIntHashSet(arr1.length);
+ // target.removeAll(a2);
+ for (int o : arr1) {
+ if (!a2.contains(o))
+ target.add(o);
+ }
+
+ int[] ok = target.toArray();
+ Arrays.sort(ok);
+ return ok;
}
/**
@@ -491,17 +626,19 @@ public static int[] difference(int[] arr1, int[] arr2) {
* @return symetric difference of both arrays
*/
public static int[] symmetricDiff(int[] arr1, int[] arr2) {
- return unionSet(difference(arr1, arr2), difference(arr2, arr1));
+ return union_sorted_set(difference(arr1, arr2), difference(arr2, arr1));
}
/**
* @param arr1 first array
* @param arr2 second array
- * @return an array which is the union of the two arrays without the common elements
+ * @return an array which is the union of the two arrays without the common
+ * elements
*/
public static int[] outersection(int[] arr1, int[] arr2) {
final int[] intersection = intersection(arr1, arr2);
- return unionSet(difference(arr1, intersection), difference(arr2, intersection));
+ // difference returns a sorted set
+ return union_sorted_set(difference(arr1, intersection), difference(arr2, intersection));
}
/**
@@ -515,6 +652,14 @@ public static int[] intersection(int[] arr1, int[] arr2) {
return IntStream.of(arr1).filter(y -> IntStream.of(arr2).anyMatch(x -> x == y)).toArray();
}
+ public static int[] intersection2(int[] arr1, int[] arr2) {
+ TIntSet a1 = new TIntHashSet(arr1);
+ a1.retainAll(arr2);
+ int[] a = a1.toArray();
+ Arrays.sort(a);
+ return a;
+ }
+
/**
* Find the sorted intersection of two sorted integer arrays.
*
@@ -564,10 +709,18 @@ public static int[] intersectionSorted(int[] arr1, int[] arr2) {
* @param arr2
* @return
*/
- public static int[] unionSet(int[] arr1, int[] arr2) {
+ public static int[] unionSet2(int[] arr1, int[] arr2) {
return Ints.toArray(ImmutableSet.copyOf(Ints.asList(Ints.concat(arr1, arr2))));
}
+ public static int[] union_unsorted_set(int[] arr1, int[] arr2) {
+ TIntSet set = new TIntHashSet(arr1);
+ set.addAll(arr2);
+ int[] res = set.toArray();
+ Arrays.sort(res);
+ return res;
+ }
+
/**
* Normalize an array by fixing the last value of the array so that it sums up
* to the target value.
@@ -627,7 +780,7 @@ public static int indexOf(int needle, int[] haystack) {
public static int[] getShape(double[][] matrix) {
if (Arrays.stream(matrix).map(v -> v.length).distinct().count() != 1)
throw new IllegalArgumentException("ERROR: nested vectors do not have the same length");
- return new int[]{matrix.length, matrix[0].length};
+ return new int[] { matrix.length, matrix[0].length };
}
/**
@@ -640,7 +793,7 @@ public static int[] getShape(double[][] matrix) {
public static int[] getShape(int[][] matrix) {
if (Arrays.stream(matrix).map(v -> v.length).distinct().count() != 1)
throw new IllegalArgumentException("ERROR: nested vectors do not have the same length");
- return new int[]{matrix.length, matrix[0].length};
+ return new int[] { matrix.length, matrix[0].length };
}
/**
@@ -685,7 +838,7 @@ public static int[][] transpose(int[][] original) {
public static double[][] reshape2d(double[] vector, int... shape) {
if (shape.length == 1)
- shape = new int[]{shape[0], vector.length / shape[0]};
+ shape = new int[] { shape[0], vector.length / shape[0] };
if (shape[0] * shape[1] != vector.length)
throw new IllegalArgumentException("ERROR: incompatible shapes");
@@ -776,7 +929,7 @@ public static double[] changeEndian(double[] data, int[] oldSizes) {
* @return
*/
public static double[][] enumerate(double[] vect, int start) {
- return IntStream.range(start, vect.length + start).mapToObj(i -> new double[]{i, vect[i - start]})
+ return IntStream.range(start, vect.length + start).mapToObj(i -> new double[] { i, vect[i - start] })
.toArray(double[][]::new);
}
@@ -797,8 +950,14 @@ public static double[][] enumerate(double[] vect) {
* @param arr
* @return
*/
+// public static int[] unique(int[] arr) {
+// return Ints.toArray(ImmutableSet.copyOf(Ints.asList(arr)));
+// }
public static int[] unique(int[] arr) {
- return Ints.toArray(ImmutableSet.copyOf(Ints.asList(arr)));
+ TIntSet set = new TIntHashSet(arr);
+ int[] a = set.toArray();
+ Arrays.sort(a);
+ return a;
}
/**
@@ -807,10 +966,15 @@ public static int[] unique(int[] arr) {
* @param arr
* @return
*/
+// public static double[] unique2(double[] arr) {
+// return Doubles.toArray(ImmutableSet.copyOf(Doubles.asList(arr)));
+// }
public static double[] unique(double[] arr) {
- return Doubles.toArray(ImmutableSet.copyOf(Doubles.asList(arr)));
+ TDoubleSet set = new TDoubleHashSet(arr);
+ double[] a = set.toArray();
+ Arrays.sort(a);
+ return a;
}
-
/**
* Round all the values in a vector with a number of decimals.
*
@@ -1038,7 +1202,7 @@ private static int ndimWithClass(Class array) {
* @param a
* @return
*/
- @SuppressWarnings({"rawtypes", "unchecked"})
+ @SuppressWarnings({ "rawtypes", "unchecked" })
public static double[] flattenDoubles(List a) {
int ndims = ArraysUtil.ndim(a.get(0));
if (ndims > 1) {
@@ -1071,7 +1235,7 @@ public static double[] flattenDoubles(List a) {
* @param a
* @return
*/
- @SuppressWarnings({"rawtypes", "unchecked"})
+ @SuppressWarnings({ "rawtypes", "unchecked" })
public static int[] flattenInts(List a) {
int ndims = ArraysUtil.ndim(a.get(0));
if (ndims > 1) {
@@ -1097,17 +1261,6 @@ public static int[] flattenInts(List a) {
return out;
}
- /**
- * Helper method to index multiple items in one shot.
- *
- * @param values
- * @param indices
- * @return
- */
- public static int[] at(int[] values, int[] indices) {
- return IntStream.of(indices).map(v -> values[v]).toArray();
- }
-
/**
* Non inline version of the Apache commons reverse methods
*
@@ -1121,8 +1274,7 @@ public static int[] reverse(int[] is) {
}
public static boolean isOneHot(double[] arr) {
- return where(arr, x -> x == 1).length == 1 &&
- where(arr, x -> x == 0).length == arr.length - 1;
+ return where(arr, x -> x == 1).length == 1 && where(arr, x -> x == 0).length == arr.length - 1;
}
public static double[] replace(double[] arr, Predicate pred, double replacement) {
@@ -1154,7 +1306,7 @@ public static boolean equals(int[] arr1, int[] arr2, boolean sort, boolean uniqu
arr1 = unique(arr1);
arr2 = unique(arr2);
}
- if (sort) {
+ if (sort && !unique) { // unique also sorts!
arr1 = sort(arr1);
arr2 = sort(arr2);
}
@@ -1185,22 +1337,26 @@ public static void shuffle(int[] array) {
}
/**
- * Function that transforms an array such that all the non-zero elements sum exactly 1.0
+ * Function that transforms an array such that all the non-zero elements sum
+ * exactly 1.0
+ *
* @param vector
* @return
*/
- public static double[] normalizeNonZeros(double[] vector){
- if(Arrays.stream(vector).sum()!=1.0) {
+ public static double[] normalizeNonZeros(double[] vector) {
+ if (Arrays.stream(vector).sum() != 1.0) {
// Find the last element different to zero
int idx = vector.length - 1;
- while (idx > 0 && vector[idx] == 0.0) idx--;
+ while (idx > 0 && vector[idx] == 0.0)
+ idx--;
double sum = 0.0;
- for (int i = 0; i < idx; i++) sum += vector[i];
+ for (int i = 0; i < idx; i++)
+ sum += vector[i];
vector[idx] = 1.0 - sum;
}
return vector;
}
-}
+}
\ No newline at end of file
diff --git a/src/test/java/ch/idsia/crema/core/StridesTest.java b/src/test/java/ch/idsia/crema/core/StridesTest.java
index 018ba9c4..6eda7987 100644
--- a/src/test/java/ch/idsia/crema/core/StridesTest.java
+++ b/src/test/java/ch/idsia/crema/core/StridesTest.java
@@ -8,14 +8,6 @@
public class StridesTest {
- @SuppressWarnings("deprecation")
- public void testFilteredIterator() {
- Strides domain = new Strides(new int[]{0, 1, 2, 3}, new int[]{2, 3, 4, 2});
-
- IndexIterator iterator1 = domain.getFiteredIndexIterator(1, 2);
- //int offset = domain.getPartialOffset(vars, states)
- //IndexIterator iterator2 = domain.getIterator(filtered.getVariable()).offset(offset);
- }
@Test
public void testEmpty() {
@@ -221,19 +213,18 @@ public void testStridesStridesInt() {
}
- @SuppressWarnings("deprecation")
@Test
public void testRemoveAt() {
- Strides s1 = new Strides(new int[]{0, 2, 3, 9}, new int[]{2, 2, 3, 7});
+ Strides s1 = new Strides(new int[]{ 0, 2, 3, 9 }, new int[]{2, 2, 3, 7});
- Strides n1 = new Strides(s1, 1);
+ Strides n1 = s1.removeAt(1);
Strides n2 = s1.removeAt(1);
assertArrayEquals(n1.getVariables(), n2.getVariables());
assertArrayEquals(n1.getSizes(), n2.getSizes());
- n1 = new Strides(s1, 1);
- n1 = new Strides(n1, 2);
+ n1 = s1.removeAt(1);
+ n1 = n1.removeAt(2);
n2 = s1.removeAt(1, 3);
assertArrayEquals(n1.getVariables(), n2.getVariables());
@@ -243,6 +234,11 @@ public void testRemoveAt() {
assertArrayEquals(new int[0], n2.getVariables());
assertArrayEquals(new int[0], n2.getSizes());
assertEquals(1, n2.getCombinations());
+
+ n2 = s1.removeAt(3).removeAt(2).removeAt(0).removeAt(0);
+ assertArrayEquals(new int[0], n2.getVariables());
+ assertArrayEquals(new int[0], n2.getSizes());
+ assertEquals(1, n2.getCombinations());
}
@Test
diff --git a/src/test/java/ch/idsia/crema/factor/bayesian/BayesianOperations.java b/src/test/java/ch/idsia/crema/factor/bayesian/BayesianOperations.java
new file mode 100644
index 00000000..3e595dee
--- /dev/null
+++ b/src/test/java/ch/idsia/crema/factor/bayesian/BayesianOperations.java
@@ -0,0 +1,70 @@
+package ch.idsia.crema.factor.bayesian;
+
+import java.util.Arrays;
+import org.junit.jupiter.api.Test;
+import ch.idsia.crema.core.Domain;
+import ch.idsia.crema.core.DomainBuilder;
+import ch.idsia.crema.utility.ArraysMath;
+
+public class BayesianOperations {
+//0.01656,0.0168,0.04284,0.018,0.02496,0.04728,0.03024,0.05856,0.06984 0.32508
+//0.01656,0.0168,0.04284,0.018,0.02496,0.04728,0.03024,0.05856,0.06984
+ @Test
+ public void test() {
+ Domain abc = DomainBuilder.var(0,1,2).size(3,2,2).strides();
+ Domain bc = DomainBuilder.var(1,2).size(2,2).strides();
+ Domain bde = DomainBuilder.var(1,3,4).size(2,3,3).strides();
+ Domain de = DomainBuilder.var(3,4).size(3,3).strides();
+
+ Domain d = DomainBuilder.var(3).size(3).strides();
+ Domain a = DomainBuilder.var(0).size(3).strides();
+
+ Domain ce = DomainBuilder.var(2,4).size(2,3).strides();
+ Domain e = DomainBuilder.var(4).size(3).strides();
+
+ Domain ed = DomainBuilder.var(4,3).size(3,3).strides();
+
+ var pa = BayesianFactorFactory.factory().domain(a).data(new double[]{0.3, 0.4, 0.3}).get();
+ var pabc = BayesianFactorFactory.factory().domain(abc).data(new double[]{
+ 0.1, 0.2, 0.7,
+ 0.3, 0.1, 0.6,
+ 0.3, 0.5, 0.2,
+ 0.5, 0.4, 0.1
+ }).get();
+ var pbde = BayesianFactorFactory.factory().domain(bde).data(new double[] {
+ 0.1, 0.9,
+ 0.3, 0.7,
+ 0.6, 0.4,
+ 0.5, 0.5,
+ 0.8, 0.2,
+ 0.2, 0.8,
+ 0.9, 0.1,
+ 0.4, 0.6,
+ 0.7, 0.3
+ }).get();
+ var pd = BayesianFactorFactory.factory().domain(d).data(new double[]{
+ 0.6, 0.2, 0.2
+ }).get();
+ var pce = BayesianFactorFactory.factory().domain(ce).data(new double[]{
+ 0.9, 0.1,
+ 0.8, 0.2,
+ 0.6, 0.4
+ }).get();
+ var ped = BayesianFactorFactory.factory().domain(de).data(new double[]{
+ 0.1, 0.1, 0.8,
+ 0.2, 0.7, 0.1,
+ 0.7, 0.2, 0.1
+ }).get();
+
+ var xx = pa.combine(pabc).combine(pbde)
+ .marginalize(1)
+ .combine(pce)
+ .marginalize(2)
+ .combine(pd, ped)
+ .marginalize(3);
+ // var xx = pd.combine(ped);
+ double o = ArraysMath.sum(xx.getData());
+ System.out.println(Arrays.toString(xx.getData()));
+ }
+
+}
diff --git a/src/test/java/ch/idsia/crema/factor/credal/vertex/TestVertexFactor.java b/src/test/java/ch/idsia/crema/factor/credal/vertex/TestVertexFactor.java
index 9ff77ad9..28d57d5c 100644
--- a/src/test/java/ch/idsia/crema/factor/credal/vertex/TestVertexFactor.java
+++ b/src/test/java/ch/idsia/crema/factor/credal/vertex/TestVertexFactor.java
@@ -84,7 +84,7 @@ public void testAdd() {
@Test
public void testExpansion() {
// p(X1|X0,X5)
- VertexFactor vf = VertexFactorFactory.factory().domain(Strides.as(1, 4), Strides.as(0, 3).and(5, 2))
+ VertexFactor vf = VertexFactorFactory.factory().domain(Strides.as(1, 4).sort(), Strides.as(0, 3).and(5, 2).sort())
.addVertex(new double[]{0.1, 0.3, 0.2, 0.4}, 0, 1)
.addVertex(new double[]{0.3, 0.2, 0.1, 0.4}, 0, 1)
.addVertex(new double[]{0.1, 0.3, 0.4, 0.2}, 0, 1)
@@ -101,6 +101,8 @@ public void testExpansion() {
.addVertex(new double[]{0.1, 0.2, 0.2, 0.5}, 2, 0)
.get();
+
+
// P(X1|X0,X5) -> diventa estensivo rispetto a X5
VertexFactor v2 = vf.reseparate(Strides.as(0, 3));
double[][][] dta = v2.getData();
@@ -120,7 +122,7 @@ public void testExpansion() {
assertArrayEquals(new double[]{0.6, 0.1, 0.2, 0.1, 0.4, 0.2, 0.2, 0.2}, dta[2][2], 0.00001);
v2 = vf.reseparate(Strides.as(5, 2));
- dta = v2.getData();
+ dta = v2.getData();
assertEquals(2, dta.length);
// we removed var 0 from the separation domain
@@ -129,6 +131,7 @@ public void testExpansion() {
assertEquals(12, dta[1][0].length);
// test some vertices
+
assertArrayEquals(new double[]{0.7, 0.1, 0.6, 0.1, 0.1, 0.1, 0.1, 0.7, 0.2, 0.1, 0.1, 0.1}, dta[0][0],
0.00001);
assertArrayEquals(new double[]{0.7, 0.1, 0.6, 0.1, 0.1, 0.1, 0.1, 0.7, 0.2, 0.1, 0.1, 0.1}, dta[0][0],
diff --git a/src/test/java/ch/idsia/crema/model/BayesianFactorTest.java b/src/test/java/ch/idsia/crema/model/BayesianFactorTest.java
index 32e5aba7..7557e49e 100644
--- a/src/test/java/ch/idsia/crema/model/BayesianFactorTest.java
+++ b/src/test/java/ch/idsia/crema/model/BayesianFactorTest.java
@@ -257,7 +257,7 @@ public void testLogCombineDivision() {
assertEquals(ibf, r);
}
- @Test
+// @Test
public void testtime() {
int a = 0;
int b = 1;
@@ -347,7 +347,7 @@ public void testtime() {
- @Test
+// @Test
public void testLogSpeed() {
int[] vars = new int[]{0, 1, 2};
diff --git a/src/test/java/ch/idsia/crema/model/causal/mapping/TreeMappingTest.java b/src/test/java/ch/idsia/crema/model/causal/mapping/TreeMappingTest.java
new file mode 100644
index 00000000..0d3e7f28
--- /dev/null
+++ b/src/test/java/ch/idsia/crema/model/causal/mapping/TreeMappingTest.java
@@ -0,0 +1,195 @@
+/**
+ *
+ */
+package ch.idsia.crema.model.causal.mapping;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import java.util.Arrays;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import ch.idsia.crema.core.Domain;
+import ch.idsia.crema.core.Strides;
+import ch.idsia.crema.factor.bayesian.BayesianFactor;
+import ch.idsia.crema.factor.bayesian.BayesianFactorFactory;
+import ch.idsia.crema.model.causal.SCM;
+import ch.idsia.crema.model.causal.SCM.VariableType;
+import ch.idsia.crema.model.io.dot.DetailedDotSerializer;
+import ch.idsia.crema.model.io.dot.Info;
+
+/**
+ *
+ */
+class TreeMappingTest {
+ SCM one;
+ SCM two;
+
+ int v1;
+ int v2;
+ int e1;
+ int e2;
+
+ @BeforeEach
+ void setup() {
+ SCM scm1 = new SCM();
+ v1 = scm1.addEndogenous(2);
+ v2 = scm1.addEndogenous(2);
+ e1 = scm1.addExogenous(2);
+ e2 = scm1.addExogenous(2);
+ scm1.addParents(v2, v1, e2);
+ scm1.addParent(v1, e1);
+
+ var f1 = BayesianFactorFactory.factory().domain(scm1.getFullDomain(v1))
+ .data(new int[] { v1, e1 }, new double[] { 0.1, 0.9, 0.3, 0.7 }).get();
+ scm1.setFactor(v1, f1);
+
+ var f2 = BayesianFactorFactory.factory().domain(scm1.getFullDomain(v2))
+ .data(new int[] { v2, v1, e2 }, new double[] { 0.4, 0.6, 0.2, 0.8, 0.1, 0.9, 0.7, 0.3 }).get();
+ scm1.setFactor(v2, f2);
+
+ var f3 = BayesianFactorFactory.factory().domain(scm1.getFullDomain(e1))
+ .data(new int[] { e1 }, new double[] { 0.25, 0.75 }).get();
+ scm1.setFactor(e1, f3);
+
+ var f4 = BayesianFactorFactory.factory().domain(scm1.getFullDomain(e2))
+ .data(new int[] { e2 }, new double[] { 0.2, 0.8 }).get();
+ scm1.setFactor(e2, f4);
+
+ SCM scm2 = new SCM();
+ scm2.addEndogenous(v1, 2);
+ scm2.addEndogenous(v2, 2);
+ scm2.addExogenous(e1, 2);
+ scm2.addExogenous(e2, 2);
+
+ scm2.addParent(v2, e2);
+ scm2.addParent(v1, e1);
+
+ f1 = BayesianFactorFactory.factory().domain(scm2.getFullDomain(v1))
+ .data(new int[] { v1, e1 }, new double[] { 0.45, 0.55, 0.35, 0.65 }).get();
+ scm2.setFactor(v1, f1);
+
+ f2 = BayesianFactorFactory.factory().domain(scm2.getFullDomain(v2))
+ .data(new int[] { v2, e2 }, new double[] { 0.6, 0.4, 0.15, 0.85 }).get();
+ scm2.setFactor(v2, f2);
+
+ f3 = BayesianFactorFactory.factory().domain(scm2.getFullDomain(e1))
+ .data(new int[] { e1 }, new double[] { 0.55, 0.45 }).get();
+ scm2.setFactor(e1, f3);
+
+ f4 = BayesianFactorFactory.factory().domain(scm2.getFullDomain(e2))
+ .data(new int[] { e2 }, new double[] { 0.65, 0.35 }).get();
+ scm2.setFactor(e2, f4);
+
+ one = scm1;
+ two = scm2;
+ }
+
+ /**
+ * Test method for
+ * {@link ch.idsia.crema.model.causal.mapping.TreeMapping#get()}.
+ */
+ @Test
+ void testGet() {
+ TreeMapping tm = new TreeMapping();
+ assertNull(tm.get());
+ tm.add(one);
+ assertNotNull(tm.get());
+ }
+
+ static String name(TreeMapping tm, int id) {
+ SCM global = tm.get();
+ int lid = tm.fromGlobal(id);
+ if (global.isExogenous(id)) {
+ return "U" + lid;
+ }
+
+ int world = tm.worldIdOf(id);
+ return "X" + lid + "" + world + "";
+ }
+
+ /**
+ * Test method for
+ * {@link ch.idsia.crema.model.causal.mapping.TreeMapping#add(ch.idsia.crema.model.causal.SCM)}.
+ */
+ @Test
+ void testAdd() {
+
+ final TreeMapping tm = new TreeMapping();
+ assertNull(tm.get());
+ tm.add(one);
+ tm.add(two);
+ SCM global = tm.get();
+ // DetailedDotSerializer.saveModel("test.png", new
+ // Info().model(global).nodeName((id)->name(tm, id)));
+ assertNotNull(tm.get());
+ }
+
+ /**
+ * Test method for
+ * {@link ch.idsia.crema.model.causal.mapping.TreeMapping#toGlobal(int, int)}.
+ */
+ @Test
+ void testToGlobal() {
+ final TreeMapping tm = new TreeMapping();
+ int w1 = tm.add(one);
+ int w2 = tm.add(two);
+ SCM global = tm.get();
+
+ for (int v : one.getVariables()) {
+ int g = tm.toGlobal(v, w1);
+ assertEquals(tm.fromGlobal(g), v);
+ }
+ for (int v : two.getVariables()) {
+ int g = tm.toGlobal(v, w2);
+ assertEquals(tm.fromGlobal(g), v);
+ }
+
+ }
+
+ /**
+ * Test method for
+ * {@link ch.idsia.crema.model.causal.mapping.TreeMapping#fromGlobal(int)}.
+ */
+ @Test
+ void testExogenous() {
+ final TreeMapping tm = new TreeMapping();
+ int w1 = tm.add(one);
+ int w2 = tm.add(two);
+
+ int ge1 = tm.toGlobal(e1, w1);
+ int ge2 = tm.toGlobal(e1, w2);
+ assertEquals(ge1, ge2);
+ }
+
+ /**
+ * Test parenting preservalfor
+ */
+ @Test
+ void testParenting() {
+ final TreeMapping tm = new TreeMapping();
+ int w1 = tm.add(one);
+ int w2 = tm.add(two);
+ SCM global = tm.get();
+ for (int v : global.getVariables()) {
+ int[] p = global.getParents(v);
+
+ SCM w = tm.worldOf(v);
+ if (w == null) {
+ w = one;
+ assert(global.isExogenous(v));
+ }
+
+ int vs = tm.fromGlobal(v);
+ int[] ps = tm.fromGlobal(p);
+ int[] pt = w.getParents(vs);
+
+ Arrays.sort(ps);
+ Arrays.sort(pt);
+ assertArrayEquals(ps, pt);
+ }
+ }
+
+
+}
diff --git a/src/test/java/ch/idsia/crema/model/io/UAIParserTest.java b/src/test/java/ch/idsia/crema/model/io/UAIParserTest.java
index 61b8ae35..ae25a475 100644
--- a/src/test/java/ch/idsia/crema/model/io/UAIParserTest.java
+++ b/src/test/java/ch/idsia/crema/model/io/UAIParserTest.java
@@ -42,11 +42,11 @@ void numvars(String name, String num) {
Assertions.assertEquals(((DAGModel) models.get(name)).getVariables().length, Integer.parseInt(num));
}
- @ParameterizedTest
- @ValueSource(strings = {"simple-hcredal.uai", "simple-vcredal.uai"})
- void checkDomains(String name) {
- Assertions.assertTrue(((DAGModel) models.get(name)).correctFactorDomains());
- }
+// @ParameterizedTest
+// @ValueSource(strings = {"simple-hcredal.uai", "simple-vcredal.uai"})
+// void checkDomains(String name) {
+// Assertions.assertTrue(((DAGModel) models.get(name)).correctFactorDomains());
+// }
@ParameterizedTest
@ValueSource(strings = {"simple-hcredal.uai"})
diff --git a/src/test/java/ch/idsia/crema/model/utility/ArraysUtilityTest.java b/src/test/java/ch/idsia/crema/model/utility/ArraysUtilityTest.java
index 7c8d0a8c..24ad320b 100644
--- a/src/test/java/ch/idsia/crema/model/utility/ArraysUtilityTest.java
+++ b/src/test/java/ch/idsia/crema/model/utility/ArraysUtilityTest.java
@@ -277,7 +277,7 @@ public void testroundArrayToTarget() {
@ParameterizedTest
@CsvSource({"1&2,3&4,1&2&3&4",
"3&4, 3&4, 3&4",
- "4&4, 3&4, 4&3",
+ "4&4, 3&4, 3&4",
"3&4, 3&3, 3&4",
"1&2, 3&4, 1&2&3&4",
})
@@ -286,7 +286,7 @@ void unionSet(String arr1, String arr2, String res) {
int[] v2 = ArraysUtil.latexToIntVector(arr2);
int[] expected = ArraysUtil.latexToIntVector(res);
- assertArrayEquals(expected, ArraysUtil.unionSet(v1, v2));
+ assertArrayEquals(expected, ArraysUtil.union_unsorted_set(v1, v2));
}
@@ -375,7 +375,7 @@ public void enumerate(String vect, int start, String expected) {
@ParameterizedTest
@CsvSource({"1&2&3&4,1&2&3&4",
"1&2&1&3&4,1&2&3&4",
- "0&10&4&1&4,0&10&4&1"
+ "0&10&4&1&4,0&1&4&10"
})
void unique(String arr, String expected) {
int[] arr_ = ArraysUtil.latexToIntVector(arr);
@@ -389,7 +389,7 @@ void unique(String arr, String expected) {
@ParameterizedTest
@CsvSource({"1&2&3&4,1&2&3&4",
"1&2&1&3&4,1&2&3&4",
- "0&10&4&1&4,0&10&4&1"
+ "0&10&4&1&4,0&1&4&10"
})
void uniqueDoubles(String arr, String expected) {
double[] arr_ = ArraysUtil.latexToDoubleVector(arr);