Skip to content

Commit

Permalink
develop --> master for release 1.1.1 #39
Browse files Browse the repository at this point in the history
develop --> master for release 1.1.1
  • Loading branch information
remisultan authored Apr 21, 2021
2 parents 33f29a0 + f687d35 commit ebd5c61
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 31 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ This repo intends to implement Machine Learning algorithms with Java and https:/

You can check for examples in this package: ``src/main/java/org/rsultan/example``

## Requirements

- JDK15+

## Getting started

Clone the repo and execute this
Expand All @@ -23,7 +27,7 @@ Then you can import this to your `pom.xml`
</dependency>
```

There is an existing artifactory today but I am not satisfied with how to make it public today.
There is an existing artifactory but I am not satisfied with how to make it public today.
There will be an artifactory in the future.

Once your are good with this, you can go read the [wiki](https://github.com/remisultan/java-ml/wiki)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,15 @@ public KMedoids train(Dataframe dataframe) {
X = dataframe.toMatrix();
Xt = X.transpose();
var medoidFactory = medoidType.getMedoidFactory();

LOG.info("Initialising centroids");
centroids = buildInitialCentroids(medoidFactory);
LOG.info("Centroids initialised");

range(0, numberOfIterations)
.filter(epoch -> ofNullable(loss).orElse(-1.0D) != 0.0D)
.forEach(epoch -> {
LOG.info("Epoch {}, Loss : {} for {}", epoch, loss, medoidType);
LOG.info("Epoch {} {}", epoch, medoidType);
distances = medoidFactory.computeDistance(centroids, X).transpose();
cluster = Nd4j.argMin(distances, 1);

Expand All @@ -78,6 +81,7 @@ public KMedoids train(Dataframe dataframe) {
var newCenters = Nd4j.create(newMedoids, K, Xt.rows());
loss = medoidFactory.computeNorm(centroids.sub(newCenters));
centroids = newCenters;
LOG.info("Epoch {}, Loss : {} for {}", epoch, loss, medoidType);
});
this.WCSS = distances.transpose().mmul(distances).sum().div(distances.rows());
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ public LinearRegression setPredictorNames(String... names) {
public LinearRegression train(Dataframe dataframe) {
var dataframeIntercept = dataframe.map(INTERCEPT, () -> 1);
X = dataframeIntercept.toMatrix(predictorNames);
XMean = X.mean(true, 1);
X = X.div(XMean);
Xt = X.transpose();
Y = dataframeIntercept.toVector(responseVariableName);

Expand Down
48 changes: 21 additions & 27 deletions src/test/java/org/rsultan/core/regression/LinearRegressionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,37 +30,31 @@ private static Stream<Arguments> params_that_must_apply_linear_regression() {
return Stream.of(
of("y",
new String[]{"x"},
new double[]{-2.5062695924764875, 3.136363636363635},
0.19435736677115964,
0.4408598039866638,
0.9028213166144202,
new double[]{-3.257172004150948, 9.360368540114695},
new double[]{0.9763837685298659, 0.0012912173574425312},
new double[]{0.6300940438871474, 3.7664576802507823, 6.902821316614418,
10.039184952978053, 13.175548589341687}),
new double[]{1.609823385706477E-15, 1.0},
3.0075322011551075E-30,
1.7342238036525468E-15,
1.0,
new double[]{0.6855708946307656, 1.4124415416419548E15},
new double[]{0.27110493295189886, 0.0},
new double[]{1.0000000000000016, 2.0000000000000018, 3.0000000000000018, 4.000000000000002, 5.000000000000002}),
of("y",
new String[]{"x", "x2"},
new double[]{3.6994152781695915, -5.663780713939179, 2.9897427464590947},
0.015929603431057288,
0.12621253278124675,
0.9920351982844714,
new double[]{4.072386773691408, -5.460955300045582, 13.427152773172018},
new double[]{0.027670010280800295, 0.9840326600175213, 0.002750470910569236},
new double[]{1.0253773106895072, 4.330824836127612, 13.615757854483908,
28.88017636575839, 50.124080369951066}),
new double[]{1.1124434706744069E-13, 0.9999999999999791, 1.3183898417423734E-15 },
4.332236737008807E-27,
6.581972908641304E-14,
1.0,
new double[]{0.49839415538353254, 5.878981885019946E12, 0.047400374308657484},
new double[]{0.3338093937142339, 0.0, 0.4832508422900186},
new double[]{1.0000000000000917, 2.0000000000000746, 3.000000000000061, 4.000000000000049, 5.00000000000004}),
of("y",
new String[]{"x", "x2", "x3"},
new double[]{-8.170418390061572, 16.59195036587023, -10.47235921093084,
3.050349984089662},
5.727012408526747E-4,
0.023931177172313835,
0.9997136493795736,
new double[]{-5.010158649913984, 6.139000734969269, -8.380896856206233,
18.413660521274217},
new double[]{0.93729116849317, 0.051399003206532456, 0.9621983098931299,
0.017269651919306184},
new double[]{0.999522748967479, 7.526845370672824, 29.713649379592425,
85.86203468026427, 194.27410117722633})
new double[]{3.3661962106634746E-12, 0.9999999999994151, -1.0902390101819037E-13, 1.9165224962591765E-14},
2.498960921826301E-24,
1.5808102105649181E-12,
1.0,
new double[]{0.19358290785808935, 4.3994833288545204E10, -0.01292277119642834, 0.020574554152200282},
new double[]{0.43913350496326764, 7.235212429179683E-12, 0.5041132168725599, 0.49345183987744734},
new double[]{1.0000000000026914, 2.0000000000019136, 3.000000000001148, 4.000000000000509, 5.000000000000112})
);
}

Expand Down

0 comments on commit ebd5c61

Please sign in to comment.