' # Neural Networks

include "plot.dx"

' ## NN Prelude

def relu (input : Float) : Float =
  select (input > 0.0) input 0.0
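
' For instance, `relu` passes positive values through and clamps
' negatives to zero:

:p relu 3.0
:p relu (-2.0)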

' Parameter sets are nested tuples, so give pairs `Add` and `VSpace`
' instances; this lets the training loop do arithmetic on whole
' parameter collections at once.

instance [Add a, Add b] Add (a & b)
  add = \(a, b) (c, d). ((a + c), (b + d))
  sub = \(a, b) (c, d). ((a - c), (b - d))
  zero = (zero, zero)

instance [VSpace a, VSpace b] VSpace (a & b)
  scaleVec = \s (a, b). (scaleVec s a, scaleVec s b)
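
' With these instances, tuples of parameters behave like vectors:

:p (1.0, 2.0) + (3.0, 4.0)
:p scaleVec 2.0 (1.0, (2.0, 3.0))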

' A `Layer` packages a forward function together with a parameter
' initializer, indexed by its input, output, and parameter types.

data Layer inp:Type out:Type params:Type =
  AsLayer {forward: (params -> inp -> out) & init: (Key -> params)}

def forward (l:Layer i o p) (p:p) (x:i) : o =
  (AsLayer l') = l
  (getAt #forward l') p x

def init (l:Layer i o p) (k:Key) : p =
  (AsLayer l') = l
  (getAt #init l') k

' ## Layers

' Dense layer: computes `bias.j + sum i. weight.i.j * x.i`.

def DenseParams (a:Type) (b:Type) : Type =
  ((a=>b=>Float) & (b=>Float))

def dense (a:Type) (b:Type) : Layer (a=>Float) (b=>Float) (DenseParams a b) =
  AsLayer {
    forward = (\(weight, bias) x.
      for j. (bias.j + sum for i. weight.i.j * x.i)),
    init = arb
  }
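
' For example, a layer from 2 input features to 3 outputs (sizes chosen
' just for illustration); `arb` draws random parameters from a key:

d23 = dense (Fin 2) (Fin 3)
:t init d23 (newKey 0)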

' CNN layer

def CNNParams (inc:Type) (outc:Type) (kw:Int) (kh:Int) : Type =
  ((outc=>inc=>Fin kh=>Fin kw=>Float) &
   (outc=>Float))

' `conv2d` keeps the input's spatial size: output entries near the
' bottom/right edge, where the kernel might overflow, are set to `zero`.

def conv2d (x:inc=>(Fin h)=>(Fin w)=>Float)
           (kernel:outc=>inc=>(Fin kh)=>(Fin kw)=>Float) :
           outc=>(Fin h)=>(Fin w)=>Float =
  for o i j.
    (i', j') = (ordinal i, ordinal j)
    case (i' + kh) < h && (j' + kw) < w of
      True ->
        sum for (ki, kj, inp).
          (di, dj) = (fromOrdinal (Fin h) (i' + (ordinal ki)),
                      fromOrdinal (Fin w) (j' + (ordinal kj)))
          x.inp.di.dj * kernel.o.inp.ki.kj
      False -> zero
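
' A quick shape check (sizes chosen just for illustration): a 3x3 kernel
' with one input channel and two output channels preserves the 5x5
' spatial extent.

im5 : Fin 1 => Fin 5 => Fin 5 => Float = arb (newKey 2)
ker : Fin 2 => Fin 1 => Fin 3 => Fin 3 => Float = arb (newKey 3)
:t conv2d im5 ker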

def cnn (h:Int) ?-> (w:Int) ?-> (inc:Type) (outc:Type) (kw:Int) (kh:Int) :
    Layer (inc=>(Fin h)=>(Fin w)=>Float)
          (outc=>(Fin h)=>(Fin w)=>Float)
          (CNNParams inc outc kw kh) =
  AsLayer {
    forward = (\(weight, bias) x. for o i j. (conv2d x weight).o.i.j + bias.o),
    init = arb
  }

' Pooling

' `split` reshapes a flat table into tiles, `imtile` tiles both image
' axes, and `meanpool` averages over each `kh x kw` tile.

def split (x: m=>v) : n=>o=>v =
  for i j. x.((ordinal (i,j))@m)

def imtile (x: a=>b=>v) : n=>o=>p=>q=>v =
  for kw kh w h. (split (split x).w.kw).h.kh

def meanpool (kh: Type) (kw: Type) (x: m=>n=>Float) : (h=>w=>Float) =
  out : (kh => kw => h => w => Float) = imtile x
  mean for (i,j). out.i.j
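
' A small sketch: mean-pooling a 4x4 table down to 2x2. The annotation
' on the result pins down the implicit output shape.

ex4 : Fin 4 => Fin 4 => Float = for i j. IToF (ordinal (i,j))
pooled : Fin 2 => Fin 2 => Float = meanpool (Fin 2) (Fin 2) ex4
:p pooled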

' ## Simple point classifier

' Random 2D points get label 1 exactly when their coordinates share a
' sign; this pattern is not linearly separable.

[k1, k2] = splitKey $ newKey 1
x1 : Fin 100 => Float = arb k1
x2 : Fin 100 => Float = arb k2
y = for i. case ((x1.i > 0.0) && (x2.i > 0.0)) || ((x1.i < 0.0) && (x2.i < 0.0)) of
  True -> 1
  False -> 0
xs = for i. [x1.i, x2.i]

:html showPlot $ xycPlot x1 x2 $ for i. IToF y.i

' A two-layer perceptron: dense, relu, dense, then `logsoftmax` to get
' log class probabilities.

simple = \h1.
  ndense1 = dense (Fin 2) h1
  ndense2 = dense h1 (Fin 2)
  AsLayer {
    forward = (\(dense1, dense2) x.
      x1' = forward ndense1 dense1 x
      x1 = for i. relu x1'.i
      logsoftmax $ forward ndense2 dense2 x1),
    init = (\key.
      [k1, k2] = splitKey key
      (init ndense1 k1, init ndense2 k2))
  }

:t simple
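
' The packaged layer initializes and applies like any other; the hidden
' size of 5 here is arbitrary:

demo_net = simple (Fin 5)
demo_params = init demo_net (newKey 2)
:p forward demo_net demo_params [0.5, (-0.5)]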

' Train a multiclass classifier with minibatch SGD. The flat batch is
' split so that `minibatch * minibatches = batch`. Each step follows the
' gradient of the negative log-likelihood of the true class under the
' model's `logsoftmax` output; `vjp` returns the loss together with a
' function that computes this gradient.

def trainClass [VSpace p] (model: Layer a (b=>Float) p)
               (x: batch=>a)
               (y: batch=>b)
               (epochs : Type)
               (minibatch : Type)
               (minibatches : Type) :
               (epochs => p & epochs => Float) =
  xs : minibatches => minibatch => a = split x
  ys : minibatches => minibatch => b = split y
  unzip $ withState (init model $ newKey 0) $ \params.
    for _ : epochs.
      loss = sum $ for b : minibatches.
        (loss, gradfn) = vjp (\params.
          -sum for j.
            result = forward model params xs.b.j
            result.(ys.b.j)) (get params)
        gparams = gradfn 1.0
        -- note: 100 is a hard-coded minibatch size baked into the step size
        params := (get params) - scaleVec (0.05 / (IToF 100)) gparams
        loss
      (get params, loss)

-- todo: Do I have to give minibatches as a param?
simple_model = simple (Fin 10)
(all_params, losses) = trainClass simple_model xs (for i. (y.i @ (Fin 2))) (Fin 500) (Fin 100) (Fin 1)

' Evaluate the classifier on a 10x10 grid, snapshotting every 10th
' epoch; the red and green channels show the two class probabilities.

span = linspace (Fin 10) (-1.0) (1.0)
tests = for h : (Fin 50). for i. for j.
  r = forward simple_model all_params.((ordinal h * 10)@_) [span.i, span.j]
  [exp r.(1@_), exp r.(0@_), 0.0]

:html imseqshow tests

' ## LeNet for image classification

H = 28
W = 28
Image = Fin 1 => Fin H => Fin W => Float
Class = Fin 10

' Two conv/relu blocks, 4x4 mean-pooling down to 7x7, then two dense
' layers with a `logsoftmax` readout.

lenet = \h1 h2 h3.
  ncnn1 = cnn (Fin 1) h1 3 3
  ncnn2 = cnn h1 h2 3 3
  Pooled = (h2 & Fin 7 & Fin 7)
  ndense1 = dense Pooled h3
  ndense2 = dense h3 Class
  AsLayer {
    forward = (\(cnn1, cnn2, dense1, dense2) inp.
      x : Image = inp
      x1' = forward ncnn1 cnn1 x
      x1 = for i j k. relu x1'.i.j.k
      x2' = forward ncnn2 cnn2 x1
      x2 = for i j k. relu x2'.i.j.k
      x3 : (h2 => Fin 7 => Fin 7 => Float) = for c. meanpool (Fin 4) (Fin 4) x2.c
      x4' = forward ndense1 dense1 for (i,j,k). x3.i.j.k
      x4 = for i. relu x4'.i
      logsoftmax $ forward ndense2 dense2 x4),
    init = (\key.
      [k1, k2, k3, k4] = splitKey key
      (init ncnn1 k1, init ncnn2 k2,
       init ndense1 k3, init ndense2 k4))
  }

:t lenet

' ## Data Loading

Batch = Fin 5000
Full = Fin ((size Batch) * H * W)

' Raw bytes are read as signed `Char`s; `pixel` reinterprets them as
' unsigned values in `[0, 255]`.

def pixel (x:Char) : Float32 =
  r = W8ToI x
  IToF case r < 0 of
    True -> r + 256
    False -> r

def getIm : Batch => Image =
  (AsList _ im) = unsafeIO do readFile "examples/mnist.bin"
  raw = unsafeCastTable Full im
  for b: Batch c: (Fin 1) i: (Fin H) j: (Fin W).
    pixel raw.((ordinal (b, i, j)) @ Full)

def getLabel : Batch => Class =
  (AsList _ im2) = unsafeIO do readFile "examples/labels.bin"
  r = unsafeCastTable Batch im2
  for i. (W8ToI r.i @ Class)

' ## Training loop

' Get the binary files from:

' `wget https://github.com/srush/learns-dex/raw/main/mnist.bin`

' `wget https://github.com/srush/learns-dex/raw/main/labels.bin`

' Comment out the next two lines if you don't have the data files:

ims = getIm
labels = getLabel

small_ims = for i: (Fin 10). ims.((ordinal i)@_)
small_labels = for i: (Fin 10). labels.((ordinal i)@_)

:p small_labels

Epochs = (Fin 5)
Minibatches = (Fin 1)
Minibatch = (Fin 10)

:t ims.(2@_)

model = lenet (Fin 1) (Fin 1) (Fin 20)
init_param = init model $ newKey 0
:p forward model init_param (ims.(2@Batch))

' Sanity check: the gradient of the summed output with respect to the
' parameters has the same type as the parameters themselves.

:t (grad ((\x param. sum (forward model param x)) (ims.(2@_)))) init_param

(all_params', losses') = trainClass model small_ims small_labels Epochs Minibatch Minibatches

:p losses'