diff --git a/java-ml-example/pom.xml b/java-ml-example/pom.xml index 1329f0e..ea34dd7 100644 --- a/java-ml-example/pom.xml +++ b/java-ml-example/pom.xml @@ -10,7 +10,6 @@ java-ml-example - 2.1.4-SNAPSHOT java-ml-example diff --git a/java-ml-example/src/main/java/org/rsultan/example/ExtendedIsolationForestExample.java b/java-ml-example/src/main/java/org/rsultan/example/ExtendedIsolationForestExample.java new file mode 100644 index 0000000..942404a --- /dev/null +++ b/java-ml-example/src/main/java/org/rsultan/example/ExtendedIsolationForestExample.java @@ -0,0 +1,42 @@ +package org.rsultan.example; + +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.factory.Nd4j; +import org.rsultan.core.clustering.ensemble.evaluation.TPRThresholdEvaluator; +import org.rsultan.core.clustering.ensemble.isolationforest.ExtendedIsolationForest; +import org.rsultan.dataframe.Dataframes; + +import java.io.IOException; + +public class ExtendedIsolationForestExample { + + /* + You can use the http dataset --> args[0] + You can use the http_reduced.csv dataset for testing --> args[1] + + threshold = 0.7000000000000001 + =========================== + TPR ║ FPR + =========================== + 0.9963817277250113 ║ 0.0031 + =========================== + */ + static { + Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); + } + + public static void main(String[] args) throws IOException { + var df = Dataframes.csv(args[0], ",", "\"", true); + var testDf = Dataframes.csv(args[1], ",", "\"", true); + var trainTestDataframe = Dataframes.trainTest(df.getColumns()).setSplitValue(0.99); + + var model = new ExtendedIsolationForest(200, 2); + var evaluator = new TPRThresholdEvaluator("attack", "anomalies") + .setDesiredTPR(0.7) + .setExternalTestDataframe(testDf) + .setLearningRate(0.1); + Double threshold = evaluator.evaluate(model, trainTestDataframe); + System.out.println("threshold = " + threshold); + evaluator.showMetrics(); + } +} diff --git a/java-ml/pom.xml b/java-ml/pom.xml index 78d8a98..8364719 100644 --- a/java-ml/pom.xml +++ b/java-ml/pom.xml @@ -12,7 +12,6 @@ java-ml - 2.1.4-SNAPSHOT java-ml diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/domain/IsolationNode.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/domain/IsolationNode.java index 0cbdd9c..e64c721 100644 --- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/domain/IsolationNode.java +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/domain/IsolationNode.java @@ -5,23 +5,21 @@ import java.io.Serializable; import org.nd4j.linalg.api.ndarray.INDArray; -public record IsolationNode( - int feature, - double featureThreshold, +public record IsolationNode( + T nodeData, INDArray data, - IsolationNode left, - IsolationNode right + IsolationNode left, + IsolationNode right ) implements Serializable { public IsolationNode(INDArray data) { - this(-1, -1, data, null, null); + this(null, data, null, null); } public IsolationNode( - int feature, - double featureThreshold, - IsolationNode left, - IsolationNode right) { - this(feature, featureThreshold, null, left, right); + T nodeData, + IsolationNode left, + IsolationNode right) { + this(nodeData, null, left, right); } public boolean isLeaf() { diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/evaluation/TPRThresholdEvaluator.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/evaluation/TPRThresholdEvaluator.java index 9024390..de0a801 100644 --- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/evaluation/TPRThresholdEvaluator.java +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/evaluation/TPRThresholdEvaluator.java @@ -2,73 +2,90 @@ import org.rsultan.core.Evaluator; import org.rsultan.core.clustering.ensemble.isolationforest.IsolationForest; +import org.rsultan.dataframe.Dataframe; import org.rsultan.dataframe.Dataframes; import org.rsultan.dataframe.Row; import org.rsultan.dataframe.TrainTestDataframe; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class TPRThresholdEvaluator implements Evaluator { - private final String responseVariable; - private final String predictionColumn; - private double desiredTPR = 0.9; - private double learningRate = 0.01; - private double TPR = 0; - private double FPR = 0; + private static final Logger LOG = LoggerFactory.getLogger(TPRThresholdEvaluator.class); + private static final String IS_AN_ANOMALY = "isAnAnomaly"; + private final String responseVariable; + private final String predictionColumn; + private double desiredTPR = 0.9; + private double learningRate = 0.01; + private double TPR = 0; + private double FPR = 0; + private Dataframe externalTestDataframe; - public TPRThresholdEvaluator(String responseVariable, String predictionColumn) { - this.responseVariable = responseVariable; - this.predictionColumn = predictionColumn; - } + public TPRThresholdEvaluator(String responseVariable, String predictionColumn) { + this.responseVariable = responseVariable; + this.predictionColumn = predictionColumn; + } - @Override - public Double evaluate(IsolationForest trainable, TrainTestDataframe dataframe) { - var dfSplit = dataframe.shuffle().split(); - double threshold = 1; - while (threshold > 0 && TPR <= desiredTPR) { - threshold -= learningRate; - var trained = trainable.setAnomalyThreshold(threshold) - .train(dfSplit.train().mapWithout(responseVariable)); - var responses = dfSplit.test().get(responseVariable); - var predictions = trained.predict(dfSplit.test().mapWithout(responseVariable)) - .get(predictionColumn); + @Override + public Double evaluate(IsolationForest trainable, TrainTestDataframe dataframe) { + var dfSplit = dataframe.shuffle().split(); + var trained = trainable.setUseAnomalyScoresOnly(true) + .train(dfSplit.train().mapWithout(responseVariable)); + final Dataframe testDf = externalTestDataframe != null ? externalTestDataframe : dfSplit.test(); + final Dataframe predict = trained.predict(testDf.mapWithout(responseVariable)); - double truePositives = 0; - double trueNegatives = 0; - double falsePositives = 0; - double falseNegative = 0; + double threshold = 1; + while (threshold > 0 && TPR <= desiredTPR) { + LOG.info("Evaluating isolation forest with threshold {}", threshold); + threshold -= learningRate; + final double finalThreshold = threshold; + var responses = testDf.get(responseVariable); + var predictions = predict.map(IS_AN_ANOMALY, + (Double score) -> (score >= finalThreshold ? 1L : 0L), predictionColumn) + .get(IS_AN_ANOMALY); - for (int i = 0; i < responses.size(); i++) { - var response = responses.get(i); - var prediction = predictions.get(i); - truePositives += response == 1L && prediction == 1L ? 1L : 0L; - trueNegatives += response == 0L && prediction == 0L ? 1L : 0L; - falsePositives += response == 0L && prediction == 1L ? 1L : 0L; - falseNegative += response == 1L && prediction == 0L ? 1L : 0L; - } - TPR = truePositives / (truePositives + falseNegative); - TPR = Double.isNaN(TPR) ? 0 : TPR; + double truePositives = 0; + double trueNegatives = 0; + double falsePositives = 0; + double falseNegative = 0; - FPR = falsePositives / (falsePositives + trueNegatives); - FPR = Double.isNaN(FPR) ? 0L : FPR; - } - if (threshold < 0) { - throw new IllegalArgumentException("Cannot have desired TPR"); + for (int i = 0; i < responses.size(); i++) { + var response = responses.get(i); + var prediction = predictions.get(i); + truePositives += response == 1L && prediction == 1L ? 1L : 0L; + trueNegatives += response == 0L && prediction == 0L ? 1L : 0L; + falsePositives += response == 0L && prediction == 1L ? 1L : 0L; + falseNegative += response == 1L && prediction == 0L ? 1L : 0L; + } + TPR = truePositives / (truePositives + falseNegative); + TPR = Double.isNaN(TPR) ? 0 : TPR; + + FPR = falsePositives / (falsePositives + trueNegatives); + FPR = Double.isNaN(FPR) ? 0L : FPR; + } + if (threshold < 0) { + throw new IllegalArgumentException("Cannot have desired TPR"); + } + + return threshold; } - return threshold; - } + public TPRThresholdEvaluator setDesiredTPR(double desiredTPR) { + this.desiredTPR = desiredTPR; + return this; + } - public TPRThresholdEvaluator setDesiredTPR(double desiredTPR) { - this.desiredTPR = desiredTPR; - return this; - } + public TPRThresholdEvaluator setLearningRate(double learningRate) { + this.learningRate = learningRate; + return this; + } - public TPRThresholdEvaluator setLearningRate(double learningRate) { - this.learningRate = learningRate; - return this; - } + public TPRThresholdEvaluator setExternalTestDataframe(Dataframe dataframe) { + this.externalTestDataframe = dataframe; + return this; + } - public void showMetrics() { - Dataframes.create(new String[]{"TPR", "FPR"}, new Row(TPR, FPR)).tail(); - } + public void showMetrics() { + Dataframes.create(new String[]{"TPR", "FPR"}, new Row(TPR, FPR)).tail(); + } } diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/ExtendedIsolationForest.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/ExtendedIsolationForest.java new file mode 100644 index 0000000..e161fe1 --- /dev/null +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/ExtendedIsolationForest.java @@ -0,0 +1,56 @@ +package org.rsultan.core.clustering.ensemble.isolationforest; + +import org.apache.commons.lang3.RandomUtils; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.rsultan.core.clustering.ensemble.isolationforest.tree.ExtendedIsolationTree; +import org.rsultan.dataframe.Dataframe; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.stream.LongStream; + +import static java.lang.String.format; +import static java.util.stream.IntStream.range; +import static org.rsultan.core.clustering.ensemble.isolationforest.utils.ScoreUtils.averagePathLength; + +public class ExtendedIsolationForest extends IsolationForest { + + private static final Logger LOG = LoggerFactory.getLogger(ExtendedIsolationForest.class); + + private final int extensionLevel; + + public ExtendedIsolationForest(int nbTrees, int extensionLevel) { + super(nbTrees); + this.extensionLevel = extensionLevel; + } + + public ExtendedIsolationForest setSampleSize(int sampleSize) { + super.setSampleSize(sampleSize); + return this; + } + + public ExtendedIsolationForest setAnomalyThreshold(double anomalyThreshold) { + super.setAnomalyThreshold(anomalyThreshold); + return this; + } + + @Override + public ExtendedIsolationForest train(Dataframe dataframe) { + var matrix = dataframe.toMatrix(); + if (extensionLevel > matrix.columns() - 1 || extensionLevel < 0) { + throw new IllegalArgumentException(format("extensionLevel must be between 0 and %d, current is [%d]", matrix.columns() - 1, extensionLevel)); + } + int realSample = sampleSize >= matrix.rows() ? sampleSize / 10 : sampleSize; + int treeDepth = (int) Math.ceil(Math.log(realSample) / Math.log(2)); + isolationTrees = range(0, nbTrees).parallel() + .peek(i -> LOG.info("Tree number: {}", i)) + .mapToObj(i -> LongStream.range(0, realSample) + .map(idx -> RandomUtils.nextLong(0, matrix.rows())) + .toArray()) + .map(NDArrayIndex::indices) + .map(matrix::get) + .map(m -> new ExtendedIsolationTree(treeDepth, extensionLevel).train(m)) + .toList(); + return this; + } +} diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java index 9d9fa2a..69b1bf6 100644 --- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java @@ -9,10 +9,11 @@ import org.apache.commons.lang3.RandomUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.ops.transforms.Transforms; +import org.rsultan.core.RawTrainable; import org.rsultan.core.Trainable; +import org.rsultan.core.clustering.ensemble.isolationforest.tree.IsolationTree; import org.rsultan.dataframe.Column; import org.rsultan.dataframe.Dataframe; import org.slf4j.Logger; @@ -23,8 +24,9 @@ public class IsolationForest implements Trainable { private static final Logger LOG = LoggerFactory.getLogger(IsolationTree.class); protected final int nbTrees; protected double anomalyThreshold = 0.5; - protected List isolationTrees; + protected List> isolationTrees; protected int sampleSize = 256; + private boolean useAnomalyScoresOnly; public IsolationForest(int nbTrees) { this.nbTrees = nbTrees; @@ -40,6 +42,11 @@ public IsolationForest setAnomalyThreshold(double anomalyThreshold) { return this; } + public IsolationForest setUseAnomalyScoresOnly(boolean useAnomalyScoresOnly) { + this.useAnomalyScoresOnly = useAnomalyScoresOnly; + return this; + } + @Override public IsolationForest train(Dataframe dataframe) { var matrix = dataframe.toMatrix(); @@ -61,9 +68,11 @@ public IsolationForest train(Dataframe dataframe) { public Dataframe predict(Dataframe dataframe) { var matrix = dataframe.toMatrix(); var anomalyScores = computeAnomalyScore(matrix); - var isAnomaly = new Column<>("anomalies", DoubleStream.of( - anomalyScores.toDoubleVector() - ).mapToObj(score -> score >= anomalyThreshold ? 1L : 0L).toArray()); + final DoubleStream doubleStream = DoubleStream.of( + anomalyScores.toDoubleVector() + ); + var isAnomaly = new Column<>("anomalies", + useAnomalyScoresOnly ? doubleStream.boxed().toArray() : doubleStream.mapToObj(score -> score >= anomalyThreshold ? 1L : 0L).toArray()); return dataframe.addColumn(isAnomaly); } diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java deleted file mode 100644 index b1801c8..0000000 --- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java +++ /dev/null @@ -1,93 +0,0 @@ -package org.rsultan.core.clustering.ensemble.isolationforest; - -import static java.util.stream.LongStream.range; -import static org.apache.commons.lang3.RandomUtils.nextDouble; -import static org.apache.commons.lang3.RandomUtils.nextInt; -import static org.rsultan.core.clustering.ensemble.isolationforest.utils.ScoreUtils.averagePathLength; - -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.NDArrayIndex; -import org.rsultan.core.RawTrainable; -import org.rsultan.core.clustering.ensemble.domain.IsolationNode; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class IsolationTree implements RawTrainable { - - public static final Logger LOG = LoggerFactory.getLogger(IsolationTree.class); - private final int treeDepthLimit; - private IsolationNode tree; - - public IsolationTree(int treeDepthLimit) { - this.treeDepthLimit = treeDepthLimit; - } - - @Override - public IsolationTree train(INDArray matrix) { - this.tree = buildTree(matrix, treeDepthLimit); - return this; - } - - private IsolationNode buildTree(INDArray matrix, int currentDepth) { - LOG.info("Tree Depth {}", currentDepth); - if (currentDepth <= 0 || matrix.rows() <= 2) { - return new IsolationNode(matrix); - } - int numberOfFeatures = matrix.columns(); - int splitFeature = nextInt(0, numberOfFeatures); - var feature = matrix.getColumn(splitFeature); - double startInclusive = feature.minNumber().doubleValue(); - double endInclusive = feature.maxNumber().doubleValue(); - - double valueSplit = - getValueSplit(startInclusive, endInclusive); - - var leftIndices = range(0, feature.columns()).parallel() - .filter(idx -> feature.getDouble(idx) < valueSplit) - .toArray(); - var left = getVector(matrix, leftIndices); - - var rightIndices = range(0, feature.columns()).parallel() - .filter(idx -> feature.getDouble(idx) > valueSplit) - .toArray(); - var right = getVector(matrix, rightIndices); - - return new IsolationNode( - splitFeature, - valueSplit, - buildTree(left, currentDepth - 1), - buildTree(right, currentDepth - 1) - ); - } - - private INDArray getVector(INDArray matrix, long[] indices) { - return matrix.get(NDArrayIndex.indices(indices)); - } - - private double getValueSplit(double startInclusive, double endInclusive) { - if (startInclusive < 0 && endInclusive < 0) { - return -nextDouble(endInclusive * -1, startInclusive * -1); - } else if (startInclusive < 0 && endInclusive >= 0) { - return nextDouble(0, endInclusive + startInclusive * -1) + startInclusive; - } - return nextDouble(startInclusive, endInclusive); - } - - @Override - public INDArray predict(INDArray matrix) { - var pathLengths = Nd4j.zeros(1, matrix.rows()); - for (int i = 0; i < matrix.rows(); i++) { - var row = matrix.getRow(i); - var node = tree; - int length = 0; - while (!node.isLeaf()) { - node = row.getDouble(node.feature()) < node.featureThreshold() ? node.left() : node.right(); - length++; - } - int leafSize = node.data().rows(); - pathLengths.put(0, i, length + averagePathLength(leafSize)); - } - return pathLengths; - } -} diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/AbstractTree.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/AbstractTree.java new file mode 100644 index 0000000..7c54b0c --- /dev/null +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/AbstractTree.java @@ -0,0 +1,58 @@ +package org.rsultan.core.clustering.ensemble.isolationforest.tree; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.rsultan.core.clustering.ensemble.domain.IsolationNode; +import org.rsultan.core.clustering.ensemble.isolationforest.tree.IsolationTree.Feature; + +import java.util.function.LongPredicate; + +import static java.util.stream.LongStream.range; +import static org.rsultan.core.clustering.ensemble.isolationforest.utils.ScoreUtils.averagePathLength; + +public abstract class AbstractTree { + + protected final int treeDepthLimit; + protected IsolationNode tree; + + public AbstractTree(int treeDepthLimit) { + this.treeDepthLimit = treeDepthLimit; + } + + protected IsolationNode buildTree(INDArray X, int currentDepth){ + if (currentDepth <= 0 || X.rows() <= 2) { + return new IsolationNode<>(X); + } + return buildNode(X, currentDepth); + } + + protected abstract IsolationNode buildNode(INDArray X, int currentDepth); + + protected abstract boolean chooseLeftNode(INDArray row, NODE_DATA slope); + + protected INDArray getVector(INDArray matrix, long[] indices) { + return matrix.get(NDArrayIndex.indices(indices)); + } + + protected long[] getIndices(INDArray ndArray, LongPredicate predicate){ + return range(0, ndArray.rows()).parallel().filter(predicate).toArray(); + } + + public INDArray predict(INDArray matrix) { + var pathLengths = Nd4j.zeros(1, matrix.rows()); + for (int i = 0; i < matrix.rows(); i++) { + var row = matrix.getRow(i); + var node = tree; + int length = 0; + while (!node.isLeaf()) { + var slope = node.nodeData(); + node = chooseLeftNode(row, slope) ? node.left() : node.right(); + length++; + } + int leafSize = node.data().rows(); + pathLengths.put(0, i, length + averagePathLength(leafSize)); + } + return pathLengths; + } +} diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/ExtendedIsolationTree.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/ExtendedIsolationTree.java new file mode 100644 index 0000000..21b8ed8 --- /dev/null +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/ExtendedIsolationTree.java @@ -0,0 +1,87 @@ +package org.rsultan.core.clustering.ensemble.isolationforest.tree; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.nd4j.common.util.MathUtils; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.rsultan.core.RawTrainable; +import org.rsultan.core.clustering.ensemble.domain.IsolationNode; +import org.rsultan.core.clustering.ensemble.isolationforest.tree.ExtendedIsolationTree.Slope; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +import static java.util.stream.LongStream.range; + +public class ExtendedIsolationTree extends AbstractTree implements RawTrainable { + + public static final Logger LOG = LoggerFactory.getLogger(ExtendedIsolationTree.class); + private final int extensionLevel; + + public ExtendedIsolationTree(int treeDepthLimit, int extensionLevel) { + super(treeDepthLimit); + this.extensionLevel = extensionLevel; + } + + @Override + public ExtendedIsolationTree train(INDArray matrix) { + this.tree = buildTree(matrix, treeDepthLimit); + return this; + } + + @Override + protected IsolationNode buildNode(INDArray X, int currentDepth) { + LOG.info("Tree Depth {}", currentDepth); + if (currentDepth <= 0 || X.rows() <= 2) { + return new IsolationNode<>(X); + } + int numberOfFeatures = X.columns(); + var mins = X.min(true, 0); + var maxs = X.max(true, 0); + var n = getNormalVector(mins.shape()); + var p = getIntercept(numberOfFeatures, mins, maxs); + var w = X.sub(p).mmul(n); + + var left = getVector(X, getIndices(w, idx -> w.getDouble(idx) < 0)); + var right = getVector(X, getIndices(w, idx -> w.getDouble(idx) >= 0)); + + return new IsolationNode<>( + new Slope(n, p), + buildTree(left, currentDepth - 1), + buildTree(right, currentDepth - 1) + ); + } + + private INDArray getIntercept(int numberOfFeatures, INDArray mins, INDArray maxs) { + var p = Nd4j.zeros(mins.shape()); + for (int i = 0; i < numberOfFeatures; i++) { + p.put(0, i, MathUtils.randomDoubleBetween(mins.getDouble(0, i), maxs.getDouble(0, i))); + } + return p; + } + + private INDArray getNormalVector(long[] shape) { + var distribution = new NormalDistribution(); + var n = Nd4j.zeros(shape); + for (int i = 0; i < n.columns(); i++) { + n.put(0, i, distribution.sample()); + } + final int numSamples = n.columns() - this.extensionLevel - 1; + if (numSamples > 0) { + final INDArray arange = Nd4j.create(range(0, n.columns()).boxed().toList()); + int[] indices = Nd4j.choice(arange, Nd4j.rand(n.columns()), numSamples).toIntVector(); + for (int index : indices) { + n.putScalar(index, 0.0D); + } + } + return n.transpose(); + } + + @Override + protected boolean chooseLeftNode(INDArray row, Slope slope) { + return row.sub(slope.p).mmul(slope.n).getDouble(0) < 0; + } + + static record Slope(INDArray n, INDArray p){} + +} diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/IsolationTree.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/IsolationTree.java new file mode 100644 index 0000000..bc0e9d3 --- /dev/null +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/IsolationTree.java @@ -0,0 +1,54 @@ +package org.rsultan.core.clustering.ensemble.isolationforest.tree; + +import static org.apache.commons.lang3.RandomUtils.nextInt; + +import org.nd4j.common.util.MathUtils; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.rsultan.core.RawTrainable; +import org.rsultan.core.clustering.ensemble.domain.IsolationNode; +import org.rsultan.core.clustering.ensemble.isolationforest.tree.IsolationTree.Feature; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class IsolationTree extends AbstractTree implements RawTrainable { + + public static final Logger LOG = LoggerFactory.getLogger(IsolationTree.class); + + public IsolationTree(int treeDepthLimit) { + super(treeDepthLimit); + } + + @Override + public IsolationTree train(INDArray matrix) { + this.tree = buildTree(matrix, treeDepthLimit); + return this; + } + + @Override + protected IsolationNode buildNode(INDArray X, int currentDepth) { + LOG.info("Tree Depth {}", currentDepth); + int numberOfFeatures = X.columns(); + int splitFeature = nextInt(0, numberOfFeatures); + var feature = X.getColumn(splitFeature); + double startInclusive = feature.minNumber().doubleValue(); + double endInclusive = feature.maxNumber().doubleValue(); + double valueSplit = MathUtils.randomDoubleBetween(startInclusive, endInclusive); + + var left = getVector(X, getIndices(feature, idx -> feature.getDouble(idx) < valueSplit)); + var right = getVector(X, getIndices(feature, idx -> feature.getDouble(idx) >= valueSplit)); + + return new IsolationNode<>( + new Feature(splitFeature, valueSplit), + buildTree(left, currentDepth - 1), + buildTree(right, currentDepth - 1) + ); + } + + @Override + protected boolean chooseLeftNode(INDArray row, Feature feature) { + return row.getDouble(feature.feature) < feature.threshold ; + } + + static record Feature(int feature, double threshold){} + +} diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/utils/ScoreUtils.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/utils/ScoreUtils.java index fd4f95d..0d5a5e2 100644 --- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/utils/ScoreUtils.java +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/utils/ScoreUtils.java @@ -17,4 +17,5 @@ public static double averagePathLength(double leafSize) { private static double harmonicNumber(double leafSize) { return Math.log(leafSize - 1) + EULER_CONSTANT; } + } diff --git a/java-ml/src/test/java/org/rsultan/core/clustering/IsolationForestTest.java b/java-ml/src/test/java/org/rsultan/core/clustering/IsolationForestTest.java index c57f9f4..33fbe93 100644 --- a/java-ml/src/test/java/org/rsultan/core/clustering/IsolationForestTest.java +++ b/java-ml/src/test/java/org/rsultan/core/clustering/IsolationForestTest.java @@ -11,6 +11,7 @@ import java.util.List; import org.junit.jupiter.api.Test; import org.rsultan.core.clustering.ensemble.evaluation.TPRThresholdEvaluator; +import org.rsultan.core.clustering.ensemble.isolationforest.ExtendedIsolationForest; import org.rsultan.core.clustering.ensemble.isolationforest.IsolationForest; import org.rsultan.dataframe.Dataframes; import org.rsultan.dataframe.Row; @@ -55,5 +56,35 @@ public void should_evaluate_tpr() { assertThat(threshold).isGreaterThan(0.4); } + @Test + public void must_perform_extended_isolation() { + var df = Dataframes.create(new String[]{"x", "y", "response"}, rows); + var predict = new ExtendedIsolationForest(10, 1).setAnomalyThreshold(0.56).setSampleSize(15).train(df).predict(df); + int anomalies = predict.filter("anomalies", obj -> obj.equals(1L)).getRowSize(); + int nonAnomalies = predict.filter("anomalies", obj -> obj.equals(0L)).getRowSize(); + + assertThat(anomalies).isBetween(1, predict.getRowSize()); + assertThat(nonAnomalies).isEqualTo(predict.getRowSize() - anomalies); + } + + @Test + public void should_evaluate_extended_tpr() { + var df = Dataframes.trainTest(new String[]{"x", "y", "response"}, rows); + var model = new ExtendedIsolationForest(10, 1).setSampleSize(15); + var evaluator = new TPRThresholdEvaluator("response", "anomalies").setDesiredTPR(0.9).setLearningRate(0.01); + var threshold = evaluator.evaluate(model, df); + evaluator.showMetrics(); + assertThat(threshold).isGreaterThan(0.4); + } + @Test + public void must_perform_isolation_and_retrieve_scores() { + var df = Dataframes.create(new String[]{"x", "y", "response"}, rows); + var predict = new IsolationForest(10).setAnomalyThreshold(0.56).setSampleSize(15).setUseAnomalyScoresOnly(true).train(df).predict(df); + int anomalies = predict.filter("anomalies", (Number obj) -> obj.doubleValue() > 0.56).getRowSize(); + int nonAnomalies = predict.filter("anomalies", (Number obj) -> obj.doubleValue() <= 0.56).getRowSize(); + + assertThat(anomalies).isBetween(1, predict.getRowSize()); + assertThat(nonAnomalies).isEqualTo(predict.getRowSize() - anomalies); + } } diff --git a/version.properties b/version.properties index 82f2bfb..1edc995 100755 --- a/version.properties +++ b/version.properties @@ -1 +1 @@ -version.next=2.1.4 +version.next=2.2.0 \ No newline at end of file