diff --git a/README.md b/README.md index 866fe8e..cc4bf55 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,31 @@ # java-ml This repo intends to implement Machine Learning algorithms with Java and https://github.com/deeplearning4j/nd4j purely for understanding -You can check for examples in this package: ``src/main/java/org/rsultan/example`` \ No newline at end of file +You can check for examples in this package: ``src/main/java/org/rsultan/example`` + +## Getting started + +Clone the repo and execute this + +```bash + $ git fetch --all --tags + $ git checkout tags/ -b + $ ./mvnw clean install +``` + +Then you can import this to your `pom.xml` + +```xml + + org.rsultan + java-ml + [latest-version] + +``` + +There is an existing artifactory today but I am not satisfied with how to make it public today. +There will be an artifactory in the future. + +Once your are good with this, you can go read the [wiki](https://github.com/remisultan/java-ml/wiki) + +Good luck ! \ No newline at end of file diff --git a/src/main/java/org/rsultan/core/clustering/kmedoids/evaluation/KMedoidEvaluator.java b/src/main/java/org/rsultan/core/clustering/kmedoids/evaluation/KMedoidEvaluator.java index b196e81..c1ea444 100644 --- a/src/main/java/org/rsultan/core/clustering/kmedoids/evaluation/KMedoidEvaluator.java +++ b/src/main/java/org/rsultan/core/clustering/kmedoids/evaluation/KMedoidEvaluator.java @@ -21,10 +21,6 @@ public class KMedoidEvaluator { private final MedoidType medoidType; private final int epoch; - public KMedoidEvaluator(int minK, int maxK) { - this(minK, maxK, MEAN); - } - public KMedoidEvaluator(int minK, int maxK, MedoidType medoidType) { this(minK, maxK, medoidType, 100); } diff --git a/src/main/java/org/rsultan/core/clustering/kmedoids/strategy/PlusPlusStrategy.java b/src/main/java/org/rsultan/core/clustering/kmedoids/strategy/PlusPlusStrategy.java index fc21f94..ecf8eb2 100644 --- a/src/main/java/org/rsultan/core/clustering/kmedoids/strategy/PlusPlusStrategy.java +++ b/src/main/java/org/rsultan/core/clustering/kmedoids/strategy/PlusPlusStrategy.java @@ -30,7 +30,7 @@ public INDArray initialiseCenters(long K, INDArray X) { var probabilities = squaredDistances.div(squaredDistances.sum(true, 0)); var cumulativeSumTheshold = probabilities.cumsum(0) .getWhere(nextDouble(0, 1), greaterThanOrEqual()); - addNewCenters(X, centers, X.columns() - cumulativeSumTheshold.columns() + 1); + addNewCenters(X, centers, X.columns() - cumulativeSumTheshold.columns()); } return C.columns() == 1 ? C.transpose() : C; } diff --git a/src/main/java/org/rsultan/core/clustering/medoidshift/MeanShift.java b/src/main/java/org/rsultan/core/clustering/medoidshift/MeanShift.java index 752bf21..836afad 100644 --- a/src/main/java/org/rsultan/core/clustering/medoidshift/MeanShift.java +++ b/src/main/java/org/rsultan/core/clustering/medoidshift/MeanShift.java @@ -6,7 +6,7 @@ public class MeanShift extends MedoidShift { - public MeanShift(long bandwidth, long epoch) { + public MeanShift(double bandwidth, long epoch) { super(bandwidth, epoch, MEAN); } diff --git a/src/main/java/org/rsultan/core/clustering/medoidshift/MedianShift.java b/src/main/java/org/rsultan/core/clustering/medoidshift/MedianShift.java index d69bfe6..8bfd511 100644 --- a/src/main/java/org/rsultan/core/clustering/medoidshift/MedianShift.java +++ b/src/main/java/org/rsultan/core/clustering/medoidshift/MedianShift.java @@ -1,13 +1,12 @@ package org.rsultan.core.clustering.medoidshift; -import static org.rsultan.core.clustering.type.MedoidType.MEAN; import static org.rsultan.core.clustering.type.MedoidType.MEDIAN; import org.rsultan.dataframe.Dataframe; public class MedianShift extends MedoidShift { - public MedianShift(long bandwidth, long epoch) { + public MedianShift(double bandwidth, long epoch) { super(bandwidth, epoch, MEDIAN); } diff --git a/src/main/java/org/rsultan/core/clustering/medoidshift/MedoidShift.java b/src/main/java/org/rsultan/core/clustering/medoidshift/MedoidShift.java index 8362640..3716446 100644 --- a/src/main/java/org/rsultan/core/clustering/medoidshift/MedoidShift.java +++ b/src/main/java/org/rsultan/core/clustering/medoidshift/MedoidShift.java @@ -1,8 +1,8 @@ package org.rsultan.core.clustering.medoidshift; - import static java.util.stream.Collectors.toList; import static java.util.stream.LongStream.range; +import static java.util.stream.LongStream.rangeClosed; import java.util.Arrays; import java.util.concurrent.atomic.AtomicBoolean; @@ -12,17 +12,22 @@ import org.rsultan.core.clustering.type.MedoidType; import org.rsultan.dataframe.Column; import org.rsultan.dataframe.Dataframe; +import org.rsultan.dataframe.Dataframes; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public abstract class MedoidShift implements Clustering { + private static final Logger LOG = LoggerFactory.getLogger(MedoidShift.class); + private final MedoidType medoidType; - private final long bandwidth; + private final double bandwidth; private final long epoch; - private INDArray C; + private INDArray centroids; private INDArray Xt; - protected MedoidShift(long bandwidth, long epoch, MedoidType medoidType) { + protected MedoidShift(double bandwidth, long epoch, MedoidType medoidType) { this.medoidType = medoidType; this.epoch = epoch; this.bandwidth = bandwidth; @@ -33,32 +38,30 @@ public MedoidShift train(Dataframe dataframe) { var medoidFactory = medoidType.getMedoidFactory(); var isTerminated = new AtomicBoolean(false); Xt = dataframe.toMatrix().transpose(); - C = Xt.dup(); + centroids = Xt.dup(); range(0, epoch) .filter(epoch -> !isTerminated.get()) .forEach(epoch -> { - System.out.println("Epoch " + epoch); - var newCentroidList = range(0, C.columns()) + LOG.info("Epoch : {}", epoch); + var newCentroidList = range(0, centroids.columns()) .parallel().unordered() .mapToObj(col -> { - var centroid = C.getColumn(col); + var centroid = centroids.getColumn(col); var range = epoch == 0 ? range(col, Xt.columns()) : range(0, Xt.columns()); return range.parallel().unordered().mapToObj(Xt::getColumn) .filter(feature -> medoidFactory.computeNorm(feature.sub(centroid)) < bandwidth) .collect(toList()); }) - .map(features -> Nd4j.create(features, features.size(), C.rows())) + .map(features -> Nd4j.create(features, features.size(), centroids.rows())) .map(medoidFactory::computeMedoids) .distinct() .collect(toList()); - var newC = Nd4j.create(newCentroidList, newCentroidList.size(), C.rows()) + var newC = Nd4j.create(newCentroidList, newCentroidList.size(), centroids.rows()) .transpose(); - if (C.equalShapes(newC) && C.equals(newC) || C.columns() == 1) { + if (centroids.equalShapes(newC) && centroids.equals(newC) || centroids.columns() == 1) { isTerminated.set(true); } - C = newC; - System.out.println(C); - System.out.println(Arrays.toString(C.shape())); + centroids = newC; }); return this; @@ -68,14 +71,23 @@ public MedoidShift train(Dataframe dataframe) { public Dataframe predict(Dataframe dataframe) { var medoidFactory = medoidType.getMedoidFactory(); var Xpredict = dataframe.toMatrix(); - var distances = medoidFactory.computeDistance(Xpredict, C.transpose()); + var distances = medoidFactory.computeDistance(Xpredict, centroids.transpose()); var indices = Nd4j.argMin(distances, 1); - var predictions = C.transpose().get(indices); + var predictions = centroids.transpose().get(indices); return dataframe.addColumn(new Column<>("predictions", range(0, predictions.rows()).mapToObj(predictions::getRow).collect(toList()))); } - public INDArray getC() { - return C; + + public void showMetrics() { + var centroids = range(0, this.centroids.columns()).boxed() + .map(idx -> Arrays.toString(this.centroids.getColumn(idx).toDoubleVector())) + .collect(toList()); + var indices = new Column<>("K", rangeClosed(1, this.centroids.columns()).boxed().collect(toList())); + Dataframes.create(indices, new Column<>("centroids", centroids)).tail(); + } + + public INDArray getCentroids() { + return centroids; } } diff --git a/src/main/java/org/rsultan/dataframe/transform/matrix/MatrixDataframe.java b/src/main/java/org/rsultan/dataframe/transform/matrix/MatrixDataframe.java index 322b151..c5820a0 100644 --- a/src/main/java/org/rsultan/dataframe/transform/matrix/MatrixDataframe.java +++ b/src/main/java/org/rsultan/dataframe/transform/matrix/MatrixDataframe.java @@ -19,7 +19,7 @@ public class MatrixDataframe implements MatrixTransform { - private static final String NUMBER_REGEX = "^\\d+(\\.\\d+)*$"; + private static final String NUMBER_REGEX = "^(-?\\d+(\\.\\d+)*)$"; private final Dataframe dataframe; public MatrixDataframe(Dataframe dataframe) { diff --git a/src/main/java/org/rsultan/example/MedoidShiftExample.java b/src/main/java/org/rsultan/example/MedoidShiftExample.java index 52e5066..fca75d3 100644 --- a/src/main/java/org/rsultan/example/MedoidShiftExample.java +++ b/src/main/java/org/rsultan/example/MedoidShiftExample.java @@ -1,20 +1,14 @@ package org.rsultan.example; -import static java.awt.image.BufferedImage.TYPE_INT_RGB; - import java.awt.Color; -import java.awt.image.BufferedImage; import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.stream.IntStream; import javax.imageio.ImageIO; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; -import org.rsultan.core.clustering.kmedoids.KMeans; -import org.rsultan.core.clustering.kmedoids.KMedians; import org.rsultan.core.clustering.medoidshift.MeanShift; import org.rsultan.core.clustering.medoidshift.MedianShift; import org.rsultan.dataframe.Column; @@ -71,10 +65,10 @@ public static void main(String[] args) }); var trainedMeanShift = futureMeanShift.get(); - System.out.println("MeanShift Centroids : " + meanShift.getC()); + System.out.println("MeanShift Centroids : " + meanShift.getCentroids()); var trainedMedianShift = futureMedianShift.get(); - System.out.println("MedianShift Centroids : " + medianShift.getC()); + System.out.println("MedianShift Centroids : " + medianShift.getCentroids()); } } diff --git a/src/main/java/org/rsultan/utils/CSVUtils.java b/src/main/java/org/rsultan/utils/CSVUtils.java index 9bba709..d3da41d 100644 --- a/src/main/java/org/rsultan/utils/CSVUtils.java +++ b/src/main/java/org/rsultan/utils/CSVUtils.java @@ -15,8 +15,8 @@ public class CSVUtils { - private static final Pattern DOUBLE_VALUE_REGEX = Pattern.compile("\\d+\\.\\d+"); - private static final Pattern LONG_VALUE_REGEX = Pattern.compile("\\d+"); + private static final Pattern DOUBLE_VALUE_REGEX = Pattern.compile("-?\\d+\\.\\d+"); + private static final Pattern LONG_VALUE_REGEX = Pattern.compile("-?\\d+"); public static final String HEADER_PREFIX = "c"; private static Object getValueWithType(String value) { diff --git a/src/test/java/org/rsultan/core/clustering/MedoidShiftTest.java b/src/test/java/org/rsultan/core/clustering/MedoidShiftTest.java index fad9434..9081d6e 100644 --- a/src/test/java/org/rsultan/core/clustering/MedoidShiftTest.java +++ b/src/test/java/org/rsultan/core/clustering/MedoidShiftTest.java @@ -14,9 +14,6 @@ import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; -import org.rsultan.core.clustering.kmedoids.KMeans; -import org.rsultan.core.clustering.kmedoids.KMedians; -import org.rsultan.core.clustering.kmedoids.KMedoids; import org.rsultan.core.clustering.medoidshift.MeanShift; import org.rsultan.core.clustering.medoidshift.MedianShift; import org.rsultan.core.clustering.medoidshift.MedoidShift; @@ -31,12 +28,12 @@ public class MedoidShiftTest { private static Stream params_that_must_apply_kmedoids() { return Stream.of( - Arguments.of(new MeanShift(60, 30)), - Arguments.of(new MeanShift(65, 30)), - Arguments.of(new MeanShift(70, 30)), - Arguments.of(new MedianShift(60, 30)), - Arguments.of(new MedianShift(65, 30)), - Arguments.of(new MedianShift(70, 30)) + Arguments.of(new MeanShift(60D, 30)), + Arguments.of(new MeanShift(65D, 30)), + Arguments.of(new MeanShift(70D, 30)), + Arguments.of(new MedianShift(60D, 30)), + Arguments.of(new MedianShift(65D, 30)), + Arguments.of(new MedianShift(70D, 30)) ); } @@ -50,9 +47,10 @@ public void must_apply_kmedoids(MedoidShift medoidShift) { new Column<>("c3", range(0, 20).boxed().map(idx -> nextFloat(0, 10)).collect(toList())), new Column<>("c4", range(0, 20).boxed().map(idx -> nextInt(0, 10)).collect(toList())) ); - medoidShift.train(dataframe); + medoidShift = medoidShift.train(dataframe); + medoidShift.showMetrics(); - assertThat(medoidShift.getC()).isNotNull(); + assertThat(medoidShift.getCentroids()).isNotNull(); assertThat(medoidShift.predict(dataframe)).isNotNull(); } } diff --git a/src/test/java/org/rsultan/core/regression/LinearRegressionTest.java b/src/test/java/org/rsultan/core/regression/LinearRegressionTest.java index 4bf0142..0bf6f08 100644 --- a/src/test/java/org/rsultan/core/regression/LinearRegressionTest.java +++ b/src/test/java/org/rsultan/core/regression/LinearRegressionTest.java @@ -82,7 +82,7 @@ public void must_apply_linear_regression( double[] expectedPValues, double[] expectedPredictions ) throws IOException { - var dataframe = Dataframes.csv(getResourceFileName("org/rsultan/utils/example.csv")); + var dataframe = Dataframes.csv(getResourceFileName("org/rsultan/utils/example-linear-regression.csv")); var linearRegression = new LinearRegression() .setPredictorNames(predictors) .setResponseVariableName(responseVariable) diff --git a/src/test/java/org/rsultan/dataframe/DataframeTest.java b/src/test/java/org/rsultan/dataframe/DataframeTest.java index 91981f9..1c5e87c 100644 --- a/src/test/java/org/rsultan/dataframe/DataframeTest.java +++ b/src/test/java/org/rsultan/dataframe/DataframeTest.java @@ -24,13 +24,15 @@ private static Stream params_that_must_load_dataframe_correctly() { of(new Column[]{new Column<>("Doubles", 0D, 1D, 2D, 3D, 4D)}, 5, 1), of(new Column[]{new Column<>("Floats", 0F, 1F, 2F, 3F, 4F)}, 5, 1), of(new Column[]{new Column<>("Strings", "1.1", "2.1", "3", "4.4", "5.3")}, 5, 1), + of(new Column[]{new Column<>("Negatives", "-1.1", "-2.1", "-3", "-4.4", "-5.3")}, 5, 1), of(new Column[]{ new Column<>("Integers", 0, 1, 2, 3, 4), new Column<>("Longs", 0L, 1L, 2L, 3L, 4L), new Column<>("Doubles", 0D, 1D, 2D, 3D, 4D), new Column<>("Floats", 0F, 1F, 2F, 3F, 4F), - new Column<>("Strings", "1.1", "2.1", "3", "4.4", "5.3") - }, 5, 5) + new Column<>("Strings", "1.1", "2.1", "3", "4.4", "5.3"), + new Column<>("Negatives", "-1.1", "-2.1", "-3", "-4.4", "-5.3") + }, 5, 6) ); } @@ -186,11 +188,11 @@ public void must_remove_column() { @Test public void must_load_dataframe_from_csv() throws IOException { var df = Dataframes.csv(getResourceFileName("org/rsultan/utils/example.csv")); - assertThat(df.get("y")).containsExactly(1L, 2L, 3L, 4L, 5L); - assertThat(df.get("x")).containsExactly(1.0D, 2.0D, 3.0D, 4.0D, 5.0D); - assertThat(df.get("x2")).containsExactly(1L, 4L, 9L, 16L, 25L); - assertThat(df.get("x3")).containsExactly(1L, 8L, 27L, 64L, 125L); - assertThat(df.get("strColumn")).containsExactly("a", "b", "c", "d", "e"); + assertThat(df.get("y")).containsExactly(1L, 2L, 3L, 4L, 5L, -6L); + assertThat(df.get("x")).containsExactly(1.0D, 2.0D, 3.0D, 4.0D, 5.0D, -5.0D); + assertThat(df.get("x2")).containsExactly(1L, 4L, 9L, 16L, 25L, -25L); + assertThat(df.get("x3")).containsExactly(1L, 8L, 27L, 64L, 125L, -125L); + assertThat(df.get("strColumn")).containsExactly("a", "b", "c", "d", "e", "f"); } @Test diff --git a/src/test/java/org/rsultan/utils/CSVUtilsTest.java b/src/test/java/org/rsultan/utils/CSVUtilsTest.java index e03d99d..904923a 100644 --- a/src/test/java/org/rsultan/utils/CSVUtilsTest.java +++ b/src/test/java/org/rsultan/utils/CSVUtilsTest.java @@ -50,18 +50,18 @@ public void must_throw_exception_due_to_wrong_input(String fileName, public void must_read_csv_and_return_columns() throws IOException { var columns = CSVUtils.read(getResourceFileName(EXAMPLE_CSV), ",", true); - assertThat(columns).hasSize(5); + assertThat(columns).hasSize((5)); assertThat(columns[0].columnName()).isEqualTo("y"); - assertThat((List) columns[0].values()).hasSize(5).containsExactly(1L, 2L, 3L, 4L, 5L); + assertThat((List) columns[0].values()).hasSize(6).containsExactly(1L, 2L, 3L, 4L, 5L, -6L); assertThat(columns[1].columnName()).isEqualTo("x"); - assertThat((List) columns[1].values()).hasSize(5).containsExactly(1.0D, 2.0D, 3.0D, 4.0D, 5.0D); + assertThat((List) columns[1].values()).hasSize(6).containsExactly(1.0D, 2.0D, 3.0D, 4.0D, 5.0D, -5.0D); assertThat(columns[2].columnName()).isEqualTo("x2"); - assertThat((List) columns[2].values()).hasSize(5).containsExactly(1L, 4L, 9L, 16L, 25L); + assertThat((List) columns[2].values()).hasSize(6).containsExactly(1L, 4L, 9L, 16L, 25L, -25L); assertThat(columns[3].columnName()).isEqualTo("x3"); - assertThat((List) columns[3].values()).hasSize(5).containsExactly(1L, 8L, 27L, 64L, 125L); + assertThat((List) columns[3].values()).hasSize(6).containsExactly(1L, 8L, 27L, 64L, 125L, -125L); assertThat(columns[4].columnName()).isEqualTo("strColumn"); - assertThat((List) columns[4].values()).hasSize(5).containsExactly("a", "b", "c", "d", "e"); + assertThat((List) columns[4].values()).hasSize(6).containsExactly("a", "b", "c", "d", "e", "f"); } diff --git a/src/test/resources/org/rsultan/utils/example-linear-regression.csv b/src/test/resources/org/rsultan/utils/example-linear-regression.csv new file mode 100644 index 0000000..aa904d4 --- /dev/null +++ b/src/test/resources/org/rsultan/utils/example-linear-regression.csv @@ -0,0 +1,6 @@ +y,x,x2,x3,strColumn +1,1.0,1,1,a +2,2.0,4,8,b +3,3.0,9,27,c +4,4.0,16,64,d +5,5.0,25,125,e \ No newline at end of file diff --git a/src/test/resources/org/rsultan/utils/example.csv b/src/test/resources/org/rsultan/utils/example.csv index d2ad527..4dcf3d1 100644 --- a/src/test/resources/org/rsultan/utils/example.csv +++ b/src/test/resources/org/rsultan/utils/example.csv @@ -4,3 +4,4 @@ y,x,x2,x3,strColumn 3,3.0,9,27,c 4,4.0,16,64,d 5,5.0,25,125,e +-6,-5.0,-25,-125,f diff --git a/version.properties b/version.properties index d45c622..1d9ac47 100755 --- a/version.properties +++ b/version.properties @@ -1 +1 @@ -version.next=1.0.2 +version.next=1.1.0