NN / MNist Example (#456)
apaszke authored Jan 20, 2021
2 parents 8aeaf90 + c06efae commit d178dc0
Showing 1 changed file with 249 additions and 0 deletions: examples/nn.dx
' # Neural Networks

include "plot.dx"

' ## NN Prelude


def relu (input : Float) : Float =
  select (input > 0.0) input 0.0
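
' A quick check (a sketch, not in the original file): `relu` passes positive
' inputs through unchanged and maps negative inputs to zero.

:p (relu 2.0, relu (-2.0))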

instance [Add a, Add b] Add (a & b)
  add = \(a, b) (c, d). ((a + c), (b + d))
  sub = \(a, b) (c, d). ((a - c), (b - d))
  zero = (zero, zero)

instance [VSpace a, VSpace b] VSpace (a & b)
  scaleVec = \s (a, b). (scaleVec s a, scaleVec s b)

data Layer inp:Type out:Type params:Type =
  AsLayer {forward:(params -> inp -> out) & init:(Key -> params)}


def forward (l:Layer i o p) (p:p) (x:i) : o =
  (AsLayer l') = l
  (getAt #forward l') p x

def init (l:Layer i o p) (k:Key) : p =
  (AsLayer l') = l
  (getAt #init l') k
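
' As a quick illustration of the `Layer` record (a sketch, not part of the
' original example), a parameter-free identity layer built directly with the
' `AsLayer` constructor:

idLayer : Layer Float Float Unit = AsLayer {forward = (\_ x. x), init = (\_. ())}

:t forward idLayer (init idLayer $ newKey 0) 1.0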


' ## Layers

' Dense layer

def DenseParams (a:Type) (b:Type) : Type =
  ((a=>b=>Float) & (b=>Float))

def dense (a:Type) (b:Type) : Layer (a=>Float) (b=>Float) (DenseParams a b) =
  AsLayer {
    forward = (\(weight, bias) x.
                 for j. (bias.j + sum for i. weight.i.j * x.i)),
    init = arb
  }
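
' A small usage sketch of `dense` (the names below are illustrative): draw
' parameters with `arb` via `init`, then apply the layer to a length-3 vector.

dlayer = dense (Fin 3) (Fin 2)
dparams = init dlayer $ newKey 0

:t forward dlayer dparams [1.0, 2.0, 3.0]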


' CNN layer

def CNNParams (inc:Type) (outc:Type) (kw:Int) (kh:Int) : Type =
  ((outc=>inc=>Fin kh=>Fin kw=>Float) &
   (outc=>Float))

def conv2d (x:inc=>(Fin h)=>(Fin w)=>Float)
           (kernel:outc=>inc=>(Fin kh)=>(Fin kw)=>Float) :
           outc=>(Fin h)=>(Fin w)=>Float =
  for o i j.
    (i', j') = (ordinal i, ordinal j)
    case (i' + kh) < h && (j' + kw) < w of
      True ->
        sum for (ki, kj, inp).
          (di, dj) = (fromOrdinal (Fin h) (i' + (ordinal ki)),
                      fromOrdinal (Fin w) (j' + (ordinal kj)))
          x.inp.di.dj * kernel.o.inp.ki.kj
      False -> zero
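
' `conv2d` keeps the input height and width, writing `zero` wherever the kernel
' would run past the right or bottom edge. A shape-level sketch (illustrative
' names, not from the original file):

kern : Fin 4 => Fin 1 => Fin 3 => Fin 3 => Float = arb $ newKey 0
tinyIm : Fin 1 => Fin 8 => Fin 8 => Float = arb $ newKey 1

:t conv2d tinyIm kern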

def cnn (h:Int) ?-> (w:Int) ?-> (inc:Type) (outc:Type) (kw:Int) (kh:Int) :
    Layer (inc=>(Fin h)=>(Fin w)=>Float)
          (outc=>(Fin h)=>(Fin w)=>Float)
          (CNNParams inc outc kw kh) =
  AsLayer {
    forward = (\(weight, bias) x. for o i j. (conv2d x weight).o.i.j + bias.o),
    init = arb
  }

' Pooling

def split (x: m=>v) : n=>o=>v =
  for i j. x.((ordinal (i,j))@m)
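
' `split` reshapes a table row-major; the result shape is not determined by the
' input, so it has to be pinned down with an annotation. A small sketch:

rows : Fin 2 => Fin 3 => Int = split (for i:(Fin 6). ordinal i)

:p rows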

def imtile (x: a=>b=>v) : n=>o=>p=>q=>v =
  for kw kh w h. (split (split x).w.kw).h.kh

def meanpool (kh: Type) (kw: Type) (x : m=>n=>Float) : (h=>w=>Float) =
  out : (kh => kw => h => w => Float) = imtile x
  mean for (i,j). out.i.j
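
' A small sketch of `meanpool`: averaging a 4x4 image over 2x2 windows, with the
' output shape fixed by annotation (the same pattern the LeNet model uses below).

im4 : Fin 4 => Fin 4 => Float = arb $ newKey 0
pooled : Fin 2 => Fin 2 => Float = meanpool (Fin 2) (Fin 2) im4

:p pooled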

' ## Simple point classifier

[k1, k2] = splitKey $ newKey 1
x1 : Fin 100 => Float = arb k1
x2 : Fin 100 => Float = arb k2
y = for i. case ((x1.i > 0.0) && (x2.i > 0.0)) || ((x1.i < 0.0) && (x2.i < 0.0)) of
  True -> 1
  False -> 0
xs = for i. [x1.i, x2.i]


:html showPlot $ xycPlot x1 x2 $ for i. IToF y.i

simple = \h1.
  ndense1 = dense (Fin 2) h1
  ndense2 = dense h1 (Fin 2)
  AsLayer {
    forward = (\(dense1, dense2) x.
                 x1' = forward ndense1 dense1 x
                 x1 = for i. relu x1'.i
                 logsoftmax $ forward ndense2 dense2 x1),
    init = (\key.
              [k1, k2] = splitKey key
              (init ndense1 k1, init ndense2 k2))
  }

:t simple

' Train a multiclass classifier with minibatch SGD
' `minibatch * minibatches = batch`

def trainClass [VSpace p] (model: Layer a (b=>Float) p)
               (x: batch=>a)
               (y: batch=>b)
               (epochs : Type)
               (minibatch : Type)
               (minibatches : Type) :
               (epochs => p & epochs => Float) =
  xs : minibatches => minibatch => a = split x
  ys : minibatches => minibatch => b = split y
  unzip $ withState (init model $ newKey 0) $ \params.
    for _ : epochs.
      loss = sum $ for b : minibatches.
        (loss, gradfn) = vjp (\params.
                                -sum for j.
                                  result = forward model params xs.b.j
                                  result.(ys.b.j)) (get params)
        gparams = gradfn 1.0
        params := (get params) - scaleVec (0.05 / (IToF 100)) gparams
        loss
      (get params, loss)

-- todo : Do I have to give minibatches as a param?
simple_model = simple (Fin 10)
(all_params,losses) = trainClass simple_model xs (for i. (y.i @ (Fin 2))) (Fin 500) (Fin 100) (Fin 1)

span = linspace (Fin 10) (-1.0) (1.0)
tests = for h : (Fin 50). for i. for j.
  r = forward simple_model all_params.((ordinal h * 10)@_) [span.i, span.j]
  [exp r.(1@_), exp r.(0@_), 0.0]


:html imseqshow tests

' ## LeNet for image classification

H = 28
W = 28
Image = Fin 1 => Fin H => Fin W => Float
Class = Fin 10

lenet = \h1 h2 h3.
  ncnn1 = cnn (Fin 1) h1 3 3
  ncnn2 = cnn h1 h2 3 3
  Pooled = (h2 & Fin 7 & Fin 7)
  ndense1 = dense Pooled h3
  ndense2 = dense h3 Class
  AsLayer {
    forward = (\(cnn1, cnn2, dense1, dense2) inp.
                 x:Image = inp
                 x1' = forward ncnn1 cnn1 x
                 x1 = for i j k. relu x1'.i.j.k
                 x2' = forward ncnn2 cnn2 x1
                 x2 = for i j k. relu x2'.i.j.k
                 x3 : (h2 => Fin 7 => Fin 7 => Float) = for c. meanpool (Fin 4) (Fin 4) x2.c
                 x4' = forward ndense1 dense1 for (i,j,k). x3.i.j.k
                 x4 = for i. relu x4'.i
                 logsoftmax $ forward ndense2 dense2 x4),
    init = (\key.
              [k1, k2, k3, k4] = splitKey key
              (init ncnn1 k1, init ncnn2 k2,
               init ndense1 k3, init ndense2 k4))
  }

:t lenet


' ## Data Loading





Batch = Fin 5000
Full = Fin ((size Batch) * H * W)

def pixel (x:Char) : Float32 =
  -- W8ToI reads the byte as signed; remap negative values to a positive intensity.
  r = W8ToI x
  IToF case r < 0 of
    True -> (abs r) + 128
    False -> r

def getIm : Batch => Image =
  (AsList _ im) = unsafeIO do readFile "examples/mnist.bin"
  raw = unsafeCastTable Full im
  for b: Batch c: (Fin 1) i:(Fin W) j:(Fin H).
    pixel raw.((ordinal (b, i, j)) @ Full)

def getLabel : Batch => Class =
  (AsList _ im2) = unsafeIO do readFile "examples/labels.bin"
  r = unsafeCastTable Batch im2
  for i. (W8ToI r.i @ Class)


' ## Training loop


' Get binary files from:

' `wget https://github.com/srush/learns-dex/raw/main/mnist.bin`

' `wget https://github.com/srush/learns-dex/raw/main/labels.bin`

' Comment out the following lines if you have not downloaded the binary files above.

ims = getIm
labels = getLabel

small_ims = for i: (Fin 10). ims.((ordinal i)@_)
small_labels = for i: (Fin 10). labels.((ordinal i)@_)

:p small_labels

Epochs = (Fin 5)
Minibatches = (Fin 1)
Minibatch = (Fin 10)
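
' Size sanity check (illustrative): the minibatch layout should tile the small
' training batch exactly, so this product should equal 10.

:p size Minibatch * size Minibatches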

:t ims.(2@_)

model = lenet (Fin 1) (Fin 1) (Fin 20)
init_param = (init model $ newKey 0)
:p forward model init_param (ims.(2@Batch))

' Sanity check

:t (grad ((\x param. sum (forward model param x)) (ims.(2@_)))) init_param

(all_params', losses') = trainClass model small_ims small_labels Epochs Minibatch Minibatches

:p losses'


