Skip to content

Commit

Permalink
Fix a bug when the predicted probability is 0 or 1 in the Bernoulli l…
Browse files Browse the repository at this point in the history
…ikelihood
  • Loading branch information
gkronber committed Sep 20, 2023
1 parent 871e73d commit c75cf04
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions HEAL.NonlinearRegression/Likelihoods/BernoulliLikelihood.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace HEAL.NonlinearRegression {

public class BernoulliLikelihood : LikelihoodBase {

internal BernoulliLikelihood(BernoulliLikelihood original) : base(original) { }
protected BernoulliLikelihood(BernoulliLikelihood original) : base(original) { }
public BernoulliLikelihood(double[,] x, double[] y, Expression<Expr.ParametricFunction> modelExpr) : base(modelExpr, x, y, 0) { }

public override double[,] FisherInformation(double[] p) {
Expand Down Expand Up @@ -35,7 +35,8 @@ public BernoulliLikelihood(double[,] x, double[] y, Expression<Expr.ParametricFu
for (int k = 0; k < n; k++) {
var hessianTerm = (yPred[i] - 1) * yPred[i] * yHess[j, i, k] * (y[i] - yPred[i]);
var gradientTerm = (-2 * y[i] * yPred[i] + yPred[i] * yPred[i] + y[i]) * yJac[i, j] * yJac[i, k];
hess[j, k] += s * (hessianTerm + gradientTerm);
if ((hessianTerm + gradientTerm) > 0)
hess[j, k] += s * (hessianTerm + gradientTerm);
}
}
}
Expand Down Expand Up @@ -66,18 +67,20 @@ public override void NegLogLikelihoodGradient(double[] p, out double nll, double
for (int i = 0; i < m; i++) {
if (y[i] != 0.0 && y[i] != 1.0) throw new ArgumentException("target variable must be binary (0/1) for Bernoulli likelihood");
if (y[i] == 1) {
nll -= Math.Log(yPred[i]);
nll -= Math.Log(yPred[i]); // potential log(0)
if (nll_grad != null) {
for (int j = 0; j < n; j++) {
nll_grad[j] -= yJac[i, j] / yPred[i];
if (yJac[i, j] > 0)
nll_grad[j] -= yJac[i, j] / yPred[i]; // potential division by zero
}
}
} else {
// y[i]==0
nll -= Math.Log(1 - yPred[i]);
nll -= Math.Log(1 - yPred[i]); // potential log(0)
if (nll_grad != null) {
for (int j = 0; j < n; j++) {
nll_grad[j] += yJac[i, j] / (1 - yPred[i]);
if (yJac[i, j] > 0)
nll_grad[j] += yJac[i, j] / (1 - yPred[i]); // potential division by zero
}
}
}
Expand Down

0 comments on commit c75cf04

Please sign in to comment.