Skip to content

Commit

Permalink
for scalac RBM.scala
Browse files Browse the repository at this point in the history
  • Loading branch information
Yusuke Sugomori committed Mar 21, 2013
1 parent 1a4c59d commit eafd9bd
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 43 deletions.
6 changes: 5 additions & 1 deletion java/DBN/src/DBN.java
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ public void predict(int[] x, double[] y) {
log_layer.softmax(y);
}

public static void main(String[] arg) {
private static void test_dbn() {
Random rng = new Random(123);

double pretrain_lr = 0.1;
Expand Down Expand Up @@ -215,4 +215,8 @@ public static void main(String[] arg) {
System.out.println();
}
}

public static void main(String[] args) {
test_dbn();
}
}
6 changes: 5 additions & 1 deletion java/LogisticRegression/src/LogisticRegression.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void predict(int[] x, double[] y) {
softmax(y);
}

public static void main(String[] arg) {
private static void test_lr() {
double learning_rate = 0.1;
double n_epochs = 500;

Expand Down Expand Up @@ -121,4 +121,8 @@ public static void main(String[] arg) {
System.out.println();
}
}

public static void main(String[] args) {
test_lr();
}
}
7 changes: 6 additions & 1 deletion java/RBM/src/RBM.java
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ public void reconstruct(int[] v, double[] reconstructed_v) {



public static void main(String[] arg) {
private static void test_rbm() {
Random rng = new Random(123);

double learning_rate = 0.1;
Expand Down Expand Up @@ -212,4 +212,9 @@ public static void main(String[] arg) {
System.out.println();
}
}

public static void main(String[] args) {
test_rbm();
}

}
87 changes: 47 additions & 40 deletions scala/RBM.scala
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// $ scalac RBM.scala
// $ scala RBM

import scala.util.Random
import scala.math

Expand Down Expand Up @@ -159,57 +162,61 @@ class RBM(val N: Int, val n_visible: Int, val n_hidden: Int,
}


def test_rbm() {
val rng: Random = new Random(123)
object RBM {
def test_rbm() {
val rng: Random = new Random(123)

var learning_rate: Double = 0.1
val training_epochs: Int = 1000
val k: Int = 1
var learning_rate: Double = 0.1
val training_epochs: Int = 1000
val k: Int = 1

val train_N: Int = 6;
val test_N: Int = 2
val n_visible: Int = 6
val n_hidden: Int = 3
val train_N: Int = 6;
val test_N: Int = 2
val n_visible: Int = 6
val n_hidden: Int = 3

val train_X: Array[Array[Int]] = Array(
Array(1, 1, 1, 0, 0, 0),
Array(1, 0, 1, 0, 0, 0),
Array(1, 1, 1, 0, 0, 0),
Array(0, 0, 1, 1, 1, 0),
Array(0, 0, 1, 0, 1, 0),
Array(0, 0, 1, 1, 1, 0)
)
val train_X: Array[Array[Int]] = Array(
Array(1, 1, 1, 0, 0, 0),
Array(1, 0, 1, 0, 0, 0),
Array(1, 1, 1, 0, 0, 0),
Array(0, 0, 1, 1, 1, 0),
Array(0, 0, 1, 0, 1, 0),
Array(0, 0, 1, 1, 1, 0)
)


val rbm: RBM = new RBM(train_N, n_visible, n_hidden, rng=rng)
val rbm: RBM = new RBM(train_N, n_visible, n_hidden, rng=rng)

var i: Int = 0
var j: Int = 0
var i: Int = 0
var j: Int = 0

// train
var epoch: Int = 0
for(epoch <- 0 until training_epochs) {
for(i <- 0 until train_N) {
rbm.contrastive_divergence(train_X(i), learning_rate, k)
// train
var epoch: Int = 0
for(epoch <- 0 until training_epochs) {
for(i <- 0 until train_N) {
rbm.contrastive_divergence(train_X(i), learning_rate, k)
}
}
}

// test data
val test_X: Array[Array[Int]] = Array(
Array(1, 1, 0, 0, 0, 0),
Array(0, 0, 0, 1, 1, 0)
)
// test data
val test_X: Array[Array[Int]] = Array(
Array(1, 1, 0, 0, 0, 0),
Array(0, 0, 0, 1, 1, 0)
)

val reconstructed_X: Array[Array[Double]] = Array.ofDim[Double](test_N, n_visible)
for(i <- 0 until test_N) {
rbm.reconstruct(test_X(i), reconstructed_X(i))
for(j <- 0 until n_visible) {
printf("%.5f", reconstructed_X(i)(j))
val reconstructed_X: Array[Array[Double]] = Array.ofDim[Double](test_N, n_visible)
for(i <- 0 until test_N) {
rbm.reconstruct(test_X(i), reconstructed_X(i))
for(j <- 0 until n_visible) {
printf("%.5f ", reconstructed_X(i)(j))
}
println()
}
println()
}

}
}

test_rbm()
def main(args: Array[String]) {
test_rbm()
}

}

0 comments on commit eafd9bd

Please sign in to comment.