Maxpool and convolution fixes (#56)
* maxpool and convolution fixes

* softmax working

* dropout and softmax loss

* fix relu. whoops

* mnist

* big stability changes and QOL

* conv fallback

* fallback

* use matrix multiplication
christopherzimmerman authored Nov 1, 2020
1 parent 0464462 commit 6ba53ea
Showing 43 changed files with 1,564 additions and 87 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -222,12 +222,12 @@ y_train = [[0.0], [1.0], [1.0], [0.0]].to_tensor
x = ctx.variable(x_train)
net = Num::NN::Network.new(ctx) do
input [2]
# A basic network with a single hidden layer using
# a ReLU activation function
linear(2, 3)
linear 3
relu
linear(3, 1)
linear 1
# SGD Optimizer
sgd 0.7
7 changes: 5 additions & 2 deletions examples/basic_xor_classifier/xor.cr
@@ -23,6 +23,8 @@

require "../../src/num"

Num::Rand.set_seed(2)

ctx = Num::Grad::Context(Tensor(Float64)).new

bsz = 32
@@ -35,9 +37,10 @@ x_train = ctx.variable(x_train_bool.as_type(Float64))
y = y_bool.as_type(Float64)

net = Num::NN::Network.new(ctx) do
linear 2, 3
input [2]
linear 3
relu
linear 3, 1
linear 1
sgd 0.7
sigmoid_cross_entropy_loss
end
7 changes: 5 additions & 2 deletions examples/iris_dataset/iris.cr
@@ -23,6 +23,8 @@

require "../../src/num"

Num::Rand.set_seed(2)

ctx = Num::Grad::Context(Tensor(Float64)).new

labels, x_train, y_train = Num::NN.load_iris_dataset
@@ -31,9 +33,10 @@ x_train = (x_train - x_train.mean(axis: 0)) / x_train.std(axis: 0)
x_train = ctx.variable(x_train)

net = Num::NN::Network.new(ctx) do
linear 4, 3
input [4]
linear 3
relu
linear 3, 3
linear 3
sgd 0.9
sigmoid_cross_entropy_loss
end
34 changes: 34 additions & 0 deletions examples/mnist/README.md
@@ -0,0 +1,34 @@
## MNIST

This is an example of how to use convolutional layers
and max pooling to learn the MNIST dataset. Using this approach,
our network achieves > 97% accuracy on the training dataset.

```crystal
net = Num::NN::Network.new(ctx) do
input [1, 28, 28]
conv2d 20, 5, 5
relu
maxpool({2, 2}, {0, 0}, {2, 2})
conv2d 20, 5, 5
maxpool({2, 2}, {0, 0}, {2, 2})
flatten
linear 10
relu
linear 10
softmax_cross_entropy_loss
sgd 0.01
end
```

```
Epoch: 0 | Accuracy: 0.8644276947705443
Epoch: 1 | Accuracy: 0.9558931430096052
Epoch: 2 | Accuracy: 0.9677494663820705
Epoch: 3 | Accuracy: 0.9735358858057631
Epoch: 4 | Accuracy: 0.9770711045891142
```

### Accuracy over time

![mnist](mnist.png)
59 changes: 59 additions & 0 deletions examples/mnist/mnist.cr
@@ -0,0 +1,59 @@
require "../../src/num"

dataset = Num::NN.load_mnist_dataset
ctx = Num::Grad::Context(Tensor(Float32)).new

batch_size = 32

net = Num::NN::Network.new(ctx) do
input [1, 28, 28]
conv2d 20, 5, 5
relu
maxpool({2, 2}, {0, 0}, {2, 2})
conv2d 20, 5, 5
maxpool({2, 2}, {0, 0}, {2, 2})
flatten
linear 10
relu
linear 10
softmax_cross_entropy_loss
sgd 0.01
end

x_train = ctx.variable((dataset.features / 255_f32).reshape(-1, 1, 28, 28))
y_train = dataset.labels

losses = [] of Float32

5.times do |epoch|
y_trues = [] of Int32
y_preds = [] of Int32

(x_train.value.shape[0] // batch_size).times do |batch_id|
offset = batch_id * batch_size
x = x_train[offset...offset + batch_size]
target = y_train[offset...offset + batch_size]

output = net.forward(x)

loss = net.loss(output, target)
losses << loss.value.value

y_trues += target.argmax(axis: 1).to_a
y_preds += output.value.argmax(axis: 1).to_a

loss.backprop
net.optimizer.update
end

accuracy = y_trues.zip(y_preds).map { |t, p| (t == p).to_unsafe }.sum / y_trues.size

puts "Epoch: #{epoch} | Accuracy: #{accuracy}"
end

Num::Plot::Plot.plot do
scatter (0...losses.size), losses
x_label "Epochs"
y_label "Loss"
label "MNIST Accuracy"
end
Binary file added examples/mnist/mnist.png
6 changes: 6 additions & 0 deletions src/api.cr
@@ -21,11 +21,17 @@ require "./libs/nnpack"
require "./libs/plplot"

require "./grad/primitives/*"
require "./grad/extensions/float"
require "./grad/gates_arithmetic"
require "./grad/gates_blas"
require "./grad/gates_trigonometric"
require "./grad/gates_exp"
require "./grad/variable_ops"
require "./grad/infer"
require "./grad/utils"

require "./nn/primitives/*"
require "./nn/initialization"
require "./nn/layers/*"
require "./nn/gates/*"
require "./nn/optimizer"
54 changes: 54 additions & 0 deletions src/grad/extensions/float.cr
@@ -0,0 +1,54 @@
# Copyright (c) 2020 Crystal Data Contributors
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

struct Float32
macro add_operator(name, operator)
def {{operator.id}}(other : Num::Grad::Variable(Tensor(Float32)))
other.context.variable(self) {{operator.id}} other
end
end

add_operator add, :+
add_operator subtract, :-
add_operator multiply, :*
add_operator divide, :/
add_operator power, :**
end

struct Float64
macro add_operator(name, operator)
def {{operator.id}}(other : Num::Grad::Variable(Tensor(Float64)))
other.context.variable(self) {{operator.id}} other
end
end

add_operator add, :+
add_operator subtract, :-
add_operator multiply, :*
add_operator divide, :/
add_operator power, :**

def exp
Math.exp(self)
end
end
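
A minimal usage sketch of what the scalar extensions above enable (illustrative only, not part of this commit): a bare `Float64` can now sit on the left-hand side of arithmetic with a `Num::Grad::Variable`, and is wrapped as a variable in the other operand's context before the operation runs.

```crystal
require "../../src/num" # path used by the in-repo examples; adjust if installed as a shard

ctx = Num::Grad::Context(Tensor(Float64)).new
x = ctx.variable([[1.0, 2.0], [3.0, 4.0]].to_tensor)

# Dispatches to the Float64#* defined above, which wraps 2.0
# as a variable in x's context before multiplying
y = 2.0 * x
```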
56 changes: 15 additions & 41 deletions src/grad/gates_arithmetic.cr
@@ -21,7 +21,6 @@
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# :nodoc:
class Num::Grad::AddGate(T) < Num::Grad::Gate(T)
# :nodoc:
def backward(payload : Num::Grad::Payload(T)) : Array(T)
@@ -40,7 +39,6 @@ class Num::Grad::AddGate(T) < Num::Grad::Gate(T)
end
end

# :nodoc:
class Num::Grad::SubtractGate(T) < Num::Grad::Gate(T)
# :nodoc:
def backward(payload : Num::Grad::Payload(T)) : Array(T)
@@ -58,20 +56,17 @@ class Num::Grad::SubtractGate(T) < Num::Grad::Gate(T)
end
end

# :nodoc:
class Num::Grad::MultiplyGate(T) < Num::Grad::Gate(T)
class Num::Grad::TwoOpGate(T) < Num::Grad::Gate(T)
getter a : Num::Grad::Variable(T)
getter b : Num::Grad::Variable(T)
@@name = "TwoOp"

# :nodoc:
def initialize(@a : Num::Grad::Variable(T), @b : Num::Grad::Variable(T))
end

# :nodoc:
def backward(payload : Num::Grad::Payload(T)) : Array(T)
gradient = payload.variable.grad

[gradient * @b.value, @a.value * gradient]
[] of T
end

# :nodoc:
@@ -80,47 +75,34 @@ class Num::Grad::MultiplyGate(T) < Num::Grad::Gate(T)
result.grad = T.zeros_like(result.value)
result.requires_grad = true

Num::Grad.register("Mul", self, result, a, b)
Num::Grad.register(@@name, self, result, a, b)
end
end

# :nodoc:
class Num::Grad::DivideGate(T) < Num::Grad::Gate(T)
getter a : Num::Grad::Variable(T)
getter b : Num::Grad::Variable(T)
class Num::Grad::MultiplyGate(T) < Num::Grad::TwoOpGate(T)
@@name = "Multiply"

# :nodoc:
def initialize(@a : Num::Grad::Variable(T), @b : Num::Grad::Variable(T))
def backward(payload : Num::Grad::Payload(T)) : Array(T)
gradient = payload.variable.grad
[gradient * @b.value, @a.value * gradient]
end
end

class Num::Grad::DivideGate(T) < Num::Grad::TwoOpGate(T)
@@name = "Divide"

# :nodoc:
def backward(payload : Num::Grad::Payload(T)) : Array(T)
gradient = payload.variable.grad

r0 = gradient.map(@b.value) { |i, j| i / j }
r1 = gradient.map(@a.value, @b.value) { |i, j, k| -i * j / (k ** 2) }
[r0, r1]
end

# :nodoc:
def cache(result : Num::Grad::Variable(T), *args)
a, b = args
result.grad = T.zeros_like(result.value)
result.requires_grad = true
Num::Grad.register("Div", self, result, a, b)
end
end

# :nodoc:
class Num::Grad::PowerGate(T) < Num::Grad::Gate(T)
getter a : Num::Grad::Variable(T)
getter b : Num::Grad::Variable(T)

# :nodoc:
def initialize(@a : Num::Grad::Variable(T), @b : Num::Grad::Variable(T))
end
class Num::Grad::PowerGate(T) < Num::Grad::TwoOpGate(T)
@@name = "Power"

# :nodoc:
def backward(payload : Num::Grad::Payload(T)) : Array(T)
gradient = payload.variable.grad

@@ -134,12 +116,4 @@ class Num::Grad::PowerGate(T) < Num::Grad::Gate(T)

[r0, r1]
end

# :nodoc:
def cache(result : Num::Grad::Variable(T), *args)
a, b = args
result.grad = T.zeros_like(result.value)
result.requires_grad = true
Num::Grad.register("Pow", self, result, a, b)
end
end
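
The refactor above moves `initialize` and `cache` into `TwoOpGate`, so a binary gate now only sets `@@name` and implements `backward`. A hedged sketch of a hypothetical gate written against that base class (`MinimumGate` is not part of this commit, just an illustration of the pattern):

```crystal
class Num::Grad::MinimumGate(T) < Num::Grad::TwoOpGate(T)
  @@name = "Minimum"

  def backward(payload : Num::Grad::Payload(T)) : Array(T)
    gradient = payload.variable.grad
    # Route each upstream gradient element to whichever input was smaller
    r0 = gradient.map(@a.value, @b.value) { |g, i, j| i <= j ? g : g * 0 }
    r1 = gradient.map(@a.value, @b.value) { |g, i, j| i <= j ? g * 0 : g }
    [r0, r1]
  end
end
```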
44 changes: 44 additions & 0 deletions src/grad/gates_exp.cr
@@ -0,0 +1,44 @@
# Copyright (c) 2020 Crystal Data Contributors
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

class Num::Grad::ExpGate(T) < Num::Grad::Gate(T)
getter a : Num::Grad::Variable(T)

def initialize(@a : Num::Grad::Variable(T))
end

def backward(payload : Num::Grad::Payload(T)) : Array(T)
gradient = payload.variable.grad
r0 = gradient.map(a.value) do |i, j|
i * Math.exp(j)
end
[r0]
end

def cache(result : Num::Grad::Variable(T), *args)
a = args[0]
result.grad = T.zeros_like(result.value)
result.requires_grad = true
Num::Grad.register("Exp", self, result, a)
end
end
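
For reference, the backward rule encoded here is the chain rule for the exponential: with y = exp(a), dL/da = dL/dy * exp(a), which is what the `map` over `gradient` and `a.value` computes elementwise. A quick finite-difference sanity check of that derivative in plain Crystal (standalone, independent of the Grad machinery):

```crystal
a = 1.3
eps = 1e-6

# Central difference approximation of d(exp)/da versus the analytic derivative
numeric  = (Math.exp(a + eps) - Math.exp(a - eps)) / (2 * eps)
analytic = Math.exp(a)

puts (numeric - analytic).abs < 1e-5 # => true
```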