From 3372add1b997373605a1727d2d78416da6ee31fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Sultan?= Date: Sun, 18 Apr 2021 18:12:23 +0200 Subject: [PATCH 1/2] fix: removing normalization of data in linear regression (#37) * fix: removing normalization of data in linear regression * fix: test modified accordingly --- .../regression/impl/LinearRegression.java | 2 - .../core/regression/LinearRegressionTest.java | 48 ++++++++----------- 2 files changed, 21 insertions(+), 29 deletions(-) diff --git a/src/main/java/org/rsultan/core/regression/impl/LinearRegression.java b/src/main/java/org/rsultan/core/regression/impl/LinearRegression.java index 766e5fd..38b79b3 100644 --- a/src/main/java/org/rsultan/core/regression/impl/LinearRegression.java +++ b/src/main/java/org/rsultan/core/regression/impl/LinearRegression.java @@ -47,8 +47,6 @@ public LinearRegression setPredictorNames(String... names) { public LinearRegression train(Dataframe dataframe) { var dataframeIntercept = dataframe.map(INTERCEPT, () -> 1); X = dataframeIntercept.toMatrix(predictorNames); - XMean = X.mean(true, 1); - X = X.div(XMean); Xt = X.transpose(); Y = dataframeIntercept.toVector(responseVariableName); diff --git a/src/test/java/org/rsultan/core/regression/LinearRegressionTest.java b/src/test/java/org/rsultan/core/regression/LinearRegressionTest.java index 0bf6f08..f76cdc7 100644 --- a/src/test/java/org/rsultan/core/regression/LinearRegressionTest.java +++ b/src/test/java/org/rsultan/core/regression/LinearRegressionTest.java @@ -30,37 +30,31 @@ private static Stream params_that_must_apply_linear_regression() { return Stream.of( of("y", new String[]{"x"}, - new double[]{-2.5062695924764875, 3.136363636363635}, - 0.19435736677115964, - 0.4408598039866638, - 0.9028213166144202, - new double[]{-3.257172004150948, 9.360368540114695}, - new double[]{0.9763837685298659, 0.0012912173574425312}, - new double[]{0.6300940438871474, 3.7664576802507823, 6.902821316614418, - 10.039184952978053, 
13.175548589341687}), + new double[]{1.609823385706477E-15, 1.0}, + 3.0075322011551075E-30, + 1.7342238036525468E-15, + 1.0, + new double[]{0.6855708946307656, 1.4124415416419548E15}, + new double[]{0.27110493295189886, 0.0}, + new double[]{1.0000000000000016, 2.0000000000000018, 3.0000000000000018, 4.000000000000002, 5.000000000000002}), of("y", new String[]{"x", "x2"}, - new double[]{3.6994152781695915, -5.663780713939179, 2.9897427464590947}, - 0.015929603431057288, - 0.12621253278124675, - 0.9920351982844714, - new double[]{4.072386773691408, -5.460955300045582, 13.427152773172018}, - new double[]{0.027670010280800295, 0.9840326600175213, 0.002750470910569236}, - new double[]{1.0253773106895072, 4.330824836127612, 13.615757854483908, - 28.88017636575839, 50.124080369951066}), + new double[]{1.1124434706744069E-13, 0.9999999999999791, 1.3183898417423734E-15 }, + 4.332236737008807E-27, + 6.581972908641304E-14, + 1.0, + new double[]{0.49839415538353254, 5.878981885019946E12, 0.047400374308657484}, + new double[]{0.3338093937142339, 0.0, 0.4832508422900186}, + new double[]{1.0000000000000917, 2.0000000000000746, 3.000000000000061, 4.000000000000049, 5.00000000000004}), of("y", new String[]{"x", "x2", "x3"}, - new double[]{-8.170418390061572, 16.59195036587023, -10.47235921093084, - 3.050349984089662}, - 5.727012408526747E-4, - 0.023931177172313835, - 0.9997136493795736, - new double[]{-5.010158649913984, 6.139000734969269, -8.380896856206233, - 18.413660521274217}, - new double[]{0.93729116849317, 0.051399003206532456, 0.9621983098931299, - 0.017269651919306184}, - new double[]{0.999522748967479, 7.526845370672824, 29.713649379592425, - 85.86203468026427, 194.27410117722633}) + new double[]{3.3661962106634746E-12, 0.9999999999994151, -1.0902390101819037E-13, 1.9165224962591765E-14}, + 2.498960921826301E-24, + 1.5808102105649181E-12, + 1.0, + new double[]{0.19358290785808935, 4.3994833288545204E10, -0.01292277119642834, 0.020574554152200282}, + new 
double[]{0.43913350496326764, 7.235212429179683E-12, 0.5041132168725599, 0.49345183987744734}, + new double[]{1.0000000000026914, 2.0000000000019136, 3.000000000001148, 4.000000000000509, 5.000000000000112}) ); } From f687d35d468067e4ffbf86f7fcd9235e90eb684a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Sultan?= Date: Wed, 21 Apr 2021 14:43:29 +0200 Subject: [PATCH 2/2] chore: better logging for KMeans / KMedians (#38) * chore: better logging of kmedoids * fix: typo --- README.md | 6 +++++- .../java/org/rsultan/core/clustering/kmedoids/KMedoids.java | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index cc4bf55..6097d07 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,10 @@ This repo intends to implement Machine Learning algorithms with Java and https:/ You can check for examples in this package: ``src/main/java/org/rsultan/example`` +## Requirements + +- JDK15+ + ## Getting started Clone the repo and execute this @@ -23,7 +27,7 @@ Then you can import this to your `pom.xml` ``` -There is an existing artifactory today but I am not satisfied with how to make it public today. +There is an existing artifactory but I am not satisfied with how to make it public today. There will be an artifactory in the future. 
Once your are good with this, you can go read the [wiki](https://github.com/remisultan/java-ml/wiki) diff --git a/src/main/java/org/rsultan/core/clustering/kmedoids/KMedoids.java b/src/main/java/org/rsultan/core/clustering/kmedoids/KMedoids.java index 74b1425..c741f3d 100644 --- a/src/main/java/org/rsultan/core/clustering/kmedoids/KMedoids.java +++ b/src/main/java/org/rsultan/core/clustering/kmedoids/KMedoids.java @@ -55,12 +55,15 @@ public KMedoids train(Dataframe dataframe) { X = dataframe.toMatrix(); Xt = X.transpose(); var medoidFactory = medoidType.getMedoidFactory(); + + LOG.info("Initialising centroids"); centroids = buildInitialCentroids(medoidFactory); + LOG.info("Centroids initialised"); range(0, numberOfIterations) .filter(epoch -> ofNullable(loss).orElse(-1.0D) != 0.0D) .forEach(epoch -> { - LOG.info("Epoch {}, Loss : {} for {}", epoch, loss, medoidType); + LOG.info("Epoch {} {}", epoch, medoidType); distances = medoidFactory.computeDistance(centroids, X).transpose(); cluster = Nd4j.argMin(distances, 1); @@ -78,6 +81,7 @@ public KMedoids train(Dataframe dataframe) { var newCenters = Nd4j.create(newMedoids, K, Xt.rows()); loss = medoidFactory.computeNorm(centroids.sub(newCenters)); centroids = newCenters; + LOG.info("Epoch {}, Loss : {} for {}", epoch, loss, medoidType); }); this.WCSS = distances.transpose().mmul(distances).sum().div(distances.rows()); return this;