-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
develop --> master for release 2.1.0 #72
develop --> master for release 2.1.0
- Loading branch information
Showing
17 changed files
with
580,196 additions
and
2 deletions.
There are no files selected for viewing
30 changes: 30 additions & 0 deletions
30
java-ml-example/src/main/java/org/rsultan/example/IsolationForestExample.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package org.rsultan.example; | ||
|
||
import java.io.IOException; | ||
import org.nd4j.linalg.api.buffer.DataType; | ||
import org.nd4j.linalg.factory.Nd4j; | ||
import org.rsultan.core.clustering.ensemble.evaluation.TPRThresholdEvaluator; | ||
import org.rsultan.core.clustering.ensemble.isolationforest.IsolationForest; | ||
import org.rsultan.dataframe.Dataframes; | ||
import org.rsultan.dataframe.TrainTestDataframe; | ||
|
||
public class IsolationForestExample { | ||
|
||
/* | ||
You can use the http dataset --> args[0] | ||
*/ | ||
static { | ||
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); | ||
} | ||
|
||
public static void main(String[] args) throws IOException { | ||
var df = Dataframes.csv(args[0], ",", "\"", true); | ||
var trainTestDataframe = Dataframes.trainTest(df.getColumns()).setSplitValue(0.5); | ||
|
||
IsolationForest model = new IsolationForest(200).train(df.mapWithout("attack")); | ||
var evaluator = new TPRThresholdEvaluator("attack", "anomalies").setDesiredTPR(0.9).setLearningRate(0.02); | ||
Double threshold = evaluator.evaluate(model, trainTestDataframe); | ||
System.out.println("threshold = " + threshold); | ||
evaluator.showMetrics(); | ||
} | ||
} |
567,499 changes: 567,499 additions & 0 deletions
567,499
java-ml-example/src/main/resources/http/http.csv
Large diffs are not rendered by default.
Oops, something went wrong.
12,212 changes: 12,212 additions & 0 deletions
12,212
java-ml-example/src/main/resources/http/http_reduced.csv
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package org.rsultan.core; | ||
|
||
import org.rsultan.core.Trainable; | ||
import org.rsultan.dataframe.TrainTestDataframe; | ||
|
||
public interface Evaluator<V, T extends Trainable<T>> { | ||
|
||
V evaluate(T trainable, TrainTestDataframe dataframe); | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package org.rsultan.core; | ||
|
||
import java.io.Serializable; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.rsultan.dataframe.Dataframe; | ||
|
||
public interface RawTrainable<T> extends Serializable { | ||
|
||
T train(INDArray matrix); | ||
|
||
INDArray predict(INDArray matrix); | ||
|
||
} |
30 changes: 30 additions & 0 deletions
30
java-ml/src/main/java/org/rsultan/core/clustering/ensemble/domain/IsolationNode.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package org.rsultan.core.clustering.ensemble.domain; | ||
|
||
import static java.util.Objects.nonNull; | ||
|
||
import java.io.Serializable; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
|
||
public record IsolationNode( | ||
int feature, | ||
double featureThreshold, | ||
INDArray data, | ||
IsolationNode left, | ||
IsolationNode right | ||
) implements Serializable { | ||
|
||
public IsolationNode(INDArray data) { | ||
this(-1, -1, data, null, null); | ||
} | ||
public IsolationNode( | ||
int feature, | ||
double featureThreshold, | ||
IsolationNode left, | ||
IsolationNode right) { | ||
this(feature, featureThreshold, null, left, right); | ||
} | ||
|
||
public boolean isLeaf() { | ||
return nonNull(data); | ||
} | ||
} |
74 changes: 74 additions & 0 deletions
74
.../src/main/java/org/rsultan/core/clustering/ensemble/evaluation/TPRThresholdEvaluator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package org.rsultan.core.clustering.ensemble.evaluation; | ||
|
||
import org.rsultan.core.Evaluator; | ||
import org.rsultan.core.clustering.ensemble.isolationforest.IsolationForest; | ||
import org.rsultan.dataframe.Dataframes; | ||
import org.rsultan.dataframe.Row; | ||
import org.rsultan.dataframe.TrainTestDataframe; | ||
|
||
public class TPRThresholdEvaluator implements Evaluator<Double, IsolationForest> { | ||
|
||
private final String responseVariable; | ||
private final String predictionColumn; | ||
private double desiredTPR = 0.9; | ||
private double learningRate = 0.01; | ||
private double TPR = 0; | ||
private double FPR = 0; | ||
|
||
public TPRThresholdEvaluator(String responseVariable, String predictionColumn) { | ||
this.responseVariable = responseVariable; | ||
this.predictionColumn = predictionColumn; | ||
} | ||
|
||
@Override | ||
public Double evaluate(IsolationForest trainable, TrainTestDataframe dataframe) { | ||
var dfSplit = dataframe.shuffle().split(); | ||
double threshold = 1; | ||
while (threshold > 0 && TPR <= desiredTPR) { | ||
threshold -= learningRate; | ||
var trained = trainable.setAnomalyThreshold(threshold) | ||
.train(dfSplit.train().mapWithout(responseVariable)); | ||
var responses = dfSplit.test().<Long>get(responseVariable); | ||
var predictions = trained.predict(dfSplit.test().mapWithout(responseVariable)) | ||
.<Long>get(predictionColumn); | ||
|
||
double truePositives = 0; | ||
double trueNegatives = 0; | ||
double falsePositives = 0; | ||
double falseNegative = 0; | ||
|
||
for (int i = 0; i < responses.size(); i++) { | ||
var response = responses.get(i); | ||
var prediction = predictions.get(i); | ||
truePositives += response == 1L && prediction == 1L ? 1L : 0L; | ||
trueNegatives += response == 0L && prediction == 0L ? 1L : 0L; | ||
falsePositives += response == 0L && prediction == 1L ? 1L : 0L; | ||
falseNegative += response == 1L && prediction == 0L ? 1L : 0L; | ||
} | ||
TPR = truePositives / (truePositives + falseNegative); | ||
TPR = Double.isNaN(TPR) ? 0 : TPR; | ||
|
||
FPR = falsePositives / (falsePositives + trueNegatives); | ||
FPR = Double.isNaN(FPR) ? 0L : FPR; | ||
} | ||
if (threshold < 0) { | ||
throw new IllegalArgumentException("Cannot have desired TPR"); | ||
} | ||
|
||
return threshold; | ||
} | ||
|
||
public TPRThresholdEvaluator setDesiredTPR(double desiredTPR) { | ||
this.desiredTPR = desiredTPR; | ||
return this; | ||
} | ||
|
||
public TPRThresholdEvaluator setLearningRate(double learningRate) { | ||
this.learningRate = learningRate; | ||
return this; | ||
} | ||
|
||
public void showMetrics() { | ||
Dataframes.create(new String[]{"TPR", "FPR"}, new Row(TPR, FPR)).tail(); | ||
} | ||
} |
75 changes: 75 additions & 0 deletions
75
...l/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationForest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package org.rsultan.core.clustering.ensemble.isolationforest; | ||
|
||
import static java.util.stream.IntStream.range; | ||
import static org.rsultan.core.clustering.ensemble.isolationforest.utils.ScoreUtils.averagePathLength; | ||
|
||
import java.util.List; | ||
import java.util.stream.DoubleStream; | ||
import org.apache.commons.lang3.RandomUtils; | ||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.factory.Nd4j; | ||
import org.nd4j.linalg.ops.transforms.Transforms; | ||
import org.rsultan.core.Trainable; | ||
import org.rsultan.dataframe.Column; | ||
import org.rsultan.dataframe.Dataframe; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
public class IsolationForest implements Trainable<IsolationForest> { | ||
|
||
private static final Logger LOG = LoggerFactory.getLogger(IsolationTree.class); | ||
private final int nbTrees; | ||
private double anomalyThreshold = 0.5; | ||
private List<IsolationTree> isolationTrees; | ||
private int sampleSize = 256; | ||
|
||
public IsolationForest(int nbTrees) { | ||
this.nbTrees = nbTrees; | ||
} | ||
|
||
public IsolationForest setSampleSize(int sampleSize) { | ||
this.sampleSize = sampleSize; | ||
return this; | ||
} | ||
|
||
public IsolationForest setAnomalyThreshold(double anomalyThreshold) { | ||
this.anomalyThreshold = anomalyThreshold; | ||
return this; | ||
} | ||
|
||
@Override | ||
public IsolationForest train(Dataframe dataframe) { | ||
var matrix = dataframe.toMatrix(); | ||
int realSample = sampleSize >= matrix.rows() ? sampleSize / 10 : sampleSize; | ||
int treeDepth = (int) Math.ceil(Math.log(realSample) / Math.log(2)); | ||
isolationTrees = range(0, nbTrees).parallel() | ||
.peek(i -> LOG.info("Tree number: {}", i)) | ||
.mapToObj(i -> range(0, realSample) | ||
.map(idx -> RandomUtils.nextInt(0, matrix.rows())) | ||
.toArray()).map(matrix::getRows) | ||
.map(m -> new IsolationTree(treeDepth).train(m)) | ||
.toList(); | ||
return this; | ||
} | ||
|
||
@Override | ||
public Dataframe predict(Dataframe dataframe) { | ||
var matrix = dataframe.toMatrix(); | ||
var anomalyScores = computeAnomalyScore(matrix); | ||
var isAnomaly = new Column<>("anomalies", DoubleStream.of( | ||
anomalyScores.toDoubleVector() | ||
).mapToObj(score -> score >= anomalyThreshold ? 1L : 0L).toArray()); | ||
return dataframe.addColumn(isAnomaly); | ||
} | ||
|
||
private INDArray computeAnomalyScore(INDArray matrix) { | ||
var pathLengths = isolationTrees.stream().parallel().map(tree -> { | ||
LOG.info("Compute paths for tree {}", isolationTrees.indexOf(tree) + 1); | ||
return tree.predict(matrix); | ||
}).toList(); | ||
int[] shape = {pathLengths.size(), pathLengths.get(0).columns()}; | ||
var avgLength = Nd4j.create(pathLengths, shape).mean(true, 0); | ||
var twos = Nd4j.ones(avgLength.shape()).mul(2D); | ||
return Transforms.pow(twos, avgLength.neg().div(averagePathLength(sampleSize))); | ||
} | ||
} |
88 changes: 88 additions & 0 deletions
88
...-ml/src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/IsolationTree.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
package org.rsultan.core.clustering.ensemble.isolationforest; | ||
|
||
import static java.util.stream.IntStream.range; | ||
import static org.apache.commons.lang3.RandomUtils.nextDouble; | ||
import static org.apache.commons.lang3.RandomUtils.nextInt; | ||
import static org.rsultan.core.clustering.ensemble.isolationforest.utils.ScoreUtils.averagePathLength; | ||
|
||
import org.nd4j.linalg.api.ndarray.INDArray; | ||
import org.nd4j.linalg.factory.Nd4j; | ||
import org.rsultan.core.RawTrainable; | ||
import org.rsultan.core.clustering.ensemble.domain.IsolationNode; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
public class IsolationTree implements RawTrainable<IsolationTree> { | ||
|
||
public static final Logger LOG = LoggerFactory.getLogger(IsolationTree.class); | ||
private final int treeDepthLimit; | ||
private IsolationNode tree; | ||
|
||
public IsolationTree(int treeDepthLimit) { | ||
this.treeDepthLimit = treeDepthLimit; | ||
} | ||
|
||
@Override | ||
public IsolationTree train(INDArray matrix) { | ||
this.tree = buildTree(matrix, treeDepthLimit); | ||
return this; | ||
} | ||
|
||
private IsolationNode buildTree(INDArray matrix, int currentDepth) { | ||
LOG.info("Tree Depth {}", currentDepth); | ||
if (currentDepth <= 0 || matrix.rows() <= 2) { | ||
return new IsolationNode(matrix); | ||
} | ||
int numberOfFeatures = matrix.columns(); | ||
int splitFeature = nextInt(0, numberOfFeatures); | ||
var feature = matrix.getColumn(splitFeature); | ||
double startInclusive = feature.minNumber().doubleValue(); | ||
double endInclusive = feature.maxNumber().doubleValue(); | ||
|
||
double valueSplit = | ||
getValueSplit(startInclusive, endInclusive); | ||
|
||
var leftIndices = range(0, feature.columns()).parallel() | ||
.filter(idx -> feature.getDouble(idx) < valueSplit) | ||
.toArray(); | ||
var left = matrix.getRows(leftIndices); | ||
|
||
var rightIndices = range(0, feature.columns()).parallel() | ||
.filter(idx -> feature.getDouble(idx) > valueSplit) | ||
.toArray(); | ||
var right = matrix.getRows(rightIndices); | ||
|
||
return new IsolationNode( | ||
splitFeature, | ||
valueSplit, | ||
buildTree(left, currentDepth - 1), | ||
buildTree(right, currentDepth - 1) | ||
); | ||
} | ||
|
||
private double getValueSplit(double startInclusive, double endInclusive) { | ||
if (startInclusive < 0 && endInclusive < 0) { | ||
return -nextDouble(endInclusive * -1, startInclusive * -1); | ||
} else if (startInclusive < 0 && endInclusive >= 0) { | ||
return nextDouble(0, endInclusive + startInclusive * -1) + startInclusive; | ||
} | ||
return nextDouble(startInclusive, endInclusive); | ||
} | ||
|
||
@Override | ||
public INDArray predict(INDArray matrix) { | ||
var pathLengths = Nd4j.zeros(1, matrix.rows()); | ||
for (int i = 0; i < matrix.rows(); i++) { | ||
var row = matrix.getRow(i); | ||
var node = tree; | ||
int length = 0; | ||
while (!node.isLeaf()) { | ||
node = row.getDouble(node.feature()) < node.featureThreshold() ? node.left() : node.right(); | ||
length++; | ||
} | ||
int leafSize = node.data().rows(); | ||
pathLengths.put(0, i, length + averagePathLength(leafSize)); | ||
} | ||
return pathLengths; | ||
} | ||
} |
20 changes: 20 additions & 0 deletions
20
.../src/main/java/org/rsultan/core/clustering/ensemble/isolationforest/utils/ScoreUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package org.rsultan.core.clustering.ensemble.isolationforest.utils; | ||
|
||
/** Helpers used to turn isolation tree path lengths into scores. */
public class ScoreUtils {

  private static final double EULER_CONSTANT = 0.5772156649;

  /**
   * Returns the average path length term for a node holding {@code leafSize} rows.
   *
   * <p>For sizes above 2 the value is {@code 2 * (ln(leafSize - 1) + EULER_CONSTANT)
   * - 2 * (leafSize - 1) / leafSize}; it is 1 for a size of exactly 2 and 0 in every other case.
   *
   * @param leafSize number of rows
   * @return the average path length term
   */
  public static double averagePathLength(double leafSize) {
    if (leafSize == 2) {
      return 1;
    }
    // Written as a negated '>' so that NaN falls through to 0 as well.
    if (!(leafSize > 2)) {
      return 0;
    }
    double harmonic = Math.log(leafSize - 1) + EULER_CONSTANT;
    double correction = 2 * (leafSize - 1) / leafSize;
    return 2 * harmonic - correction;
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.