Skip to content

Commit

Permalink
develop --> master for release 1.1.0 #35
Browse files Browse the repository at this point in the history
develop --> master for release 1.1.0
  • Loading branch information
remisultan authored Apr 14, 2021
2 parents bceff65 + 144eb3f commit ae9cc7d
Show file tree
Hide file tree
Showing 16 changed files with 99 additions and 64 deletions.
29 changes: 28 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,31 @@
# java-ml
This repo aims to implement Machine Learning algorithms in Java with ND4J (https://github.com/deeplearning4j/nd4j), purely for learning purposes.

You can check for examples in this package: ``src/main/java/org/rsultan/example``
You can check for examples in this package: ``src/main/java/org/rsultan/example``

## Getting started

Clone the repo, then run the following commands from its root directory:

```bash
$ git fetch --all --tags
$ git checkout tags/<latest-release> -b <your-branch>
$ ./mvnw clean install
```

Then you can add the following dependency to your `pom.xml`:

```xml
<dependency>
<groupId>org.rsultan</groupId>
<artifactId>java-ml</artifactId>
<version>[latest-version]</version>
</dependency>
```

An artifact repository already exists, but I am not yet satisfied with how to make it public.
A public artifact repository will be available in the future.

Once you are good with this, you can go read the [wiki](https://github.com/remisultan/java-ml/wiki).

Good luck!
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ public class KMedoidEvaluator {
private final MedoidType medoidType;
private final int epoch;

public KMedoidEvaluator(int minK, int maxK) {
this(minK, maxK, MEAN);
}

/**
 * Convenience constructor that delegates to the four-argument constructor,
 * passing {@code 100} as the last argument.
 *
 * NOTE(review): the fourth argument is presumably the epoch count (this class
 * declares an {@code epoch} field) -- confirm against the four-argument constructor.
 *
 * @param minK       forwarded unchanged to the four-argument constructor
 * @param maxK       forwarded unchanged to the four-argument constructor
 * @param medoidType forwarded unchanged to the four-argument constructor
 */
public KMedoidEvaluator(int minK, int maxK, MedoidType medoidType) {
this(minK, maxK, medoidType, 100);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public INDArray initialiseCenters(long K, INDArray X) {
var probabilities = squaredDistances.div(squaredDistances.sum(true, 0));
var cumulativeSumTheshold = probabilities.cumsum(0)
.getWhere(nextDouble(0, 1), greaterThanOrEqual());
addNewCenters(X, centers, X.columns() - cumulativeSumTheshold.columns() + 1);
addNewCenters(X, centers, X.columns() - cumulativeSumTheshold.columns());
}
return C.columns() == 1 ? C.transpose() : C;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

public class MeanShift extends MedoidShift {

public MeanShift(long bandwidth, long epoch) {
public MeanShift(double bandwidth, long epoch) {
super(bandwidth, epoch, MEAN);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package org.rsultan.core.clustering.medoidshift;

import static org.rsultan.core.clustering.type.MedoidType.MEAN;
import static org.rsultan.core.clustering.type.MedoidType.MEDIAN;

import org.rsultan.dataframe.Dataframe;

public class MedianShift extends MedoidShift {

public MedianShift(long bandwidth, long epoch) {
public MedianShift(double bandwidth, long epoch) {
super(bandwidth, epoch, MEDIAN);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package org.rsultan.core.clustering.medoidshift;


import static java.util.stream.Collectors.toList;
import static java.util.stream.LongStream.range;
import static java.util.stream.LongStream.rangeClosed;

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -12,17 +12,22 @@
import org.rsultan.core.clustering.type.MedoidType;
import org.rsultan.dataframe.Column;
import org.rsultan.dataframe.Dataframe;
import org.rsultan.dataframe.Dataframes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class MedoidShift implements Clustering {

private static final Logger LOG = LoggerFactory.getLogger(MedoidShift.class);

private final MedoidType medoidType;
private final long bandwidth;
private final double bandwidth;
private final long epoch;

private INDArray C;
private INDArray centroids;
private INDArray Xt;

protected MedoidShift(long bandwidth, long epoch, MedoidType medoidType) {
protected MedoidShift(double bandwidth, long epoch, MedoidType medoidType) {
this.medoidType = medoidType;
this.epoch = epoch;
this.bandwidth = bandwidth;
Expand All @@ -33,32 +38,30 @@ public MedoidShift train(Dataframe dataframe) {
var medoidFactory = medoidType.getMedoidFactory();
var isTerminated = new AtomicBoolean(false);
Xt = dataframe.toMatrix().transpose();
C = Xt.dup();
centroids = Xt.dup();
range(0, epoch)
.filter(epoch -> !isTerminated.get())
.forEach(epoch -> {
System.out.println("Epoch " + epoch);
var newCentroidList = range(0, C.columns())
LOG.info("Epoch : {}", epoch);
var newCentroidList = range(0, centroids.columns())
.parallel().unordered()
.mapToObj(col -> {
var centroid = C.getColumn(col);
var centroid = centroids.getColumn(col);
var range = epoch == 0 ? range(col, Xt.columns()) : range(0, Xt.columns());
return range.parallel().unordered().mapToObj(Xt::getColumn)
.filter(feature -> medoidFactory.computeNorm(feature.sub(centroid)) < bandwidth)
.collect(toList());
})
.map(features -> Nd4j.create(features, features.size(), C.rows()))
.map(features -> Nd4j.create(features, features.size(), centroids.rows()))
.map(medoidFactory::computeMedoids)
.distinct()
.collect(toList());
var newC = Nd4j.create(newCentroidList, newCentroidList.size(), C.rows())
var newC = Nd4j.create(newCentroidList, newCentroidList.size(), centroids.rows())
.transpose();
if (C.equalShapes(newC) && C.equals(newC) || C.columns() == 1) {
if (centroids.equalShapes(newC) && centroids.equals(newC) || centroids.columns() == 1) {
isTerminated.set(true);
}
C = newC;
System.out.println(C);
System.out.println(Arrays.toString(C.shape()));
centroids = newC;
});

return this;
Expand All @@ -68,14 +71,23 @@ public MedoidShift train(Dataframe dataframe) {
public Dataframe predict(Dataframe dataframe) {
var medoidFactory = medoidType.getMedoidFactory();
var Xpredict = dataframe.toMatrix();
var distances = medoidFactory.computeDistance(Xpredict, C.transpose());
var distances = medoidFactory.computeDistance(Xpredict, centroids.transpose());
var indices = Nd4j.argMin(distances, 1);
var predictions = C.transpose().get(indices);
var predictions = centroids.transpose().get(indices);
return dataframe.addColumn(new Column<>("predictions",
range(0, predictions.rows()).mapToObj(predictions::getRow).collect(toList())));
}

public INDArray getC() {
return C;

/**
 * Builds a two-column dataframe describing the current centroids and calls
 * {@code tail()} on it.
 *
 * <p>Column {@code "K"} holds the 1-based centroid index; column
 * {@code "centroids"} holds each centroid column rendered with
 * {@link Arrays#toString(double[])}.
 *
 * <p>NOTE(review): {@code tail()} is presumably what displays the dataframe --
 * confirm against the Dataframe API.
 */
public void showMetrics() {
  final long centroidCount = this.centroids.columns();

  // 1-based labels, one per centroid column.
  var kColumn = new Column<>("K",
      rangeClosed(1, centroidCount).boxed().collect(toList()));

  // Textual rendering of every centroid column, in column order.
  var centroidColumn = new Column<>("centroids",
      range(0, centroidCount)
          .mapToObj(idx -> Arrays.toString(this.centroids.getColumn(idx).toDoubleVector()))
          .collect(toList()));

  Dataframes.create(kColumn, centroidColumn).tail();
}

/**
 * Returns the centroid matrix currently held by this instance.
 *
 * <p>The field is assigned by {@code train(...)}, which stores one centroid
 * per column. The internal array itself is returned, not a copy.
 *
 * @return the current centroids
 */
public INDArray getCentroids() {
return centroids;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

public class MatrixDataframe implements MatrixTransform {

private static final String NUMBER_REGEX = "^\\d+(\\.\\d+)*$";
private static final String NUMBER_REGEX = "^(-?\\d+(\\.\\d+)*)$";
private final Dataframe dataframe;

public MatrixDataframe(Dataframe dataframe) {
Expand Down
10 changes: 2 additions & 8 deletions src/main/java/org/rsultan/example/MedoidShiftExample.java
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
package org.rsultan.example;

import static java.awt.image.BufferedImage.TYPE_INT_RGB;

import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.stream.IntStream;
import javax.imageio.ImageIO;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.rsultan.core.clustering.kmedoids.KMeans;
import org.rsultan.core.clustering.kmedoids.KMedians;
import org.rsultan.core.clustering.medoidshift.MeanShift;
import org.rsultan.core.clustering.medoidshift.MedianShift;
import org.rsultan.dataframe.Column;
Expand Down Expand Up @@ -71,10 +65,10 @@ public static void main(String[] args)
});

var trainedMeanShift = futureMeanShift.get();
System.out.println("MeanShift Centroids : " + meanShift.getC());
System.out.println("MeanShift Centroids : " + meanShift.getCentroids());

var trainedMedianShift = futureMedianShift.get();
System.out.println("MedianShift Centroids : " + medianShift.getC());
System.out.println("MedianShift Centroids : " + medianShift.getCentroids());

}
}
4 changes: 2 additions & 2 deletions src/main/java/org/rsultan/utils/CSVUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

public class CSVUtils {

private static final Pattern DOUBLE_VALUE_REGEX = Pattern.compile("\\d+\\.\\d+");
private static final Pattern LONG_VALUE_REGEX = Pattern.compile("\\d+");
private static final Pattern DOUBLE_VALUE_REGEX = Pattern.compile("-?\\d+\\.\\d+");
private static final Pattern LONG_VALUE_REGEX = Pattern.compile("-?\\d+");
public static final String HEADER_PREFIX = "c";

private static Object getValueWithType(String value) {
Expand Down
20 changes: 9 additions & 11 deletions src/test/java/org/rsultan/core/clustering/MedoidShiftTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.rsultan.core.clustering.kmedoids.KMeans;
import org.rsultan.core.clustering.kmedoids.KMedians;
import org.rsultan.core.clustering.kmedoids.KMedoids;
import org.rsultan.core.clustering.medoidshift.MeanShift;
import org.rsultan.core.clustering.medoidshift.MedianShift;
import org.rsultan.core.clustering.medoidshift.MedoidShift;
Expand All @@ -31,12 +28,12 @@ public class MedoidShiftTest {

private static Stream<Arguments> params_that_must_apply_kmedoids() {
return Stream.of(
Arguments.of(new MeanShift(60, 30)),
Arguments.of(new MeanShift(65, 30)),
Arguments.of(new MeanShift(70, 30)),
Arguments.of(new MedianShift(60, 30)),
Arguments.of(new MedianShift(65, 30)),
Arguments.of(new MedianShift(70, 30))
Arguments.of(new MeanShift(60D, 30)),
Arguments.of(new MeanShift(65D, 30)),
Arguments.of(new MeanShift(70D, 30)),
Arguments.of(new MedianShift(60D, 30)),
Arguments.of(new MedianShift(65D, 30)),
Arguments.of(new MedianShift(70D, 30))
);
}

Expand All @@ -50,9 +47,10 @@ public void must_apply_kmedoids(MedoidShift medoidShift) {
new Column<>("c3", range(0, 20).boxed().map(idx -> nextFloat(0, 10)).collect(toList())),
new Column<>("c4", range(0, 20).boxed().map(idx -> nextInt(0, 10)).collect(toList()))
);
medoidShift.train(dataframe);
medoidShift = medoidShift.train(dataframe);
medoidShift.showMetrics();

assertThat(medoidShift.getC()).isNotNull();
assertThat(medoidShift.getCentroids()).isNotNull();
assertThat(medoidShift.predict(dataframe)).isNotNull();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public void must_apply_linear_regression(
double[] expectedPValues,
double[] expectedPredictions
) throws IOException {
var dataframe = Dataframes.csv(getResourceFileName("org/rsultan/utils/example.csv"));
var dataframe = Dataframes.csv(getResourceFileName("org/rsultan/utils/example-linear-regression.csv"));
var linearRegression = new LinearRegression()
.setPredictorNames(predictors)
.setResponseVariableName(responseVariable)
Expand Down
16 changes: 9 additions & 7 deletions src/test/java/org/rsultan/dataframe/DataframeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ private static Stream<Arguments> params_that_must_load_dataframe_correctly() {
of(new Column[]{new Column<>("Doubles", 0D, 1D, 2D, 3D, 4D)}, 5, 1),
of(new Column[]{new Column<>("Floats", 0F, 1F, 2F, 3F, 4F)}, 5, 1),
of(new Column[]{new Column<>("Strings", "1.1", "2.1", "3", "4.4", "5.3")}, 5, 1),
of(new Column[]{new Column<>("Negatives", "-1.1", "-2.1", "-3", "-4.4", "-5.3")}, 5, 1),
of(new Column<?>[]{
new Column<>("Integers", 0, 1, 2, 3, 4),
new Column<>("Longs", 0L, 1L, 2L, 3L, 4L),
new Column<>("Doubles", 0D, 1D, 2D, 3D, 4D),
new Column<>("Floats", 0F, 1F, 2F, 3F, 4F),
new Column<>("Strings", "1.1", "2.1", "3", "4.4", "5.3")
}, 5, 5)
new Column<>("Strings", "1.1", "2.1", "3", "4.4", "5.3"),
new Column<>("Negatives", "-1.1", "-2.1", "-3", "-4.4", "-5.3")
}, 5, 6)
);
}

Expand Down Expand Up @@ -186,11 +188,11 @@ public void must_remove_column() {
@Test
public void must_load_dataframe_from_csv() throws IOException {
var df = Dataframes.csv(getResourceFileName("org/rsultan/utils/example.csv"));
assertThat(df.get("y")).containsExactly(1L, 2L, 3L, 4L, 5L);
assertThat(df.get("x")).containsExactly(1.0D, 2.0D, 3.0D, 4.0D, 5.0D);
assertThat(df.get("x2")).containsExactly(1L, 4L, 9L, 16L, 25L);
assertThat(df.get("x3")).containsExactly(1L, 8L, 27L, 64L, 125L);
assertThat(df.get("strColumn")).containsExactly("a", "b", "c", "d", "e");
assertThat(df.get("y")).containsExactly(1L, 2L, 3L, 4L, 5L, -6L);
assertThat(df.get("x")).containsExactly(1.0D, 2.0D, 3.0D, 4.0D, 5.0D, -5.0D);
assertThat(df.get("x2")).containsExactly(1L, 4L, 9L, 16L, 25L, -25L);
assertThat(df.get("x3")).containsExactly(1L, 8L, 27L, 64L, 125L, -125L);
assertThat(df.get("strColumn")).containsExactly("a", "b", "c", "d", "e", "f");
}

@Test
Expand Down
12 changes: 6 additions & 6 deletions src/test/java/org/rsultan/utils/CSVUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,18 @@ public void must_throw_exception_due_to_wrong_input(String fileName,
public void must_read_csv_and_return_columns() throws IOException {
var columns = CSVUtils.read(getResourceFileName(EXAMPLE_CSV), ",", true);

assertThat(columns).hasSize(5);
assertThat(columns).hasSize((5));

assertThat(columns[0].columnName()).isEqualTo("y");
assertThat((List<Long>) columns[0].values()).hasSize(5).containsExactly(1L, 2L, 3L, 4L, 5L);
assertThat((List<Long>) columns[0].values()).hasSize(6).containsExactly(1L, 2L, 3L, 4L, 5L, -6L);
assertThat(columns[1].columnName()).isEqualTo("x");
assertThat((List<Double>) columns[1].values()).hasSize(5).containsExactly(1.0D, 2.0D, 3.0D, 4.0D, 5.0D);
assertThat((List<Double>) columns[1].values()).hasSize(6).containsExactly(1.0D, 2.0D, 3.0D, 4.0D, 5.0D, -5.0D);
assertThat(columns[2].columnName()).isEqualTo("x2");
assertThat((List<Long>) columns[2].values()).hasSize(5).containsExactly(1L, 4L, 9L, 16L, 25L);
assertThat((List<Long>) columns[2].values()).hasSize(6).containsExactly(1L, 4L, 9L, 16L, 25L, -25L);
assertThat(columns[3].columnName()).isEqualTo("x3");
assertThat((List<Long>) columns[3].values()).hasSize(5).containsExactly(1L, 8L, 27L, 64L, 125L);
assertThat((List<Long>) columns[3].values()).hasSize(6).containsExactly(1L, 8L, 27L, 64L, 125L, -125L);
assertThat(columns[4].columnName()).isEqualTo("strColumn");
assertThat((List<String>) columns[4].values()).hasSize(5).containsExactly("a", "b", "c", "d", "e");
assertThat((List<String>) columns[4].values()).hasSize(6).containsExactly("a", "b", "c", "d", "e", "f");

}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
y,x,x2,x3,strColumn
1,1.0,1,1,a
2,2.0,4,8,b
3,3.0,9,27,c
4,4.0,16,64,d
5,5.0,25,125,e
1 change: 1 addition & 0 deletions src/test/resources/org/rsultan/utils/example.csv
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ y,x,x2,x3,strColumn
3,3.0,9,27,c
4,4.0,16,64,d
5,5.0,25,125,e
-6,-5.0,-25,-125,f
2 changes: 1 addition & 1 deletion version.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version.next=1.0.2
version.next=1.1.0

0 comments on commit ae9cc7d

Please sign in to comment.