From 9483aa58a9a5f177a8591645c3d8aca1f53e1b7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Sultan?= Date: Sun, 17 Oct 2021 15:16:02 +0200 Subject: [PATCH 1/4] chore: provide sources in build (#76) --- pom.xml | 266 +++++++++++++++++++++++++++++--------------------------- 1 file changed, 139 insertions(+), 127 deletions(-) diff --git a/pom.xml b/pom.xml index 93a04e5..807b1f5 100644 --- a/pom.xml +++ b/pom.xml @@ -1,128 +1,140 @@ - - - - 4.0.0 - org.rsultan - java-ml-parent - 2.1.3-SNAPSHOT - java-ml-parent - pom - - - - release.archiva.rsultan.org - https://archiva.rsultan.org/repository/internal - - - - - 3.9.1 - 1.0 - 16 - 5.7.0 - ${java.version} - ${java.version} - UTF-8 - 2.0.0-alpha1 - 2.9.0 - - - - java-ml - java-ml-example - - - - - org.slf4j - slf4j-log4j12 - ${slf4j-log4j12.version} - - - org.slf4j - slf4j-api - ${slf4j-log4j12.version} - - - org.junit.jupiter - junit-jupiter - ${junit-jupiter.version} - test - - - org.assertj - assertj-core - ${assertj-core.version} - test - - - org.junit.jupiter - junit-jupiter-params - 5.7.0 - compile - - - - - java-ml - - - - maven-clean-plugin - 3.1.0 - - - maven-resources-plugin - 3.0.2 - - - maven-compiler-plugin - 3.8.0 - - - maven-surefire-plugin - 3.0.0-M5 - - - 1 - 0 - - - - - maven-jar-plugin - 3.0.2 - - - maven-install-plugin - 2.5.2 - - - maven-deploy-plugin - 2.8.2 - - - maven-site-plugin - 3.7.1 - - - maven-project-info-reports-plugin - 3.0.0 - - - - - - org.apache.maven.plugins - maven-compiler-plugin - - ${java.version} - ${java.version} - - - - - + + + + 4.0.0 + org.rsultan + java-ml-parent + 2.1.3-SNAPSHOT + java-ml-parent + pom + + + + release.archiva.rsultan.org + https://archiva.rsultan.org/repository/internal + + + + + 3.9.1 + 1.0 + 16 + 5.7.0 + ${java.version} + ${java.version} + UTF-8 + 2.0.0-alpha1 + 2.9.0 + + + + java-ml + java-ml-example + + + + + org.slf4j + slf4j-log4j12 + ${slf4j-log4j12.version} + + + org.slf4j + slf4j-api + ${slf4j-log4j12.version} + + + org.junit.jupiter + junit-jupiter + ${junit-jupiter.version} + test + + + org.assertj + assertj-core + ${assertj-core.version} + test + + + org.junit.jupiter + junit-jupiter-params + 5.7.0 + compile + + + + + java-ml + + + + maven-clean-plugin + 3.1.0 + + + maven-resources-plugin + 3.0.2 + + + maven-compiler-plugin + 3.8.0 + + + maven-surefire-plugin + 3.0.0-M5 + + + 1 + 0 + + + + + maven-jar-plugin + 3.0.2 + + + maven-install-plugin + 2.5.2 + + + maven-deploy-plugin + 2.8.2 + + + maven-site-plugin + 3.7.1 + + + maven-project-info-reports-plugin + 3.0.0 + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + ${java.version} + ${java.version} + + + + org.apache.maven.plugins + maven-source-plugin + + + attach-sources + + jar + + + + + + + \ No newline at end of file From 48b790f5270fd3451ca1c2c728f3fc752ce3b9f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Sultan?= Date: Sun, 17 Oct 2021 15:27:09 +0200 Subject: [PATCH 2/4] refactor: protected fields for isolation forests (#77) --- .../ensemble/isolationforest/IsolationForest.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java index 6be020a..b4b9bf7 100644 --- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java @@ -18,10 +18,10 @@ public class IsolationForest implements Trainable { private static final Logger LOG = LoggerFactory.getLogger(IsolationTree.class); - private final int nbTrees; - private double anomalyThreshold = 0.5; - private List isolationTrees; - private int sampleSize = 256; + protected final int nbTrees; + protected double anomalyThreshold = 0.5; + protected List isolationTrees; + protected int sampleSize = 256; public IsolationForest(int nbTrees) { this.nbTrees = nbTrees; From a39031dc4daaa05acb00fe24d19daa4403c619e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Sultan?= Date: Sun, 31 Oct 2021 18:21:52 +0100 Subject: [PATCH 3/4] refactor: make PCA raw trainable (#78) --- .../isolationforest/IsolationForest.java | 11 +++- .../isolationforest/IsolationTree.java | 11 +++- .../dimred/PrincipalComponentAnalysis.java | 56 ++++++++++++------- 3 files changed, 53 insertions(+), 25 deletions(-) diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java index b4b9bf7..9d9fa2a 100644 --- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java @@ -5,9 +5,12 @@ import java.util.List; import java.util.stream.DoubleStream; +import java.util.stream.LongStream; import org.apache.commons.lang3.RandomUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.ops.transforms.Transforms; import org.rsultan.core.Trainable; import org.rsultan.dataframe.Column; @@ -44,9 +47,11 @@ public IsolationForest train(Dataframe dataframe) { int treeDepth = (int) Math.ceil(Math.log(realSample) / Math.log(2)); isolationTrees = range(0, nbTrees).parallel() .peek(i -> LOG.info("Tree number: {}", i)) - .mapToObj(i -> range(0, realSample) - .map(idx -> RandomUtils.nextInt(0, matrix.rows())) - .toArray()).map(matrix::getRows) + .mapToObj(i -> LongStream.range(0, realSample) + .map(idx -> RandomUtils.nextLong(0, matrix.rows())) + .toArray()) + .map(NDArrayIndex::indices) + .map(matrix::get) .map(m -> new IsolationTree(treeDepth).train(m)) .toList(); return this; diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java index 4882ec6..b1801c8 100644 --- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java @@ -1,12 +1,13 @@ package org.rsultan.core.clustering.ensemble.isolationforest; -import static java.util.stream.IntStream.range; +import static java.util.stream.LongStream.range; import static org.apache.commons.lang3.RandomUtils.nextDouble; import static org.apache.commons.lang3.RandomUtils.nextInt; import static org.rsultan.core.clustering.ensemble.isolationforest.utils.ScoreUtils.averagePathLength; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; import org.rsultan.core.RawTrainable; import org.rsultan.core.clustering.ensemble.domain.IsolationNode; import org.slf4j.Logger; @@ -45,12 +46,12 @@ private IsolationNode buildTree(INDArray matrix, int currentDepth) { var leftIndices = range(0, feature.columns()).parallel() .filter(idx -> feature.getDouble(idx) < valueSplit) .toArray(); - var left = matrix.getRows(leftIndices); + var left = getVector(matrix, leftIndices); var rightIndices = range(0, feature.columns()).parallel() .filter(idx -> feature.getDouble(idx) > valueSplit) .toArray(); - var right = matrix.getRows(rightIndices); + var right = getVector(matrix, rightIndices); return new IsolationNode( splitFeature, @@ -60,6 +61,10 @@ private IsolationNode buildTree(INDArray matrix, int currentDepth) { ); } + private INDArray getVector(INDArray matrix, long[] indices) { + return matrix.get(NDArrayIndex.indices(indices)); + } + private double getValueSplit(double startInclusive, double endInclusive) { if (startInclusive < 0 && endInclusive < 0) { return -nextDouble(endInclusive * -1, startInclusive * -1); diff --git a/java-ml/src/main/java/org/rsultan/core/dimred/PrincipalComponentAnalysis.java b/java-ml/src/main/java/org/rsultan/core/dimred/PrincipalComponentAnalysis.java index ab47a04..8b183d8 100644 --- a/java-ml/src/main/java/org/rsultan/core/dimred/PrincipalComponentAnalysis.java +++ b/java-ml/src/main/java/org/rsultan/core/dimred/PrincipalComponentAnalysis.java @@ -6,8 +6,8 @@ import static org.nd4j.linalg.eigen.Eigen.symmetricGeneralizedEigenvalues; import java.util.List; -import org.nd4j.common.util.ArrayUtil; import org.nd4j.linalg.api.ndarray.INDArray; +import org.rsultan.core.RawTrainable; import org.rsultan.core.Trainable; import org.rsultan.dataframe.Column; import org.rsultan.dataframe.Dataframe; @@ -16,7 +16,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class PrincipalComponentAnalysis implements Trainable { +public class PrincipalComponentAnalysis implements + Trainable, RawTrainable { private static final Logger LOG = LoggerFactory.getLogger(PrincipalComponentAnalysis.class); @@ -35,26 +36,14 @@ public PrincipalComponentAnalysis(int numberOfComponents) { public PrincipalComponentAnalysis train(Dataframe dataframe) { var X = dataframe.mapWithout(responseVariable).toMatrix(); this.responseVariableData = dataframe.get(responseVariable); - int components = Math.min(numberOfComponent, X.columns()); - Xmean = X.mean(0); - X = X.sub(Xmean); - LOG.info("computing covariance matrix"); - LOG.info("computing eighenvectors"); - eighenVectors = Matrices.covariance(X); - var eighenValuesArgSort = argsort( - symmetricGeneralizedEigenvalues(eighenVectors, true).toIntVector(), false - ); - eighenVectors = eighenVectors - .getColumns(eighenValuesArgSort) - .getColumns(range(0, components).toArray()); - return this; + return this.train(X); } @Override public Dataframe predict(Dataframe dataframe) { - var Xpredict = dataframe.mapWithout(responseVariable).toMatrix().sub(Xmean); + var Xpredict = dataframe.mapWithout(responseVariable).toMatrix(); LOG.info("computing predictions"); - predictions = eighenVectors.transpose().mmul(Xpredict.transpose()).transpose(); + this.predict(Xpredict); List> columns = range(0, predictions.columns()) .mapToObj(colIdx -> new Column<>("c" + colIdx, range(0, predictions.rows()) .mapToObj(rowIdx -> predictions.getDouble(rowIdx, colIdx)) @@ -64,9 +53,27 @@ public Dataframe predict(Dataframe dataframe) { return Dataframes.create(columns.toArray(Column[]::new)); } + @Override + public PrincipalComponentAnalysis train(INDArray X) { + int components = Math.min(numberOfComponent, X.columns()); + Xmean = X.mean(0); + X = X.sub(Xmean); + LOG.info("computing covariance matrix"); + eighenVectors = Matrices.covariance(X); + LOG.info("computing eighenvectors"); + var eighenValuesArgSort = argsort( + symmetricGeneralizedEigenvalues(eighenVectors, true).toIntVector(), false + ); + eighenVectors = eighenVectors + .getColumns(eighenValuesArgSort) + .getColumns(range(0, components).toArray()); + LOG.info("eighenvectors computed"); + return this; + } + public Dataframe reconstruct() { - LOG.info("reconstructing original matrix"); - var XreBuilt = predictions.mmul(eighenVectors.transpose()).add(Xmean); + LOG.info("trying to reconstruct original matrix"); + var XreBuilt = rawReconstruct(); List> columns = range(0, XreBuilt.columns()) .mapToObj(colIdx -> new Column<>("c" + colIdx, range(0, XreBuilt.rows()) .mapToObj(rowIdx -> XreBuilt.getDouble(rowIdx, colIdx)) @@ -76,8 +83,19 @@ public Dataframe reconstruct() { return Dataframes.create(columns.toArray(Column[]::new)); } + public INDArray rawReconstruct() { + return predictions.mmul(eighenVectors.transpose()).add(Xmean); + } + + @Override + public INDArray predict(INDArray matrix) { + predictions = eighenVectors.transpose().mmul(matrix.sub(Xmean).transpose()).transpose(); + return predictions; + } + public PrincipalComponentAnalysis setResponseVariable(String responseVariable) { this.responseVariable = responseVariable; return this; } + } From 84641641976967fd6ac9450ff1b3adbe3839e9e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Sultan?= Date: Wed, 8 Dec 2021 21:13:09 +0100 Subject: [PATCH 4/4] feat: extended isolation forests (#80) --- .../ExtendedIsolationForestExample.java | 42 ++++++ .../ensemble/domain/IsolationNode.java | 20 ++- .../evaluation/TPRThresholdEvaluator.java | 123 ++++++++++-------- .../ExtendedIsolationForest.java | 56 ++++++++ .../isolationforest/IsolationForest.java | 19 ++- .../isolationforest/IsolationTree.java | 93 ------------- .../isolationforest/tree/AbstractTree.java | 58 +++++++++ .../tree/ExtendedIsolationTree.java | 87 +++++++++++++ .../isolationforest/tree/IsolationTree.java | 54 ++++++++ .../isolationforest/utils/ScoreUtils.java | 1 + .../core/clustering/IsolationForestTest.java | 31 +++++ version.properties | 2 +- 12 files changed, 423 insertions(+), 163 deletions(-) create mode 100644 java-ml-example/src/main/java/org/rsultan/example/ExtendedIsolationForestExample.java create mode 100644 java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/ExtendedIsolationForest.java delete mode 100644 java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java create mode 100644 java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/AbstractTree.java create mode 100644 java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/ExtendedIsolationTree.java create mode 100644 java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/IsolationTree.java diff --git a/java-ml-example/src/main/java/org/rsultan/example/ExtendedIsolationForestExample.java b/java-ml-example/src/main/java/org/rsultan/example/ExtendedIsolationForestExample.java new file mode 100644 index 0000000..942404a --- /dev/null +++ b/java-ml-example/src/main/java/org/rsultan/example/ExtendedIsolationForestExample.java @@ -0,0 +1,42 @@ +package org.rsultan.example; + +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.factory.Nd4j; +import org.rsultan.core.clustering.ensemble.evaluation.TPRThresholdEvaluator; +import org.rsultan.core.clustering.ensemble.isolationforest.ExtendedIsolationForest; +import org.rsultan.dataframe.Dataframes; + +import java.io.IOException; + +public class ExtendedIsolationForestExample { + + /* + You can use the http dataset --> args[0] + You can use the http_reduced.csv dataset for testing --> args[1] + + threshold = 0.7000000000000001 + =========================== + TPR ║ FPR + =========================== + 0.9963817277250113 ║ 0.0031 + =========================== + */ + static { + Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); + } + + public static void main(String[] args) throws IOException { + var df = Dataframes.csv(args[0], ",", "\"", true); + var testDf = Dataframes.csv(args[1], ",", "\"", true); + var trainTestDataframe = Dataframes.trainTest(df.getColumns()).setSplitValue(0.99); + + var model = new ExtendedIsolationForest(200, 2); + var evaluator = new TPRThresholdEvaluator("attack", "anomalies") + .setDesiredTPR(0.7) + .setExternalTestDataframe(testDf) + .setLearningRate(0.1); + Double threshold = evaluator.evaluate(model, trainTestDataframe); + System.out.println("threshold = " + threshold); + evaluator.showMetrics(); + } +} diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/domain/IsolationNode.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/domain/IsolationNode.java index 0cbdd9c..e64c721 100644 --- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/domain/IsolationNode.java +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/domain/IsolationNode.java @@ -5,23 +5,21 @@ import java.io.Serializable; import org.nd4j.linalg.api.ndarray.INDArray; -public record IsolationNode( - int feature, - double featureThreshold, +public record IsolationNode( + T nodeData, INDArray data, - IsolationNode left, - IsolationNode right + IsolationNode left, + IsolationNode right ) implements Serializable { public IsolationNode(INDArray data) { - this(-1, -1, data, null, null); + this(null, data, null, null); } public IsolationNode( - int feature, - double featureThreshold, - IsolationNode left, - IsolationNode right) { - this(feature, featureThreshold, null, left, right); + T nodeData, + IsolationNode left, + IsolationNode right) { + this(nodeData, null, left, right); } public boolean isLeaf() { diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/evaluation/TPRThresholdEvaluator.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/evaluation/TPRThresholdEvaluator.java index 9024390..de0a801 100644 --- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/evaluation/TPRThresholdEvaluator.java +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/evaluation/TPRThresholdEvaluator.java @@ -2,73 +2,90 @@ import org.rsultan.core.Evaluator; import org.rsultan.core.clustering.ensemble.isolationforest.IsolationForest; +import org.rsultan.dataframe.Dataframe; import org.rsultan.dataframe.Dataframes; import org.rsultan.dataframe.Row; import org.rsultan.dataframe.TrainTestDataframe; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class TPRThresholdEvaluator implements Evaluator { - private final String responseVariable; - private final String predictionColumn; - private double desiredTPR = 0.9; - private double learningRate = 0.01; - private double TPR = 0; - private double FPR = 0; + private static final Logger LOG = LoggerFactory.getLogger(TPRThresholdEvaluator.class); + private static final String IS_AN_ANOMALY = "isAnAnomaly"; + private final String responseVariable; + private final String predictionColumn; + private double desiredTPR = 0.9; + private double learningRate = 0.01; + private double TPR = 0; + private double FPR = 0; + private Dataframe externalTestDataframe; - public TPRThresholdEvaluator(String responseVariable, String predictionColumn) { - this.responseVariable = responseVariable; - this.predictionColumn = predictionColumn; - } + public TPRThresholdEvaluator(String responseVariable, String predictionColumn) { + this.responseVariable = responseVariable; + this.predictionColumn = predictionColumn; + } - @Override - public Double evaluate(IsolationForest trainable, TrainTestDataframe dataframe) { - var dfSplit = dataframe.shuffle().split(); - double threshold = 1; - while (threshold > 0 && TPR <= desiredTPR) { - threshold -= learningRate; - var trained = trainable.setAnomalyThreshold(threshold) - .train(dfSplit.train().mapWithout(responseVariable)); - var responses = dfSplit.test().get(responseVariable); - var predictions = trained.predict(dfSplit.test().mapWithout(responseVariable)) - .get(predictionColumn); + @Override + public Double evaluate(IsolationForest trainable, TrainTestDataframe dataframe) { + var dfSplit = dataframe.shuffle().split(); + var trained = trainable.setUseAnomalyScoresOnly(true) + .train(dfSplit.train().mapWithout(responseVariable)); + final Dataframe testDf = externalTestDataframe != null ? externalTestDataframe : dfSplit.test(); + final Dataframe predict = trained.predict(testDf.mapWithout(responseVariable)); - double truePositives = 0; - double trueNegatives = 0; - double falsePositives = 0; - double falseNegative = 0; + double threshold = 1; + while (threshold > 0 && TPR <= desiredTPR) { + LOG.info("Evaluating isolation forest with threshold {}", threshold); + threshold -= learningRate; + final double finalThreshold = threshold; + var responses = testDf.get(responseVariable); + var predictions = predict.map(IS_AN_ANOMALY, + (Double score) -> (score >= finalThreshold ? 1L : 0L), predictionColumn) + .get(IS_AN_ANOMALY); - for (int i = 0; i < responses.size(); i++) { - var response = responses.get(i); - var prediction = predictions.get(i); - truePositives += response == 1L && prediction == 1L ? 1L : 0L; - trueNegatives += response == 0L && prediction == 0L ? 1L : 0L; - falsePositives += response == 0L && prediction == 1L ? 1L : 0L; - falseNegative += response == 1L && prediction == 0L ? 1L : 0L; - } - TPR = truePositives / (truePositives + falseNegative); - TPR = Double.isNaN(TPR) ? 0 : TPR; + double truePositives = 0; + double trueNegatives = 0; + double falsePositives = 0; + double falseNegative = 0; - FPR = falsePositives / (falsePositives + trueNegatives); - FPR = Double.isNaN(FPR) ? 0L : FPR; - } - if (threshold < 0) { - throw new IllegalArgumentException("Cannot have desired TPR"); + for (int i = 0; i < responses.size(); i++) { + var response = responses.get(i); + var prediction = predictions.get(i); + truePositives += response == 1L && prediction == 1L ? 1L : 0L; + trueNegatives += response == 0L && prediction == 0L ? 1L : 0L; + falsePositives += response == 0L && prediction == 1L ? 1L : 0L; + falseNegative += response == 1L && prediction == 0L ? 1L : 0L; + } + TPR = truePositives / (truePositives + falseNegative); + TPR = Double.isNaN(TPR) ? 0 : TPR; + + FPR = falsePositives / (falsePositives + trueNegatives); + FPR = Double.isNaN(FPR) ? 0L : FPR; + } + if (threshold < 0) { + throw new IllegalArgumentException("Cannot have desired TPR"); + } + + return threshold; } - return threshold; - } + public TPRThresholdEvaluator setDesiredTPR(double desiredTPR) { + this.desiredTPR = desiredTPR; + return this; + } - public TPRThresholdEvaluator setDesiredTPR(double desiredTPR) { - this.desiredTPR = desiredTPR; - return this; - } + public TPRThresholdEvaluator setLearningRate(double learningRate) { + this.learningRate = learningRate; + return this; + } - public TPRThresholdEvaluator setLearningRate(double learningRate) { - this.learningRate = learningRate; - return this; - } + public TPRThresholdEvaluator setExternalTestDataframe(Dataframe dataframe) { + this.externalTestDataframe = dataframe; + return this; + } - public void showMetrics() { - Dataframes.create(new String[]{"TPR", "FPR"}, new Row(TPR, FPR)).tail(); - } + public void showMetrics() { + Dataframes.create(new String[]{"TPR", "FPR"}, new Row(TPR, FPR)).tail(); + } } diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/ExtendedIsolationForest.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/ExtendedIsolationForest.java new file mode 100644 index 0000000..e161fe1 --- /dev/null +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/ExtendedIsolationForest.java @@ -0,0 +1,56 @@ +package org.rsultan.core.clustering.ensemble.isolationforest; + +import org.apache.commons.lang3.RandomUtils; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.rsultan.core.clustering.ensemble.isolationforest.tree.ExtendedIsolationTree; +import org.rsultan.dataframe.Dataframe; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.stream.LongStream; + +import static java.lang.String.format; +import static java.util.stream.IntStream.range; +import static org.rsultan.core.clustering.ensemble.isolationforest.utils.ScoreUtils.averagePathLength; + +public class ExtendedIsolationForest extends IsolationForest { + + private static final Logger LOG = LoggerFactory.getLogger(ExtendedIsolationForest.class); + + private final int extensionLevel; + + public ExtendedIsolationForest(int nbTrees, int extensionLevel) { + super(nbTrees); + this.extensionLevel = extensionLevel; + } + + public ExtendedIsolationForest setSampleSize(int sampleSize) { + super.setSampleSize(sampleSize); + return this; + } + + public ExtendedIsolationForest setAnomalyThreshold(double anomalyThreshold) { + super.setAnomalyThreshold(anomalyThreshold); + return this; + } + + @Override + public ExtendedIsolationForest train(Dataframe dataframe) { + var matrix = dataframe.toMatrix(); + if (extensionLevel > matrix.columns() - 1 || extensionLevel < 0) { + throw new IllegalArgumentException(format("extensionLevel must be between 0 and %d, current is [%d]", matrix.columns() - 1, extensionLevel)); + } + int realSample = sampleSize >= matrix.rows() ? sampleSize / 10 : sampleSize; + int treeDepth = (int) Math.ceil(Math.log(realSample) / Math.log(2)); + isolationTrees = range(0, nbTrees).parallel() + .peek(i -> LOG.info("Tree number: {}", i)) + .mapToObj(i -> LongStream.range(0, realSample) + .map(idx -> RandomUtils.nextLong(0, matrix.rows())) + .toArray()) + .map(NDArrayIndex::indices) + .map(matrix::get) + .map(m -> new ExtendedIsolationTree(treeDepth, extensionLevel).train(m)) + .toList(); + return this; + } +} diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java index 9d9fa2a..69b1bf6 100644 --- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java @@ -9,10 +9,11 @@ import org.apache.commons.lang3.RandomUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.ops.transforms.Transforms; +import org.rsultan.core.RawTrainable; import org.rsultan.core.Trainable; +import org.rsultan.core.clustering.ensemble.isolationforest.tree.IsolationTree; import org.rsultan.dataframe.Column; import org.rsultan.dataframe.Dataframe; import org.slf4j.Logger; @@ -23,8 +24,9 @@ public class IsolationForest implements Trainable { private static final Logger LOG = LoggerFactory.getLogger(IsolationTree.class); protected final int nbTrees; protected double anomalyThreshold = 0.5; - protected List isolationTrees; + protected List> isolationTrees; protected int sampleSize = 256; + private boolean useAnomalyScoresOnly; public IsolationForest(int nbTrees) { this.nbTrees = nbTrees; @@ -40,6 +42,11 @@ public IsolationForest setAnomalyThreshold(double anomalyThreshold) { return this; } + public IsolationForest setUseAnomalyScoresOnly(boolean useAnomalyScoresOnly) { + this.useAnomalyScoresOnly = useAnomalyScoresOnly; + return this; + } + @Override public IsolationForest train(Dataframe dataframe) { var matrix = dataframe.toMatrix(); @@ -61,9 +68,11 @@ public IsolationForest train(Dataframe dataframe) { public Dataframe predict(Dataframe dataframe) { var matrix = dataframe.toMatrix(); var anomalyScores = computeAnomalyScore(matrix); - var isAnomaly = new Column<>("anomalies", DoubleStream.of( - anomalyScores.toDoubleVector() - ).mapToObj(score -> score >= anomalyThreshold ? 1L : 0L).toArray()); + final DoubleStream doubleStream = DoubleStream.of( + anomalyScores.toDoubleVector() + ); + var isAnomaly = new Column<>("anomalies", + useAnomalyScoresOnly ? doubleStream.boxed().toArray() : doubleStream.mapToObj(score -> score >= anomalyThreshold ? 1L : 0L).toArray()); return dataframe.addColumn(isAnomaly); } diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java deleted file mode 100644 index b1801c8..0000000 --- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java +++ /dev/null @@ -1,93 +0,0 @@ -package org.rsultan.core.clustering.ensemble.isolationforest; - -import static java.util.stream.LongStream.range; -import static org.apache.commons.lang3.RandomUtils.nextDouble; -import static org.apache.commons.lang3.RandomUtils.nextInt; -import static org.rsultan.core.clustering.ensemble.isolationforest.utils.ScoreUtils.averagePathLength; - -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.NDArrayIndex; -import org.rsultan.core.RawTrainable; -import org.rsultan.core.clustering.ensemble.domain.IsolationNode; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class IsolationTree implements RawTrainable { - - public static final Logger LOG = LoggerFactory.getLogger(IsolationTree.class); - private final int treeDepthLimit; - private IsolationNode tree; - - public IsolationTree(int treeDepthLimit) { - this.treeDepthLimit = treeDepthLimit; - } - - @Override - public IsolationTree train(INDArray matrix) { - this.tree = buildTree(matrix, treeDepthLimit); - return this; - } - - private IsolationNode buildTree(INDArray matrix, int currentDepth) { - LOG.info("Tree Depth {}", currentDepth); - if (currentDepth <= 0 || matrix.rows() <= 2) { - return new IsolationNode(matrix); - } - int numberOfFeatures = matrix.columns(); - int splitFeature = nextInt(0, numberOfFeatures); - var feature = matrix.getColumn(splitFeature); - double startInclusive = feature.minNumber().doubleValue(); - double endInclusive = feature.maxNumber().doubleValue(); - - double valueSplit = - getValueSplit(startInclusive, endInclusive); - - var leftIndices = range(0, feature.columns()).parallel() - .filter(idx -> feature.getDouble(idx) < valueSplit) - .toArray(); - var left = getVector(matrix, leftIndices); - - var rightIndices = range(0, feature.columns()).parallel() - .filter(idx -> feature.getDouble(idx) > valueSplit) - .toArray(); - var right = getVector(matrix, rightIndices); - - return new IsolationNode( - splitFeature, - valueSplit, - buildTree(left, currentDepth - 1), - buildTree(right, currentDepth - 1) - ); - } - - private INDArray getVector(INDArray matrix, long[] indices) { - return matrix.get(NDArrayIndex.indices(indices)); - } - - private double getValueSplit(double startInclusive, double endInclusive) { - if (startInclusive < 0 && endInclusive < 0) { - return -nextDouble(endInclusive * -1, startInclusive * -1); - } else if (startInclusive < 0 && endInclusive >= 0) { - return nextDouble(0, endInclusive + startInclusive * -1) + startInclusive; - } - return nextDouble(startInclusive, endInclusive); - } - - @Override - public INDArray predict(INDArray matrix) { - var pathLengths = Nd4j.zeros(1, matrix.rows()); - for (int i = 0; i < matrix.rows(); i++) { - var row = matrix.getRow(i); - var node = tree; - int length = 0; - while (!node.isLeaf()) { - node = row.getDouble(node.feature()) < node.featureThreshold() ? node.left() : node.right(); - length++; - } - int leafSize = node.data().rows(); - pathLengths.put(0, i, length + averagePathLength(leafSize)); - } - return pathLengths; - } -} diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/AbstractTree.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/AbstractTree.java new file mode 100644 index 0000000..7c54b0c --- /dev/null +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/AbstractTree.java @@ -0,0 +1,58 @@ +package org.rsultan.core.clustering.ensemble.isolationforest.tree; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.rsultan.core.clustering.ensemble.domain.IsolationNode; +import org.rsultan.core.clustering.ensemble.isolationforest.tree.IsolationTree.Feature; + +import java.util.function.LongPredicate; + +import static java.util.stream.LongStream.range; +import static org.rsultan.core.clustering.ensemble.isolationforest.utils.ScoreUtils.averagePathLength; + +public abstract class AbstractTree { + + protected final int treeDepthLimit; + protected IsolationNode tree; + + public AbstractTree(int treeDepthLimit) { + this.treeDepthLimit = treeDepthLimit; + } + + protected IsolationNode buildTree(INDArray X, int currentDepth){ + if (currentDepth <= 0 || X.rows() <= 2) { + return new IsolationNode<>(X); + } + return buildNode(X, currentDepth); + } + + protected abstract IsolationNode buildNode(INDArray X, int currentDepth); + + protected abstract boolean chooseLeftNode(INDArray row, NODE_DATA slope); + + protected INDArray getVector(INDArray matrix, long[] indices) { + return matrix.get(NDArrayIndex.indices(indices)); + } + + protected long[] getIndices(INDArray ndArray, LongPredicate predicate){ + return range(0, ndArray.rows()).parallel().filter(predicate).toArray(); + } + + public INDArray predict(INDArray matrix) { + var pathLengths = Nd4j.zeros(1, matrix.rows()); + for (int i = 0; i < matrix.rows(); i++) { + var row = matrix.getRow(i); + var node = tree; + int length = 0; + while (!node.isLeaf()) { + var slope = node.nodeData(); + node = chooseLeftNode(row, slope) ? node.left() : node.right(); + length++; + } + int leafSize = node.data().rows(); + pathLengths.put(0, i, length + averagePathLength(leafSize)); + } + return pathLengths; + } +} diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/ExtendedIsolationTree.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/ExtendedIsolationTree.java new file mode 100644 index 0000000..21b8ed8 --- /dev/null +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/ExtendedIsolationTree.java @@ -0,0 +1,87 @@ +package org.rsultan.core.clustering.ensemble.isolationforest.tree; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.nd4j.common.util.MathUtils; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.rsultan.core.RawTrainable; +import org.rsultan.core.clustering.ensemble.domain.IsolationNode; +import org.rsultan.core.clustering.ensemble.isolationforest.tree.ExtendedIsolationTree.Slope; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +import static java.util.stream.LongStream.range; + +public class ExtendedIsolationTree extends AbstractTree implements RawTrainable { + + public static final Logger LOG = LoggerFactory.getLogger(ExtendedIsolationTree.class); + private final int extensionLevel; + + public ExtendedIsolationTree(int treeDepthLimit, int extensionLevel) { + super(treeDepthLimit); + this.extensionLevel = extensionLevel; + } + + @Override + public ExtendedIsolationTree train(INDArray matrix) { + this.tree = buildTree(matrix, treeDepthLimit); + return this; + } + + @Override + protected IsolationNode buildNode(INDArray X, int currentDepth) { + LOG.info("Tree Depth {}", currentDepth); + if (currentDepth <= 0 || X.rows() <= 2) { + return new IsolationNode<>(X); + } + int numberOfFeatures = X.columns(); + var mins = X.min(true, 0); + var maxs = X.max(true, 0); + var n = getNormalVector(mins.shape()); + var p = getIntercept(numberOfFeatures, mins, maxs); + var w = X.sub(p).mmul(n); + + var left = getVector(X, getIndices(w, idx -> w.getDouble(idx) < 0)); + var right = getVector(X, getIndices(w, idx -> w.getDouble(idx) >= 0)); + + return new IsolationNode<>( + new Slope(n, p), + buildTree(left, currentDepth - 1), + buildTree(right, currentDepth - 1) + ); + } + + private INDArray getIntercept(int numberOfFeatures, INDArray mins, INDArray maxs) { + var p = Nd4j.zeros(mins.shape()); + for (int i = 0; i < numberOfFeatures; i++) { + p.put(0, i, MathUtils.randomDoubleBetween(mins.getDouble(0, i), maxs.getDouble(0, i))); + } + return p; + } + + private INDArray getNormalVector(long[] shape) { + var distribution = new NormalDistribution(); + var n = Nd4j.zeros(shape); + for (int i = 0; i < n.columns(); i++) { + n.put(0, i, distribution.sample()); + } + final int numSamples = n.columns() - this.extensionLevel - 1; + if (numSamples > 0) { + final INDArray arange = Nd4j.create(range(0, n.columns()).boxed().toList()); + int[] indices = Nd4j.choice(arange, Nd4j.rand(n.columns()), numSamples).toIntVector(); + for (int index : indices) { + n.putScalar(index, 0.0D); + } + } + return n.transpose(); + } + + @Override + protected boolean chooseLeftNode(INDArray row, Slope slope) { + return row.sub(slope.p).mmul(slope.n).getDouble(0) < 0; + } + + static record Slope(INDArray n, INDArray p){} + +} diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/IsolationTree.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/IsolationTree.java new file mode 100644 index 0000000..bc0e9d3 --- /dev/null +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/tree/IsolationTree.java @@ -0,0 +1,54 @@ +package org.rsultan.core.clustering.ensemble.isolationforest.tree; + +import static org.apache.commons.lang3.RandomUtils.nextInt; + +import org.nd4j.common.util.MathUtils; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.rsultan.core.RawTrainable; +import org.rsultan.core.clustering.ensemble.domain.IsolationNode; +import org.rsultan.core.clustering.ensemble.isolationforest.tree.IsolationTree.Feature; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class IsolationTree extends AbstractTree implements RawTrainable { + + public static final Logger LOG = LoggerFactory.getLogger(IsolationTree.class); + + public IsolationTree(int treeDepthLimit) { + super(treeDepthLimit); + } + + @Override + public IsolationTree train(INDArray matrix) { + this.tree = buildTree(matrix, treeDepthLimit); + return this; + } + + @Override + protected IsolationNode buildNode(INDArray X, int currentDepth) { + LOG.info("Tree Depth {}", currentDepth); + int numberOfFeatures = X.columns(); + int splitFeature = nextInt(0, numberOfFeatures); + var feature = X.getColumn(splitFeature); + double startInclusive = feature.minNumber().doubleValue(); + double endInclusive = feature.maxNumber().doubleValue(); + double valueSplit = MathUtils.randomDoubleBetween(startInclusive, endInclusive); + + var left = getVector(X, getIndices(feature, idx -> feature.getDouble(idx) < valueSplit)); + var right = getVector(X, getIndices(feature, idx -> feature.getDouble(idx) >= valueSplit)); + + return new IsolationNode<>( + new Feature(splitFeature, valueSplit), + buildTree(left, currentDepth - 1), + buildTree(right, currentDepth - 1) + ); + } + + @Override + protected boolean chooseLeftNode(INDArray row, Feature feature) { + return row.getDouble(feature.feature) < feature.threshold ; + } + + static record Feature(int feature, double threshold){} + +} diff --git a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/utils/ScoreUtils.java b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/utils/ScoreUtils.java index fd4f95d..0d5a5e2 100644 --- a/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/utils/ScoreUtils.java +++ b/java-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/utils/ScoreUtils.java @@ -17,4 +17,5 @@ public static double averagePathLength(double leafSize) { private static double harmonicNumber(double leafSize) { return Math.log(leafSize - 1) + EULER_CONSTANT; } + } diff --git a/java-ml/src/test/java/org/rsultan/core/clustering/IsolationForestTest.java b/java-ml/src/test/java/org/rsultan/core/clustering/IsolationForestTest.java index c57f9f4..33fbe93 100644 --- a/java-ml/src/test/java/org/rsultan/core/clustering/IsolationForestTest.java +++ b/java-ml/src/test/java/org/rsultan/core/clustering/IsolationForestTest.java @@ -11,6 +11,7 @@ import java.util.List; import org.junit.jupiter.api.Test; import org.rsultan.core.clustering.ensemble.evaluation.TPRThresholdEvaluator; +import org.rsultan.core.clustering.ensemble.isolationforest.ExtendedIsolationForest; import org.rsultan.core.clustering.ensemble.isolationforest.IsolationForest; import org.rsultan.dataframe.Dataframes; import org.rsultan.dataframe.Row; @@ -55,5 +56,35 @@ public void should_evaluate_tpr() { assertThat(threshold).isGreaterThan(0.4); } + @Test + public void must_perform_extended_isolation() { + var df = Dataframes.create(new String[]{"x", "y", "response"}, rows); + var predict = new ExtendedIsolationForest(10, 1).setAnomalyThreshold(0.56).setSampleSize(15).train(df).predict(df); + int anomalies = predict.filter("anomalies", obj -> obj.equals(1L)).getRowSize(); + int nonAnomalies = predict.filter("anomalies", obj -> obj.equals(0L)).getRowSize(); + + assertThat(anomalies).isBetween(1, predict.getRowSize()); + assertThat(nonAnomalies).isEqualTo(predict.getRowSize() - anomalies); + } + + @Test + public void should_evaluate_extended_tpr() { + var df = Dataframes.trainTest(new String[]{"x", "y", "response"}, rows); + var model = new ExtendedIsolationForest(10, 1).setSampleSize(15); + var evaluator = new TPRThresholdEvaluator("response", "anomalies").setDesiredTPR(0.9).setLearningRate(0.01); + var threshold = evaluator.evaluate(model, df); + evaluator.showMetrics(); + assertThat(threshold).isGreaterThan(0.4); + } + @Test + public void must_perform_isolation_and_retrieve_scores() { + var df = Dataframes.create(new String[]{"x", "y", "response"}, rows); + var predict = new IsolationForest(10).setAnomalyThreshold(0.56).setSampleSize(15).setUseAnomalyScoresOnly(true).train(df).predict(df); + int anomalies = predict.filter("anomalies", (Number obj) -> obj.doubleValue() > 0.56).getRowSize(); + int nonAnomalies = predict.filter("anomalies", (Number obj) -> obj.doubleValue() <= 0.56).getRowSize(); + + assertThat(anomalies).isBetween(1, predict.getRowSize()); + assertThat(nonAnomalies).isEqualTo(predict.getRowSize() - anomalies); + } } diff --git a/version.properties b/version.properties index 7665a6f..7c407e8 100755 --- a/version.properties +++ b/version.properties @@ -1 +1 @@ -version.next=2.1.3 +version.next=2.2.0