diff --git a/java-ml-example/pom.xml b/java-ml-example/pom.xml
index 1329f0e..ea34dd7 100644
--- a/java-ml-example/pom.xml
+++ b/java-ml-example/pom.xml
@@ -10,7 +10,6 @@
java-ml-example
- 2.1.4-SNAPSHOT
java-ml-example
diff --git a/java-ml-example/src/main/java/org/rsultan/example/ExtendedIsolationForestExample.java b/java-ml-example/src/main/java/org/rsultan/example/ExtendedIsolationForestExample.java
new file mode 100644
index 0000000..942404a
--- /dev/null
+++ b/java-ml-example/src/main/java/org/rsultan/example/ExtendedIsolationForestExample.java
@@ -0,0 +1,42 @@
+package org.rsultan.example;
+
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.factory.Nd4j;
+import org.rsultan.core.clustering.ensemble.evaluation.TPRThresholdEvaluator;
+import org.rsultan.core.clustering.ensemble.isolationforest.ExtendedIsolationForest;
+import org.rsultan.dataframe.Dataframes;
+
+import java.io.IOException;
+
+public class ExtendedIsolationForestExample {
+
+ /*
+ You can use the http dataset --> args[0]
+ You can use the http_reduced.csv dataset for testing --> args[1]
+
+ threshold = 0.7000000000000001
+ ===========================
+ TPR ║ FPR
+ ===========================
+ 0.9963817277250113 ║ 0.0031
+ ===========================
+ */
+ static {
+ Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
+ }
+
+ public static void main(String[] args) throws IOException {
+ var df = Dataframes.csv(args[0], ",", "\"", true);
+ var testDf = Dataframes.csv(args[1], ",", "\"", true);
+ var trainTestDataframe = Dataframes.trainTest(df.getColumns()).setSplitValue(0.99);
+
+ var model = new ExtendedIsolationForest(200, 2);
+ var evaluator = new TPRThresholdEvaluator("attack", "anomalies")
+ .setDesiredTPR(0.7)
+ .setExternalTestDataframe(testDf)
+ .setLearningRate(0.1);
+ Double threshold = evaluator.evaluate(model, trainTestDataframe);
+ System.out.println("threshold = " + threshold);
+ evaluator.showMetrics();
+ }
+}
diff --git a/java-ml/pom.xml b/java-ml/pom.xml
index 78d8a98..8364719 100644
--- a/java-ml/pom.xml
+++ b/java-ml/pom.xml
@@ -12,7 +12,6 @@
java-ml
- 2.1.4-SNAPSHOT
java-ml
diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/domain/IsolationNode.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/domain/IsolationNode.java
index 0cbdd9c..e64c721 100644
--- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/domain/IsolationNode.java
+++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/domain/IsolationNode.java
@@ -5,23 +5,21 @@
import java.io.Serializable;
import org.nd4j.linalg.api.ndarray.INDArray;
-public record IsolationNode(
- int feature,
- double featureThreshold,
+public record IsolationNode(
+ T nodeData,
INDArray data,
- IsolationNode left,
- IsolationNode right
+ IsolationNode left,
+ IsolationNode right
) implements Serializable {
public IsolationNode(INDArray data) {
- this(-1, -1, data, null, null);
+ this(null, data, null, null);
}
public IsolationNode(
- int feature,
- double featureThreshold,
- IsolationNode left,
- IsolationNode right) {
- this(feature, featureThreshold, null, left, right);
+ T nodeData,
+ IsolationNode left,
+ IsolationNode right) {
+ this(nodeData, null, left, right);
}
public boolean isLeaf() {
diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/evaluation/TPRThresholdEvaluator.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/evaluation/TPRThresholdEvaluator.java
index 9024390..de0a801 100644
--- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/evaluation/TPRThresholdEvaluator.java
+++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/evaluation/TPRThresholdEvaluator.java
@@ -2,73 +2,90 @@
import org.rsultan.core.Evaluator;
import org.rsultan.core.clustering.ensemble.isolationforest.IsolationForest;
+import org.rsultan.dataframe.Dataframe;
import org.rsultan.dataframe.Dataframes;
import org.rsultan.dataframe.Row;
import org.rsultan.dataframe.TrainTestDataframe;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
public class TPRThresholdEvaluator implements Evaluator {
- private final String responseVariable;
- private final String predictionColumn;
- private double desiredTPR = 0.9;
- private double learningRate = 0.01;
- private double TPR = 0;
- private double FPR = 0;
+ private static final Logger LOG = LoggerFactory.getLogger(TPRThresholdEvaluator.class);
+ private static final String IS_AN_ANOMALY = "isAnAnomaly";
+ private final String responseVariable;
+ private final String predictionColumn;
+ private double desiredTPR = 0.9;
+ private double learningRate = 0.01;
+ private double TPR = 0;
+ private double FPR = 0;
+ private Dataframe externalTestDataframe;
- public TPRThresholdEvaluator(String responseVariable, String predictionColumn) {
- this.responseVariable = responseVariable;
- this.predictionColumn = predictionColumn;
- }
+ public TPRThresholdEvaluator(String responseVariable, String predictionColumn) {
+ this.responseVariable = responseVariable;
+ this.predictionColumn = predictionColumn;
+ }
- @Override
- public Double evaluate(IsolationForest trainable, TrainTestDataframe dataframe) {
- var dfSplit = dataframe.shuffle().split();
- double threshold = 1;
- while (threshold > 0 && TPR <= desiredTPR) {
- threshold -= learningRate;
- var trained = trainable.setAnomalyThreshold(threshold)
- .train(dfSplit.train().mapWithout(responseVariable));
- var responses = dfSplit.test().get(responseVariable);
- var predictions = trained.predict(dfSplit.test().mapWithout(responseVariable))
- .get(predictionColumn);
+ @Override
+ public Double evaluate(IsolationForest trainable, TrainTestDataframe dataframe) {
+ var dfSplit = dataframe.shuffle().split();
+ var trained = trainable.setUseAnomalyScoresOnly(true)
+ .train(dfSplit.train().mapWithout(responseVariable));
+ final Dataframe testDf = externalTestDataframe != null ? externalTestDataframe : dfSplit.test();
+ final Dataframe predict = trained.predict(testDf.mapWithout(responseVariable));
- double truePositives = 0;
- double trueNegatives = 0;
- double falsePositives = 0;
- double falseNegative = 0;
+ double threshold = 1;
+ while (threshold > 0 && TPR <= desiredTPR) {
+ LOG.info("Evaluating isolation forest with threshold {}", threshold);
+ threshold -= learningRate;
+ final double finalThreshold = threshold;
+ var responses = testDf.get(responseVariable);
+ var predictions = predict.map(IS_AN_ANOMALY,
+ (Double score) -> (score >= finalThreshold ? 1L : 0L), predictionColumn)
+ .get(IS_AN_ANOMALY);
- for (int i = 0; i < responses.size(); i++) {
- var response = responses.get(i);
- var prediction = predictions.get(i);
- truePositives += response == 1L && prediction == 1L ? 1L : 0L;
- trueNegatives += response == 0L && prediction == 0L ? 1L : 0L;
- falsePositives += response == 0L && prediction == 1L ? 1L : 0L;
- falseNegative += response == 1L && prediction == 0L ? 1L : 0L;
- }
- TPR = truePositives / (truePositives + falseNegative);
- TPR = Double.isNaN(TPR) ? 0 : TPR;
+ double truePositives = 0;
+ double trueNegatives = 0;
+ double falsePositives = 0;
+ double falseNegative = 0;
- FPR = falsePositives / (falsePositives + trueNegatives);
- FPR = Double.isNaN(FPR) ? 0L : FPR;
- }
- if (threshold < 0) {
- throw new IllegalArgumentException("Cannot have desired TPR");
+ for (int i = 0; i < responses.size(); i++) {
+ var response = responses.get(i);
+ var prediction = predictions.get(i);
+ truePositives += response == 1L && prediction == 1L ? 1L : 0L;
+ trueNegatives += response == 0L && prediction == 0L ? 1L : 0L;
+ falsePositives += response == 0L && prediction == 1L ? 1L : 0L;
+ falseNegative += response == 1L && prediction == 0L ? 1L : 0L;
+ }
+ TPR = truePositives / (truePositives + falseNegative);
+ TPR = Double.isNaN(TPR) ? 0 : TPR;
+
+ FPR = falsePositives / (falsePositives + trueNegatives);
+ FPR = Double.isNaN(FPR) ? 0L : FPR;
+ }
+ if (threshold < 0) {
+ throw new IllegalArgumentException("Cannot have desired TPR");
+ }
+
+ return threshold;
}
- return threshold;
- }
+ public TPRThresholdEvaluator setDesiredTPR(double desiredTPR) {
+ this.desiredTPR = desiredTPR;
+ return this;
+ }
- public TPRThresholdEvaluator setDesiredTPR(double desiredTPR) {
- this.desiredTPR = desiredTPR;
- return this;
- }
+ public TPRThresholdEvaluator setLearningRate(double learningRate) {
+ this.learningRate = learningRate;
+ return this;
+ }
- public TPRThresholdEvaluator setLearningRate(double learningRate) {
- this.learningRate = learningRate;
- return this;
- }
+ public TPRThresholdEvaluator setExternalTestDataframe(Dataframe dataframe) {
+ this.externalTestDataframe = dataframe;
+ return this;
+ }
- public void showMetrics() {
- Dataframes.create(new String[]{"TPR", "FPR"}, new Row(TPR, FPR)).tail();
- }
+ public void showMetrics() {
+ Dataframes.create(new String[]{"TPR", "FPR"}, new Row(TPR, FPR)).tail();
+ }
}
diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/ExtendedIsolationForest.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/ExtendedIsolationForest.java
new file mode 100644
index 0000000..e161fe1
--- /dev/null
+++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/ExtendedIsolationForest.java
@@ -0,0 +1,56 @@
+package org.rsultan.core.clustering.ensemble.isolationforest;
+
+import org.apache.commons.lang3.RandomUtils;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.rsultan.core.clustering.ensemble.isolationforest.tree.ExtendedIsolationTree;
+import org.rsultan.dataframe.Dataframe;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.stream.LongStream;
+
+import static java.lang.String.format;
+import static java.util.stream.IntStream.range;
+import static org.rsultan.core.clustering.ensemble.isolationforest.utils.ScoreUtils.averagePathLength;
+
+public class ExtendedIsolationForest extends IsolationForest {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ExtendedIsolationForest.class);
+
+ private final int extensionLevel;
+
+ public ExtendedIsolationForest(int nbTrees, int extensionLevel) {
+ super(nbTrees);
+ this.extensionLevel = extensionLevel;
+ }
+
+ public ExtendedIsolationForest setSampleSize(int sampleSize) {
+ super.setSampleSize(sampleSize);
+ return this;
+ }
+
+ public ExtendedIsolationForest setAnomalyThreshold(double anomalyThreshold) {
+ super.setAnomalyThreshold(anomalyThreshold);
+ return this;
+ }
+
+ @Override
+ public ExtendedIsolationForest train(Dataframe dataframe) {
+ var matrix = dataframe.toMatrix();
+ if (extensionLevel > matrix.columns() - 1 || extensionLevel < 0) {
+ throw new IllegalArgumentException(format("extensionLevel must be between 0 and %d, current is [%d]", matrix.columns() - 1, extensionLevel));
+ }
+ int realSample = sampleSize >= matrix.rows() ? sampleSize / 10 : sampleSize;
+ int treeDepth = (int) Math.ceil(Math.log(realSample) / Math.log(2));
+ isolationTrees = range(0, nbTrees).parallel()
+ .peek(i -> LOG.info("Tree number: {}", i))
+ .mapToObj(i -> LongStream.range(0, realSample)
+ .map(idx -> RandomUtils.nextLong(0, matrix.rows()))
+ .toArray())
+ .map(NDArrayIndex::indices)
+ .map(matrix::get)
+ .map(m -> new ExtendedIsolationTree(treeDepth, extensionLevel).train(m))
+ .toList();
+ return this;
+ }
+}
diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java
index 9d9fa2a..69b1bf6 100644
--- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java
+++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java
@@ -9,10 +9,11 @@
import org.apache.commons.lang3.RandomUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
+import org.rsultan.core.RawTrainable;
import org.rsultan.core.Trainable;
+import org.rsultan.core.clustering.ensemble.isolationforest.tree.IsolationTree;
import org.rsultan.dataframe.Column;
import org.rsultan.dataframe.Dataframe;
import org.slf4j.Logger;
@@ -23,8 +24,9 @@ public class IsolationForest implements Trainable {
private static final Logger LOG = LoggerFactory.getLogger(IsolationTree.class);
protected final int nbTrees;
protected double anomalyThreshold = 0.5;
- protected List isolationTrees;
+ protected List extends RawTrainable>> isolationTrees;
protected int sampleSize = 256;
+ private boolean useAnomalyScoresOnly;
public IsolationForest(int nbTrees) {
this.nbTrees = nbTrees;
@@ -40,6 +42,11 @@ public IsolationForest setAnomalyThreshold(double anomalyThreshold) {
return this;
}
+ public IsolationForest setUseAnomalyScoresOnly(boolean useAnomalyScoresOnly) {
+ this.useAnomalyScoresOnly = useAnomalyScoresOnly;
+ return this;
+ }
+
@Override
public IsolationForest train(Dataframe dataframe) {
var matrix = dataframe.toMatrix();
@@ -61,9 +68,11 @@ public IsolationForest train(Dataframe dataframe) {
public Dataframe predict(Dataframe dataframe) {
var matrix = dataframe.toMatrix();
var anomalyScores = computeAnomalyScore(matrix);
- var isAnomaly = new Column<>("anomalies", DoubleStream.of(
- anomalyScores.toDoubleVector()
- ).mapToObj(score -> score >= anomalyThreshold ? 1L : 0L).toArray());
+ final DoubleStream doubleStream = DoubleStream.of(
+ anomalyScores.toDoubleVector()
+ );
+ var isAnomaly = new Column<>("anomalies",
+ useAnomalyScoresOnly ? doubleStream.boxed().toArray() : doubleStream.mapToObj(score -> score >= anomalyThreshold ? 1L : 0L).toArray());
return dataframe.addColumn(isAnomaly);
}
diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java
deleted file mode 100644
index b1801c8..0000000
--- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java
+++ /dev/null
@@ -1,93 +0,0 @@
-package org.rsultan.core.clustering.ensemble.isolationforest;
-
-import static java.util.stream.LongStream.range;
-import static org.apache.commons.lang3.RandomUtils.nextDouble;
-import static org.apache.commons.lang3.RandomUtils.nextInt;
-import static org.rsultan.core.clustering.ensemble.isolationforest.utils.ScoreUtils.averagePathLength;
-
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.indexing.NDArrayIndex;
-import org.rsultan.core.RawTrainable;
-import org.rsultan.core.clustering.ensemble.domain.IsolationNode;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-public class IsolationTree implements RawTrainable {
-
- public static final Logger LOG = LoggerFactory.getLogger(IsolationTree.class);
- private final int treeDepthLimit;
- private IsolationNode tree;
-
- public IsolationTree(int treeDepthLimit) {
- this.treeDepthLimit = treeDepthLimit;
- }
-
- @Override
- public IsolationTree train(INDArray matrix) {
- this.tree = buildTree(matrix, treeDepthLimit);
- return this;
- }
-
- private IsolationNode buildTree(INDArray matrix, int currentDepth) {
- LOG.info("Tree Depth {}", currentDepth);
- if (currentDepth <= 0 || matrix.rows() <= 2) {
- return new IsolationNode(matrix);
- }
- int numberOfFeatures = matrix.columns();
- int splitFeature = nextInt(0, numberOfFeatures);
- var feature = matrix.getColumn(splitFeature);
- double startInclusive = feature.minNumber().doubleValue();
- double endInclusive = feature.maxNumber().doubleValue();
-
- double valueSplit =
- getValueSplit(startInclusive, endInclusive);
-
- var leftIndices = range(0, feature.columns()).parallel()
- .filter(idx -> feature.getDouble(idx) < valueSplit)
- .toArray();
- var left = getVector(matrix, leftIndices);
-
- var rightIndices = range(0, feature.columns()).parallel()
- .filter(idx -> feature.getDouble(idx) > valueSplit)
- .toArray();
- var right = getVector(matrix, rightIndices);
-
- return new IsolationNode(
- splitFeature,
- valueSplit,
- buildTree(left, currentDepth - 1),
- buildTree(right, currentDepth - 1)
- );
- }
-
- private INDArray getVector(INDArray matrix, long[] indices) {
- return matrix.get(NDArrayIndex.indices(indices));
- }
-
- private double getValueSplit(double startInclusive, double endInclusive) {
- if (startInclusive < 0 && endInclusive < 0) {
- return -nextDouble(endInclusive * -1, startInclusive * -1);
- } else if (startInclusive < 0 && endInclusive >= 0) {
- return nextDouble(0, endInclusive + startInclusive * -1) + startInclusive;
- }
- return nextDouble(startInclusive, endInclusive);
- }
-
- @Override
- public INDArray predict(INDArray matrix) {
- var pathLengths = Nd4j.zeros(1, matrix.rows());
- for (int i = 0; i < matrix.rows(); i++) {
- var row = matrix.getRow(i);
- var node = tree;
- int length = 0;
- while (!node.isLeaf()) {
- node = row.getDouble(node.feature()) < node.featureThreshold() ? node.left() : node.right();
- length++;
- }
- int leafSize = node.data().rows();
- pathLengths.put(0, i, length + averagePathLength(leafSize));
- }
- return pathLengths;
- }
-}
diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/AbstractTree.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/AbstractTree.java
new file mode 100644
index 0000000..7c54b0c
--- /dev/null
+++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/AbstractTree.java
@@ -0,0 +1,58 @@
+package org.rsultan.core.clustering.ensemble.isolationforest.tree;
+
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.rsultan.core.clustering.ensemble.domain.IsolationNode;
+import org.rsultan.core.clustering.ensemble.isolationforest.tree.IsolationTree.Feature;
+
+import java.util.function.LongPredicate;
+
+import static java.util.stream.LongStream.range;
+import static org.rsultan.core.clustering.ensemble.isolationforest.utils.ScoreUtils.averagePathLength;
+
+public abstract class AbstractTree {
+
+ protected final int treeDepthLimit;
+ protected IsolationNode tree;
+
+ public AbstractTree(int treeDepthLimit) {
+ this.treeDepthLimit = treeDepthLimit;
+ }
+
+ protected IsolationNode buildTree(INDArray X, int currentDepth){
+ if (currentDepth <= 0 || X.rows() <= 2) {
+ return new IsolationNode<>(X);
+ }
+ return buildNode(X, currentDepth);
+ }
+
+ protected abstract IsolationNode buildNode(INDArray X, int currentDepth);
+
+ protected abstract boolean chooseLeftNode(INDArray row, NODE_DATA slope);
+
+ protected INDArray getVector(INDArray matrix, long[] indices) {
+ return matrix.get(NDArrayIndex.indices(indices));
+ }
+
+ protected long[] getIndices(INDArray ndArray, LongPredicate predicate){
+ return range(0, ndArray.rows()).parallel().filter(predicate).toArray();
+ }
+
+ public INDArray predict(INDArray matrix) {
+ var pathLengths = Nd4j.zeros(1, matrix.rows());
+ for (int i = 0; i < matrix.rows(); i++) {
+ var row = matrix.getRow(i);
+ var node = tree;
+ int length = 0;
+ while (!node.isLeaf()) {
+ var slope = node.nodeData();
+ node = chooseLeftNode(row, slope) ? node.left() : node.right();
+ length++;
+ }
+ int leafSize = node.data().rows();
+ pathLengths.put(0, i, length + averagePathLength(leafSize));
+ }
+ return pathLengths;
+ }
+}
diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/ExtendedIsolationTree.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/ExtendedIsolationTree.java
new file mode 100644
index 0000000..21b8ed8
--- /dev/null
+++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/ExtendedIsolationTree.java
@@ -0,0 +1,87 @@
+package org.rsultan.core.clustering.ensemble.isolationforest.tree;
+
+import org.apache.commons.math3.distribution.NormalDistribution;
+import org.nd4j.common.util.MathUtils;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.rsultan.core.RawTrainable;
+import org.rsultan.core.clustering.ensemble.domain.IsolationNode;
+import org.rsultan.core.clustering.ensemble.isolationforest.tree.ExtendedIsolationTree.Slope;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+import static java.util.stream.LongStream.range;
+
+public class ExtendedIsolationTree extends AbstractTree implements RawTrainable {
+
+ public static final Logger LOG = LoggerFactory.getLogger(ExtendedIsolationTree.class);
+ private final int extensionLevel;
+
+ public ExtendedIsolationTree(int treeDepthLimit, int extensionLevel) {
+ super(treeDepthLimit);
+ this.extensionLevel = extensionLevel;
+ }
+
+ @Override
+ public ExtendedIsolationTree train(INDArray matrix) {
+ this.tree = buildTree(matrix, treeDepthLimit);
+ return this;
+ }
+
+ @Override
+ protected IsolationNode buildNode(INDArray X, int currentDepth) {
+ LOG.info("Tree Depth {}", currentDepth);
+ if (currentDepth <= 0 || X.rows() <= 2) {
+ return new IsolationNode<>(X);
+ }
+ int numberOfFeatures = X.columns();
+ var mins = X.min(true, 0);
+ var maxs = X.max(true, 0);
+ var n = getNormalVector(mins.shape());
+ var p = getIntercept(numberOfFeatures, mins, maxs);
+ var w = X.sub(p).mmul(n);
+
+ var left = getVector(X, getIndices(w, idx -> w.getDouble(idx) < 0));
+ var right = getVector(X, getIndices(w, idx -> w.getDouble(idx) >= 0));
+
+ return new IsolationNode<>(
+ new Slope(n, p),
+ buildTree(left, currentDepth - 1),
+ buildTree(right, currentDepth - 1)
+ );
+ }
+
+ private INDArray getIntercept(int numberOfFeatures, INDArray mins, INDArray maxs) {
+ var p = Nd4j.zeros(mins.shape());
+ for (int i = 0; i < numberOfFeatures; i++) {
+ p.put(0, i, MathUtils.randomDoubleBetween(mins.getDouble(0, i), maxs.getDouble(0, i)));
+ }
+ return p;
+ }
+
+ private INDArray getNormalVector(long[] shape) {
+ var distribution = new NormalDistribution();
+ var n = Nd4j.zeros(shape);
+ for (int i = 0; i < n.columns(); i++) {
+ n.put(0, i, distribution.sample());
+ }
+ final int numSamples = n.columns() - this.extensionLevel - 1;
+ if (numSamples > 0) {
+ final INDArray arange = Nd4j.create(range(0, n.columns()).boxed().toList());
+ int[] indices = Nd4j.choice(arange, Nd4j.rand(n.columns()), numSamples).toIntVector();
+ for (int index : indices) {
+ n.putScalar(index, 0.0D);
+ }
+ }
+ return n.transpose();
+ }
+
+ @Override
+ protected boolean chooseLeftNode(INDArray row, Slope slope) {
+ return row.sub(slope.p).mmul(slope.n).getDouble(0) < 0;
+ }
+
+ static record Slope(INDArray n, INDArray p){}
+
+}
diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/IsolationTree.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/IsolationTree.java
new file mode 100644
index 0000000..bc0e9d3
--- /dev/null
+++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/IsolationTree.java
@@ -0,0 +1,54 @@
+package org.rsultan.core.clustering.ensemble.isolationforest.tree;
+
+import static org.apache.commons.lang3.RandomUtils.nextInt;
+
+import org.nd4j.common.util.MathUtils;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.rsultan.core.RawTrainable;
+import org.rsultan.core.clustering.ensemble.domain.IsolationNode;
+import org.rsultan.core.clustering.ensemble.isolationforest.tree.IsolationTree.Feature;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class IsolationTree extends AbstractTree implements RawTrainable {
+
+ public static final Logger LOG = LoggerFactory.getLogger(IsolationTree.class);
+
+ public IsolationTree(int treeDepthLimit) {
+ super(treeDepthLimit);
+ }
+
+ @Override
+ public IsolationTree train(INDArray matrix) {
+ this.tree = buildTree(matrix, treeDepthLimit);
+ return this;
+ }
+
+ @Override
+ protected IsolationNode buildNode(INDArray X, int currentDepth) {
+ LOG.info("Tree Depth {}", currentDepth);
+ int numberOfFeatures = X.columns();
+ int splitFeature = nextInt(0, numberOfFeatures);
+ var feature = X.getColumn(splitFeature);
+ double startInclusive = feature.minNumber().doubleValue();
+ double endInclusive = feature.maxNumber().doubleValue();
+ double valueSplit = MathUtils.randomDoubleBetween(startInclusive, endInclusive);
+
+ var left = getVector(X, getIndices(feature, idx -> feature.getDouble(idx) < valueSplit));
+ var right = getVector(X, getIndices(feature, idx -> feature.getDouble(idx) >= valueSplit));
+
+ return new IsolationNode<>(
+ new Feature(splitFeature, valueSplit),
+ buildTree(left, currentDepth - 1),
+ buildTree(right, currentDepth - 1)
+ );
+ }
+
+ @Override
+ protected boolean chooseLeftNode(INDArray row, Feature feature) {
+ return row.getDouble(feature.feature) < feature.threshold ;
+ }
+
+ static record Feature(int feature, double threshold){}
+
+}
diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/utils/ScoreUtils.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/utils/ScoreUtils.java
index fd4f95d..0d5a5e2 100644
--- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/utils/ScoreUtils.java
+++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/utils/ScoreUtils.java
@@ -17,4 +17,5 @@ public static double averagePathLength(double leafSize) {
private static double harmonicNumber(double leafSize) {
return Math.log(leafSize - 1) + EULER_CONSTANT;
}
+
}
diff --git a/java-ml/src/test/java/org/rsultan/core/clustering/IsolationForestTest.java b/java-ml/src/test/java/org/rsultan/core/clustering/IsolationForestTest.java
index c57f9f4..33fbe93 100644
--- a/java-ml/src/test/java/org/rsultan/core/clustering/IsolationForestTest.java
+++ b/java-ml/src/test/java/org/rsultan/core/clustering/IsolationForestTest.java
@@ -11,6 +11,7 @@
import java.util.List;
import org.junit.jupiter.api.Test;
import org.rsultan.core.clustering.ensemble.evaluation.TPRThresholdEvaluator;
+import org.rsultan.core.clustering.ensemble.isolationforest.ExtendedIsolationForest;
import org.rsultan.core.clustering.ensemble.isolationforest.IsolationForest;
import org.rsultan.dataframe.Dataframes;
import org.rsultan.dataframe.Row;
@@ -55,5 +56,35 @@ public void should_evaluate_tpr() {
assertThat(threshold).isGreaterThan(0.4);
}
+ @Test
+ public void must_perform_extended_isolation() {
+ var df = Dataframes.create(new String[]{"x", "y", "response"}, rows);
+ var predict = new ExtendedIsolationForest(10, 1).setAnomalyThreshold(0.56).setSampleSize(15).train(df).predict(df);
+ int anomalies = predict.filter("anomalies", obj -> obj.equals(1L)).getRowSize();
+ int nonAnomalies = predict.filter("anomalies", obj -> obj.equals(0L)).getRowSize();
+
+ assertThat(anomalies).isBetween(1, predict.getRowSize());
+ assertThat(nonAnomalies).isEqualTo(predict.getRowSize() - anomalies);
+ }
+
+ @Test
+ public void should_evaluate_extended_tpr() {
+ var df = Dataframes.trainTest(new String[]{"x", "y", "response"}, rows);
+ var model = new ExtendedIsolationForest(10, 1).setSampleSize(15);
+ var evaluator = new TPRThresholdEvaluator("response", "anomalies").setDesiredTPR(0.9).setLearningRate(0.01);
+ var threshold = evaluator.evaluate(model, df);
+ evaluator.showMetrics();
+ assertThat(threshold).isGreaterThan(0.4);
+ }
+ @Test
+ public void must_perform_isolation_and_retrieve_scores() {
+ var df = Dataframes.create(new String[]{"x", "y", "response"}, rows);
+ var predict = new IsolationForest(10).setAnomalyThreshold(0.56).setSampleSize(15).setUseAnomalyScoresOnly(true).train(df).predict(df);
+ int anomalies = predict.filter("anomalies", (Number obj) -> obj.doubleValue() > 0.56).getRowSize();
+ int nonAnomalies = predict.filter("anomalies", (Number obj) -> obj.doubleValue() <= 0.56).getRowSize();
+
+ assertThat(anomalies).isBetween(1, predict.getRowSize());
+ assertThat(nonAnomalies).isEqualTo(predict.getRowSize() - anomalies);
+ }
}
diff --git a/version.properties b/version.properties
index 82f2bfb..1edc995 100755
--- a/version.properties
+++ b/version.properties
@@ -1 +1 @@
-version.next=2.1.4
+version.next=2.2.0
\ No newline at end of file