Skip to content

Commit

Permalink
Fixed Logspace Marginal
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhuber committed Apr 12, 2023
1 parent 1a25066 commit 32d39a3
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 35 deletions.
22 changes: 22 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,13 @@
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>

<!-- <dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math4-core</artifactId>
<version>4.0-beta1</version>
</dependency> -->

<dependency>
<groupId>net.sf.lpsolve</groupId>
<artifactId>lp_solve</artifactId>
Expand All @@ -186,6 +193,21 @@
<artifactId>commons-lang3</artifactId>
<version>3.12.0</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-rng-sampling</artifactId>
<version>1.5</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-rng-core</artifactId>
<version>1.5</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-rng-simple</artifactId>
<version>1.5</version>
</dependency>
<dependency>
<groupId>org.eclipse.persistence</groupId>
<artifactId>org.eclipse.persistence.moxy</artifactId>
Expand Down
11 changes: 8 additions & 3 deletions src/main/java/ch/idsia/crema/factor/algebra/OperationUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ private OperationUtils() {


public static double logSum(double first, double second) {
final double min, max;
double min;
double max;

if (first < second) {
min = first;
max = second;
Expand All @@ -18,7 +20,10 @@ public static double logSum(double first, double second) {
max = first;
}

return max + FastMath.log1p(FastMath.exp(min - max));
return max + Math.log1p(Math.exp(min - max));
}

public static double logSumLowPrecision(double first, double second) {
return second + FastMath.log1p(FastMath.exp(first - second));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,5 @@
* Date: 21.04.2021 21:14
*/
public interface BayesianCollector {

double collect(BayesianFactor factor, int source);

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import org.apache.commons.math3.util.FastMath;

/**
* A collector to marginalize a variable out of a domain in logspace.
*
* Author: Claudio "Dna" Bonesana
* Project: crema
* Date: 21.04.2021 21:23
Expand All @@ -17,25 +19,56 @@ public class LogBayesianMarginal implements BayesianCollector {
private final int[] offsets;
private final int size;

/**
* Construct the collector that summs all values of the variable.
* This will compute the set of offsets defined by the strides.
*
* @param size the size of the variable to be collected
* @param stride the stride of the variable.
*/
public LogBayesianMarginal(int size, int stride) {
this.size = size;
offsets = new int[size];
for (int i = 0; i < size; ++i) {

// we can safely start from 1 as the index 0 is always 0.
for (int i = 1; i < size; ++i) {
offsets[i] = i * stride;
}
}

@Override
public final double collect(BayesianFactor factor, final int source) {

return Arrays.stream(offsets)
.mapToDouble(factor::getLogValueAt)
.reduce(0, OperationUtils::logSum);
public final double collect(BayesianFactor factor, int source) {
// 270 slowest!!
// return Arrays.stream(offsets).map(v->v+source)
// .mapToDouble(factor::getLogValueAt)
// .reduce(Double.NEGATIVE_INFINITY, OperationUtils::logSum);

/** double value = 0;
for (int i = 0; i < size; ++i) {
value += factor.getValueAt(source + offsets[i]); // TODO: try with ch.idsia.crema.factor.algebra.OperationUtils.logSum()
// // 130
double value = factor.getLogValueAt(source + offsets[0]);
for (int i = 1; i < size; ++i) {
double v = factor.getLogValueAt(source + offsets[i]);

if (v > value) {
value = v + Math.log1p(FastMath.exp(value - v));
} else {
value += Math.log1p(FastMath.exp(v - value));
}
}
return FastMath.log(value);*/
return value;

// // 226
// double value = Double.NEGATIVE_INFINITY;
// for (int i = 0; i < size; ++i) {
// double v = factor.getLogValueAt(source + offsets[i]);
// value = OperationUtils.logSum(value, v);
// }
// return value;

// 95 but wrong
// double value = factor.getValueAt(source + offsets[0]);
// for (int i = 1; i < size; ++i) {
// value += factor.getValueAt(source + offsets[i]); // TODO: try with ch.idsia.crema.factor.algebra.OperationUtils.logSum()
// }
// return FastMath.log(value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ public SimpleBayesianMarginal(int size, int stride) {

@Override
public final double collect(BayesianFactor factor, final int source) {
double value = 0;
for (int i = 0; i < size; ++i) {
double value = factor.getValueAt(source + offsets[0]);
for (int i = 1; i < size; ++i) {
value += factor.getValueAt(source + offsets[i]);
}
return value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@ public final class LogMarginal implements Collector {
private final int size;

@Override
public final double collect(final double[] data, final int source) {
double value = 0;
public final double collect(double[] data, int source) {
double value = data[source + offsets[0]];
for (int i = 0; i < size; ++i) {
value += FastMath.exp(data[source + offsets[i]]);
double v = data[source + offsets[i]];
if (v > value) {
value = v + Math.log1p(FastMath.exp(value - v));
} else {
value += Math.log1p(FastMath.exp(v - value));
}
}
return FastMath.log(value);
return value;
}

public LogMarginal(int size, int stride) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ public Marginal(int size, int stride) {

@Override
public final double collect(final double[] data, final int source) {
double value = 0;
for (int i = 0; i < size; ++i) {
double value = data[source + offsets[0]];
for (int i = 1; i < size; ++i) {
value += data[source + offsets[i]];
}
return value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ public double getLogValue(int... states) {
return getLogValueAt(domain.getOffset(states));
}

/**
* The collector method will collect data from this factor with the purpouse of eliminating the variable
* at the index specified in 'offset'. The method will use the builder to generate the final factor and the
* collector as the strategy to collect values of the removed variable.
*
* @param <F> The type of factor. Mostly identified by the builder
* @param offset the offset in the domain of the variable to be removed
* @param builder the method to generate the new factor
* @param collector the strategy to aggregate or filter the data related to the removed vsariable.
*
* @return a new Bayesian Factor with a smaller domain.
*/
protected <F extends BayesianAbstractFactor> F collect(int offset, BayesianFactorBuilder<F> builder, BayesianCollector collector) {
final int[] new_variables = new int[domain.getSize() - 1];
final int[] new_sizes = new int[domain.getSize() - 1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ch.idsia.crema.core.Strides;
import ch.idsia.crema.factor.algebra.collectors.Collector;
import ch.idsia.crema.factor.algebra.vertex.VertexOperation;

import org.apache.commons.lang3.NotImplementedException;

import java.util.ArrayList;
Expand Down
143 changes: 131 additions & 12 deletions src/test/java/ch/idsia/crema/model/BayesianFactorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,23 @@
import ch.idsia.crema.factor.bayesian.BayesianFactor;
import ch.idsia.crema.factor.bayesian.BayesianFactorFactory;
import ch.idsia.crema.factor.bayesian.BayesianLogFactor;
import ch.idsia.crema.utility.RandomUtil;

import org.junit.jupiter.api.Test;

import com.google.common.primitives.Doubles;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.time.Clock;
import java.util.Arrays;

import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.rng.sampling.distribution.DirichletSampler;
import org.apache.commons.rng.simple.RandomSource;


public class BayesianFactorTest {

@Test
Expand Down Expand Up @@ -56,11 +67,15 @@ public void testMarginalize() {
assertArrayEquals(new double[]{0.6, 0.4}, ibf.marginalize(3).marginalize(2).getData(), 1e-7);
}




@Test
public void testLogMarginalize() {
int[] vars = new int[]{1, 2};
int[] size = new int[]{3, 3};


double[] vals = new double[]{0.15, 0.05, 0.25, 0.10, 0.05, 0.05, 0.15, 0.10, 0.10};

BayesianFactor ibf = BayesianFactorFactory.factory().domain(new Strides(vars, size))
Expand Down Expand Up @@ -248,22 +263,126 @@ public void testtime() {
int b = 1;
int c = 2;

Strides d1 = new Strides(new int[]{a,b,c}, new int[]{2,3,2});


Strides d1 = new Strides(new int[]{a,c,b }, new int[]{2,2,3});
Strides d2 = new Strides(new int[]{b,c}, new int[]{3,2});
int tests = 50000;
int reps = 100;
double div = 1000.0;
long tot = 0;
double x = 0;

long nt = System.nanoTime();
for (int rep = -50; rep < reps; ++ rep){
var rs = RandomSource.JDK.create(0xb00b);
DirichletSampler sampler2 = DirichletSampler.symmetric(rs, 2, 1);
DirichletSampler sampler_prior = DirichletSampler.symmetric(rs, 6, 1);


double[][] arrpa = sampler2.samples(tests*6).toArray(len->new double[len][]);
double[][] arrpb = sampler_prior.samples(tests).toArray(len->new double[len][]);


for (int i = 0; i < tests; ++i) {

double[] parr = Doubles.concat(arrpa[i*6], arrpa[i*6+1], arrpa[i*6+2], arrpa[i*6+3], arrpa[i*6+4], arrpa[i*6+5]);
BayesianFactor pa = BayesianFactorFactory.factory().domain(d1).data(parr).log();
BayesianFactor pb = BayesianFactorFactory.factory().domain(d2).data(arrpb[i]).log();

final long nt = System.nanoTime();
var fact = pa.combine(pb)
// .marginalize(c)
// .marginalize(b)
;
long diff = (System.nanoTime() - nt);
double[] data = fact.getData();
if (rep >=0) {
tot += diff;
x += Arrays.stream(data).sum();
}
}

double x = 0;
BayesianLogFactor pa = BayesianFactorFactory.factory().domain(d1).data(new double[]{0.15, 0.85, 0.25, 0.75, 0.95, 0.05, 0.5, 0.5, 0.4, 0.6, 0.7,0.3}).log();
BayesianLogFactor pb = BayesianFactorFactory.factory().domain(d2).data(new double[]{0.15, 0.05, 0.25,0.20, 0.25, 0.1}).log();
for (int i= 0; i < 100000; ++i) {
double[] data = pa.combine(pb).marginalize(b).marginalize(c).getData();
x += data[0];
}

double time = (System.nanoTime() - nt) / 1000000000.0;
System.out.println(time + " " + x);
System.out.println("log:" + (tot/div/reps) + "\t" + (x/reps));

double x2 = 0;
tot=0;

for (int rep = -10; rep < reps; ++ rep){
var rs = RandomSource.JDK.create(0xb00b);
DirichletSampler sampler2 = DirichletSampler.symmetric(rs, 2, 1);
DirichletSampler sampler_prior = DirichletSampler.symmetric(rs, 6, 1);


double[][] arrpa = sampler2.samples(tests*6).toArray(len->new double[len][]);
double[][] arrpb = sampler_prior.samples(tests).toArray(len->new double[len][]);

for (int i = 0; i < tests; ++i) {

double[] parr = Doubles.concat(arrpa[i*6], arrpa[i*6+1], arrpa[i*6+2], arrpa[i*6+3], arrpa[i*6+4], arrpa[i*6+5]);
// BayesianFactor pa = BayesianFactorFactory.factory().domain(d1).data(parr).get();//.log();
// BayesianFactor pb = BayesianFactorFactory.factory().domain(d2).data(arrpb[i]).get();//.log();

BayesianFactor pa = BayesianFactorFactory.factory().domain(d1).data(parr).get();
BayesianFactor pb = BayesianFactorFactory.factory().domain(d2).data(arrpb[i]).get();

long nt = System.nanoTime();
var fact = pa.combine(pb)
// .marginalize(c)
// .marginalize(b)
;
long diff = (System.nanoTime() - nt) ;
double[] data = fact.getData();
if (rep >=0) {
tot += diff;
x2 += Arrays.stream(data).sum();
}
}
}

System.out.println("get:" + (tot/div/reps) + "\t" + (x/reps));
assertEquals(x,x2, 0.000000001);


}



@Test
public void testLogSpeed() {

int[] vars = new int[]{0, 1, 2};
int[] size = new int[]{2, 3, 3};

int tests = 100000;
long total = 0;
int reps = 20;
var rs = RandomSource.JDK.create(0xb00b);
for (int rep = -10; rep < reps; ++ rep){
DirichletSampler sampler = DirichletSampler.symmetric(rs, 18, 1);

long time = System.nanoTime();
double[][] arrs = sampler.samples(tests).toArray(len->new double[len][]);
for (int i = 0; i < tests; ++i) {

double[] vals = arrs[i];//new double[]{0.15, 0.05, 0.25, 0.10, 0.05, 0.05, 0.15, 0.10, 0.10};

BayesianFactor ibf = BayesianFactorFactory.factory().domain(new Strides(vars, size))
.data(vals)
.log();

double[] dta = ibf.marginalize(1).getData();
for (int j = 0; j > 3; ++j) {
assertEquals(dta[j], vals[j*6+0] + vals[j*6+2] + vals[j*6+4], 0.00000001);
assertEquals(dta[j+1], vals[j*6+1] + vals[j*6+3] + vals[j*6+5], 0.00000001);
}
}
long delta = System.nanoTime() - time;

if (rep >= 0) {
total += delta;
System.out.println(delta);
}
}
System.out.println(total/1000000.0/reps);
}
}

0 comments on commit 32d39a3

Please sign in to comment.