From 945aa65e8c5d934378de1b3a634d0cc76ae1ab65 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 26 Aug 2020 10:32:10 -0700 Subject: [PATCH 001/105] sketching out a particle filter example Co-authored-by: Dougal Maclaurin --- examples/particle-filter.dx | 65 +++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 examples/particle-filter.dx diff --git a/examples/particle-filter.dx b/examples/particle-filter.dx new file mode 100644 index 000000000..197b3e47e --- /dev/null +++ b/examples/particle-filter.dx @@ -0,0 +1,65 @@ +def Distribution (range:Type) : Type = + ( Key -> range + & range -> Float) + +def Model (state:Type) (observation:Type) : Type = + ( Distribution state -- initial state + & state -> Distribution state -- dynamics + & state -> Distribution observation) -- observations + +def sample (d: Distribution a) (k: Key) : a = + (sampler, _) = d + sampler k + +def simulate (model: Model s v) (t: Int) (key: Key) : Fin t=>(s & v) = + (init, dynamics, observe) = model + (key, subkey) = splitKey key + s0 = sample init subkey + fst $ withState s0 \s_ref . + for i. + (k1, k2) = splitKey (ixkey key i) + s = get s_ref + s_next = sample (dynamics s) k1 + v = sample (observe s) k2 + s_ref := s_next + (s, v) + +def categorical (ps: n=>Float) (key: Key) : n = todo + +def filter (model: Model s v) (num_particles: Int) (num_timesteps: Int) + (obs: Fin num_timesteps=>v) (key: Key) : Fin num_timesteps => Fin num_particles => s = + (init, dynamics, observe) = model + (key, init_key) = splitKey key + particles = for i: (Fin num_particles). sample init (ixkey init_key i) + fst $ withState particles \p_ref . + for t: (Fin num_timesteps). + p: (Fin num_particles)=>s = get p_ref + likelihoods = for i. (snd (observe p.i)) obs.t + (resample_key, dynamics_key) = splitKey (ixkey key t) + p_resampled = for i. particles.(categorical likelihoods (ixkey resample_key i)) + p_next = for i. (fst (dynamics p_resampled.i)) (ixkey dynamics_key i) + p_ref := p_next + p_resampled + + +def normalDistn (mean: Float) (var: Float) : Distribution Float = + ( \k. (randn k) * (sqrt var) + mean + , \v. -0.5 * (sq (v - mean)) / var - 0.5 * log (2.0 * pi * var) + ) + +gaussModel : Model Float Float = + ( normalDistn 0.1 0.1 + , \s. normalDistn s 1.0 + , \s. normalDistn s 0.1 + ) + +timesteps = 5 +num_particles = 3 + +truth = simulate gaussModel timesteps (newKey 0) +:p truth + +obs = for i. snd truth.i +-- samples = filter gaussModel num_particles obs (newKey 0) +-- estimated = for i. mean samples.i +-- :p estimated From b6d3636966be4d714c96246481ca8b8dea7e8def Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 26 Aug 2020 21:31:20 -0400 Subject: [PATCH 002/105] Finish particle filter and move some shared utilities to the prelude. --- examples/particle-filter.dx | 57 +++++++++------ examples/raytrace.dx | 27 ------- makefile | 3 +- prelude.dx | 137 ++++++++++++++++++++++++++++-------- 4 files changed, 143 insertions(+), 81 deletions(-) diff --git a/examples/particle-filter.dx b/examples/particle-filter.dx index 197b3e47e..f983e58ec 100644 --- a/examples/particle-filter.dx +++ b/examples/particle-filter.dx @@ -1,6 +1,6 @@ def Distribution (range:Type) : Type = ( Key -> range - & range -> Float) + & range -> Float) -- log prob def Model (state:Type) (observation:Type) : Type = ( Distribution state -- initial state @@ -24,23 +24,26 @@ def simulate (model: Model s v) (t: Int) (key: Key) : Fin t=>(s & v) = s_ref := s_next (s, v) -def categorical (ps: n=>Float) (key: Key) : n = todo - -def filter (model: Model s v) (num_particles: Int) (num_timesteps: Int) - (obs: Fin num_timesteps=>v) (key: Key) : Fin num_timesteps => Fin num_particles => s = +def filter + (num_particles: Int) (num_timesteps: Int) + (model: Model s v) + (summarize: (Fin num_particles => s) -> a) + (obs: Fin num_timesteps=>v) + (key: Key) + : Fin num_timesteps => a = (init, dynamics, observe) = model (key, init_key) = splitKey key - particles = for i: (Fin num_particles). sample init (ixkey init_key i) - fst $ withState particles \p_ref . + init_particles = for i: (Fin num_particles). sample init (ixkey init_key i) + fst $ withState init_particles \p_ref . for t: (Fin num_timesteps). - p: (Fin num_particles)=>s = get p_ref - likelihoods = for i. (snd (observe p.i)) obs.t + p_prev = get p_ref + logLikelihoods = for i. snd (observe p_prev.i) obs.t (resample_key, dynamics_key) = splitKey (ixkey key t) - p_resampled = for i. particles.(categorical likelihoods (ixkey resample_key i)) - p_next = for i. (fst (dynamics p_resampled.i)) (ixkey dynamics_key i) + resampled_idxs = categoricalBatch logLikelihoods resample_key + p_resampled = for i. p_prev.(resampled_idxs.i) + p_next = for i. fst (dynamics p_resampled.i) (ixkey dynamics_key i) p_ref := p_next - p_resampled - + summarize p_resampled def normalDistn (mean: Float) (var: Float) : Distribution Float = ( \k. (randn k) * (sqrt var) + mean @@ -50,16 +53,26 @@ def normalDistn (mean: Float) (var: Float) : Distribution Float = gaussModel : Model Float Float = ( normalDistn 0.1 0.1 , \s. normalDistn s 1.0 - , \s. normalDistn s 0.1 + , \s. normalDistn s 1.0 ) -timesteps = 5 -num_particles = 3 +timesteps = 10 +num_particles = 10000 + +truth = for i:(Fin timesteps). + s = IToF (ordinal i) + (s, sample (normalDistn s 1.0) $ ixkey (newKey 0) i) -truth = simulate gaussModel timesteps (newKey 0) -:p truth +filtered = filter num_particles _ gaussModel mean (map snd truth) (newKey 0) -obs = for i. snd truth.i --- samples = filter gaussModel num_particles obs (newKey 0) --- estimated = for i. mean samples.i --- :p estimated +:p for i. (truth.i, filtered.i) +> [ ((0.0, -0.27877414), 5.959407e-2) +> , ((1.0, -0.55018175), -0.2716055) +> , ((2.0, 2.4248328), 1.3502709) +> , ((3.0, 4.800973), 3.4571552) +> , ((4.0, 3.747046), 3.633845) +> , ((5.0, 5.4751964), 4.754587) +> , ((6.0, 5.0827684), 4.9687166) +> , ((7.0, 5.6351852), 5.3734612) +> , ((8.0, 8.845968), 7.475342) +> , ((9.0, 8.416575), 8.059247) ] diff --git a/examples/raytrace.dx b/examples/raytrace.dx index 9212388a7..41cd4227f 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -4,8 +4,6 @@ [JAX implementation](https://github.com/ericjang/pt-jax/blob/master/jaxpt_vmap.ipynb), described [here](https://blog.evjang.com/2019/11/jaxpt.html). -Specifically, it's based on his unrolled ```lax.scan``` version. - ' ### Generic Helper Functions Some of these should probably go in prelude. @@ -25,10 +23,6 @@ def dot (_:VSpace v) ?=> (s:d=>Float) (vs:d=>v) : v = sum for j. s.j .* vs.j def randuniform (lower:Float) (upper:Float) (k:Key) : Float = lower + (rand k) * (upper - lower) -def reverse (x:n=>a) : n=>a = - s = size n - for i. x.((s - 1 - ordinal i)@_) - def sampleAveraged (_:VSpace a) ?=> (sample:Key -> a) (n:Int) (k:Key) : a = snd $ withState zero \total. for i:(Fin n). @@ -45,27 +39,6 @@ def gradNumerical (n:Int) ?-> (f:Vec n -> Float) (xs:Vec n) : Vec n = eps = 0.0001 for i. (f (addAt xs i eps) - f (addAt xs i (neg eps))) / (2.0 * eps) -data IterResult a:Type b:Type = - Continue a - Done b - --- A little iteration combinator --- TODO: allow effects (currently there's some type inference bug preventing it) -def iter (init:a) (body: Int -> a -> IterResult a b) : b = - result = snd $ withState Nothing \resultRef. - withState init \carryRef. - withState 0 \i. - while (\(). isNothing (get resultRef)) \(). - case body (get i) (get carryRef) of - Continue carry -> - i := get i + 1 - carryRef := carry - Done result -> - resultRef := Just result - case result of - Just ans -> ans - Nothing -> todo -- should be unreachable - ' ### 3D Helper Functions -- TODO: implement table unpacking diff --git a/makefile b/makefile index ec366b7ff..a3ef475dc 100644 --- a/makefile +++ b/makefile @@ -61,7 +61,8 @@ example-names = uexpr-tests adt-tests type-tests eval-tests \ ad-tests mandelbrot pi sierpinsky \ regression brownian_motion particle-swarm-optimizer \ ode-integrator parser-tests serialize-tests \ - mcmc record-variant-tests simple-include-test ctc raytrace + mcmc record-variant-tests simple-include-test ctc \ + raytrace particle-filter quine-test-targets = $(example-names:%=run-%) diff --git a/prelude.dx b/prelude.dx index 07d40f1f9..6364bbfec 100644 --- a/prelude.dx +++ b/prelude.dx @@ -288,6 +288,10 @@ def abs (x:Float) : Float = select (x > 0.0) x (-x) def mod (x:Int) (y:Int) : Int = rem (y + rem x y) y def compose (f:b->c) (g:a->b) (x:a) : c = f (g x) +def reverse (x:n=>a) : n=>a = + s = size n + for i. x.((s - 1 - ordinal i)@_) + def slice (xs:n=>a) (start:Int) (m:Type) : m=>a = for i. xs.(fromOrdinal _ (ordinal i + start)) @@ -316,6 +320,51 @@ def std (xs:n=>Float) : Float = sqrt $ mean (map sq xs) - sq (mean xs) def any (xs:n=>Bool) : Bool = reduce False (||) xs def all (xs:n=>Bool) : Bool = reduce True (&&) xs +def while + (eff:Effects) ?-> + (cond: Unit -> {|eff} Bool) + (body: Unit -> {|eff} Unit) + : {|eff} Unit = + cond' : Unit -> {|eff} Int8 = \_. BToI8 $ cond () + %while cond' body + +data IterResult a:Type b:Type = + Continue a + Done b + +-- A little iteration combinator +-- TODO: allow effects (currently there's some type inference bug preventing it) +def iter (init:a) (body: Int -> a -> IterResult a b) : b = + result = snd $ withState Nothing \resultRef. + withState init \carryRef. + withState 0 \i. + while (\(). isNothing (get resultRef)) \(). + case body (get i) (get carryRef) of + Continue carry -> + i := get i + 1 + carryRef := carry + Done result -> + resultRef := Just result + case result of + Just ans -> ans + Nothing -> todo -- should be unreachable + +-- returns the highest index `i` such that `xs.i <= x` +def searchSorted (_:Ord a) ?=> (xs:n=>a) (x:a) : Maybe n = + case size n == 0 of + True -> Nothing + False -> case x < xs.(fromOrdinal _ 0) of + True -> Nothing + False -> + iter (0, size n) \_ (low, high). + numLeft = high - low + case numLeft == 1 of + True -> Done $ Just $ fromOrdinal _ low + False -> + centerIx = low + idiv (high - low) 2 + case x < xs.(fromOrdinal _ centerIx) of + True -> Continue (low, centerIx) + False -> Continue (centerIx, high) def applyN (n:Int) (x:a) (f:a -> a) : a = snd $ withState x \ref. for _:(Fin n). @@ -339,6 +388,29 @@ def vdot (x:n=>Float) (y:n=>Float) : Float = fsum \i. x.i * y.i def inner (x:n=>Float) (mat:n=>m=>Float) (y:m=>Float) : Float = fsum \(i,j). x.i * mat.i.j * y.j +'min / max etc + +def minBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x < f y) x y +def maxBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x > f y) x y + +def min (_:Ord o) ?=> (x1: o) -> (x2: o) : o = minBy id x1 x2 +def max (_:Ord o) ?=> (x1: o) -> (x2: o) : o = maxBy id x1 x2 + +def minimumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = + reduce xs.(0@_) (minBy f) xs +def maximumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = + reduce xs.(0@_) (maxBy f) xs + +def minimum (_:Ord o) ?=> (xs:n=>o) : o = minimumBy id xs +def maximum (_:Ord o) ?=> (xs:n=>o) : o = maximumBy id xs + +def argmin (_:Ord o) ?=> (xs:n=>o) : n = + zeroth = (0@_, xs.(0@_)) + compare = \(idx1, x1) (idx2, x2). + select (x1 < x2) (idx1, x1) (idx2, x2) + zipped = for i. (i, xs.i) + fst $ reduce zeroth compare zipped + 'Functions for working with the pseudorandom number generator -- TODO: newtype @@ -376,29 +448,40 @@ def bern (p:Float) (k:Key) : Bool = rand k < p def randnVec (n:Type) ?-> (k:Key) : n=>Float = for i. randn (ixkey k i) -'min / max etc - -def minBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x < f y) x y -def maxBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x > f y) x y - -def min (_:Ord o) ?=> (x1: o) -> (x2: o) : o = minBy id x1 x2 -def max (_:Ord o) ?=> (x1: o) -> (x2: o) : o = maxBy id x1 x2 - -def minimumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = - reduce xs.(0@_) (minBy f) xs -def maximumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = - reduce xs.(0@_) (maxBy f) xs - -def minimum (_:Ord o) ?=> (xs:n=>o) : o = minimumBy id xs -def maximum (_:Ord o) ?=> (xs:n=>o) : o = maximumBy id xs - -def argmin (_:Ord o) ?=> (xs:n=>o) : n = - zeroth = (0@_, xs.(0@_)) - compare = \(idx1, x1) (idx2, x2). - select (x1 < x2) (idx1, x1) (idx2, x2) - zipped = for i. (i, xs.i) - fst $ reduce zeroth compare zipped - +def cumSum (xs: n=>Float) : n=>Float = + fst $ withState 0.0 \total. + for i. + newTotal = get total + xs.i + total := newTotal + newTotal + +def cumSumLow (xs: n=>Float) : n=>Float = + fst $ withState 0.0 \total. + for i. + oldTotal = get total + total := oldTotal + xs.i + oldTotal + +-- cdf should include 0.0 but not 1.0 +def categoricalFromCDF (cdf: n=>Float) (key: Key) : n = + r = rand key + case searchSorted cdf r of + Just i -> i + +def normalizePdf (xs: d=>Float) : d=>Float = xs / sum xs + +def cdfForCategorical (logprobs: n=>Float) : n=>Float = + maxLogProb = maximum logprobs + cumSumLow $ normalizePdf $ map exp $ for i. logprobs.i - maxLogProb + +def categorical (logprobs: n=>Float) (key: Key) : n = + categoricalFromCDF (cdfForCategorical logprobs) key + +-- batch variant to share the work of forming the cumsum +-- (alternatively we could rely on hoisting of loop constants) +def categoricalBatch (logprobs: n=>Float) (key: Key) : m=>n = + cdf = cdfForCategorical logprobs + for i. categoricalFromCDF cdf $ ixkey key i 'Automatic differentiation @@ -429,14 +512,6 @@ def checkDerivBase (f:Float->Float) (x:Float) : Bool = def checkDeriv (f:Float->Float) (x:Float) : Bool = checkDerivBase f x && checkDerivBase (deriv f) x -def while - (eff:Effects) ?-> - (cond: Unit -> {|eff} Bool) - (body: Unit -> {|eff} Unit) - : {|eff} Unit = - cond' : Unit -> {|eff} Int8 = \_. BToI8 $ cond () - %while cond' body - 'Vector support -- TODO: Reenable vector suport once fixed-width types are supported. From 673b5fb74effa809c989be6bbfc50d5ebf9803e3 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Fri, 4 Dec 2020 16:47:24 -0500 Subject: [PATCH 003/105] Strengthen warning about experimental status --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b8d2c1229..310b7dca9 100644 --- a/README.md +++ b/README.md @@ -21,8 +21,8 @@ or these example programs: * [Basis function regression](https://google-research.github.io/dex-lang/regression.html) * [Brownian bridge](https://google-research.github.io/dex-lang/brownian_motion.html) -Please note that Dex is an experimental research project at an early stage of -development. Contributions welcome! +⚠️ Dex is an experimental research project at an early stage of +development. Expect monstrous bugs and razor-sharp edges. Contributions welcome! ⚠️ ## Dependencies From 16a698d8dae6105a701f403748eee421da1bd3c3 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Fri, 18 Dec 2020 15:02:40 -0500 Subject: [PATCH 004/105] Make separate directories for tests and libraries. (Rather than lumping them all together under `examples/`.) --- examples/bad-binary-file.dxbo | 3 -- examples/dxbo-example.dxbo | Bin 1104 -> 0 bytes examples/fluidsim.dx | 2 +- examples/mandelbrot.dx | 2 +- examples/raytrace.dx | 2 +- examples/regression.dx | 2 +- examples/sierpinski.dx | 2 +- examples/simple-include-test.dx | 7 ---- {examples => lib}/diagram.dx | 0 {examples => lib}/plot.dx | 4 +-- {examples => lib}/png.dx | 0 prelude.dx => lib/prelude.dx | 0 makefile | 34 ++++++++++-------- src/lib/TopLevel.hs | 7 +++- src/resources/Resources.hs | 2 +- {examples => tests}/ad-tests-interp.dx | 0 {examples => tests}/ad-tests.dx | 0 {examples => tests}/adt-tests.dx | 0 {examples => tests}/complex-tests.dx | 0 {examples => tests}/eval-tests.dx | 0 {examples => tests}/flop-tests.dx | 0 {examples => tests}/gpu-tests.dx | 0 {examples => tests}/include-test.dx | 6 ++-- {examples => tests}/included.dx | 0 {examples => tests}/interp-tests.dx | 0 {examples => tests}/jax-tests.dx | 0 {examples => tests}/linear-tests.dx | 0 {examples => tests}/loopy-ad-tests.dx | 0 {examples => tests}/monad-tests.dx | 0 {examples => tests}/parser-tests.dx | 0 {examples => tests}/record-variant-tests.dx | 0 .../repl-multiline-test-expected-output | 0 {examples => tests}/repl-multiline-test.dx | 0 {examples => tests}/serialize-tests.dx | 0 {examples => tests}/shadow-tests.dx | 0 {examples => tests}/show-tests.dx | 0 tests/simple-include-test.dx | 7 ++++ {examples => tests}/somedata.dxo | 0 {examples => tests}/trig-tests.dx | 0 {examples => tests}/type-tests.dx | 0 {examples => tests}/typeclass-tests.dx | 0 {examples => tests}/uexpr-tests.dx | 0 {examples => tests}/web-tests.dx | 0 43 files changed, 44 insertions(+), 36 deletions(-) delete mode 100644 examples/bad-binary-file.dxbo delete mode 100644 examples/dxbo-example.dxbo delete mode 100644 examples/simple-include-test.dx rename {examples => lib}/diagram.dx (100%) rename {examples => lib}/plot.dx (98%) rename {examples => lib}/png.dx (100%) rename prelude.dx => lib/prelude.dx (100%) rename {examples => tests}/ad-tests-interp.dx (100%) rename {examples => tests}/ad-tests.dx (100%) rename {examples => tests}/adt-tests.dx (100%) rename {examples => tests}/complex-tests.dx (100%) rename {examples => tests}/eval-tests.dx (100%) rename {examples => tests}/flop-tests.dx (100%) rename {examples => tests}/gpu-tests.dx (100%) rename {examples => tests}/include-test.dx (93%) rename {examples => tests}/included.dx (100%) rename {examples => tests}/interp-tests.dx (100%) rename {examples => tests}/jax-tests.dx (100%) rename {examples => tests}/linear-tests.dx (100%) rename {examples => tests}/loopy-ad-tests.dx (100%) rename {examples => tests}/monad-tests.dx (100%) rename {examples => tests}/parser-tests.dx (100%) rename {examples => tests}/record-variant-tests.dx (100%) rename {examples => tests}/repl-multiline-test-expected-output (100%) rename {examples => tests}/repl-multiline-test.dx (100%) rename {examples => tests}/serialize-tests.dx (100%) rename {examples => tests}/shadow-tests.dx (100%) rename {examples => tests}/show-tests.dx (100%) create mode 100644 tests/simple-include-test.dx rename {examples => tests}/somedata.dxo (100%) rename {examples => tests}/trig-tests.dx (100%) rename {examples => tests}/type-tests.dx (100%) rename {examples => tests}/typeclass-tests.dx (100%) rename {examples => tests}/uexpr-tests.dx (100%) rename {examples => tests}/web-tests.dx (100%) diff --git a/examples/bad-binary-file.dxbo b/examples/bad-binary-file.dxbo deleted file mode 100644 index 0fb628742..000000000 --- a/examples/bad-binary-file.dxbo +++ /dev/null @@ -1,3 +0,0 @@ --- dex-object-file-v0.0.1 num-header-bytes 128 --------------------------------- -type: (2=>1=>Real, Int) -bufferSizes: [8, 8, 8] diff --git a/examples/dxbo-example.dxbo b/examples/dxbo-example.dxbo deleted file mode 100644 index 656783b3ed89a76d684c56c10007a534fa22a921..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1104 zcmah|+iu%14Bh>b{RH1?qf!v3S(gQ9ffnmSUe~^50V3OUtU{IoNlu!6{U}-8iUGw8 z!?bCNJUoZw9BO>vro6|>^F}J<`^k7RE<$r95Tu#tQ{bJ|ajB0|e8= zUjr=2q^?w7OTCTWza{f06c(v;Iuv^Qr^F9@0Ds&e2wle%W#2StV^Ko*R5qoPj=G{8 zmC;^EJ)jIJ_Coato2+0kWw3(mjICzu>pEwUowL<9Hjl60*{YbZqFCp!&NrioF%&0! zweVsTkSriuYR@1(&Mq#8uj$1lIr2Q7xqgoE3F0`A?uct~3{)$lx73Gkg$9jSBeiEh zRF&yEL&b>Q76X7#7@P@_KKPDM(GkiZl-%m!ScNw98zq3!^tu){h$NeoJL-2A<%$F; z)LY;F(S0+>bsLmC1=n{l&>S3YGkW&1fYqD<-P0L^*?j%1Uq`fCl5=N*R0eweD0e%o z88k|Gj{$3D&4G8~-q;*mC6urrtQAKWI=zh!M;@92Z!Epd+3>RJ*T=QCbP56hdI8>mfRs>mD!jo`&-sXdYeL*8p+Ue22$A@eOwSP+o5gybVx zj2 (n:Type) ?-> (m:Type) ?-> (x: n=>m=>a) : n=>m=>a = -- Todo: update in place without starting with a copy. diff --git a/examples/mandelbrot.dx b/examples/mandelbrot.dx index 700ede927..b384c5086 100644 --- a/examples/mandelbrot.dx +++ b/examples/mandelbrot.dx @@ -1,6 +1,6 @@ '# Mandelbrot set -include "examples/plot.dx" +include "plot.dx" 'Escape time algorithm diff --git a/examples/raytrace.dx b/examples/raytrace.dx index 9b51c3d4e..c02fda15a 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -6,7 +6,7 @@ described [here](https://blog.evjang.com/2019/11/jaxpt.html). Specifically, it's based on his unrolled ```lax.scan``` version. -include "examples/plot.dx" +include "plot.dx" ' ### Generic Helper Functions Some of these should probably go in prelude. diff --git a/examples/regression.dx b/examples/regression.dx index f33a84aec..61f99909e 100644 --- a/examples/regression.dx +++ b/examples/regression.dx @@ -1,6 +1,6 @@ '# Basis function regression -include "examples/plot.dx" +include "plot.dx" -- Conjugate gradients solver def solve (m:Type)?-> : m=>m=>Float -> m=>Float -> m=>Float = diff --git a/examples/sierpinski.dx b/examples/sierpinski.dx index cdfde70c1..815020c2c 100644 --- a/examples/sierpinski.dx +++ b/examples/sierpinski.dx @@ -1,6 +1,6 @@ '# Sierpinski triangle ("chaos game") -include "examples/plot.dx" +include "plot.dx" update : n=>Point -> Key -> Point -> Point = \points key (x,y). diff --git a/examples/simple-include-test.dx b/examples/simple-include-test.dx deleted file mode 100644 index b251c0378..000000000 --- a/examples/simple-include-test.dx +++ /dev/null @@ -1,7 +0,0 @@ - -include "examples/included.dx" -> 30 -> 40 - -:p x -> 10 diff --git a/examples/diagram.dx b/lib/diagram.dx similarity index 100% rename from examples/diagram.dx rename to lib/diagram.dx diff --git a/examples/plot.dx b/lib/plot.dx similarity index 98% rename from examples/plot.dx rename to lib/plot.dx index fde6d9754..0b32cde28 100644 --- a/examples/plot.dx +++ b/lib/plot.dx @@ -1,7 +1,7 @@ '# Plotting library -include "examples/diagram.dx" -include "examples/png.dx" +include "diagram.dx" +include "png.dx" data CompactSet a:Type = Interval a a diff --git a/examples/png.dx b/lib/png.dx similarity index 100% rename from examples/png.dx rename to lib/png.dx diff --git a/prelude.dx b/lib/prelude.dx similarity index 100% rename from prelude.dx rename to lib/prelude.dx diff --git a/makefile b/makefile index 3fb4fb955..1b93918b8 100644 --- a/makefile +++ b/makefile @@ -80,16 +80,20 @@ build-python: build # --- running tests --- # TODO: re-enable linear-tests ad-tests include-test chol -example-names = uexpr-tests adt-tests type-tests eval-tests show-tests \ - shadow-tests monad-tests \ - ad-tests mandelbrot pi sierpinski \ +example-names = mandelbrot pi sierpinski \ regression brownian_motion particle-swarm-optimizer \ - ode-integrator parser-tests serialize-tests \ - mcmc record-variant-tests simple-include-test ctc raytrace \ - isomorphisms typeclass-tests complex-tests trig-tests \ - ode-integrator linear_algebra fluidsim + ode-integrator mcmc ctc raytrace \ + isomorphisms ode-integrator linear_algebra fluidsim -quine-test-targets = $(example-names:%=run-%) +test-names = uexpr-tests adt-tests type-tests eval-tests show-tests \ + shadow-tests monad-tests \ + ad-tests parser-tests serialize-tests \ + record-variant-tests simple-include-test \ + typeclass-tests complex-tests trig-tests + +all-names = $(test-names:%=tests/%) $(example-names:%=examples/%) + +quine-test-targets = $(all-names:%=run-%) update-targets = $(example-names:%=update-%) @@ -102,7 +106,9 @@ quine-tests: $(quine-test-targets) quine-tests-interp: runinterp-eval-tests runinterp-ad-tests-interp runinterp-interp-tests run-%: export DEX_ALLOW_CONTRACTIONS=0 -run-%: examples/%.dx build +run-tests/%: tests/%.dx build + misc/check-quine $< $(dex) script --allow-errors +run-examples/%: examples/%.dx build misc/check-quine $< $(dex) script --allow-errors # Run these with profiling on while they're catching lots of crashes @@ -112,16 +118,16 @@ prop-tests: cbits/libdex.so update-all: $(update-targets) update-%: export DEX_ALLOW_CONTRACTIONS=0 -update-%: examples/%.dx build +update-%: tests/%.dx build $(dex) script --allow-errors $< > $<.tmp mv $<.tmp $< run-gpu-tests: export DEX_ALLOC_CONTRACTIONS=0 -run-gpu-tests: examples/gpu-tests.dx build +run-gpu-tests: tests/gpu-tests.dx build misc/check-quine $< $(dex) --backend LLVM-CUDA script --allow-errors update-gpu-tests: export DEX_ALLOW_CONTRACTIONS=0 -update-gpu-tests: examples/gpu-tests.dx build +update-gpu-tests: tests/gpu-tests.dx build $(dex) --backend LLVM-CUDA script --allow-errors $< > $<.tmp mv $<.tmp $< @@ -140,8 +146,8 @@ uexpr-tests: repl-test: misc/check-no-diff \ - examples/repl-multiline-test-expected-output \ - <($(dex) repl < examples/repl-multiline-test.dx) + tests/repl-multiline-test-expected-output \ + <($(dex) repl < tests/repl-multiline-test.dx) # --- running and querying benchmarks --- diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 07e9a8e1e..b2ee837bb 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -112,7 +112,8 @@ evalSourceBlockM env block = case sbContents block of Just (ty, _) -> logTop (TextOut $ pprint ty) >> return mempty _ -> liftEitherIO $ throw UnboundVarErr $ pprint v IncludeSourceFile fname -> do - source <- liftIO $ readFile fname + fullPath <- liftIO $ findSourceFile fname + source <- liftIO $ readFile fullPath evalSourceBlocks env $ parseProg source UnParseable _ s -> liftEitherIO $ throw ParseErr s _ -> return mempty @@ -373,3 +374,7 @@ traverseLiterals block f = traverseAtomLiterals atom = case atom of Con (Lit x) -> lift $ lift $ f x _ -> traverseAtom def atom + +-- TODO: use something like a `DEXPATH` env var for finding source files +findSourceFile :: FilePath -> IO FilePath +findSourceFile fpath = return $ "lib/" ++ fpath diff --git a/src/resources/Resources.hs b/src/resources/Resources.hs index cb9c8cd3b..a21e98b29 100644 --- a/src/resources/Resources.hs +++ b/src/resources/Resources.hs @@ -12,4 +12,4 @@ dexrtBC :: B.ByteString dexrtBC = $(embedFile "src/lib/dexrt.bc") preludeSource :: String -preludeSource = B.unpack $ $(embedFile "prelude.dx") +preludeSource = B.unpack $ $(embedFile "lib/prelude.dx") diff --git a/examples/ad-tests-interp.dx b/tests/ad-tests-interp.dx similarity index 100% rename from examples/ad-tests-interp.dx rename to tests/ad-tests-interp.dx diff --git a/examples/ad-tests.dx b/tests/ad-tests.dx similarity index 100% rename from examples/ad-tests.dx rename to tests/ad-tests.dx diff --git a/examples/adt-tests.dx b/tests/adt-tests.dx similarity index 100% rename from examples/adt-tests.dx rename to tests/adt-tests.dx diff --git a/examples/complex-tests.dx b/tests/complex-tests.dx similarity index 100% rename from examples/complex-tests.dx rename to tests/complex-tests.dx diff --git a/examples/eval-tests.dx b/tests/eval-tests.dx similarity index 100% rename from examples/eval-tests.dx rename to tests/eval-tests.dx diff --git a/examples/flop-tests.dx b/tests/flop-tests.dx similarity index 100% rename from examples/flop-tests.dx rename to tests/flop-tests.dx diff --git a/examples/gpu-tests.dx b/tests/gpu-tests.dx similarity index 100% rename from examples/gpu-tests.dx rename to tests/gpu-tests.dx diff --git a/examples/include-test.dx b/tests/include-test.dx similarity index 93% rename from examples/include-test.dx rename to tests/include-test.dx index 87a74a4bd..c5127ef2f 100644 --- a/examples/include-test.dx +++ b/tests/include-test.dx @@ -1,12 +1,12 @@ -include "examples/included.dx" +include "included.dx" > 30 > 40 :p x > 10 -load dxo "examples/somedata.dxo" as dat +load dxo "somedata.dxo" as dat :t dat > (Float, 2, (2=>(3=>Float)), (2=>(Int, Bool))) @@ -27,7 +27,7 @@ load dxbo "test-scratch/bin-data-dump.dxbo" as dat2 load dxbo "not-a-file" as notData > IO error: not-a-file: openFile: does not exist (No such file or directory) -load dxbo "examples/bad-binary-file.dxbo" as badData +load dxbo "bad-binary-file.dxbo" as badData > IO error: unexpected number of buffers: [16,8] vs [8,8,8] > Validation error > Claimed header length: 128 diff --git a/examples/included.dx b/tests/included.dx similarity index 100% rename from examples/included.dx rename to tests/included.dx diff --git a/examples/interp-tests.dx b/tests/interp-tests.dx similarity index 100% rename from examples/interp-tests.dx rename to tests/interp-tests.dx diff --git a/examples/jax-tests.dx b/tests/jax-tests.dx similarity index 100% rename from examples/jax-tests.dx rename to tests/jax-tests.dx diff --git a/examples/linear-tests.dx b/tests/linear-tests.dx similarity index 100% rename from examples/linear-tests.dx rename to tests/linear-tests.dx diff --git a/examples/loopy-ad-tests.dx b/tests/loopy-ad-tests.dx similarity index 100% rename from examples/loopy-ad-tests.dx rename to tests/loopy-ad-tests.dx diff --git a/examples/monad-tests.dx b/tests/monad-tests.dx similarity index 100% rename from examples/monad-tests.dx rename to tests/monad-tests.dx diff --git a/examples/parser-tests.dx b/tests/parser-tests.dx similarity index 100% rename from examples/parser-tests.dx rename to tests/parser-tests.dx diff --git a/examples/record-variant-tests.dx b/tests/record-variant-tests.dx similarity index 100% rename from examples/record-variant-tests.dx rename to tests/record-variant-tests.dx diff --git a/examples/repl-multiline-test-expected-output b/tests/repl-multiline-test-expected-output similarity index 100% rename from examples/repl-multiline-test-expected-output rename to tests/repl-multiline-test-expected-output diff --git a/examples/repl-multiline-test.dx b/tests/repl-multiline-test.dx similarity index 100% rename from examples/repl-multiline-test.dx rename to tests/repl-multiline-test.dx diff --git a/examples/serialize-tests.dx b/tests/serialize-tests.dx similarity index 100% rename from examples/serialize-tests.dx rename to tests/serialize-tests.dx diff --git a/examples/shadow-tests.dx b/tests/shadow-tests.dx similarity index 100% rename from examples/shadow-tests.dx rename to tests/shadow-tests.dx diff --git a/examples/show-tests.dx b/tests/show-tests.dx similarity index 100% rename from examples/show-tests.dx rename to tests/show-tests.dx diff --git a/tests/simple-include-test.dx b/tests/simple-include-test.dx new file mode 100644 index 000000000..93cff77e6 --- /dev/null +++ b/tests/simple-include-test.dx @@ -0,0 +1,7 @@ + +include "../tests/included.dx" +> 30 +> 40 + +:p x +> 10 diff --git a/examples/somedata.dxo b/tests/somedata.dxo similarity index 100% rename from examples/somedata.dxo rename to tests/somedata.dxo diff --git a/examples/trig-tests.dx b/tests/trig-tests.dx similarity index 100% rename from examples/trig-tests.dx rename to tests/trig-tests.dx diff --git a/examples/type-tests.dx b/tests/type-tests.dx similarity index 100% rename from examples/type-tests.dx rename to tests/type-tests.dx diff --git a/examples/typeclass-tests.dx b/tests/typeclass-tests.dx similarity index 100% rename from examples/typeclass-tests.dx rename to tests/typeclass-tests.dx diff --git a/examples/uexpr-tests.dx b/tests/uexpr-tests.dx similarity index 100% rename from examples/uexpr-tests.dx rename to tests/uexpr-tests.dx diff --git a/examples/web-tests.dx b/tests/web-tests.dx similarity index 100% rename from examples/web-tests.dx rename to tests/web-tests.dx From 45e5454a1f9d5aa8f30b9d0fc29c982bcd8b93f1 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Fri, 18 Dec 2020 15:05:20 -0500 Subject: [PATCH 005/105] Fix GHC warnings. --- src/lib/Cat.hs | 1 - src/lib/Embed.hs | 1 - src/lib/JIT.hs | 1 - src/lib/Serialize.hs | 1 - src/lib/Syntax.hs | 1 - 5 files changed, 5 deletions(-) diff --git a/src/lib/Cat.hs b/src/lib/Cat.hs index 3edb9df41..aa6d703fa 100644 --- a/src/lib/Cat.hs +++ b/src/lib/Cat.hs @@ -18,7 +18,6 @@ module Cat (CatT, MonadCat, runCatT, look, extend, scoped, looks, extendLocal, -- Monad for tracking monoidal state import Control.Applicative -import Control.Monad.Fail import Control.Monad.State.Strict import Control.Monad.Reader import Control.Monad.Writer diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 12349af6e..f75e89c67 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -34,7 +34,6 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP import Control.Applicative import Control.Monad -import Control.Monad.Fail import Control.Monad.Except hiding (Except) import Control.Monad.Reader import Control.Monad.Writer hiding (Alt) diff --git a/src/lib/JIT.hs b/src/lib/JIT.hs index 8ddf1e67e..a7776b1da 100644 --- a/src/lib/JIT.hs +++ b/src/lib/JIT.hs @@ -33,7 +33,6 @@ import Control.Monad.State.Strict import Control.Monad.Reader import Data.ByteString.Short (toShort) import qualified Data.ByteString.Char8 as B -import Data.List (concat) import Data.String import Data.Foldable import Data.Text.Prettyprint.Doc diff --git a/src/lib/Serialize.hs b/src/lib/Serialize.hs index 7275c6968..11b552be6 100644 --- a/src/lib/Serialize.hs +++ b/src/lib/Serialize.hs @@ -22,7 +22,6 @@ import Interpreter import Syntax import Type import PPrint -import Interpreter (indices) pprintVal :: Val -> IO String pprintVal val = asStr <$> prettyVal val diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 0120cc14c..04052e457 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -66,7 +66,6 @@ module Syntax ( import qualified Data.Map.Strict as M import Control.Exception hiding (throw) -import Control.Monad.Fail import Control.Monad.Identity import Control.Monad.Writer hiding (Alt) import Control.Monad.Except hiding (Except) From d53771b26033e8d4e99ff4272d1799fb4f202a59 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Fri, 18 Dec 2020 16:31:33 -0500 Subject: [PATCH 006/105] Update examples and prelude to "modern Dex". --- examples/brownian_motion.dx | 26 ++++++------ examples/mandelbrot.dx | 14 +++---- examples/mcmc.dx | 20 +++++---- examples/ode-integrator.dx | 8 ++-- examples/pi.dx | 11 +++-- examples/raytrace.dx | 82 +++++++++++++++++++------------------ examples/regression.dx | 69 ++++++++++++++++--------------- examples/sierpinski.dx | 16 ++++---- lib/plot.dx | 4 ++ 9 files changed, 127 insertions(+), 123 deletions(-) diff --git a/examples/brownian_motion.dx b/examples/brownian_motion.dx index f97d0f759..4461e76ad 100644 --- a/examples/brownian_motion.dx +++ b/examples/brownian_motion.dx @@ -1,22 +1,22 @@ +include "plot.dx" UnitInterval = Float -bmIter : (Key & Float & Float & UnitInterval) -> (Key & Float & Float & UnitInterval) = - \(key, y, sigma, t). - (kDraw, kL, kR) = splitKey3 key - t' = abs (t - 0.5) - y' = sigma * randn kDraw * (0.5 - t') - key' = select (t > 0.5) kL kR - (key', y + y', sigma / sqrt 2.0, t' * 2.0) +def bmIter ((key, y, sigma, t):(Key & Float & Float & UnitInterval)) : + (Key & Float & Float & UnitInterval) = + (kDraw, kL, kR) = splitKey3 key + t' = abs (t - 0.5) + y' = sigma * randn kDraw * (0.5 - t') + key' = select (t > 0.5) kL kR + (key', y + y', sigma / sqrt 2.0, t' * 2.0) -sampleBM : Key -> UnitInterval -> Float = - \key t. - (_, y, _, _) = fold (key, 0.0, 1.0, t) \i:(Fin 10). bmIter - y +def sampleBM (key:Key) (t:UnitInterval) : Float = + (_, y, _, _) = fold (key, 0.0, 1.0, t) \i:(Fin 10). bmIter + y xs = linspace (Fin 1000) 0.0 1.0 ys = map (sampleBM (newKey 0)) xs --- :plot zip xs ys --- > +:html showPlot $ xyPlot xs ys +> diff --git a/examples/mandelbrot.dx b/examples/mandelbrot.dx index b384c5086..468ef17dd 100644 --- a/examples/mandelbrot.dx +++ b/examples/mandelbrot.dx @@ -4,17 +4,15 @@ include "plot.dx" 'Escape time algorithm -update : Complex -> Complex -> Complex = - \c z. c + (z * z) +def update (c:Complex) (z:Complex) : Complex = c + (z * z) tol = 2.0 -inBounds : Complex -> Bool = - \z. complex_abs z < tol +def inBounds (z:Complex) : Bool = complex_abs z < tol -escapeTime : Complex -> Float = - \c. fst $ fold (0.0, zero) $ \i:(Fin 1000) (n, z). - z' = update c z - (n + BToF (inBounds z'), z') +def escapeTime (c:Complex) : Float = + fst $ fold (0.0, zero) $ \i:(Fin 1000) (n, z). + z' = update c z + (n + BToF (inBounds z'), z') 'Evaluate on a grid and plot the results diff --git a/examples/mcmc.dx b/examples/mcmc.dx index 9bf7159ed..012370ff1 100644 --- a/examples/mcmc.dx +++ b/examples/mcmc.dx @@ -1,6 +1,8 @@ -- === General MCMC utilities === +include "plot.dx" + LogProb : Type = Float def runChain @@ -88,17 +90,17 @@ mhParams = 0.1 numSamples = 500 k0 = newKey 1 -hmcSamples = runChain randnVec (hmcStep hmcParams myLogProb) numSamples k0 mhSamples = runChain randnVec (mhStep mhParams myLogProb) numSamples k0 - -:p meanAndCovariance hmcSamples -> ([1.4468338, 2.4944723], [[1.065676, 2.047594e-2], [2.047594e-2, 5.288498e-2]]) - --- :plot for i. (IToF (ordinal i), hmcSamples.i.(0@_)) --- > +hmcSamples = runChain randnVec (hmcStep hmcParams myLogProb) numSamples k0 :p meanAndCovariance mhSamples > ([0.64555484, 2.4140575], [[0.38236195, 0.17941256], [0.17941256, 0.22895703]]) --- :plot for i. (IToF (ordinal i), mhSamples.i.(0@_)) --- > +:html showPlot $ yPlot (for i. mhSamples.i.(0@_)) +> + +:p meanAndCovariance hmcSamples +> ([1.4468338, 2.4944723], [[1.065676, 2.047594e-2], [2.047594e-2, 5.288498e-2]]) + +:html showPlot $ yPlot (for i. hmcSamples.i.(0@_)) +> diff --git a/examples/ode-integrator.dx b/examples/ode-integrator.dx index 0b87184e2..dc784b177 100644 --- a/examples/ode-integrator.dx +++ b/examples/ode-integrator.dx @@ -4,6 +4,8 @@ This version is a port of the [Jax implementation](https://github.com/google/jax One difference is that it uses a lower-triangular matrix type for the Butcher tableau, and so avoids zero-padding everywhere. +include "plot.dx" + Time = Float -- Should this go in the prelude? @@ -168,7 +170,5 @@ exact_e = [[exp 1.0]] times = linspace (Fin 100) 0.00001 1.0 ys = odeint myDyn z0 t0 times --- :plot --- ys' = for i. ys.i.(fromOrdinal _ 0) --- zip times ys' --- > +:html showPlot $ yPlot for i. ys.i.(fromOrdinal _ 0) +> diff --git a/examples/pi.dx b/examples/pi.dx index 46de500ac..6b2625f9c 100644 --- a/examples/pi.dx +++ b/examples/pi.dx @@ -1,20 +1,19 @@ '# Monte Carlo estimates of pi -estimatePiArea : Key -> Float = \key. +def estimatePiArea (key:Key) : Float = (k1, k2) = splitKey key x = rand k1 y = rand k2 inBounds = (sq x + sq y) < 1.0 4.0 * BToF inBounds -estimatePiAvgVal : Key -> Float = \key. +def estimatePiAvgVal (key:Key) : Float = x = rand key 4.0 * sqrt (1.0 - sq x) -meanAndStdDev : Int -> (Key -> Float) -> Key -> (Float & Float) = - \n f key. - samps = for i:(Fin n). many f key i - (mean samps, std samps) +def meanAndStdDev (n:Int) (f: Key -> Float) (key:Key) : (Float & Float) = + samps = for i:(Fin n). many f key i + (mean samps, std samps) numSamps = 1000000 diff --git a/examples/raytrace.dx b/examples/raytrace.dx index c02fda15a..bd9f693a2 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -4,7 +4,6 @@ [JAX implementation](https://github.com/ericjang/pt-jax/blob/master/jaxpt_vmap.ipynb), described [here](https://blog.evjang.com/2019/11/jaxpt.html). -Specifically, it's based on his unrolled ```lax.scan``` version. include "plot.dx" @@ -41,7 +40,7 @@ data IterResult a:Type b:Type = Done b -- A little iteration combinator --- TODO: allow effects (currently there's some type inference bug preventing it) +-- TODO: allow effects (bug #267) def iter (init:a) (body: Int -> a -> IterResult a b) : b = result = snd $ withState Nothing \resultRef. withState init \carryRef. @@ -59,18 +58,14 @@ def iter (init:a) (body: Int -> a -> IterResult a b) : b = ' ### 3D Helper Functions --- TODO: implement table unpacking -def unpackvec3 (p:Vec 3) : (Float & Float & Float) = - (p.(0@(Fin 3)), p.(1@(Fin 3)), p.(2@(Fin 3))) - def cross (a:Vec 3) (b:Vec 3) : Vec 3 = - (a1, a2, a3) = unpackvec3 a - (b1, b2, b3) = unpackvec3 b + [a1, a2, a3] = a + [b1, b2, b3] = b [a2 * b3 - a3 * b2, a3 * b1 - a1 * b3, a1 * b2 - a2 * b1] -- TODO: Use `data Color = Red | Green | Blue` and ADTs for index sets -def ColorImage (height:Int) (width:Int) : Type = Fin height => Fin width => Color -def GrayScaleImage (height:Int) (width:Int) : Type = Fin height => Fin width => Float +data Image = + MkImage height:Int width:Int (Fin height => Fin width => Color) xHat : Vec 3 = [1., 0., 0.] yHat : Vec 3 = [0., 1., 0.] @@ -81,19 +76,19 @@ Angle = Float -- angle in radians def rotateX (p:Vec 3) (angle:Angle) : Vec 3 = c = cos angle s = sin angle - (px, py, pz) = unpackvec3(p) + [px, py, pz] = p [px, c*py - s*pz, s*py + c*pz] def rotateY (p:Vec 3) (angle:Angle) : Vec 3 = c = cos angle s = sin angle - (px, py, pz) = unpackvec3(p) + [px, py, pz] = p [c*px + s*pz, py, - s*px+ c*pz] def rotateZ (p:Vec 3) (angle:Angle) : Vec 3 = c = cos angle s = sin angle - (px, py, pz) = unpackvec3(p) + [px, py, pz] = p [c*px - s*py, s*px+c*py, pz] def sampleCosineWeightedHemisphere (normal: Vec 3) (k:Key) : Vec 3 = @@ -112,7 +107,6 @@ def sampleCosineWeightedHemisphere (normal: Vec 3) (k:Key) : Vec 3 = ' ### Raytracer Distance = Float -def Image (n:Int) :Type = Fin n => Fin n => Color -- TODO: hide the size Position = Vec 3 Direction = Vec 3 -- Should be normalized. TODO: use a newtype wrapper @@ -142,7 +136,9 @@ Filter = Color -- TODO: use a record -- num samples, num bounces, share seed? -Params = (Int & Int & Bool) +Params = { numSamples : Int + & maxBounces : Int + & shareSeed : Bool } -- TODO: use a list instead, once they work data Scene n:Type = MkScene (n=>Object) @@ -258,12 +254,11 @@ def sampleLightRadiance radiance += coeff .* rayDirectRadiance scene outRay def trace (params:Params) (scene:Scene n) (init_ray:Ray) (k:Key) : Color = - (_, max_bounces, _) = params -- TODO: we ought to be able to use an accumulator here, but there's a bug noFilter = [1.0, 1.0, 1.0] iter (noFilter, zero, init_ray) $ \i (filter, radiance, ray). - case i >= max_bounces of + case i >= getAt #maxBounces params of True -> Done radiance False -> case raymarch scene ray of HitNothing -> Done radiance @@ -281,11 +276,16 @@ def trace (params:Params) (scene:Scene n) (init_ray:Ray) (k:Key) : Color = -- TODO: add number of pixels once we can hide sizes -- sensor half-width, pinhole-sensor distance, pinhole position -- (Assumes we're looking towards -z.) -Camera = (Position & Float & Float) +Camera = + { numPix : Int + & pos : Position + & halfWidth : Float + & sensorDist : Float } +-- TODO: might be better with an anonymous dependent pair for the result def cameraRays (n:Int) (camera:Camera) : Fin n => Fin n => (Key -> Ray) = -- images indexed from top-left - (pos, halfWidth, sensorDist) = camera + halfWidth = getAt #halfWidth camera pixHalfWidth = halfWidth / IToF n ys = reverse $ linspace (Fin n) (neg halfWidth) halfWidth xs = linspace (Fin n) (neg halfWidth) halfWidth @@ -293,23 +293,21 @@ def cameraRays (n:Int) (camera:Camera) : Fin n => Fin n => (Key -> Ray) = (kx, ky) = splitKey key x = xs.j + randuniform (-pixHalfWidth) pixHalfWidth kx y = ys.i + randuniform (-pixHalfWidth) pixHalfWidth ky - (pos, normalize [x, y, neg sensorDist]) + (getAt #pos camera, normalize [x, y, neg (getAt #sensorDist camera)]) -def takePicture - (params:Params) (scene:Scene m) (n:Int) (camera:Camera) - : ColorImage n n = - (numSamples, _, shareSeed) = params +def takePicture (params:Params) (scene:Scene m) (camera:Camera) : Image = + n = getAt #numPix camera rays = cameraRays n camera rootKey = newKey 0 image = for i j. - pixKey = case shareSeed of - True -> rootKey - False -> ixkey (ixkey rootKey i) j + pixKey = if getAt #shareSeed params + then rootKey + else ixkey (ixkey rootKey i) j sampleRayColor : Key -> Color = \k. (k1, k2) = splitKey k trace params scene (rays.i.j k1) k2 - sampleAveraged sampleRayColor numSamples pixKey - image / mean (for (i,j,k). image.i.j.k) + sampleAveraged sampleRayColor (getAt #numSamples params) pixKey + MkImage _ _ $ image / mean (for (i,j,k). image.i.j.k) ' ### Define the scene and render it @@ -331,25 +329,29 @@ theScene = MkScene $ , PassiveObject (Sphere [ 2.0, 2.0, -2.0] 1.5) (Mirror) ] -camera = (10.0 .* zHat, 0.3, 1.0) +defaultParams = { numSamples = 50 + , maxBounces = 10 + , shareSeed = True } --- num_pix = 250 -num_pix = 10 -num_samples = 50 -num_bounces = 10 -share_prng = True -params = (num_samples, num_bounces, share_prng) +defaultCamera = { numPix = 250 + , pos = 10.0 .* zHat + , halfWidth = 0.3 + , sensorDist = 1.0 } +-- We change to a small num pix here to reduce the compute needed for tests +camera = defaultCamera |> setAt #numPix 10 +params = defaultParams -- %time -image = takePicture params theScene num_pix camera - +(MkImage _ _ image) = takePicture params theScene camera :html imshow image > 'Just for fun, here's what we get with a single sample (sharing the PRNG key among pixels) -:html imshow $ - takePicture (1, num_bounces, share_prng) theScene num_pix camera +params2 = defaultParams |> setAt #numSamples 1 +(MkImage _ _ image2) = takePicture params2 theScene camera + +:html imshow image2 > diff --git a/examples/regression.dx b/examples/regression.dx index 61f99909e..0fa9e54eb 100644 --- a/examples/regression.dx +++ b/examples/regression.dx @@ -3,21 +3,20 @@ include "plot.dx" -- Conjugate gradients solver -def solve (m:Type)?-> : m=>m=>Float -> m=>Float -> m=>Float = - \mat b. - x0 = for i:m. 0.0 - ax = mat **. x0 - r0 = b - ax - (xOut, _, _) = fold (x0, r0, r0) $ - \s:m (x, r, p). - ap = mat **. p - alpha = vdot r r / vdot p ap - x' = x + alpha .* p - r' = r - alpha .* ap - beta = vdot r' r' / (vdot r r + 0.000001) - p' = r' + beta .* p - (x', r', p') - xOut +def solve (mat:m=>m=>Float) (b:m=>Float) : m=>Float = + x0 = for i:m. 0.0 + ax = mat **. x0 + r0 = b - ax + (xOut, _, _) = fold (x0, r0, r0) $ + \s:m (x, r, p). + ap = mat **. p + alpha = vdot r r / vdot p ap + x' = x + alpha .* p + r' = r - alpha .* ap + beta = vdot r' r' / (vdot r r + 0.000001) + p' = r' + beta .* p + (x', r', p') + xOut 'Make some synthetic data @@ -25,41 +24,42 @@ Nx = Fin 100 noise = 0.1 (k1, k2) = splitKey (newKey 0) -trueFun : Float -> Float = - \x. x + sin (5.0 * x) +def trueFun (x:Float) : Float = + x + sin (5.0 * x) xs : Nx=>Float = for i. rand (ixkey k1 i) ys : Nx=>Float = for i. trueFun xs.i + noise * randn (ixkey k2 i) --- :html showPlot $ xyPlot xs ys +:html showPlot $ xyPlot xs ys +> 'Implement basis function regression -regress : (Float -> d=>Float) -> n=>Float -> n=>Float -> d=>Float = - \featurize xRaw y. - x = map featurize xRaw - xT = transpose x - solve (xT ** x) (xT **. y) +def regress (featurize: Float -> d=>Float) (xRaw:n=>Float) (y:n=>Float) : d=>Float = + x = map featurize xRaw + xT = transpose x + solve (xT ** x) (xT **. y) 'Fit a third-order polynomial -poly : Float -> d=>Float = - \x. for i. pow x (IToF (ordinal i)) +def poly (x:Float) : d=>Float = + for i. pow x (IToF (ordinal i)) params : (Fin 4)=>Float = regress poly xs ys -predict : Float -> Float = - \x. vdot params (poly x) +def predict (x:Float) : Float = + vdot params (poly x) xsTest = linspace (Fin 200) 0.0 1.0 --- :html showPlot $ xyPlot xsTest (map predict xsTest) +:html showPlot $ xyPlot xsTest (map predict xsTest) +> 'RMS error -rmsErr : n=>Float -> n=>Float -> Float = - \truth pred. sqrt $ mean for i. sq (pred.i - truth.i) +def rmsErr (truth:n=>Float) (pred:n=>Float) : Float = + sqrt $ mean for i. sq (pred.i - truth.i) :p rmsErr ys (map predict xs) > 0.25269496 @@ -73,7 +73,8 @@ def tabCat (xs:n=>a) (ys:m=>a) : ({left:n|right:m})=>a = xsPlot = tabCat xs xsTest ysPlot = tabCat ys $ map predict xsTest --- :html showPlot $ xycPlot xsPlot ysPlot $ --- for i. case i of --- {| left = _ |} -> 0.0 --- {| right = _ |} -> 1.0 +:html showPlot $ xycPlot xsPlot ysPlot $ + for i. case i of + {| left = _ |} -> 0.0 + {| right = _ |} -> 1.0 +> diff --git a/examples/sierpinski.dx b/examples/sierpinski.dx index 815020c2c..64a8f8aea 100644 --- a/examples/sierpinski.dx +++ b/examples/sierpinski.dx @@ -2,18 +2,16 @@ include "plot.dx" -update : n=>Point -> Key -> Point -> Point = - \points key (x,y). - (x', y') = points.(randIdx key) - (0.5 * (x + x'), 0.5 * (y + y')) +def update (points:n=>Point) (key:Key) ((x,y):Point) : Point = + (x', y') = points.(randIdx key) + (0.5 * (x + x'), 0.5 * (y + y')) -runChain : n:Int -> (Key -> a -> a) -> Key -> a -> (Fin n)=>a = - \n f key x0. scan' x0 (many f key) +def runChain (n:Int) (f:Key -> a -> a) (key:Key) (x0:a) : Fin n => a = + scan' x0 (many f key) trianglePoints : (Fin 3)=>Point = [(0.0, 0.0), (1.0, 0.0), (0.5, sqrt 0.75)] (xs, ys) = unzip $ runChain 3000 (update trianglePoints) (newKey 0) (0.0, 0.0) --- Disabling this for now because the plotting function is too slow to compile --- :html showPlot $ xyPlot xs ys --- > +:html showPlot $ xyPlot xs ys +> diff --git a/lib/plot.dx b/lib/plot.dx index 0b32cde28..396bd7498 100644 --- a/lib/plot.dx +++ b/lib/plot.dx @@ -109,6 +109,10 @@ def xycPlot (xs:n=>Float) (ys:n=>Float) (cs:n=>Float) : Plot n Float Float Float setYData (autoScale ys) |> setCData (autoScale cs) +def yPlot (ys:n=>Float) : Plot n Float Float Unit = + xs = for i. IToF $ ordinal i + xyPlot xs ys + -- xs = linspace (Fin 100) 0. 1.0 -- :html showPlot $ xycPlot xs xs xs From 7ce56f72003106d263dcd1bddb013463c715c767 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Fri, 18 Dec 2020 22:17:52 -0500 Subject: [PATCH 007/105] Delete duplicate prelude left over from merge --- prelude.dx | 610 ----------------------------------------------------- 1 file changed, 610 deletions(-) delete mode 100644 prelude.dx diff --git a/prelude.dx b/prelude.dx deleted file mode 100644 index 6364bbfec..000000000 --- a/prelude.dx +++ /dev/null @@ -1,610 +0,0 @@ - -'## Dex prelude - -'Runs before every Dex program unless an alternative is provided with `--prelude`. - -'Wrappers around primitives - -Unit = %UnitType -Type = %TyKind -Effects = %EffKind -Fields = %LabeledRowKind - -Int64 = %Int64 -Int32 = %Int32 -Int8 = %Int8 -Float64 = %Float64 -Float32 = %Float32 - -Int = Int64 -Float = Float64 - -def (&) (a:Type) (b:Type) : Type = %PairType a b -def (,) (x:a) (y:b) : (a & b) = %pair x y -def fst (p: (a & b)) : a = %fst p -def snd (p: (a & b)) : b = %snd p - -def idiv (x:Int) (y:Int) : Int = %idiv x y -def rem (x:Int) (y:Int) : Int = %irem x y -def ipow (x:Int) (y:Int) : Int = %ipow x y - -def fdiv (x:Float) (y:Float) : Float = %fdiv x y - -def internalCast (b:Type) (x:a) : b = %cast b x - -def F64ToF (x : Float64) : Float = internalCast _ x -def F32ToF (x : Float32) : Float = internalCast _ x -def FToF64 (x : Float) : Float64 = internalCast _ x -def FToF32 (x : Float) : Float32 = internalCast _ x -def I64ToI (x : Int64) : Int = internalCast _ x -def I32ToI (x : Int32) : Int = internalCast _ x -def I8ToI (x : Int8 ) : Int = internalCast _ x -def IToI64 (x : Int) : Int64 = internalCast _ x -def IToI32 (x : Int) : Int32 = internalCast _ x -def IToI8 (x : Int) : Int8 = internalCast _ x - -data Add a:Type = - MkAdd (a->a->a) (a->a->a) a -- add, sub, zero - -def (+) (d:Add a) ?=> : a -> a -> a = case d of MkAdd add _ _ -> add -def (-) (d:Add a) ?=> : a -> a -> a = case d of MkAdd _ sub _ -> sub -def zero (d:Add a) ?=> : a = case d of MkAdd _ _ zero -> zero - -@instance float64Add : Add Float64 = MkAdd (\x:Float64 y:Float64. %fadd x y) (\x y. %fsub x y) (FToF64 0.0) -@instance float32Add : Add Float32 = MkAdd (\x:Float32 y:Float32. %fadd x y) (\x y. %fsub x y) (FToF32 0.0) -@instance int64Add : Add Int64 = MkAdd (\x:Int64 y:Int64. %iadd x y) (\x y. %isub x y) (IToI64 0) -@instance int32Add : Add Int32 = MkAdd (\x:Int32 y:Int32. %iadd x y) (\x y. %isub x y) (IToI32 0) -@instance int8Add : Add Int8 = MkAdd (\x:Int8 y:Int8. %iadd x y) (\x y. %isub x y) (IToI8 0) -@instance unitAdd : Add Unit = MkAdd (\x y. ()) (\x y. ()) () - -@instance tabAdd : Add a ?=> Add (n=>a) = - (MkAdd ( \xs ys. for i. xs.i + ys.i ) - ( \xs ys. for i. xs.i - ys.i ) - ( for _. zero )) - -data Mul a:Type = MkMul (a->a->a) a -- multiply, one - -def (*) (d:Mul a) ?=> : a -> a -> a = case d of MkMul mul _ -> mul -def one (d:Mul a) ?=> : a = case d of MkMul _ one -> one - -@instance floatMul : Mul Float = MkMul (\x:Float y:Float. %fmul x y) 1.0 -@instance intMul : Mul Int = MkMul (\x:Int y:Int. %imul x y) 1 -@instance unitMul : Mul Unit = MkMul (\x y. ()) () - -data VSpace a:Type = MkVSpace (Add a) (Float -> a -> a) - -@superclass -def addFromVSpace (d:VSpace a) : Add a = case d of MkVSpace addDict _ -> addDict - -flip : (a -> b -> c) -> (b -> a -> c) = \f x y. f y x -uncurry : (a -> b -> c) -> (a & b) -> c = \f (x,y). f x y -const : a -> b -> a = \x _. x - -def (.*) (d:VSpace a) ?=> : Float -> a -> a = case d of MkVSpace _ scale -> scale -(*.) : VSpace a ?=> a -> Float -> a = flip (.*) -def (/) (_:VSpace a) ?=> (v:a) (s:Float) : a = (fdiv 1.0 s) .* v -def neg (_:VSpace a) ?=> (v:a) : a = (-1.0) .* v - -@instance floatVS : VSpace Float = MkVSpace float64Add (*) -@instance tabVS : VSpace a ?=> VSpace (n=>a) = MkVSpace tabAdd \s xs. for i. s .* xs.i -@instance unitVS : VSpace Unit = MkVSpace unitAdd \s u. () - -data Bool = - False - True - -def BToI8 (x : Bool) : Int8 = - case x of - False -> (IToI8 0) - True -> (IToI8 1) - -def I8ToB (x : Int8) : Bool = - t = True - f = False - %select x t f - -def (&&) (x:Bool) (y:Bool) : Bool = - x' = BToI8 x - y' = BToI8 y - I8ToB $ %and x' y' - -def (||) (x:Bool) (y:Bool) : Bool = - x' = BToI8 x - y' = BToI8 y - I8ToB $ %or x' y' - -def not (x:Bool) : Bool = - x' = BToI8 x - I8ToB $ %not x' - -'Sum types - -data Maybe a:Type = - Nothing - Just a - -def isNothing (x:Maybe a) : Bool = case x of - Nothing -> True - Just _ -> False - -data (|) a:Type b:Type = - Left a - Right b - -def select (p:Bool) (x:a) (y:a) : a = case p of - True -> x - False -> y - -def b2i (x:Bool) : Int = - case x of - False -> 0 - True -> 1 - -def IToF (x:Int) : Float = internalCast _ x -def FToI (x:Float) : Int = internalCast _ x -def b2r (x:Bool) : Float = IToF (b2i x) -def todo (a:Type) ?-> : a = %throwError a - -'Effects - -def Ref (r:Type) (a:Type) : Type = %Ref r a -def get (ref:Ref h s) : {State h} s = %get ref -def (:=) (ref:Ref h s) (x:s) : {State h} Unit = %put ref x -def ask (ref:Ref h r) : {Read h} r = %ask ref -def (+=) (ref:Ref h w) (x:w) : {Accum h} Unit = %tell ref x -def (!) (ref:Ref h (n=>a)) (i:n) : Ref h a = %indexRef ref i -def fstRef (ref: Ref h (a & b)) : Ref h a = %fstRef ref -def sndRef (ref: Ref h (a & b)) : Ref h b = %sndRef ref - -def withReader - (eff:Effects) ?-> (a:Type) ?-> (r:Type) ?-> - (init:r) (action: (h:Type ?-> Ref h r -> {Read h|eff} a)) - : {|eff} a = - def explicitAction (h':Type) (ref:Ref h' r) : {Read h'|eff} a = action ref - %runReader init explicitAction - -def withAccum - (eff:Effects) ?-> (a:Type) ?-> (w:Type) ?-> - (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) - : {|eff} (a & w) = - def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = action ref - %runWriter explicitAction - -def withState - (eff:Effects) ?-> (a:Type) ?-> (s:Type) ?-> - (init:s) - (action: (h:Type ?-> Ref h s -> {State h |eff} a)) - : {|eff} (a & s) = - def explicitAction (h':Type) (ref:Ref h' s) : {State h'|eff} a = action ref - %runState init explicitAction - -'Type classes - -data Eq a:Type = MkEq (a -> a -> Bool) -data Ord a:Type = MkOrd (Eq a) (a -> a -> Bool) (a -> a -> Bool) -- eq, gt, lt - -@superclass -def eqFromOrd (d:Ord a) : Eq a = case d of MkOrd eq _ _ -> eq - -def (==) (d:Eq a) ?=> (x:a) (y:a) : Bool = case d of MkEq eq -> eq x y -def (/=) (d:Eq a) ?=> (x:a) (y:a) : Bool = not $ x == y - -def (>) (d:Ord a) ?=> (x:a) (y:a) : Bool = case d of MkOrd _ gt _ -> gt x y -def (<) (d:Ord a) ?=> (x:a) (y:a) : Bool = case d of MkOrd _ _ lt -> lt x y -def (<=) (d:Ord a) ?=> (x:a) (y:a) : Bool = x=) (d:Ord a) ?=> (x:a) (y:a) : Bool = x>y || x==y - -@instance intEq : Eq Int = MkEq \x:Int y:Int. I8ToB $ %ieq x y -@instance floatEq : Eq Float = MkEq \x:Float y:Float. I8ToB $ %feq x y -@instance unitEq : Eq Unit = MkEq \x y. True - -@instance intOrd : Ord Int = (MkOrd intEq (\x y. I8ToB $ %igt x y) - (\x y. I8ToB $ %ilt x y)) -@instance floatOrd : Ord Float = (MkOrd floatEq (\x y. I8ToB $ %fgt x y) - (\x y. I8ToB $ %flt x y)) -@instance unitOrd : Ord Unit = MkOrd unitEq (\x y. False) (\x y. False) - -@instance -def pairEq (eqA: Eq a)?=> (eqB: Eq b)?=> : Eq (a & b) = MkEq $ - \(x1,x2) (y1,y2). x1 == y1 && x2 == y2 - -@instance -def pairOrd (ordA: Ord a)?=> (ordB: Ord b)?=> : Ord (a & b) = - pairGt = \(x1,x2) (y1,y2). x1 > y1 || (x1 == y1 && x2 > y2) - pairLt = \(x1,x2) (y1,y2). x1 < y1 || (x1 == y1 && x2 < y2) - MkOrd pairEq pairGt pairLt - --- TODO: accumulate using the True/&& monoid -@instance -def tabEq (n:Type) ?-> (eqA: Eq a) ?=> : Eq (n=>a) = MkEq $ - \xs ys. - numDifferent : Float = - snd $ withAccum \ref. for i. - ref += (IToF (b2i (xs.i /= ys.i))) - numDifferent == 0.0 - -'Wrappers around C library functions - -def exp (x:Float) : Float = %exp x -def exp2 (x:Float) : Float = %exp2 x -def log (x:Float) : Float = %log x -def log2 (x:Float) : Float = %log2 x -def log10 (x:Float) : Float = %log10 x - -def sin (x:Float) : Float = %sin x -def cos (x:Float) : Float = %cos x -def tan (x:Float) : Float = %tan x - -def floor (x:Float) : Float = %floor x -def ceil (x:Float) : Float = %ceil x -def round (x:Float) : Float = %round x - -def sqrt (x:Float) : Float = %sqrt x -def pow (x:Float) (y:Float) : Float = %fpow x y - -def lgamma (x:Float) : Float = - x64 = FToF64 x - F64ToF $ %ffi lgamma Float64 x64 -def log1p (x:Float) : Float = - x64 = FToF64 x - F64ToF $ %ffi log1p Float64 x64 -def lbeta (x:Float) (y:Float) : Float = lgamma x + lgamma y - lgamma (x + y) - -'Working with index sets - -def Range (low:Int) (high:Int) : Type = %IntRange low high -def Fin (n:Int) : Type = Range 0 n -def ordinal (i:a) : Int = %asint i -def size (n:Type) : Int = %idxSetSize n -def fromOrdinal (n:Type) (i:Int) : n = %asidx n i -def asidx (n:Type) ?-> (i:Int) : n = fromOrdinal n i -def (@) (i:Int) (n:Type) : n = fromOrdinal n i -def ixadd (n:Type) ?-> (i:n) (x:Int) : n = fromOrdinal n $ ordinal i + x -def ixsub (n:Type) ?-> (i:n) (x:Int) : n = fromOrdinal n $ ordinal i - x -def iota (n:Type) : n=>Int = for i. ordinal i - --- TODO: we want Eq and Ord for all index sets, not just `Fin n` -@instance -def finEq (n:Int) ?-> : Eq (Fin n) = MkEq \x y. ordinal x == ordinal y - -@instance -def finOrd (n:Int) ?-> : Ord (Fin n) = - MkOrd finEq (\x y. ordinal x > ordinal y) (\x y. ordinal x < ordinal y) - -'Misc - -pi : Float = 3.141592653589793 - -def id (x:a) : a = x -def dup (x:a) : (a & a) = (x, x) --- TODO: unpack pair in args once we fix the bug -def swap (p:(a&b)) : (b&a) = (snd p, fst p) -def map (f:a->{|eff} b) (xs: n=>a) : {|eff} (n=>b) = for i. f xs.i -def zip (xs:n=>a) (ys:n=>b) : (n=>(a&b)) = for i. (xs.i, ys.i) -def unzip (xys:n=>(a&b)) : (n=>a & n=>b) = (map fst xys, map snd xys) -def fanout (n:Type) (x:a) : n=>a = for i. x -def sq (d:Mul a) ?=> (x:a) : a = x * x -def abs (x:Float) : Float = select (x > 0.0) x (-x) -def mod (x:Int) (y:Int) : Int = rem (y + rem x y) y -def compose (f:b->c) (g:a->b) (x:a) : c = f (g x) - -def reverse (x:n=>a) : n=>a = - s = size n - for i. x.((s - 1 - ordinal i)@_) - -def slice (xs:n=>a) (start:Int) (m:Type) : m=>a = - for i. xs.(fromOrdinal _ (ordinal i + start)) - -def scan (init:a) (body:n->a->(a&b)) : (a & n=>b) = - swap $ withState init \s. for i. - c = get s - (c', y) = body i c - s := c' - y - -def fold (init:a) (body:(n->a->a)) : a = fst $ scan init \i x. (body i x, ()) -def reduce (identity:a) (binop:(a->a->a)) (xs:n=>a) : a = - -- `binop` should be a commutative and associative, and form a - -- commutative monoid with `identity` - -- TODO: implement with a parallelizable monoid-parameterized writer - fold identity (\i c. binop c xs.i) - --- TODO: call this `scan` and call the current `scan` something else -def scan' (init:a) (body:n->a->a) : n=>a = snd $ scan init \i x. dup (body i x) --- TODO: allow tables-via-lambda and get rid of this -def fsum (xs:n->Float) : Float = snd $ withAccum \ref. for i. ref += xs i -def sum (_: Add v) ?=> (xs:n=>v) : v = reduce zero (+) xs -def prod (_: Mul v) ?=> (xs:n=>v) : v = reduce one (*) xs -def mean (n:Type) ?-> (xs:n=>Float) : Float = sum xs / IToF (size n) -def std (xs:n=>Float) : Float = sqrt $ mean (map sq xs) - sq (mean xs) -def any (xs:n=>Bool) : Bool = reduce False (||) xs -def all (xs:n=>Bool) : Bool = reduce True (&&) xs - -def while - (eff:Effects) ?-> - (cond: Unit -> {|eff} Bool) - (body: Unit -> {|eff} Unit) - : {|eff} Unit = - cond' : Unit -> {|eff} Int8 = \_. BToI8 $ cond () - %while cond' body - -data IterResult a:Type b:Type = - Continue a - Done b - --- A little iteration combinator --- TODO: allow effects (currently there's some type inference bug preventing it) -def iter (init:a) (body: Int -> a -> IterResult a b) : b = - result = snd $ withState Nothing \resultRef. - withState init \carryRef. - withState 0 \i. - while (\(). isNothing (get resultRef)) \(). - case body (get i) (get carryRef) of - Continue carry -> - i := get i + 1 - carryRef := carry - Done result -> - resultRef := Just result - case result of - Just ans -> ans - Nothing -> todo -- should be unreachable - --- returns the highest index `i` such that `xs.i <= x` -def searchSorted (_:Ord a) ?=> (xs:n=>a) (x:a) : Maybe n = - case size n == 0 of - True -> Nothing - False -> case x < xs.(fromOrdinal _ 0) of - True -> Nothing - False -> - iter (0, size n) \_ (low, high). - numLeft = high - low - case numLeft == 1 of - True -> Done $ Just $ fromOrdinal _ low - False -> - centerIx = low + idiv (high - low) 2 - case x < xs.(fromOrdinal _ centerIx) of - True -> Continue (low, centerIx) - False -> Continue (centerIx, high) - -def applyN (n:Int) (x:a) (f:a -> a) : a = - snd $ withState x \ref. for _:(Fin n). - ref := f (get ref) - -def linspace (n:Type) (low:Float) (high:Float) : n=>Float = - dx = (high - low) / IToF (size n) - for i:n. low + IToF (ordinal i) * dx - -def transpose (x:n=>m=>Float) : m=>n=>Float = for i j. x.j.i -def vdot (x:n=>Float) (y:n=>Float) : Float = fsum \i. x.i * y.i - --- matmul. Better symbol to use? `@`? -(**) : (l=>m=>Float) -> (m=>n=>Float) -> (l=>n=>Float) = \x y. - y' = transpose y - for i k. fsum \j. x.i.j * y'.k.j - -(**.) : (n=>m=>Float) -> (m=>Float) -> (n=>Float) = \mat v. for i. vdot mat.i v -(.**) : (m=>Float) -> (n=>m=>Float) -> (n=>Float) = flip (**.) - -def inner (x:n=>Float) (mat:n=>m=>Float) (y:m=>Float) : Float = - fsum \(i,j). x.i * mat.i.j * y.j - -'min / max etc - -def minBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x < f y) x y -def maxBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x > f y) x y - -def min (_:Ord o) ?=> (x1: o) -> (x2: o) : o = minBy id x1 x2 -def max (_:Ord o) ?=> (x1: o) -> (x2: o) : o = maxBy id x1 x2 - -def minimumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = - reduce xs.(0@_) (minBy f) xs -def maximumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = - reduce xs.(0@_) (maxBy f) xs - -def minimum (_:Ord o) ?=> (xs:n=>o) : o = minimumBy id xs -def maximum (_:Ord o) ?=> (xs:n=>o) : o = maximumBy id xs - -def argmin (_:Ord o) ?=> (xs:n=>o) : n = - zeroth = (0@_, xs.(0@_)) - compare = \(idx1, x1) (idx2, x2). - select (x1 < x2) (idx1, x1) (idx2, x2) - zipped = for i. (i, xs.i) - fst $ reduce zeroth compare zipped - -'Functions for working with the pseudorandom number generator - --- TODO: newtype -Key = Int64 - -def hash (x:Key) (y:Int) : Key = - y64 = IToI64 y - %ffi threefry2x32 Int64 x y64 -def newKey (x:Int) : Key = hash (IToI64 0) x -def splitKey (k:Key) : (Key & Key) = (hash k 0, hash k 1) -def splitKey3 (k:Key) : (Key & Key & Key) = - (k1, k') = splitKey k - (k2, k3) = splitKey k' - (k1, k2, k3) - -def many (f:Key->a) (k:Key) (i:n) : a = f (hash k (ordinal i)) -def ixkey (k:Key) (i:n) : Key = hash k (ordinal i) -def ixkey2 (k:Key) (i:n) (j:m) : Key = hash (hash k (ordinal i)) (ordinal j) -def rand (k:Key) : Float = F64ToF $ %ffi randunif Float64 k -def randVec (n:Int) (f: Key -> a) (k: Key) : Fin n => a = - for i:(Fin n). f (ixkey k i) - -def randn (k:Key) : Float = - (k1, k2) = splitKey k - u1 = rand k1 - u2 = rand k2 - sqrt ((-2.0) * log u1) * cos (2.0 * pi * u2) - -def randIdx (n:Type) ?-> (k:Key) : n = - unif = rand k - fromOrdinal n $ FToI $ floor $ unif * IToF (size n) - -def bern (p:Float) (k:Key) : Bool = rand k < p - -def randnVec (n:Type) ?-> (k:Key) : n=>Float = - for i. randn (ixkey k i) - -def cumSum (xs: n=>Float) : n=>Float = - fst $ withState 0.0 \total. - for i. - newTotal = get total + xs.i - total := newTotal - newTotal - -def cumSumLow (xs: n=>Float) : n=>Float = - fst $ withState 0.0 \total. - for i. - oldTotal = get total - total := oldTotal + xs.i - oldTotal - --- cdf should include 0.0 but not 1.0 -def categoricalFromCDF (cdf: n=>Float) (key: Key) : n = - r = rand key - case searchSorted cdf r of - Just i -> i - -def normalizePdf (xs: d=>Float) : d=>Float = xs / sum xs - -def cdfForCategorical (logprobs: n=>Float) : n=>Float = - maxLogProb = maximum logprobs - cumSumLow $ normalizePdf $ map exp $ for i. logprobs.i - maxLogProb - -def categorical (logprobs: n=>Float) (key: Key) : n = - categoricalFromCDF (cdfForCategorical logprobs) key - --- batch variant to share the work of forming the cumsum --- (alternatively we could rely on hoisting of loop constants) -def categoricalBatch (logprobs: n=>Float) (key: Key) : m=>n = - cdf = cdfForCategorical logprobs - for i. categoricalFromCDF cdf $ ixkey key i - -'Automatic differentiation - --- TODO: add vector space constraints -def linearize (f:a->b) (x:a) : (b & a --o b) = %linearize f x -def jvp (f:a->b) (x:a) : a --o b = snd (linearize f x) -def transposeLinear (f:a --o b) : b --o a = %linearTranspose f - -def vjp (f:a->b) (x:a) : (b & b --o a) = - (y, df) = linearize f x - (y, transposeLinear df) - -def grad (f:a->Float) (x:a) : a = snd (vjp f x) 1.0 - -def deriv (f:Float->Float) (x:Float) : Float = jvp f x 1.0 - -def derivRev (f:Float->Float) (x:Float) : Float = snd (vjp f x) 1.0 - -def checkDerivBase (f:Float->Float) (x:Float) : Bool = - -- TODO: parse 1e-5 - eps = 0.00005 - ansFwd = deriv f x - ansRev = derivRev f x - ansNumeric = (f (x + eps) - f (x - eps)) / (2. * eps) - isClose = \a b. abs (a - b) < 0.001 - isClose ansFwd ansNumeric && isClose ansRev ansNumeric - -def checkDeriv (f:Float->Float) (x:Float) : Bool = - checkDerivBase f x && checkDerivBase (deriv f) x - -'Vector support - --- TODO: Reenable vector suport once fixed-width types are supported. --- def UNSAFEFromOrdinal (n : Type) (i : Int) : n = %unsafeAsIndex n i --- --- VectorWidth = 4 -- XXX: Keep this synced with the constant defined in Array.hs --- VectorFloat = todo --- --- def packVector (a : Float) (b : Float) (c : Float) (d : Float) : VectorFloat = %vectorPack a b c d --- def indexVector (v : VectorFloat) (i : Fin VectorWidth) : Float = %vectorIndex v i --- --- -- NB: Backends should be smart enough to optimize this to a vector load from v --- def loadVector (v : (Fin VectorWidth)=>Float) : VectorFloat = --- idx = Fin VectorWidth --- (packVector v.(UNSAFEFromOrdinal idx 0) --- v.(UNSAFEFromOrdinal idx 1) --- v.(UNSAFEFromOrdinal idx 2) --- v.(UNSAFEFromOrdinal idx 3)) --- def storeVector (v : VectorFloat) : (Fin VectorWidth)=>Float = --- idx = Fin VectorWidth --- [ indexVector v (UNSAFEFromOrdinal idx 0) --- , indexVector v (UNSAFEFromOrdinal idx 1) --- , indexVector v (UNSAFEFromOrdinal idx 2) --- , indexVector v (UNSAFEFromOrdinal idx 3) ] --- --- def broadcastVector (v : Float) : VectorFloat = packVector v v v v --- --- @instance vectorFloatAdd : Add VectorFloat = --- (MkAdd ( \x y. %vfadd x y ) --- ( \x y. %vfsub x y ) --- ( broadcastVector zero )) --- @instance vectorFloatMul : Mul VectorFloat = --- MkMul (\x y. %vfmul x y) $ packVector 1.0 1.0 1.0 1.0 --- @instance vectorFloatVSpace : VSpace VectorFloat = --- MkVSpace vectorFloatAdd \x v. broadcastVector x * v - -'Tiling - -def Tile (n : Type) (m : Type) : Type = %IndexSlice n m - --- One can think of instances of `Tile n m` as injective functions `m -> n`, --- with the special property that consecutive elements of m map to consecutive --- elements of n. In this view (+>) is just function application, while ++> --- is currying followed by function application. We cannot represent currying --- in isolation, because `Tile n (Tile u v)` does not make sense, unlike `Tile n (u & v)`. -def (+>) (l : Type) ?-> (t:Tile n l) (i : l) : n = %sliceOffset t i -def (++>) (t : Tile n (u & v)) (i : u) : Tile n v = %sliceCurry t i - -def tile (l : Type) ?-> - (fTile : (t:(Tile n l) -> {|eff} l=>a)) - (fScalar : n -> {|eff} a) : {|eff} n=>a = %tiled fTile fScalar -def tile1 (n : Type) ?-> (l : Type) ?-> (m : Type) ?-> - (fTile : (t:(Tile n l) -> {|eff} m=>l=>a)) - (fScalar : n -> {|eff} m=>a) : {|eff} m=>n=>a = %tiledd fTile fScalar - --- TODO: This should become just `loadVector $ for i. arr.(t +> i)` --- once we are able to eliminate temporary arrays. Until then, we inline for performance... ---def loadTile (t : Tile n (Fin VectorWidth)) (arr : n=>Float) : VectorFloat = --- idx = Fin VectorWidth --- (packVector arr.(t +> UNSAFEFromOrdinal idx 0) --- arr.(t +> UNSAFEFromOrdinal idx 1) --- arr.(t +> UNSAFEFromOrdinal idx 2) --- arr.(t +> UNSAFEFromOrdinal idx 3)) - -'Numerical utilities - -def logsumexp (x: n=>Float) : Float = - m = maximum x - m + (log $ sum for i. exp (x.i - m)) - -def logsoftmax (x: n=>Float) : n=>Float = - lse = logsumexp x - for i. x.i - lse - -def softmax (x: n=>Float) : n=>Float = - m = maximum x - e = for i. exp (x.i - m) - s = sum e - for i. e.i / s - -def evalpoly (_:VSpace v) ?=> (coefficients:n=>v) (x:Float) : v = - -- Evaluate a polynomial at x. Same as Numpy's polyval. - fold zero \i c. coefficients.i + x .* c - -data List a:Type = - AsList n:Int foo:(Fin n => a) - -def (<>) (x:List a) (y:List a) : List a = - (AsList nx xs) = x - (AsList ny ys) = y - nz = nx + ny - AsList _ $ for i:(Fin nz). - i' = ordinal i - case i' < nx of - True -> xs.(fromOrdinal _ i') - False -> ys.(fromOrdinal _ (i' - nx)) From b367ddcc86d6240b0a223289fc90bcd69e51aeb3 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Fri, 18 Dec 2020 22:54:17 -0500 Subject: [PATCH 008/105] Make `splitKey` produce a table instead of a pair, subsuming `splitKey3` etc. Now that we have pattern matching on table constructors, this is just as convenient to use, because the pattern tells the function how many keys to produce. Also add an `Arbitrary` type class for creating dummy data at any type. --- examples/brownian_motion.dx | 2 +- examples/ctc.dx | 16 ++++++++-------- examples/mcmc.dx | 6 +++--- examples/particle-filter.dx | 8 ++++---- examples/particle-swarm-optimizer.dx | 10 +++++----- examples/pi.dx | 2 +- examples/raytrace.dx | 10 +++++----- examples/regression.dx | 2 +- lib/prelude.dx | 24 +++++++++++++++++------- 9 files changed, 45 insertions(+), 35 deletions(-) diff --git a/examples/brownian_motion.dx b/examples/brownian_motion.dx index 4461e76ad..9f9456291 100644 --- a/examples/brownian_motion.dx +++ b/examples/brownian_motion.dx @@ -5,7 +5,7 @@ UnitInterval = Float def bmIter ((key, y, sigma, t):(Key & Float & Float & UnitInterval)) : (Key & Float & Float & UnitInterval) = - (kDraw, kL, kR) = splitKey3 key + [kDraw, kL, kR] = splitKey key t' = abs (t - 0.5) y' = sigma * randn kDraw * (0.5 - t') key' = select (t > 0.5) kL kR diff --git a/examples/ctc.dx b/examples/ctc.dx index 6b63fa852..4e6544eb2 100644 --- a/examples/ctc.dx +++ b/examples/ctc.dx @@ -103,16 +103,16 @@ def randIdxNoZero (n:Type) -> (k:Key) : n = unif = rand k fromOrdinal n $ (1 + (FToI (floor ( unif * IToF ((size n) - 1))))) -vocab = Fin 6 +Vocab = Fin 6 position = Fin 3 -blank = 0@vocab +blank = 0@Vocab -- Create random logits -time = Fin 4 -logits = for i:time j:vocab. (randn $ ixkey2 (newKey 0) i j) +Time = Fin 4 +logits : Time => Vocab => Float = arb $ newKey 0 -- Create random labels -labels = for i:position. randIdxNoZero vocab (newKey (ordinal i)) +labels = for i:position. randIdxNoZero Vocab (newKey (ordinal i)) :p labels > [(1@Fin 6), (5@Fin 6), (5@Fin 6)] @@ -130,14 +130,14 @@ labels = for i:position. randIdxNoZero vocab (newKey (ordinal i)) -- e.g. the summed-over labels should include blanks. -:p sum for i:vocab. +:p sum for i:Vocab. exp $ ctc blank logits [i] > 0.14146839 -:p sum for (i, j):(vocab & vocab). +:p sum for (i, j):(Vocab & Vocab). exp $ ctc blank logits [i, j] > 0.7091234 -:p sum for (i, j, k):(vocab & vocab & vocab). +:p sum for (i, j, k):(Vocab & Vocab & Vocab). exp $ ctc blank logits [i, j, k] > 0.9251011 diff --git a/examples/mcmc.dx b/examples/mcmc.dx index 012370ff1..a33dfc3ca 100644 --- a/examples/mcmc.dx +++ b/examples/mcmc.dx @@ -11,7 +11,7 @@ def runChain (numSamples: Int) (k:Key) : Fin numSamples => a = - (k1, k2) = splitKey k + [k1, k2] = splitKey k fst $ withState (initialize k1) \s. for i:(Fin numSamples). x = step (ixkey k2 i) (get s) @@ -45,7 +45,7 @@ def mhStep (k:Key) (x:d=>Float) : d=>Float = - (k1, k2) = splitKey k + [k1, k2] = splitKey k proposal = x + stepSize .* randnVec k1 propose logProb x proposal k2 @@ -74,7 +74,7 @@ def hmcStep (x:d=>Float) : d=>Float = hamiltonian = \(x, p). logProb x - 0.5 * vdot p p - (k1, k2) = splitKey k + [k1, k2] = splitKey k p = randnVec k1 proposal = leapfrogIntegrate params logProb (x, p) fst $ propose hamiltonian (x, p) proposal k2 diff --git a/examples/particle-filter.dx b/examples/particle-filter.dx index 2bcff1040..f88d8e541 100644 --- a/examples/particle-filter.dx +++ b/examples/particle-filter.dx @@ -13,11 +13,11 @@ def sample (d: Distribution a) (k: Key) : a = def simulate (model: Model s v) (t: Int) (key: Key) : Fin t=>(s & v) = (init, dynamics, observe) = model - (key, subkey) = splitKey key + [key, subkey] = splitKey key s0 = sample init subkey fst $ withState s0 \s_ref . for i. - (k1, k2) = splitKey (ixkey key i) + [k1, k2] = splitKey (ixkey key i) s = get s_ref s_next = sample (dynamics s) k1 v = sample (observe s) k2 @@ -32,13 +32,13 @@ def filter (key: Key) : Fin num_timesteps => a = (init, dynamics, observe) = model - (key, init_key) = splitKey key + [key, init_key] = splitKey key init_particles = for i: (Fin num_particles). sample init (ixkey init_key i) fst $ withState init_particles \p_ref . for t: (Fin num_timesteps). p_prev = get p_ref logLikelihoods = for i. snd (observe p_prev.i) obs.t - (resample_key, dynamics_key) = splitKey (ixkey key t) + [resample_key, dynamics_key] = splitKey (ixkey key t) resampled_idxs = categoricalBatch logLikelihoods resample_key p_resampled = for i. p_prev.(resampled_idxs.i) p_next = for i. fst (dynamics p_resampled.i) (ixkey dynamics_key i) diff --git a/examples/particle-swarm-optimizer.dx b/examples/particle-swarm-optimizer.dx index dc5126e5d..3677c15e7 100644 --- a/examples/particle-swarm-optimizer.dx +++ b/examples/particle-swarm-optimizer.dx @@ -72,7 +72,7 @@ def optimize minbyfst pbests.p (f newPositions.p, newPositions.p) newGbest:(Float & d=>Float) = minbyfst gbest (minimumbyfst newPbests) - (keyG, keyP, keyNext) = splitKey3 keyL + [keyG, keyP, keyNext] = splitKey keyL (gscore, gloc) = gbest plocs = map snd pbests gVel:(np=>d=>Float) = for p i. @@ -87,7 +87,7 @@ def optimize (keyNext,newGbest,newPbests,newPositions,newVelocities) randInit1 = \keyI1. randBounded keyI1 lb ub randInit = \keyI. for p:np. randInit1 $ ixkey keyI p - (keyPos, keyVel, keyLoop) = splitKey3 key + [keyPos, keyVel, keyLoop] = splitKey key initPositions:(np=>d=>Float) = randInit keyPos initVelocities:(np=>d=>Float) = randInit keyVel initPs:(np=>(Float & d=>Float)) = for p. (f initPositions.p, initPositions.p) @@ -103,13 +103,13 @@ Run it for more iterations and result should improve. Which it indeed does. :p optimize 50 10 (newKey 0) rosenbrock2 ([-10.0, -10.0],[20.0, 20.0]) (0.5,0.3,0.4) -> [3.7902741, 14.911411] +> [7.698643e-2, 0.23281813] :p optimize 50 20 (newKey 0) rosenbrock2 ([-10.0, -10.0],[20.0, 20.0]) (0.5,0.3,0.4) -> [1.737732, 3.1227117] +> [0.90125036, 0.75044703] :p optimize 50 100 (newKey 0) rosenbrock2 ([-10.0, -10.0],[20.0, 20.0]) (0.5,0.3,0.4) -> [1.0062296, 1.0128789] +> [0.9990686, 0.9981924] :p optimize 50 1000 (newKey 0) rosenbrock2 ([-10.0, -10.0],[20.0, 20.0]) (0.5,0.3,0.4) > [1.0, 1.0] diff --git a/examples/pi.dx b/examples/pi.dx index 6b2625f9c..6c05e4b86 100644 --- a/examples/pi.dx +++ b/examples/pi.dx @@ -1,7 +1,7 @@ '# Monte Carlo estimates of pi def estimatePiArea (key:Key) : Float = - (k1, k2) = splitKey key + [k1, k2] = splitKey key x = rand k1 y = rand k2 inBounds = (sq x + sq y) < 1.0 diff --git a/examples/raytrace.dx b/examples/raytrace.dx index 40c5a43d1..08d6758b7 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -67,7 +67,7 @@ def rotateZ (p:Vec 3) (angle:Angle) : Vec 3 = [c*px - s*py, s*px+c*py, pz] def sampleCosineWeightedHemisphere (normal: Vec 3) (k:Key) : Vec 3 = - (k1, k2) = splitKey k + [k1, k2] = splitKey k u1 = rand k1 u2 = rand k2 uu = normalize $ cross normal [0.0, 1.1, 1.1] @@ -204,7 +204,7 @@ def rayDirectRadiance (scene:Scene n) (ray:Ray) : Radiance = HitObj _ _ -> zero def sampleSquare (hw:Float) (k:Key) : Position = - (kx, kz) = splitKey k + [kx, kz] = splitKey k x = randuniform (- hw) hw kx z = randuniform (- hw) hw kz [x, 0.0, z] @@ -241,7 +241,7 @@ def trace (params:Params) (scene:Scene n) (init_ray:Ray) (k:Key) : Color = True -> Done intensity -- TODO: scale etc False -> Done radiance HitObj incidentRay osurf -> - (k1, k2) = splitKey $ hash k i + [k1, k2] = splitKey $ hash k i lightRadiance = sampleLightRadiance scene osurf incidentRay k1 outRayHemisphere = sampleReflection osurf incidentRay k2 newFilter = surfaceFilter filter (snd osurf) @@ -265,7 +265,7 @@ def cameraRays (n:Int) (camera:Camera) : Fin n => Fin n => (Key -> Ray) = ys = reverse $ linspace (Fin n) (neg halfWidth) halfWidth xs = linspace (Fin n) (neg halfWidth) halfWidth for i j. \key. - (kx, ky) = splitKey key + [kx, ky] = splitKey key x = xs.j + randuniform (-pixHalfWidth) pixHalfWidth kx y = ys.i + randuniform (-pixHalfWidth) pixHalfWidth ky (getAt #pos camera, normalize [x, y, neg (getAt #sensorDist camera)]) @@ -279,7 +279,7 @@ def takePicture (params:Params) (scene:Scene m) (camera:Camera) : Image = then rootKey else ixkey (ixkey rootKey i) j sampleRayColor : Key -> Color = \k. - (k1, k2) = splitKey k + [k1, k2] = splitKey k trace params scene (rays.i.j k1) k2 sampleAveraged sampleRayColor (getAt #numSamples params) pixKey MkImage _ _ $ image / mean (for (i,j,k). image.i.j.k) diff --git a/examples/regression.dx b/examples/regression.dx index 0fa9e54eb..fbf484aa1 100644 --- a/examples/regression.dx +++ b/examples/regression.dx @@ -22,7 +22,7 @@ def solve (mat:m=>m=>Float) (b:m=>Float) : m=>Float = Nx = Fin 100 noise = 0.1 -(k1, k2) = splitKey (newKey 0) +[k1, k2] = splitKey (newKey 0) def trueFun (x:Float) : Float = x + sin (5.0 * x) diff --git a/lib/prelude.dx b/lib/prelude.dx index d2d41bd2e..2b05011bc 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -483,15 +483,10 @@ def hash (x:Key) (y:Int32) : Key = y64 = IToI64 y %ffi threefry2x32 Int64 x y64 def newKey (x:Int) : Key = hash (IToI64 0) x -def splitKey (k:Key) : (Key & Key) = (hash k 0, hash k 1) -def splitKey3 (k:Key) : (Key & Key & Key) = - (k1, k') = splitKey k - (k2, k3) = splitKey k' - (k1, k2, k3) - def many (f:Key->a) (k:Key) (i:n) : a = f (hash k (ordinal i)) def ixkey (k:Key) (i:n) : Key = hash k (ordinal i) def ixkey2 (k:Key) (i:n) (j:m) : Key = hash (hash k (ordinal i)) (ordinal j) +def splitKey (n:Int) ?-> (k:Key) : Fin n => Key = for i. ixkey k i def rand (k:Key) : Float = F64ToF $ %ffi randunif Float64 k def randVec (n:Int) (f: Key -> a) (k: Key) : Fin n => a = for i:(Fin n). f (ixkey k i) @@ -500,7 +495,7 @@ def randMat (n:Int) (m:Int) (f: Key -> a) (k: Key) : Fin n => Fin m => a = for i j. f (ixkey2 k i j) def randn (k:Key) : Float = - (k1, k2) = splitKey k + [k1, k2] = splitKey k u1 = rand k1 u2 = rand k2 sqrt ((-2.0) * log u1) * cos (2.0 * pi * u2) @@ -524,6 +519,21 @@ def cumSum (xs: n=>Float) : n=>Float = total := newTotal newTotal +interface Arbitrary a:Type where + arb : Key -> a + +instance float32Arb : Arbitrary Float32 where + arb = randn + +instance in32Arb : Arbitrary Int32 where + arb = \key. FToI $ randn key * 5.0 + +instance tabArb : Arbitrary a ?=> Arbitrary (n=>a) where + arb = \key. for i. arb $ ixkey key i + +instance finArb : n:Int ?-> Arbitrary (Fin n) where + arb = randIdx + 'min / max etc def minBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x < f y) x y From d8540257c8e00c9d9e86f4b53190052b5a145b68 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 19 Dec 2020 16:57:04 -0500 Subject: [PATCH 009/105] [README] Edit build dependency instructions. (#360) Use consistent formatting and style. --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2c2213c61..0b6e6fac3 100644 --- a/README.md +++ b/README.md @@ -28,9 +28,12 @@ development. Expect monstrous bugs and razor-sharp edges. Contributions welcome! * Install [stack](https://www.haskellstack.org) * Install LLVM 9 - * `apt-get install llvm-9-dev` on Ubuntu/Debian, - * `brew install llvm@9` on macOS, and ensure it is on your `PATH` e.g. via `export PATH="$(brew --prefix llvm@9)/bin:$PATH"` before building. - * Install libpng (often included by default in *nix) + * Ubuntu/Debian: `apt-get install llvm-9-dev` + * macOS: `brew install llvm@9` + * Make sure `llvm@9` is on your `PATH` before building. Example: `export PATH="$(brew --prefix llvm@9)/bin:$PATH"` + * Install libpng (often included by default in *nix platforms) + * Ubuntu/Debian: `apt-get install libpng-dev` + * macOS: `brew install libpng` ## Building From 5dc6a6f7b89d2ad371994b1e56fe96d680ffde0b Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Sat, 19 Dec 2020 23:44:20 -0500 Subject: [PATCH 010/105] Add an IO effect, modeled as `{State World}` with a special `World` token. The goal is to let us sequence effectful FFI calls, without worrying about reordering and DCE, and then wrap the sequence in `unsafePerformIO` to expose it as a pure function. --- lib/prelude.dx | 12 ++++++++++++ src/dex.hs | 2 +- src/lib/Autodiff.hs | 4 ++-- src/lib/Imp.hs | 11 +++++++---- src/lib/Inference.hs | 6 ++++-- src/lib/Interpreter.hs | 2 +- src/lib/Parser.hs | 5 +++-- src/lib/Simplify.hs | 3 +++ src/lib/Syntax.hs | 13 +++++++++++-- src/lib/Type.hs | 11 ++++++++++- src/lib/dexrt.cpp | 8 ++++++++ 11 files changed, 62 insertions(+), 15 deletions(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index 2b05011bc..7edb0f6c8 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -249,6 +249,9 @@ def withState def explicitAction (h':Type) (ref:Ref h' s) : {State h'|eff} a = action ref %runState init explicitAction +def unsafeIO (f: Unit -> {State World|eff} a) : {|eff} a = + %runIO f + 'Type classes data Eq a:Type = MkEq (a -> a -> Bool) @@ -893,6 +896,15 @@ instance showFloat64 : Show Float64 where AsList n $ for i:(Fin n). %ptrLoad (%ptrOffset ptr (ordinal i)) +def writeStdErr (s:String) : {State World} Unit = + (AsList n cs) = s + %ffi IO writeToStdErr Int n (%getPtr cs) + () + +def throwMsg (s:String) : a = unsafeIO \(). + writeStdErr s + %throwError a + -- pipe-like reverse function application def (|>) (x:a) (f: a -> b) : b = f x diff --git a/src/dex.hs b/src/dex.hs index 9c36c9ae8..c0b73f6df 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -62,7 +62,7 @@ runMode evalMode preludeFile opts = do exportFunctions objPath exportedFuns env opts evalPrelude :: EvalConfig -> Maybe FilePath -> IO TopEnv -evalPrelude opts fname = flip execStateT mempty $ do +evalPrelude opts fname = flip execStateT initTopEnv $ do source <- case fname of Nothing -> return $ preludeSource Just path -> liftIO $ readFile path diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index f17115bfd..2d5ae353c 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -167,7 +167,7 @@ linearizeOp op = case op of RecordSplit vs r -> (RecordSplit <$> traverse la vs <*> la r) `bindLin` emitOp VariantLift ts v -> (VariantLift ts <$> la v) `bindLin` emitOp VariantSplit ts v -> (VariantSplit ts <$> la v) `bindLin` emitOp - FFICall _ _ _ -> error $ "Can't differentiate through an FFI call" + FFICall _ _ _ _ -> error $ "Can't differentiate through an FFI call" where emitDiscrete = if isTrivialForAD (Op op) then LinA $ withZeroTangent <$> emitOp op @@ -625,7 +625,7 @@ transposeOp op ct = case op of ToOrdinal _ -> notLinear IdxSetSize _ -> notLinear ThrowError _ -> notLinear - FFICall _ _ _ -> notLinear + FFICall _ _ _ _ -> notLinear DataConTag _ -> notLinear ToEnum _ _ -> notLinear where diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index a9c967ed0..77b5ecb59 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -98,9 +98,10 @@ toImpModule env backend cc entryName argBinders maybeDest block = do requiredFunctions :: HasVars a => Scope -> a -> [(Name, Atom)] requiredFunctions scope expr = - for (transitiveClosure getParents immediateParents) $ \fname -> do - let (_, LetBound _ (Atom f)) = scope ! fname - (fname, f) + flip foldMap (transitiveClosure getParents immediateParents) $ \fname -> + case scope ! fname of + (_, LetBound _ (Atom f)) -> [(fname, f)] + _ -> [] where getParents :: Name -> [Name] getParents fname = envNames $ freeVars $ scope ! fname @@ -314,7 +315,7 @@ toImpOp (maybeDest, op) = case op of _ -> error $ "Not a data constructor: " ++ pprint con ToEnum ~ty@(TypeCon (DataDef _ _ cons) _) i -> returnVal $ Con $ SumAsProd ty i (map (const []) cons) - FFICall name returnTy xs -> do + FFICall _ name returnTy xs -> do let returnTys = fromScalarOrPairType returnTy let xTys = map (fromScalarType . getType) xs f <- emitFFIFunction name xTys returnTys @@ -507,6 +508,8 @@ toImpHof env (maybeDest, hof) = do copyAtom sDest =<< impSubst env s void $ translateBlock (env <> ref @> sDest) (Just aDest, body) PairVal <$> destToAtom aDest <*> destToAtom sDest + RunIO ~(Lam (Abs _ (_, body))) -> do + translateBlock env (maybeDest, body) Linearize _ -> error "Unexpected Linearize" Transpose _ -> error "Unexpected Transpose" diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index bd398e8c3..9912b3b3d 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -308,7 +308,7 @@ lookupSourceVar v = do Nothing -> do scope <- getScope let v' = asGlobal $ varName v - case envLookup scope (v':>()) of + case envLookup scope v' of Just (_, DataBoundTypeCon def ) -> return $ TypeCon def [] Just (_, DataBoundDataCon def con) -> return $ DataCon def [] con [] Just (ty, _) -> return $ Var $ v':>ty @@ -394,7 +394,9 @@ checkULam (p, ann) body piTy = do checkUEff :: EffectRow -> UInferM EffectRow checkUEff (EffectRow effs t) = do - effs' <- forM effs $ \(effName, region) -> (effName,) <$> lookupVarName TyKind region + effs' <- forM effs $ \(effName, region) -> do + (Var (v:>TyKind)) <- lookupSourceVar (region:>()) + return (effName, v) t' <- forM t $ \tv -> lookupVarName EffKind tv return $ EffectRow effs' t' where diff --git a/src/lib/Interpreter.hs b/src/lib/Interpreter.hs index f9f43fa18..db8b1df7f 100644 --- a/src/lib/Interpreter.hs +++ b/src/lib/Interpreter.hs @@ -90,7 +90,7 @@ evalOp expr = case expr of ScalarUnOp op x -> return $ case op of FNeg -> applyFloatUnOp (0-) x _ -> error $ "Not implemented: " ++ pprint expr - FFICall name _ args -> return $ case name of + FFICall _ name _ args -> return $ case name of "randunif" -> Float64Val $ c_unif x where [Int64Val x] = args "threefry2x32" -> Int64Val $ c_threefry x y where [Int64Val x, Int64Val y] = args _ -> error $ "FFI function not recognized: " ++ name diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index c253cc25b..7df13f440 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -487,7 +487,7 @@ effects :: Parser EffectRow effects = braces someEffects <|> return Pure where someEffects = do - effs <- liftM2 (,) effectName lowerName `sepBy` sym "," + effs <- liftM2 (,) effectName (lowerName <|> upperName) `sepBy` sym "," v <- optional $ symbol "|" >> lowerName return $ EffectRow effs v @@ -678,10 +678,11 @@ uPrim = withSrc $ do s <- primName case s of "ffi" -> do + mayDoIO <- (symbol "IO" $> True) <|> return False f <- lexeme $ some nameTailChar retTy <- leafExpr args <- some leafExpr - return $ UPrimExpr $ OpExpr $ FFICall f retTy args + return $ UPrimExpr $ OpExpr $ FFICall mayDoIO f retTy args _ -> case strToPrimName s of Just prim -> UPrimExpr <$> traverse (const leafExpr) prim Nothing -> fail $ "Unrecognized primitive: " ++ s diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 05a87b191..f9175dc8c 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -475,6 +475,9 @@ simplifyHof hof = case hof of (ans, sOut) <- fromPair =<< (emit $ Hof $ RunState s' lam') ans' <- applyRecon recon ans return $ PairVal ans' sOut + RunIO lam -> do + ~(lam', Nothing) <- simplifyLam lam + emit $ Hof $ RunIO lam' where applyRecon Nothing x = return x applyRecon (Just f) x = f x diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 04052e457..29ae13eb5 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -46,7 +46,7 @@ module Syntax ( subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, applyNaryAbs, applyDataDefParams, freshSkolemVar, IndexStructure, mkConsList, mkConsListTy, fromConsList, fromConsListTy, extendEffRow, - getProjection, + getProjection, theWorld, initTopEnv, varType, binderType, isTabTy, LogLevel (..), IRVariant (..), applyIntBinOp, applyIntCmpOp, applyFloatBinOp, applyFloatUnOp, getIntLit, getFloatLit, sizeOf, vectorWidth, @@ -319,7 +319,7 @@ data PrimOp e = | IndexRef e e | FstRef e | SndRef e - | FFICall String e [e] + | FFICall Bool String e [e] -- bool indicates it may do IO | Inject e | PtrOffset e e | PtrLoad e @@ -363,6 +363,7 @@ data PrimHof e = | RunReader e e | RunWriter e | RunState e e + | RunIO e | Linearize e | Transpose e | PTileReduce e e -- index set, thread body @@ -445,6 +446,13 @@ instance Eq EffectRow where EffectRow effs t == EffectRow effs' t' = sort effs == sort effs' && t == t' +theWorld :: Name +theWorld = GlobalName "World" + +initTopEnv :: TopEnv +initTopEnv = + (theWorld:>TyKind) @> (TyKind, LamBound ImplicitArrow) + -- === top-level constructs === data SourceBlock = SourceBlock @@ -1505,6 +1513,7 @@ builtinNames = M.fromList , ("runReader" , HofExpr $ RunReader () ()) , ("runWriter" , HofExpr $ RunWriter ()) , ("runState" , HofExpr $ RunState () ()) + , ("runIO" , HofExpr $ RunIO ()) , ("tiled" , HofExpr $ Tile 0 () ()) , ("tiledd" , HofExpr $ Tile 1 () ()) , ("TyKind" , TCExpr $ TypeKind) diff --git a/src/lib/Type.hs b/src/lib/Type.hs index c2c4faeff..4d4af8b53 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -276,6 +276,7 @@ exprEffs expr = case expr of MAsk -> S.singleton (Reader, h) MTell _ -> S.singleton (Writer, h) where RefTy (Var (h:>_)) _ = getType ref + FFICall True _ _ _ -> S.singleton (State, theWorld) _ -> NoEffects Hof hof -> case hof of For _ f -> functionEffs f @@ -287,6 +288,8 @@ exprEffs expr = case expr of RunWriter f -> handleRunner Writer f RunState _ f -> handleRunner State f PTileReduce _ _ -> mempty + RunIO ~(Lam (Abs _ (PlainArrow (EffectRow effs Nothing), _))) -> + S.delete (State, theWorld) $ S.fromList effs Case _ alts _ -> foldMap (\(Abs _ block) -> blockEffs block) alts where handleRunner effName ~(BinaryFunVal (Bind (h:>_)) _ (EffectRow effs Nothing) _) = @@ -436,6 +439,7 @@ instance CoreVariant (PrimHof a) where RunReader _ _ -> alwaysAllowed RunWriter _ -> alwaysAllowed RunState _ _ -> alwaysAllowed + RunIO _ -> alwaysAllowed Linearize _ -> goneBy Simp Transpose _ -> goneBy Simp Tile _ _ _ -> alwaysAllowed @@ -665,13 +669,14 @@ typeCheckOp op = case op of UnsafeFromOrdinal ty i -> ty|:TyKind >> i|:IdxRepTy $> ty ToOrdinal i -> typeCheck i $> IdxRepTy IdxSetSize i -> typeCheck i $> IdxRepTy - FFICall _ ansTy args -> do + FFICall mayDoIO _ ansTy args -> do forM_ args $ \arg -> do argTy <- typeCheck arg case argTy of BaseTy _ -> return () _ -> throw TypeErr $ "All arguments of FFI calls have to be " ++ "fixed-width base types, but got: " ++ pprint argTy + when mayDoIO $ declareEff (State, Just theWorld) return ansTy Inject i -> do TC tc <- typeCheck i @@ -856,6 +861,10 @@ typeCheckHof hof = case hof of (resultTy, stateTy) <- checkAction State f s |: stateTy return $ PairTy resultTy stateTy + RunIO f -> do + FunTy _ eff resultTy <- typeCheck f + extendAllowedEffect (State, theWorld) $ declareEffs eff + return resultTy checkAction :: EffectName -> Atom -> TypeM (Type, Type) checkAction effName f = do diff --git a/src/lib/dexrt.cpp b/src/lib/dexrt.cpp index c3ef5780b..d99e189df 100644 --- a/src/lib/dexrt.cpp +++ b/src/lib/dexrt.cpp @@ -184,6 +184,14 @@ void doubleVec(char **resultPtr, int32_t n, float* xs) { *result2Ptr = p2; } +int32_t writeToStdErr(int32_t numBytes, char* bytes) { + fwrite(bytes, 1, (size_t) numBytes, stderr); + fprintf(stderr, "\n"); + fflush(stderr); + return 0; +} + + void encodePNG(char **resultPtr, int8_t* pixels, int32_t width, int32_t height) { png_image img; memset(&img, 0, sizeof(img)); From 6acdf92f3f50aa38af83cbdff2e2c32337ec99d7 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Sun, 20 Dec 2020 21:05:53 -0500 Subject: [PATCH 011/105] Expose some unix file operations using the IO effect. There's still a lot to do around errors and thread safety. --- lib/io.dx | 79 +++++++++++++++++++++++++++++++++++++++++++++++ makefile | 2 +- tests/io-tests.dx | 12 +++++++ 3 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 lib/io.dx create mode 100644 tests/io-tests.dx diff --git a/lib/io.dx b/lib/io.dx new file mode 100644 index 000000000..84b7439d2 --- /dev/null +++ b/lib/io.dx @@ -0,0 +1,79 @@ + +'File system operations + +FilePath : Type = String +data CString = MkCString CharPtr + +data StreamMode = + ReadMode + WriteMode + +data Stream mode:StreamMode = MkStream CharPtr + +-- TODO: check the string contains no nulls +def asCString (s:String) : CString = + (AsList _ s') = s <> (AsList _ "\NUL") + MkCString %getPtr s' + +def fopen (path:String) (mode:StreamMode) : {State World} (Stream mode) = + modeStr = AsList _ case mode of + ReadMode -> "r" + WriteMode -> "w" + (MkCString path') = asCString path + (MkCString mode') = asCString modeStr + MkStream $ %ffi IO fopen CharPtr path' mode' + +def fclose (mode:StreamMode) ?-> (stream:Stream mode) : {State World} Unit = + (MkStream stream') = stream + %ffi IO fclose Int64 stream' + () + +def fwrite (stream:Stream WriteMode) (s:String) : {State World} Unit = + (MkStream stream') = stream + (AsList n s') = s + ptr = %getPtr s' + ans = %ffi IO fwrite Int64 ptr (IToI64 1) (IToI64 n) stream' + () + +def fread (stream:Stream ReadMode) : {State World} String = + (MkStream stream') = stream + -- TODO: do the malloc and pointer reads/writes in the {IO World} effect + -- TODO: allow reading longer files! + n = 4096 + buffer = for i:(Fin n). '\NUL' + ptr = %getPtr buffer + numRead = I64ToI $ %ffi IO fread Int64 ptr (IToI64 1) (IToI64 n) stream' + AsList numRead (for i. %ptrLoad (%ptrOffset ptr (ordinal i))) + +def deleteFile (f:FilePath) : {State World} Unit = + (MkCString f') = asCString f + %ffi IO remove Int64 f' + () + +def withFile (f:FilePath) (mode:StreamMode) + (action: Stream mode -> {State World} a) + : {State World} a = + stream = fopen f mode + result = action stream + fclose stream + result + +def writeFile (f:FilePath) (s:String) : {State World} Unit = + withFile f WriteMode \stream. fwrite stream s + +def readFile (f:FilePath) : {State World} String = + withFile f ReadMode \stream. fread stream + +def writeTemp (s:String) : {State World} FilePath = + -- TODO: Make this properly atomic. It can fail if another process creates a + -- file with same name after we ask for the name and before we create it. + template = "/tmp/dex-XXXXXX\NUL" + ptr = %getPtr template + %ffi IO mktemp CharPtr ptr + AsList 15 for i. %ptrLoad (%ptrOffset ptr (ordinal i)) + +def withTempFile (action: FilePath -> {State World} a) : {State World} a = + tmpFile = writeTemp (AsList _ []) + result = action tmpFile + deleteFile tmpFile + result diff --git a/makefile b/makefile index a2ff6b532..ef8adb0f2 100644 --- a/makefile +++ b/makefile @@ -86,7 +86,7 @@ example-names = mandelbrot pi sierpinski \ isomorphisms ode-integrator linear_algebra fluidsim test-names = uexpr-tests adt-tests type-tests eval-tests show-tests \ - shadow-tests monad-tests \ + shadow-tests monad-tests io-tests \ ad-tests parser-tests serialize-tests \ record-variant-tests simple-include-test \ typeclass-tests complex-tests trig-tests diff --git a/tests/io-tests.dx b/tests/io-tests.dx new file mode 100644 index 000000000..db26bae5a --- /dev/null +++ b/tests/io-tests.dx @@ -0,0 +1,12 @@ + +include "io.dx" + +:p unsafeIO \(). + withTempFile \fname. + withFile fname WriteMode \stream. + fwrite stream $ AsList _ "lorem ipsum\n" + fwrite stream $ AsList _ "dolor sit amet\n" + readFile fname +> (AsList 27 "lorem ipsum +> dolor sit amet +> ") From 433916d25c13b28a80d9db6644888f5ac3883e90 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Sun, 20 Dec 2020 22:10:16 -0500 Subject: [PATCH 012/105] Treat all FFI functions as potentially IO-causing. This just forces us to explicitly wrap with `unsafeIO` when that's what we mean. Previously it was too easy to forget the IO tag. --- lib/diagram.dx | 2 +- lib/io.dx | 12 ++++++------ lib/png.dx | 2 +- lib/prelude.dx | 14 +++++++------- src/lib/Autodiff.hs | 4 ++-- src/lib/Imp.hs | 2 +- src/lib/Interpreter.hs | 2 +- src/lib/Parser.hs | 3 +-- src/lib/Syntax.hs | 2 +- src/lib/Type.hs | 6 +++--- 10 files changed, 24 insertions(+), 25 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index e19fd325b..abcec481c 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -12,7 +12,7 @@ data Geom = -- TODO: replace with `Fin 3 => Word8` when we fix #348 HtmlColor : Type = (Word8 & Word8 & Word8) -def showHex (x:Int32) : String = +def showHex (x:Int32) : String = unsafeIO \(). (n, ptr) = %ffi showHex (Int32 & CharPtr) x AsList n $ for i:(Fin n). %ptrLoad (%ptrOffset ptr (ordinal i)) diff --git a/lib/io.dx b/lib/io.dx index 84b7439d2..78391bd90 100644 --- a/lib/io.dx +++ b/lib/io.dx @@ -21,18 +21,18 @@ def fopen (path:String) (mode:StreamMode) : {State World} (Stream mode) = WriteMode -> "w" (MkCString path') = asCString path (MkCString mode') = asCString modeStr - MkStream $ %ffi IO fopen CharPtr path' mode' + MkStream $ %ffi fopen CharPtr path' mode' def fclose (mode:StreamMode) ?-> (stream:Stream mode) : {State World} Unit = (MkStream stream') = stream - %ffi IO fclose Int64 stream' + %ffi fclose Int64 stream' () def fwrite (stream:Stream WriteMode) (s:String) : {State World} Unit = (MkStream stream') = stream (AsList n s') = s ptr = %getPtr s' - ans = %ffi IO fwrite Int64 ptr (IToI64 1) (IToI64 n) stream' + ans = %ffi fwrite Int64 ptr (IToI64 1) (IToI64 n) stream' () def fread (stream:Stream ReadMode) : {State World} String = @@ -42,12 +42,12 @@ def fread (stream:Stream ReadMode) : {State World} String = n = 4096 buffer = for i:(Fin n). '\NUL' ptr = %getPtr buffer - numRead = I64ToI $ %ffi IO fread Int64 ptr (IToI64 1) (IToI64 n) stream' + numRead = I64ToI $ %ffi fread Int64 ptr (IToI64 1) (IToI64 n) stream' AsList numRead (for i. %ptrLoad (%ptrOffset ptr (ordinal i))) def deleteFile (f:FilePath) : {State World} Unit = (MkCString f') = asCString f - %ffi IO remove Int64 f' + %ffi remove Int64 f' () def withFile (f:FilePath) (mode:StreamMode) @@ -69,7 +69,7 @@ def writeTemp (s:String) : {State World} FilePath = -- file with same name after we ask for the name and before we create it. template = "/tmp/dex-XXXXXX\NUL" ptr = %getPtr template - %ffi IO mktemp CharPtr ptr + %ffi mktemp CharPtr ptr AsList 15 for i. %ptrLoad (%ptrOffset ptr (ordinal i)) def withTempFile (action: FilePath -> {State World} a) : {State World} a = diff --git a/lib/png.dx b/lib/png.dx index 8fe4dbd2f..f4898cd6c 100644 --- a/lib/png.dx +++ b/lib/png.dx @@ -93,7 +93,7 @@ def base64Decode (s:String) : Maybe String = Html : Type = List Char -def makePNG (img:n=>m=>(Fin 3)=>Word8) : List Byte = +def makePNG (img:n=>m=>(Fin 3)=>Word8) : List Byte = unsafeIO \(). (AsList _ imgFlat) = toList for (i,j,k). img.i.j.k (n, ptr) = (%ffi encodePNG (Int & CharPtr) (%getPtr imgFlat) (size m) (size n)) diff --git a/lib/prelude.dx b/lib/prelude.dx index 7edb0f6c8..82e19e130 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -482,7 +482,7 @@ def inner (x:n=>Float) (mat:n=>m=>Float) (y:m=>Float) : Float = -- TODO: newtype Key = Int64 -def hash (x:Key) (y:Int32) : Key = +def hash (x:Key) (y:Int32) : Key = unsafeIO \(). y64 = IToI64 y %ffi threefry2x32 Int64 x y64 def newKey (x:Int) : Key = hash (IToI64 0) x @@ -490,7 +490,7 @@ def many (f:Key->a) (k:Key) (i:n) : a = f (hash k (ordinal i)) def ixkey (k:Key) (i:n) : Key = hash k (ordinal i) def ixkey2 (k:Key) (i:n) (j:m) : Key = hash (hash k (ordinal i)) (ordinal j) def splitKey (n:Int) ?-> (k:Key) : Fin n => Key = for i. ixkey k i -def rand (k:Key) : Float = F64ToF $ %ffi randunif Float64 k +def rand (k:Key) : Float = unsafeIO \(). F64ToF $ %ffi randunif Float64 k def randVec (n:Int) (f: Key -> a) (k: Key) : Fin n => a = for i:(Fin n). f (ixkey k i) @@ -873,32 +873,32 @@ interface Show a:Type where show : a -> String instance showInt32 : Show Int32 where - show = \x: Int32. + show = \x: Int32. unsafeIO \(). (n, ptr) = %ffi showInt32 (Int32 & CharPtr) x AsList n $ for i:(Fin n). %ptrLoad (%ptrOffset ptr (ordinal i)) instance showInt64 : Show Int64 where - show = \x: Int64. + show = \x: Int64. unsafeIO \(). (n, ptr) = %ffi showInt64 (Int32 & CharPtr) x AsList n $ for i:(Fin n). %ptrLoad (%ptrOffset ptr (ordinal i)) instance showFloat32 : Show Float32 where - show = \x: Float32. + show = \x: Float32.unsafeIO \(). (n, ptr) = %ffi showFloat32 (Int32 & CharPtr) x AsList n $ for i:(Fin n). %ptrLoad (%ptrOffset ptr (ordinal i)) instance showFloat64 : Show Float64 where - show = \x: Float64. + show = \x: Float64.unsafeIO \(). (n, ptr) = %ffi showFloat64 (Int32 & CharPtr) x AsList n $ for i:(Fin n). %ptrLoad (%ptrOffset ptr (ordinal i)) def writeStdErr (s:String) : {State World} Unit = (AsList n cs) = s - %ffi IO writeToStdErr Int n (%getPtr cs) + %ffi writeToStdErr Int n (%getPtr cs) () def throwMsg (s:String) : a = unsafeIO \(). diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index 2d5ae353c..e1b9b4578 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -167,7 +167,7 @@ linearizeOp op = case op of RecordSplit vs r -> (RecordSplit <$> traverse la vs <*> la r) `bindLin` emitOp VariantLift ts v -> (VariantLift ts <$> la v) `bindLin` emitOp VariantSplit ts v -> (VariantSplit ts <$> la v) `bindLin` emitOp - FFICall _ _ _ _ -> error $ "Can't differentiate through an FFI call" + FFICall _ _ _ -> error $ "Can't differentiate through an FFI call" where emitDiscrete = if isTrivialForAD (Op op) then LinA $ withZeroTangent <$> emitOp op @@ -625,7 +625,7 @@ transposeOp op ct = case op of ToOrdinal _ -> notLinear IdxSetSize _ -> notLinear ThrowError _ -> notLinear - FFICall _ _ _ _ -> notLinear + FFICall _ _ _ -> notLinear DataConTag _ -> notLinear ToEnum _ _ -> notLinear where diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 77b5ecb59..8337bbbd5 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -315,7 +315,7 @@ toImpOp (maybeDest, op) = case op of _ -> error $ "Not a data constructor: " ++ pprint con ToEnum ~ty@(TypeCon (DataDef _ _ cons) _) i -> returnVal $ Con $ SumAsProd ty i (map (const []) cons) - FFICall _ name returnTy xs -> do + FFICall name returnTy xs -> do let returnTys = fromScalarOrPairType returnTy let xTys = map (fromScalarType . getType) xs f <- emitFFIFunction name xTys returnTys diff --git a/src/lib/Interpreter.hs b/src/lib/Interpreter.hs index db8b1df7f..f9f43fa18 100644 --- a/src/lib/Interpreter.hs +++ b/src/lib/Interpreter.hs @@ -90,7 +90,7 @@ evalOp expr = case expr of ScalarUnOp op x -> return $ case op of FNeg -> applyFloatUnOp (0-) x _ -> error $ "Not implemented: " ++ pprint expr - FFICall _ name _ args -> return $ case name of + FFICall name _ args -> return $ case name of "randunif" -> Float64Val $ c_unif x where [Int64Val x] = args "threefry2x32" -> Int64Val $ c_threefry x y where [Int64Val x, Int64Val y] = args _ -> error $ "FFI function not recognized: " ++ name diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 7df13f440..001df0cee 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -678,11 +678,10 @@ uPrim = withSrc $ do s <- primName case s of "ffi" -> do - mayDoIO <- (symbol "IO" $> True) <|> return False f <- lexeme $ some nameTailChar retTy <- leafExpr args <- some leafExpr - return $ UPrimExpr $ OpExpr $ FFICall mayDoIO f retTy args + return $ UPrimExpr $ OpExpr $ FFICall f retTy args _ -> case strToPrimName s of Just prim -> UPrimExpr <$> traverse (const leafExpr) prim Nothing -> fail $ "Unrecognized primitive: " ++ s diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 29ae13eb5..b8b75945c 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -319,7 +319,7 @@ data PrimOp e = | IndexRef e e | FstRef e | SndRef e - | FFICall Bool String e [e] -- bool indicates it may do IO + | FFICall String e [e] | Inject e | PtrOffset e e | PtrLoad e diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 4d4af8b53..11b0de085 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -276,7 +276,7 @@ exprEffs expr = case expr of MAsk -> S.singleton (Reader, h) MTell _ -> S.singleton (Writer, h) where RefTy (Var (h:>_)) _ = getType ref - FFICall True _ _ _ -> S.singleton (State, theWorld) + FFICall _ _ _ -> S.singleton (State, theWorld) _ -> NoEffects Hof hof -> case hof of For _ f -> functionEffs f @@ -669,14 +669,14 @@ typeCheckOp op = case op of UnsafeFromOrdinal ty i -> ty|:TyKind >> i|:IdxRepTy $> ty ToOrdinal i -> typeCheck i $> IdxRepTy IdxSetSize i -> typeCheck i $> IdxRepTy - FFICall mayDoIO _ ansTy args -> do + FFICall _ ansTy args -> do forM_ args $ \arg -> do argTy <- typeCheck arg case argTy of BaseTy _ -> return () _ -> throw TypeErr $ "All arguments of FFI calls have to be " ++ "fixed-width base types, but got: " ++ pprint argTy - when mayDoIO $ declareEff (State, Just theWorld) + declareEff (State, Just theWorld) return ansTy Inject i -> do TC tc <- typeCheck i From 1d811352be66f632904d76430660de1057b82b47 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 11 Dec 2020 16:18:27 +0000 Subject: [PATCH 013/105] Make it possible to export JITed Dex functions to Python The API is still far from ergonomic, because it returns raw function pointers, with signatures derived from pretty obscure Haskell code. The next step will be to clean up the expected signatures and handle all the destination allocation and dereferencing internally. This is a big-ish change, mostly because it required a further refactor of our compilation pipeline, and a few additions to the llvm-hs APIs (non-bracketed context and module management). Until those changes are merged upstream, I changed `stack.yaml` to point to the PR branch in my fork. --- dex.cabal | 9 +- python/dex/__init__.py | 100 ++++++++++-------- python/tests/api_test.py | 22 ++++ src/Dex/Foreign/API.hs | 41 +++++++ .../API.hs => Dex/Foreign/Context.hs} | 87 ++------------- src/Dex/Foreign/JIT.hs | 80 ++++++++++++++ src/Dex/Foreign/Serialize.hs | 78 ++++++++++++++ src/Dex/Foreign/Util.hs | 16 +++ src/{foreign => Dex/Foreign}/rts.c | 0 src/lib/LLVM/JIT.hs | 40 ++++--- src/lib/LLVM/Shims.hs | 16 +-- src/lib/LLVMExec.hs | 75 ++++++------- src/lib/TopLevel.hs | 59 ++++++----- stack-macos.yaml | 7 +- stack.yaml | 9 +- 15 files changed, 411 insertions(+), 228 deletions(-) create mode 100644 src/Dex/Foreign/API.hs rename src/{foreign/API.hs => Dex/Foreign/Context.hs} (54%) create mode 100644 src/Dex/Foreign/JIT.hs create mode 100644 src/Dex/Foreign/Serialize.hs create mode 100644 src/Dex/Foreign/Util.hs rename src/{foreign => Dex/Foreign}/rts.c (100%) diff --git a/dex.cabal b/dex.cabal index 8fe9d6835..8ecdef23b 100644 --- a/dex.cabal +++ b/dex.cabal @@ -84,10 +84,11 @@ executable dex foreign-library Dex type: native-shared - other-modules: API - build-depends: base, dex, dex-resources, mtl - hs-source-dirs: src/foreign - c-sources: src/foreign/rts.c + other-modules: Dex.Foreign.API, Dex.Foreign.Util, Dex.Foreign.JIT + , Dex.Foreign.Context, Dex.Foreign.Serialize + build-depends: base, dex, dex-resources, mtl, llvm-hs, containers + hs-source-dirs: src/ + c-sources: src/Dex/Foreign/rts.c cc-options: -std=c11 -fPIC ghc-options: -Wall -fPIC default-language: Haskell2010 diff --git a/python/dex/__init__.py b/python/dex/__init__.py index 6850c436d..e60ffaea6 100644 --- a/python/dex/__init__.py +++ b/python/dex/__init__.py @@ -36,56 +36,49 @@ class CRectArray(ctypes.Structure): class HsAtom(ctypes.Structure): pass class HsContext(ctypes.Structure): pass - -_init = lib.dexInit -_init.restype = None -_init.argtypes = [] - -_fini = lib.dexFini -_fini.restype = None -_fini.argtypes = [] - -_create_context = lib.dexCreateContext -_create_context.restype = ctypes.POINTER(HsContext) -_create_context.argtypes = [] - -_destroy_context = lib.dexDestroyContext -_destroy_context.restype = None -_destroy_context.argtypes = [ctypes.POINTER(HsContext)] - -_print = lib.dexPrint -_print.restype = ctypes.c_char_p -_print.argtypes = [ctypes.POINTER(HsAtom)] - -_insert = lib.dexInsert -_insert.restype = ctypes.POINTER(HsContext) -_insert.argtypes = [ctypes.POINTER(HsContext), ctypes.c_char_p, ctypes.POINTER(HsAtom)] - -_eval = lib.dexEval -_eval.restype = ctypes.POINTER(HsContext) -_eval.argtypes = [ctypes.POINTER(HsContext), ctypes.c_char_p] - -_evalExpr = lib.dexEvalExpr -_evalExpr.restype = ctypes.POINTER(HsAtom) -_evalExpr.argtypes = [ctypes.POINTER(HsContext), ctypes.c_char_p] - -_lookup = lib.dexLookup -_lookup.restype = ctypes.POINTER(HsAtom) -_lookup.argtypes = [ctypes.POINTER(HsContext), ctypes.c_char_p] - -_toCAtom = lib.dexToCAtom -_toCAtom.restype = ctypes.c_int -_toCAtom.argtypes = [ctypes.POINTER(HsAtom), ctypes.POINTER(CAtom)] - -_getError = lib.dexGetError -_getError.restype = ctypes.c_char_p -_getError.argtypes = [] +class HsJIT(ctypes.Structure): pass +class NativeFunctionObj(ctypes.Structure): pass + +HsAtomPtr = ctypes.POINTER(HsAtom) +HsContextPtr = ctypes.POINTER(HsContext) +HsJITPtr = ctypes.POINTER(HsJIT) +CAtomPtr = ctypes.POINTER(CAtom) +NativeFunction = ctypes.POINTER(NativeFunctionObj) + +def _dex_func(name, *signature): + argtypes, restype = signature[:-1], signature[-1] + f = getattr(lib, name) + f.restype = restype + f.argtypes = argtypes + return f + +_init = _dex_func('dexInit', None) +_fini = _dex_func('dexFini', None) +_getError = _dex_func('dexGetError', ctypes.c_char_p) + +_create_context = _dex_func('dexCreateContext', HsContextPtr) +_destroy_context = _dex_func('dexDestroyContext', HsContextPtr, None) + +_eval = _dex_func('dexEval', HsContextPtr, ctypes.c_char_p, HsContextPtr) +_insert = _dex_func('dexInsert', HsContextPtr, ctypes.c_char_p, HsAtomPtr, HsContextPtr) +_evalExpr = _dex_func('dexEvalExpr', HsContextPtr, ctypes.c_char_p, HsAtomPtr) +_lookup = _dex_func('dexLookup', HsContextPtr, ctypes.c_char_p, HsAtomPtr) + +_print = _dex_func('dexPrint', HsAtomPtr, ctypes.c_char_p) +_toCAtom = _dex_func('dexToCAtom', HsAtomPtr, CAtomPtr, ctypes.c_int) + +_createJIT = _dex_func('dexCreateJIT', HsJITPtr) +_destroyJIT = _dex_func('dexDestroyJIT', HsJITPtr, None) +_compile = _dex_func('dexCompile', HsJITPtr, HsContextPtr, HsAtomPtr, NativeFunction) +_unload = _dex_func('dexUnload', HsJITPtr, NativeFunction, None) _init() +_jit = _createJIT() _nofree = False @atexit.register def _teardown(): global _nofree + _destroyJIT(_jit) _fini() _nofree = True # Don't destruct any Haskell objects after the RTS has been shutdown @@ -94,7 +87,7 @@ def _as_cstr(x: str): return ctypes.c_char_p(x.encode('ascii')) def _from_cstr(cx): - return cx.value.decode('ascii') + return cx.decode('ascii') class Module: @@ -148,6 +141,7 @@ def __del__(self): pass def __repr__(self): + # TODO: Free! return _print(self).decode('ascii') def __int__(self): @@ -176,3 +170,19 @@ def __call__(self, *args): old_env, env = env, _insert(env, _as_cstr(f"python_arg{i}"), atom) _destroy_context(old_env) return eval(" ".join(f"python_arg{i}" for i in range(len(args) + 1)), module=self.module, _env=env) + + def compile(self): + func_ptr = _compile(_jit, self.module, self) + if not func_ptr: + raise RuntimeError("Failed to JIT-compile a Dex function") + return NativeFunction(func_ptr) + + +class NativeFunction: + def __init__(self, ptr): + self._as_parameter_ = ptr + self.ptr = ptr + + def __del__(self): + if _nofree: return + _unload(_jit, self) diff --git a/python/tests/api_test.py b/python/tests/api_test.py index 6282b4875..4b55d7b92 100644 --- a/python/tests/api_test.py +++ b/python/tests/api_test.py @@ -5,6 +5,8 @@ # https://developers.google.com/open-source/licenses/bsd import unittest +import ctypes +import numpy as np from textwrap import dedent # TODO: Write a proper setup.py instead of using this hack... @@ -43,3 +45,23 @@ def addOne (x: Float) : Float = x + 1.0 def test_scalar_conversions(): assert float(dex.eval("5.0")) == 5.0 assert int(dex.eval("5")) == 5 + +def test_jit(): + m = dex.eval(r"\x:Float. 1.0 / (1.0 + exp(-x))") + native_func = m.compile() + func_ptr = ctypes.cast(native_func.ptr, ctypes.c_void_p).value + signature = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_float, ctypes.POINTER(ctypes.c_float)) + func = signature(func_ptr) + + def dex_sigmoid(x): + res = ctypes.c_float() + has_error = func(x, ctypes.pointer(res)) + assert not has_error + return res.value + + one = np.float32(1.0) + def py_sigmoid(x): return one / (one + np.exp(-x)) + + for value in map(np.float32, (-1.0, -0.5, 0.0, 0.5, 1.0)): + np.testing.assert_allclose(dex_sigmoid(value), py_sigmoid(value), + rtol=1e-4, atol=1e-6) diff --git a/src/Dex/Foreign/API.hs b/src/Dex/Foreign/API.hs new file mode 100644 index 000000000..7a284eb26 --- /dev/null +++ b/src/Dex/Foreign/API.hs @@ -0,0 +1,41 @@ +-- Copyright 2020 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +module Dex.Foreign.API where + +import Foreign.Ptr +import Foreign.C + +import Syntax + +import Dex.Foreign.Context +import Dex.Foreign.Serialize +import Dex.Foreign.JIT + +-- Public API (commented out exports are defined in rts.c) + +-- Initialization and basic runtime +-- foreign export ccall "dexInit" _ :: IO () +-- foreign export ccall "dexFini" _ :: IO () +-- foreign export ccall "dexGetError" _ :: CString + +-- Context +foreign export ccall "dexCreateContext" dexCreateContext :: IO (Ptr Context) +foreign export ccall "dexDestroyContext" dexDestroyContext :: Ptr Context -> IO () +foreign export ccall "dexInsert" dexInsert :: Ptr Context -> CString -> Ptr Atom -> IO (Ptr Context) +foreign export ccall "dexEval" dexEval :: Ptr Context -> CString -> IO (Ptr Context) +foreign export ccall "dexEvalExpr" dexEvalExpr :: Ptr Context -> CString -> IO (Ptr Atom) +foreign export ccall "dexLookup" dexLookup :: Ptr Context -> CString -> IO (Ptr Atom) + +-- Serialization +foreign export ccall "dexPrint" dexPrint :: Ptr Atom -> IO CString +foreign export ccall "dexToCAtom" dexToCAtom :: Ptr Atom -> Ptr CAtom -> IO CInt + +-- JIT +foreign export ccall "dexCreateJIT" dexCreateJIT :: IO (Ptr JIT) +foreign export ccall "dexDestroyJIT" dexDestroyJIT :: Ptr JIT -> IO () +foreign export ccall "dexCompile" dexCompile :: Ptr JIT -> Ptr Context -> Ptr Atom -> IO (Ptr NativeFunction) +foreign export ccall "dexUnload" dexUnload :: Ptr JIT -> Ptr NativeFunction -> IO () diff --git a/src/foreign/API.hs b/src/Dex/Foreign/Context.hs similarity index 54% rename from src/foreign/API.hs rename to src/Dex/Foreign/Context.hs index 23e1da09b..6b0ab72fe 100644 --- a/src/foreign/API.hs +++ b/src/Dex/Foreign/Context.hs @@ -4,19 +4,21 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module API where +module Dex.Foreign.Context ( + Context (..), + setError, + dexCreateContext, dexDestroyContext, + dexInsert, dexLookup, + dexEval, dexEvalExpr, + ) where import Control.Monad.State.Strict import Foreign.Ptr import Foreign.StablePtr -import Foreign.Storable -import Foreign.Marshal.Alloc import Foreign.C.String -import Foreign.C.Types import Data.String -import Data.Word import Data.Int import Data.Functor import Data.Foldable @@ -26,22 +28,10 @@ import Syntax hiding (sizeOf) import Type import TopLevel import Parser (parseExpr, exprAsModule) -import Serialize (pprintVal) import Env hiding (Tag) import PPrint --- Public API (commented out exports are defined in rts.c) --- foreign export ccall "dexInit" _ :: IO () --- foreign export ccall "dexFini" _ :: IO () --- foreign export ccall "dexGetError" _ :: CString -foreign export ccall "dexCreateContext" dexCreateContext :: IO (Ptr Context) -foreign export ccall "dexDestroyContext" dexDestroyContext :: Ptr Context -> IO () -foreign export ccall "dexPrint" dexPrint :: Ptr Atom -> IO CString -foreign export ccall "dexInsert" dexInsert :: Ptr Context -> CString -> Ptr Atom -> IO (Ptr Context) -foreign export ccall "dexEval" dexEval :: Ptr Context -> CString -> IO (Ptr Context) -foreign export ccall "dexEvalExpr" dexEvalExpr :: Ptr Context -> CString -> IO (Ptr Atom) -foreign export ccall "dexLookup" dexLookup :: Ptr Context -> CString -> IO (Ptr Atom) -foreign export ccall "dexToCAtom" dexToCAtom :: Ptr Atom -> Ptr CAtom -> IO CInt +import Dex.Foreign.Util data Context = Context EvalConfig TopEnv @@ -72,9 +62,6 @@ dexCreateContext = do dexDestroyContext :: Ptr Context -> IO () dexDestroyContext = freeStablePtr . castPtrToStablePtr . castPtr -dexPrint :: Ptr Atom -> IO CString -dexPrint atomPtr = newCString =<< pprintVal =<< fromStablePtr atomPtr - dexEval :: Ptr Context -> CString -> IO (Ptr Context) dexEval ctxPtr sourcePtr = do Context evalConfig env <- fromStablePtr ctxPtr @@ -118,61 +105,3 @@ dexLookup ctxPtr namePtr = do Just _ -> setError "Looking up an expression" $> nullPtr Nothing -> setError "Unbound name" $> nullPtr -dexToCAtom :: Ptr Atom -> Ptr CAtom -> IO CInt -dexToCAtom atomPtr resultPtr = do - atom <- fromStablePtr atomPtr - case atom of - Con con -> case con of - Lit (VecLit _) -> notSerializable - Lit l -> poke resultPtr (CLit l) $> 1 - _ -> notSerializable - _ -> notSerializable - where - notSerializable = setError "Unserializable atom" $> 0 - -dexFreeCAtom :: Ptr CAtom -> IO () -dexFreeCAtom = free - -data CAtom = CLit LitVal | CRectArray (Ptr ()) [Int] [Int] - -instance Storable CAtom where - sizeOf _ = tag + val + val + val - where tag = 8; val = 8 - alignment _ = 8 - peek addr = do - tag <- val @Word64 0 - case tag of - 0 -> do - litTag <- val @Word64 1 - CLit <$> case litTag of - 0 -> Int64Lit <$> val 2 - 1 -> Int32Lit <$> val 2 - 2 -> Word8Lit <$> val 2 - 3 -> Float64Lit <$> val 2 - 4 -> Float32Lit <$> val 2 - _ -> error "Invalid tag" - _ -> error "Invalid tag" - where - val :: forall a. Storable a => Int -> IO a - val i = peekByteOff (castPtr addr) (i * 8) - poke addr catom = case catom of - CLit lit -> do - val @Word64 0 0 - case lit of - Int64Lit v -> val @Word64 1 0 >> val 2 v - Int32Lit v -> val @Word64 1 1 >> val 2 v - Word8Lit v -> val @Word64 1 2 >> val 2 v - Float64Lit v -> val @Word64 1 3 >> val 2 v - Float32Lit v -> val @Word64 1 4 >> val 2 v - VecLit _ -> error "Unsupported" - PtrLit _ _ -> error "Unsupported" - CRectArray _ _ _ -> error "Unsupported" - where - val :: forall a. Storable a => Int -> a -> IO () - val i v = pokeByteOff (castPtr addr) (i * 8) v - -fromStablePtr :: Ptr a -> IO a -fromStablePtr = deRefStablePtr . castPtrToStablePtr . castPtr - -toStablePtr :: a -> IO (Ptr a) -toStablePtr x = castPtr . castStablePtrToPtr <$> newStablePtr x diff --git a/src/Dex/Foreign/JIT.hs b/src/Dex/Foreign/JIT.hs new file mode 100644 index 000000000..c188dfaf3 --- /dev/null +++ b/src/Dex/Foreign/JIT.hs @@ -0,0 +1,80 @@ +-- Copyright 2020 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +{-# LANGUAGE RecordWildCards #-} + +module Dex.Foreign.JIT ( + JIT, NativeFunction, + dexCreateJIT, dexDestroyJIT, + dexCompile, dexUnload + ) where + +import Control.Monad.State.Strict + +import Foreign.Ptr + +import Data.IORef +import qualified Data.Map.Strict as M + +import LLVM.Target (TargetMachine) +import qualified LLVM.Relocation as R +import qualified LLVM.CodeModel as CM +import qualified LLVM.CodeGenOpt as CGO +import qualified LLVM.JIT +import qualified LLVM.Shims + +import Logging +import LLVMExec +import JIT +import Syntax hiding (sizeOf) +import TopLevel + +import Dex.Foreign.Util +import Dex.Foreign.Context + +data JIT = ForeignJIT { jit :: LLVM.JIT.JIT + , jitTargetMachine :: TargetMachine + , funcToModuleRef :: IORef (M.Map (Ptr NativeFunction) LLVM.JIT.NativeModule) + } + + +dexCreateJIT :: IO (Ptr JIT) +dexCreateJIT = do + jitTargetMachine <- LLVM.Shims.newHostTargetMachine R.PIC CM.Large CGO.Aggressive + jit <- LLVM.JIT.createJIT jitTargetMachine + funcToModuleRef <- newIORef mempty + toStablePtr ForeignJIT{..} + +dexDestroyJIT :: Ptr JIT -> IO () +dexDestroyJIT jitPtr = do + ForeignJIT{..} <- fromStablePtr jitPtr + funcToModule <- readIORef funcToModuleRef + forM_ (M.toList funcToModule) $ \(_, m) -> LLVM.JIT.unloadNativeModule m + LLVM.JIT.destroyJIT jit + LLVM.Shims.disposeTargetMachine jitTargetMachine + +data NativeFunction + +dexCompile :: Ptr JIT -> Ptr Context -> Ptr Atom -> IO (Ptr NativeFunction) +dexCompile jitPtr ctxPtr funcAtomPtr = do + ForeignJIT{..} <- fromStablePtr jitPtr + Context _ env <- fromStablePtr ctxPtr + funcAtom <- fromStablePtr funcAtomPtr + let impMod = prepareFunctionForExport env "userFunc" funcAtom + nativeModule <- execLogger Nothing $ \logger -> do + llvmAST <- impToLLVM logger impMod + LLVM.JIT.compileModule jit llvmAST + (standardCompilationPipeline logger ["userFunc"] jitTargetMachine) + funcPtr <- castFunPtrToPtr <$> LLVM.JIT.getFunctionPtr nativeModule "userFunc" + modifyIORef funcToModuleRef $ M.insert funcPtr nativeModule + return $ funcPtr + +dexUnload :: Ptr JIT -> Ptr NativeFunction -> IO () +dexUnload jitPtr funcPtr = do + ForeignJIT{..} <- fromStablePtr jitPtr + funcToModule <- readIORef funcToModuleRef + LLVM.JIT.unloadNativeModule $ funcToModule M.! funcPtr + modifyIORef funcToModuleRef $ M.delete funcPtr diff --git a/src/Dex/Foreign/Serialize.hs b/src/Dex/Foreign/Serialize.hs new file mode 100644 index 000000000..76560c8df --- /dev/null +++ b/src/Dex/Foreign/Serialize.hs @@ -0,0 +1,78 @@ +-- Copyright 2020 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +module Dex.Foreign.Serialize ( + CAtom, + dexPrint, dexToCAtom + ) where + +import Data.Int +import Data.Word +import Data.Functor + +import Foreign.C +import Foreign.Ptr +import Foreign.Storable + +import Syntax +import Serialize (pprintVal) + +import Dex.Foreign.Context (setError) +import Dex.Foreign.Util + +-- TODO: Free! +dexPrint :: Ptr Atom -> IO CString +dexPrint atomPtr = newCString =<< pprintVal =<< fromStablePtr atomPtr + +data CAtom = CLit LitVal | CRectArray (Ptr ()) [Int] [Int] + +instance Storable CAtom where + sizeOf _ = tag + val + val + val + where tag = 8; val = 8 + alignment _ = 8 + peek addr = do + tag <- val @Word64 0 + case tag of + 0 -> do + litTag <- val @Word64 1 + CLit <$> case litTag of + 0 -> Int64Lit <$> val 2 + 1 -> Int32Lit <$> val 2 + 2 -> Word8Lit <$> val 2 + 3 -> Float64Lit <$> val 2 + 4 -> Float32Lit <$> val 2 + _ -> error "Invalid tag" + _ -> error "Invalid tag" + where + val :: forall a. Storable a => Int -> IO a + val i = peekByteOff (castPtr addr) (i * 8) + poke addr catom = case catom of + CLit lit -> do + val @Word64 0 0 + case lit of + Int64Lit v -> val @Word64 1 0 >> val 2 v + Int32Lit v -> val @Word64 1 1 >> val 2 v + Word8Lit v -> val @Word64 1 2 >> val 2 v + Float64Lit v -> val @Word64 1 3 >> val 2 v + Float32Lit v -> val @Word64 1 4 >> val 2 v + VecLit _ -> error "Unsupported" + PtrLit _ _ -> error "Unsupported" + CRectArray _ _ _ -> error "Unsupported" + where + val :: forall a. Storable a => Int -> a -> IO () + val i v = pokeByteOff (castPtr addr) (i * 8) v + +dexToCAtom :: Ptr Atom -> Ptr CAtom -> IO CInt +dexToCAtom atomPtr resultPtr = do + atom <- fromStablePtr atomPtr + case atom of + Con con -> case con of + Lit (VecLit _) -> notSerializable + Lit l -> poke resultPtr (CLit l) $> 1 + _ -> notSerializable + _ -> notSerializable + where + notSerializable = setError "Unserializable atom" $> 0 diff --git a/src/Dex/Foreign/Util.hs b/src/Dex/Foreign/Util.hs new file mode 100644 index 000000000..aaa3ce8ec --- /dev/null +++ b/src/Dex/Foreign/Util.hs @@ -0,0 +1,16 @@ +-- Copyright 2020 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +module Dex.Foreign.Util (fromStablePtr, toStablePtr) where + +import Foreign.StablePtr +import Foreign.Ptr + +fromStablePtr :: Ptr a -> IO a +fromStablePtr = deRefStablePtr . castPtrToStablePtr . castPtr + +toStablePtr :: a -> IO (Ptr a) +toStablePtr x = castPtr . castStablePtrToPtr <$> newStablePtr x diff --git a/src/foreign/rts.c b/src/Dex/Foreign/rts.c similarity index 100% rename from src/foreign/rts.c rename to src/Dex/Foreign/rts.c diff --git a/src/lib/LLVM/JIT.hs b/src/lib/LLVM/JIT.hs index 2075902f5..4649152c4 100644 --- a/src/lib/LLVM/JIT.hs +++ b/src/lib/LLVM/JIT.hs @@ -29,10 +29,11 @@ import qualified LLVM.OrcJIT as OrcJIT import qualified LLVM.Target as T import qualified LLVM.Linking as Linking -import qualified LLVM.Module as Mod -import qualified LLVM.AST as L -import qualified LLVM.AST.Global as L +import qualified LLVM.AST +import qualified LLVM.AST.Global as LLVM.AST import qualified LLVM.AST.Constant as C +import qualified LLVM.Module as LLVM +import qualified LLVM.Context as LLVM import LLVM.Shims @@ -71,16 +72,22 @@ data NativeModule = NativeModule { moduleJIT :: JIT , moduleKey :: OrcJIT.ModuleKey , moduleDtors :: [FunPtr (IO ())] + , llvmModule :: LLVM.Module + , llvmContext :: LLVM.Context } --- XXX: This destroys the passed in module! +type CompilationPipeline = LLVM.Module -> IO () + -- TODO: This leaks resources if we fail halfway -compileModule :: JIT -> (L.Module, Mod.Module) -> IO NativeModule -compileModule moduleJIT@JIT{..} (ast, m) = do +compileModule :: JIT -> LLVM.AST.Module -> CompilationPipeline -> IO NativeModule +compileModule moduleJIT@JIT{..} ast compilationPipeline = do + llvmContext <- LLVM.createContext + llvmModule <- LLVM.createModuleFromAST llvmContext ast + compilationPipeline llvmModule moduleKey <- OrcJIT.allocateModuleKey execSession resolver <- newSymbolResolver execSession (makeResolver compileLayer) modifyIORef resolvers (M.insert moduleKey resolver) - OrcJIT.addModule compileLayer moduleKey m + OrcJIT.addModule compileLayer moduleKey llvmModule moduleDtors <- forM dtorNames $ \dtorName -> do dtorSymbol <- OrcJIT.mangleSymbol compileLayer (fromString dtorName) Right (OrcJIT.JITSymbol dtorAddr _) <- OrcJIT.findSymbol compileLayer dtorSymbol False @@ -109,15 +116,16 @@ compileModule moduleJIT@JIT{..} (ast, m) = do -- Unfortunately the JIT layers we use here don't handle the destructors properly, -- so we have to find and call them ourselves. dtorNames = do - let dtorStructs = flip foldMap (L.moduleDefinitions ast) $ \case - L.GlobalDefinition - L.GlobalVariable{name="llvm.global_dtors", - initializer=Just (C.Array _ elems), - ..} -> elems + let dtorStructs = flip foldMap (LLVM.AST.moduleDefinitions ast) $ \case + LLVM.AST.GlobalDefinition + LLVM.AST.GlobalVariable{ + name="llvm.global_dtors", + initializer=Just (C.Array _ elems), + ..} -> elems _ -> [] -- Sort in the order of decreasing priority! fmap snd $ sortBy (flip compare) $ flip fmap dtorStructs $ - \(C.Struct _ _ [C.Int _ n, C.GlobalReference _ (L.Name dname), _]) -> + \(C.Struct _ _ [C.Int _ n, C.GlobalReference _ (LLVM.AST.Name dname), _]) -> (n, C8BS.unpack $ SBS.fromShort dname) foreign import ccall "dynamic" @@ -133,9 +141,11 @@ unloadNativeModule NativeModule{..} = do modifyIORef resolvers (M.delete moduleKey) OrcJIT.removeModule compileLayer moduleKey OrcJIT.releaseModuleKey execSession moduleKey + LLVM.disposeModule llvmModule + LLVM.disposeContext llvmContext -withNativeModule :: JIT -> (L.Module, Mod.Module) -> (NativeModule -> IO a) -> IO a -withNativeModule jit m = bracket (compileModule jit m) unloadNativeModule +withNativeModule :: JIT -> LLVM.AST.Module -> CompilationPipeline -> (NativeModule -> IO a) -> IO a +withNativeModule jit m p = bracket (compileModule jit m p) unloadNativeModule getFunctionPtr :: NativeModule -> String -> IO (FunPtr a) getFunctionPtr NativeModule{..} funcName = do diff --git a/src/lib/LLVM/Shims.hs b/src/lib/LLVM/Shims.hs index 9cbd02119..860b5540a 100644 --- a/src/lib/LLVM/Shims.hs +++ b/src/lib/LLVM/Shims.hs @@ -7,7 +7,6 @@ module LLVM.Shims ( SymbolResolver (..), newSymbolResolver, disposeSymbolResolver, newTargetMachine, newHostTargetMachine, disposeTargetMachine, - newTargetOptions, disposeTargetOptions ) where import qualified Data.Map as M @@ -73,24 +72,15 @@ newTargetMachine (Target.Target targetFFI) triple cpu features targetOptFFI relocModelFFI codeModelFFI cgoLevelFFI where encodeFeature (Target.CPUFeature f, on) = (if on then "+" else "-") <> f -newHostTargetMachine :: R.Model -> CM.Model -> CGO.Level -> IO (Target.TargetMachine, Target.TargetOptions) +newHostTargetMachine :: R.Model -> CM.Model -> CGO.Level -> IO Target.TargetMachine newHostTargetMachine relocModel codeModel cgoLevel = do Target.initializeAllTargets triple <- Target.getProcessTargetTriple (target, _) <- Target.lookupTarget Nothing triple cpu <- Target.getHostCPUName features <- Target.getHostCPUFeatures - targetOptions <- newTargetOptions - tm <- newTargetMachine target triple cpu features targetOptions relocModel codeModel cgoLevel - return (tm, targetOptions) + Target.withTargetOptions $ \targetOptions -> + newTargetMachine target triple cpu features targetOptions relocModel codeModel cgoLevel disposeTargetMachine :: Target.TargetMachine -> IO () disposeTargetMachine (Target.TargetMachine tmFFI) = Target.FFI.disposeTargetMachine tmFFI - --- llvm-hs doesn't expose any way to manage target options in a non-bracketed way - -newTargetOptions :: IO Target.TargetOptions -newTargetOptions = Target.TargetOptions <$> Target.FFI.createTargetOptions - -disposeTargetOptions :: Target.TargetOptions -> IO () -disposeTargetOptions (Target.TargetOptions optsFFI) = Target.FFI.disposeTargetOptions optsFFI diff --git a/src/lib/LLVMExec.hs b/src/lib/LLVMExec.hs index e6dcc2095..43d2c27ae 100644 --- a/src/lib/LLVMExec.hs +++ b/src/lib/LLVMExec.hs @@ -10,6 +10,7 @@ module LLVMExec (LLVMKernel (..), ptxDataLayout, ptxTargetTriple, compileAndEval, compileAndBench, exportObjectFile, + standardCompilationPipeline, compileCUDAKernel, loadLitVal) where import qualified LLVM.Analysis as L @@ -27,7 +28,6 @@ import qualified LLVM.Internal.Module as Mod import qualified LLVM.PassManager as P import qualified LLVM.Transforms as P import qualified LLVM.Target as T -import qualified LLVM.Linking as Linking import LLVM.Context import Data.Time.Clock (getCurrentTime, diffUTCTime) import System.IO @@ -107,16 +107,21 @@ checkedCallFunPtr sync argsPtr resultPtr fPtr = do compileOneOff :: Logger [Output] -> L.Module -> String -> (DexExecutable -> IO a) -> IO a compileOneOff logger ast name f = do - withContext $ \c -> do - Mod.withModuleFromAST c ast $ \m -> do - withHostTargetMachine $ \tm -> do - linkDexrt c m - let exports = [name] - internalize exports m - optimizeModule c logger tm m - withJIT tm $ \jit -> do - withNativeModule jit (ast, m) $ \compiled -> - f =<< getFunctionPtr compiled name + withHostTargetMachine $ \tm -> + withJIT tm $ \jit -> + withNativeModule jit ast (standardCompilationPipeline logger [name] tm) $ \compiled -> + f =<< getFunctionPtr compiled name + +standardCompilationPipeline :: Logger [Output] -> [String] -> T.TargetMachine -> Mod.Module -> IO () +standardCompilationPipeline logger exports tm m = do + linkDexrt m + internalize exports m + showModule m >>= logPass JitPass + L.verify m + runDefaultPasses tm m + showModule m >>= logPass LLVMOpt + showAsm tm m >>= logPass AsmPass + where logPass passName s = logThis logger [PassInfo passName s] -- === object file export === @@ -125,23 +130,19 @@ compileOneOff logger ast name f = do exportObjectFile :: FilePath -> [(L.Module, [String])] -> IO () exportObjectFile objFile modules = do withContext $ \c -> do - void $ Linking.loadLibraryPermanently Nothing withHostTargetMachine $ \tm -> - withBrackets (fmap (toLLVM c tm) modules) $ \mods -> do + withBrackets (fmap (toLLVM c) modules) $ \mods -> do Mod.withModuleFromAST c L.defaultModule $ \exportMod -> do void $ foldM linkModules exportMod mods - linkDexrt c exportMod - internalize allExports exportMod + execLogger Nothing $ \logger -> + standardCompilationPipeline logger allExports tm exportMod Mod.writeObjectToFile tm (Mod.File objFile) exportMod where allExports = foldMap snd modules - toLLVM :: Context -> T.TargetMachine -> (L.Module, [String]) -> (Mod.Module -> IO a) -> IO a - toLLVM c tm (ast, exports) cont = do - Mod.withModuleFromAST c ast $ \m -> do - internalize exports m - execLogger Nothing $ \logger -> optimizeModule c logger tm m - cont m + toLLVM :: Context -> (L.Module, [String]) -> (Mod.Module -> IO a) -> IO a + toLLVM c (ast, exports) cont = do + Mod.withModuleFromAST c ast $ \m -> internalize exports m >> cont m linkModules a b = a <$ Mod.linkModules a b @@ -154,15 +155,6 @@ exportObjectFile objFile modules = do -- === LLVM passes === -optimizeModule :: Context -> Logger [Output] -> T.TargetMachine -> Mod.Module -> IO () -optimizeModule ctx logger tm m = do - showModule m >>= logPass JitPass - L.verify m - runDefaultPasses tm m - showModule m >>= logPass LLVMOpt - showAsm ctx tm m >>= logPass AsmPass - where logPass passName s = logThis logger [PassInfo passName s] - runDefaultPasses :: T.TargetMachine -> Mod.Module -> IO () runDefaultPasses t m = do P.withPassManager defaultPasses $ \pm -> void $ P.runPassManager pm m @@ -187,6 +179,7 @@ runPasses passes mt m = do internalize :: [String] -> Mod.Module -> IO () internalize names m = runPasses [P.InternalizeFunctions names, P.GlobalDeadCodeElimination] Nothing m + -- === supported target machines === -- XXX: We need to use the large code model for macOS, because the libC functions @@ -222,8 +215,9 @@ withGPUTargetMachine computeCapability next = do showModule :: Mod.Module -> IO String showModule m = unpack <$> Mod.moduleLLVMAssembly m -showAsm :: Context -> T.TargetMachine -> Mod.Module -> IO String -showAsm ctx t m' = do +showAsm :: T.TargetMachine -> Mod.Module -> IO String +showAsm t m' = do + ctx <- Mod.moduleContext m' -- Uncomment this to dump assembly to a file that can be linked to a C benchmark suite: -- withModuleClone ctx m' $ \m -> Mod.writeObjectToFile t (Mod.File "asm.o") m withModuleClone ctx m' $ \m -> unpack <$> Mod.moduleTargetAssembly t m @@ -272,6 +266,7 @@ ptrArray p = map (\i -> p `plusPtr` (i * cellSize)) [0..] -- === dex runtime === +{-# NOINLINE dexrtAST #-} dexrtAST :: L.Module dexrtAST = unsafePerformIO $ do withContext $ \ctx -> do @@ -289,8 +284,9 @@ dexrtAST = unsafePerformIO $ do _ -> L.GlobalDefinition $ f { L.functionAttributes = [] } stripDef def = def -linkDexrt :: Context -> Mod.Module -> IO () -linkDexrt ctx m = do +linkDexrt :: Mod.Module -> IO () +linkDexrt m = do + ctx <- Mod.moduleContext m dataLayout <- Mod.getDataLayout =<< Mod.readModule m targetTriple <- Mod.getTargetTriple =<< Mod.readModule m let dexrtTargetAST = dexrtAST { L.moduleDataLayout = dataLayout @@ -310,10 +306,8 @@ compileCUDAKernel logger (LLVMKernel ast) = do withContext $ \ctx -> Mod.withModuleFromAST ctx ast $ \m -> do withGPUTargetMachine (pack arch) $ \tm -> do - linkLibdevice ctx m - linkDexrt ctx m - internalize ["kernel"] m - optimizeModule ctx logger tm m + linkLibdevice m + standardCompilationPipeline logger ["kernel"] tm m ptx <- Mod.moduleTargetAssembly tm m usePTXAS <- maybe False (=="1") <$> lookupEnv "DEX_USE_PTXAS" if usePTXAS @@ -348,8 +342,9 @@ libdevice = unsafePerformIO $ do return $ m { L.moduleDataLayout = Just ptxDataLayout , L.moduleTargetTriple = Just ptxTargetTriple } -linkLibdevice :: Context -> Mod.Module -> IO () -linkLibdevice ctx m = +linkLibdevice :: Mod.Module -> IO () +linkLibdevice m = do + ctx <- Mod.moduleContext m Mod.withModuleFromAST ctx zeroNVVMReflect $ \reflectm -> Mod.withModuleFromAST ctx libdevice $ \ldm -> do Mod.linkModules m ldm diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index b2ee837bb..e1957c523 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -7,7 +7,7 @@ {-# LANGUAGE FlexibleContexts #-} module TopLevel (evalSourceBlock, evalDecl, evalSource, evalFile, - exportFunctions, EvalConfig (..)) where + exportFunctions, prepareFunctionForExport, EvalConfig (..)) where import Control.Monad.State.Strict import Control.Monad.Reader @@ -257,32 +257,25 @@ runCArg :: CArgEnv -> CArgM a -> Embed (a, [IBinder], CArgEnv) runCArg initEnv m = repack <$> runCatT (runWriterT m) initEnv where repack ((ans, cargs), env) = (ans, cargs, env) -exportFunctions :: FilePath -> [(String, Atom)] -> TopEnv -> EvalConfig -> IO () -exportFunctions objPath funcs env opts = do - let names = fmap fst funcs - unless (length (nub names) == length names) $ liftEitherIO $ - throw CompilerErr "Duplicate export names" - modules <- forM funcs $ \(nameStr, func) -> do - -- Create a module that simulates an application of arguments to the function - let ((dest, cargs), (_, decls)) = flip runEmbed (freeVars func) $ do - (args, cargArgs, cargEnv) <- runCArg mempty $ createArgs $ getType func - resultAtom <- naryApp func args - (resultDest, cdestArgs, _) <- runCArg cargEnv $ createDest mempty $ getType resultAtom - void $ emitTo outputName PlainLet $ Atom resultAtom - return (resultDest, cargArgs <> cdestArgs) - - let coreModule = Module Core decls mempty - let defunctionalized = simplifyModule env coreModule - let Module _ optDecls optBindings = optimizeModule defunctionalized - let (_, LetBound PlainLet outputExpr) = optBindings ! outputName - let block = Block optDecls outputExpr - - let backend = backendName opts - let name = Name TopFunctionName (fromString nameStr) 0 - let (_, impModule, _) = toImpModule env backend CEntryFun name cargs (Just dest) block - llvmAST <- execLogger Nothing $ flip impToLLVM impModule - return (llvmAST, [nameStr]) - exportObjectFile objPath modules +prepareFunctionForExport :: TopEnv -> String -> Atom -> ImpModule +prepareFunctionForExport env nameStr func = do + -- Create a module that simulates an application of arguments to the function + let ((dest, cargs), (_, decls)) = flip runEmbed (freeVars func) $ do + (args, cargArgs, cargEnv) <- runCArg mempty $ createArgs $ getType func + resultAtom <- naryApp func args + (resultDest, cdestArgs, _) <- runCArg cargEnv $ createDest mempty $ getType resultAtom + void $ emitTo outputName PlainLet $ Atom resultAtom + return (resultDest, cargArgs <> cdestArgs) + + let coreModule = Module Core decls mempty + let defunctionalized = simplifyModule env coreModule + let Module _ optDecls optBindings = optimizeModule defunctionalized + let (_, LetBound PlainLet outputExpr) = optBindings ! outputName + let block = Block optDecls outputExpr + + let name = Name TopFunctionName (fromString nameStr) 0 + let (_, impModule, _) = toImpModule env LLVM CEntryFun name cargs (Just dest) block + impModule where outputName = GlobalName "_ans_" @@ -338,6 +331,18 @@ exportFunctions objPath funcs env opts = do tell [Bind $ name :> bt] return $ Var $ name :> BaseTy bt +exportFunctions :: FilePath -> [(String, Atom)] -> TopEnv -> EvalConfig -> IO () +exportFunctions objPath funcs env opts = do + unless (backendName opts == LLVM) $ liftEitherIO $ + throw CompilerErr "Export only supported with the LLVM CPU backend" + let names = fmap fst funcs + unless (length (nub names) == length names) $ liftEitherIO $ + throw CompilerErr "Duplicate export names" + modules <- forM funcs $ \(name, funcAtom) -> do + let impModule = prepareFunctionForExport env name funcAtom + (,[name]) <$> execLogger Nothing (flip impToLLVM impModule) + exportObjectFile objPath modules + abstractPtrLiterals :: Block -> ([IBinder], [LitVal], Block) abstractPtrLiterals block = flip evalState mempty $ do block' <- traverseLiterals block $ \val -> case val of diff --git a/stack-macos.yaml b/stack-macos.yaml index fbc7107e1..c14681f7f 100644 --- a/stack-macos.yaml +++ b/stack-macos.yaml @@ -10,8 +10,11 @@ packages: - . extra-deps: - - llvm-hs-9.0.1 - - llvm-hs-pure-9.0.0 + - github: apaszke/llvm-hs + commit: a9a74be1a7c15f3d21b2fffd35a425002ae7736f + subdirs: + - llvm-hs + - llvm-hs-pure - megaparsec-8.0.0 - prettyprinter-1.6.2 - store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001 diff --git a/stack.yaml b/stack.yaml index 1d5bae6ae..445dd9ffd 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1,4 +1,4 @@ -# Copyright 2019 Google LLC +# Copyright 2020 Google LLC # # Use of this source code is governed by a BSD-style # license that can be found in the LICENSE file or at @@ -10,8 +10,11 @@ packages: - . extra-deps: - - llvm-hs-9.0.1 - - llvm-hs-pure-9.0.0 + - github: apaszke/llvm-hs + commit: a9a74be1a7c15f3d21b2fffd35a425002ae7736f + subdirs: + - llvm-hs + - llvm-hs-pure - megaparsec-8.0.0 - prettyprinter-1.6.2 - store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001 From a1982fb61b1a79bc44e37432cdb25d8f738d19ed Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 21 Dec 2020 13:23:01 -0500 Subject: [PATCH 014/105] Use the IO effect to expose malloc/free/load/store. The memory discipline is C-style: the user is responsible for freeing what they allocate. I added some bracketed functions like `withAlloc` to make this easier. This lets us retire `%getPtr` which was always a bit dodgy. --- lib/diagram.dx | 3 +- lib/io.dx | 35 ++++++++++------------ lib/png.dx | 6 ++-- lib/prelude.dx | 68 ++++++++++++++++++++++++++---------------- src/lib/Autodiff.hs | 22 +++++++++++--- src/lib/Embed.hs | 10 ++++++- src/lib/Imp.hs | 38 ++++++++++------------- src/lib/Interpreter.hs | 4 ++- src/lib/Simplify.hs | 5 ++-- src/lib/Syntax.hs | 24 +++++++++++---- src/lib/Type.hs | 18 ++++++++--- 11 files changed, 143 insertions(+), 90 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index abcec481c..1d15998f1 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -14,8 +14,7 @@ HtmlColor : Type = (Word8 & Word8 & Word8) def showHex (x:Int32) : String = unsafeIO \(). (n, ptr) = %ffi showHex (Int32 & CharPtr) x - AsList n $ for i:(Fin n). - %ptrLoad (%ptrOffset ptr (ordinal i)) + stringFromCharPtr n ptr -- TODO: we should add overloaded string literals so we don't need this def str (n:Int) ?-> (s:(Fin n=>Char)) : String = AsList _ s diff --git a/lib/io.dx b/lib/io.dx index 78391bd90..d8502e730 100644 --- a/lib/io.dx +++ b/lib/io.dx @@ -11,17 +11,17 @@ data StreamMode = data Stream mode:StreamMode = MkStream CharPtr -- TODO: check the string contains no nulls -def asCString (s:String) : CString = - (AsList _ s') = s <> (AsList _ "\NUL") - MkCString %getPtr s' +def withCString (s:String) (action: CString -> {State World} a) : {State World} a = + (AsList n s') = s <> (AsList _ "\NUL") + withTabPtr s' \ptr. action $ MkCString ptr def fopen (path:String) (mode:StreamMode) : {State World} (Stream mode) = modeStr = AsList _ case mode of ReadMode -> "r" WriteMode -> "w" - (MkCString path') = asCString path - (MkCString mode') = asCString modeStr - MkStream $ %ffi fopen CharPtr path' mode' + withCString path \(MkCString pathPtr). + withCString modeStr \(MkCString modePtr). + MkStream $ %ffi fopen CharPtr pathPtr modePtr def fclose (mode:StreamMode) ?-> (stream:Stream mode) : {State World} Unit = (MkStream stream') = stream @@ -31,23 +31,21 @@ def fclose (mode:StreamMode) ?-> (stream:Stream mode) : {State World} Unit = def fwrite (stream:Stream WriteMode) (s:String) : {State World} Unit = (MkStream stream') = stream (AsList n s') = s - ptr = %getPtr s' - ans = %ffi fwrite Int64 ptr (IToI64 1) (IToI64 n) stream' + withTabPtr s' \ptr. + %ffi fwrite Int64 ptr (IToI64 1) (IToI64 n) stream' () def fread (stream:Stream ReadMode) : {State World} String = (MkStream stream') = stream - -- TODO: do the malloc and pointer reads/writes in the {IO World} effect -- TODO: allow reading longer files! n = 4096 - buffer = for i:(Fin n). '\NUL' - ptr = %getPtr buffer - numRead = I64ToI $ %ffi fread Int64 ptr (IToI64 1) (IToI64 n) stream' - AsList numRead (for i. %ptrLoad (%ptrOffset ptr (ordinal i))) + withAlloc n \ptr. + numRead = I64ToI $ %ffi fread Int64 ptr (IToI64 1) (IToI64 n) stream' + stringFromCharPtr numRead ptr def deleteFile (f:FilePath) : {State World} Unit = - (MkCString f') = asCString f - %ffi remove Int64 f' + withCString f \(MkCString ptr). + %ffi remove Int64 ptr () def withFile (f:FilePath) (mode:StreamMode) @@ -67,10 +65,9 @@ def readFile (f:FilePath) : {State World} String = def writeTemp (s:String) : {State World} FilePath = -- TODO: Make this properly atomic. It can fail if another process creates a -- file with same name after we ask for the name and before we create it. - template = "/tmp/dex-XXXXXX\NUL" - ptr = %getPtr template - %ffi mktemp CharPtr ptr - AsList 15 for i. %ptrLoad (%ptrOffset ptr (ordinal i)) + withCString (AsList _ "/tmp/dex-XXXXXX") \(MkCString ptr). + %ffi mktemp CharPtr ptr + stringFromCharPtr 15 ptr def withTempFile (action: FilePath -> {State World} a) : {State World} a = tmpFile = writeTemp (AsList _ []) diff --git a/lib/png.dx b/lib/png.dx index f4898cd6c..d131f2bf1 100644 --- a/lib/png.dx +++ b/lib/png.dx @@ -95,9 +95,9 @@ Html : Type = List Char def makePNG (img:n=>m=>(Fin 3)=>Word8) : List Byte = unsafeIO \(). (AsList _ imgFlat) = toList for (i,j,k). img.i.j.k - (n, ptr) = (%ffi encodePNG (Int & CharPtr) (%getPtr imgFlat) - (size m) (size n)) - AsList n $ for i. %ptrLoad (%ptrOffset ptr (ordinal i)) + withTabPtr imgFlat \ptr. + (n, ptr) = %ffi encodePNG (Int & CharPtr) ptr (size m) (size n) + AsList n $ tabFromCharPtr ptr def pngToHtml (png:List Byte) : List Char = (toList " : Eq (Fin n) = MkEq \x y. ordinal x == ordinal y def finOrd (n:Int) ?-> : Ord (Fin n) = MkOrd finEq (\x y. ordinal x > ordinal y) (\x y. ordinal x < ordinal y) +'## Raw pointer operations + +def Ptr (ty:Type) : Type = %makePtrType ty + +CharPtr : Type = %CharPtr + +-- TODO: generalize these to other pointer types (Storable type class?) +def malloc (n:Int) : {State World} CharPtr = %charAlloc n +def free (ptr:CharPtr) : {State World} Unit = %charFree ptr +def ptrStore (ptr:CharPtr) (x:Char) : {State World} Unit = %ptrStore ptr x +def ptrLoad (ptr:CharPtr) : {State World} Char = %ptrLoad ptr +def ptrOffset (ptr:CharPtr) (i:Int) : CharPtr = %ptrOffset ptr i + +-- TODO: generalize these brackets to allow other effects + +def withAlloc (n:Int) (action: CharPtr -> {State World} a) : {State World} a = + ptr = malloc n + result = action ptr + free ptr + result + +def withTabPtr (xs:n=>Char) (action : CharPtr -> {State World} a) : {State World} a = + withAlloc (size n) \ptr. + for i. ptrStore (ptrOffset ptr (ordinal i)) xs.i + action ptr + +def tabFromCharPtr (ptr:CharPtr) : {State World} n=>Char = + for i. ptrLoad $ ptrOffset ptr (ordinal i) + 'Misc pi : Float = 3.141592653589793 @@ -864,7 +893,8 @@ def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = String : Type = List Char -CharPtr : Type = %CharPtr +def stringFromCharPtr (n:Int) (ptr:CharPtr) : {State World} String= + AsList n $ tabFromCharPtr ptr -- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint def codepoint (c:Char) : Int = W8ToI c @@ -875,35 +905,31 @@ interface Show a:Type where instance showInt32 : Show Int32 where show = \x: Int32. unsafeIO \(). (n, ptr) = %ffi showInt32 (Int32 & CharPtr) x - AsList n $ for i:(Fin n). - %ptrLoad (%ptrOffset ptr (ordinal i)) + stringFromCharPtr n ptr instance showInt64 : Show Int64 where show = \x: Int64. unsafeIO \(). (n, ptr) = %ffi showInt64 (Int32 & CharPtr) x - AsList n $ for i:(Fin n). - %ptrLoad (%ptrOffset ptr (ordinal i)) + stringFromCharPtr n ptr instance showFloat32 : Show Float32 where show = \x: Float32.unsafeIO \(). (n, ptr) = %ffi showFloat32 (Int32 & CharPtr) x - AsList n $ for i:(Fin n). - %ptrLoad (%ptrOffset ptr (ordinal i)) + stringFromCharPtr n ptr instance showFloat64 : Show Float64 where show = \x: Float64.unsafeIO \(). (n, ptr) = %ffi showFloat64 (Int32 & CharPtr) x - AsList n $ for i:(Fin n). - %ptrLoad (%ptrOffset ptr (ordinal i)) + stringFromCharPtr n ptr -def writeStdErr (s:String) : {State World} Unit = - (AsList n cs) = s - %ffi writeToStdErr Int n (%getPtr cs) - () +-- def writeStdErr (s:String) : {State World} Unit = +-- (AsList n cs) = s +-- %ffi writeToStdErr Int n (%getPtr cs) +-- () -def throwMsg (s:String) : a = unsafeIO \(). - writeStdErr s - %throwError a +-- def throwMsg (s:String) : a = unsafeIO \(). +-- writeStdErr s +-- %throwError a -- pipe-like reverse function application def (|>) (x:a) (f: a -> b) : b = f x @@ -1104,16 +1130,6 @@ def (>>) (x:Byte) (y:Int) : Byte = %shr x (IToW8 y) def (.|.) (x:Byte) (y:Byte) : Byte = %or x y def (.&.) (x:Byte) (y:Byte) : Byte = %and x y -'## Raw pointer operations - -def Ptr (ty:Type) : Type = %makePtrType ty - -def tabToPtr (n:Int) ?-> (xs:(Fin n)=>Float) : Ptr Float = - %getPtr xs - -def ptrToTab (n:Int) (ptr:Ptr Float) : Fin n => Float = - for i:(Fin n). %ptrLoad (%ptrOffset ptr (ordinal i)) - '## Misc -- TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index e1b9b4578..59cdf502d 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -133,11 +133,14 @@ linearizeOp op = case op of FstRef ref -> (FstRef <$> la ref ) `bindLin` emitOp SndRef ref -> (SndRef <$> la ref ) `bindLin` emitOp Select p t f -> (Select p <$> la t <*> la f ) `bindLin` emitOp - PtrLoad _ -> emitWithZero -- XXX: This assumes that pointers are always constants + -- XXX: This assumes that pointers are always constants + PtrLoad _ -> emitWithZero + PtrStore _ _ -> emitDiscrete PtrOffset _ _ -> emitDiscrete + IOAlloc _ _ -> emitDiscrete + IOFree _ -> emitDiscrete TabCon ty xs -> (TabCon ty <$> traverse la xs) `bindLin` emitOp Inject _ -> emitDiscrete - GetPtr _ -> emitDiscrete MakePtrType _ -> emitDiscrete SliceOffset _ _ -> emitDiscrete SliceCurry _ _ -> emitDiscrete @@ -259,6 +262,14 @@ linearizeHof env hof = case hof of RunWriter lam -> linearizeEff Nothing lam True (const RunWriter) (emitRunWriter "r") Writer RunReader val lam -> linearizeEff (Just val) lam False RunReader (emitRunReader "r") Reader RunState val lam -> linearizeEff (Just val) lam True RunState (emitRunState "r") State + RunIO ~(Lam (Abs _ (arrow, body))) -> LinA $ do + arrow' <- substEmbed env arrow + -- TODO: consider the possibility of other effects here besides IO + lam <- buildLam (Ignore UnitTy) arrow' $ \_ -> + tangentFunAsLambda $ linearizeBlock env body + result <- emit $ Hof $ RunIO lam + (ans, linLam) <- fromPair result + return (ans, applyLinToTangents linLam) -- TODO: Consider providing an upper bound for the number of while iterations as a hint. -- In the current form the best we can do is try to use some dynamically growing lists, -- but that won't work on the GPU. @@ -588,7 +599,6 @@ transposeOp op ct = case op of else transposeAtom y =<< mul ct =<< substNonlin x ScalarBinOp FDiv x y -> transposeAtom x =<< div' ct =<< substNonlin y ScalarBinOp _ _ _ -> notLinear - GetPtr _ -> notLinear MakePtrType _ -> notLinear PrimEffect refArg m -> do refArg' <- substTranspose linRefSubst refArg @@ -616,8 +626,11 @@ transposeOp op ct = case op of RecordSplit _ _ -> notImplemented VariantLift _ _ -> notImplemented VariantSplit _ _ -> notImplemented + PtrStore _ _ -> notLinear PtrLoad _ -> notLinear - PtrOffset _ _ -> notLinear + PtrOffset _ _ -> notLinear + IOAlloc _ _ -> notLinear + IOFree _ -> notLinear Inject _ -> notLinear SliceOffset _ _ -> notLinear SliceCurry _ _ -> notLinear @@ -676,6 +689,7 @@ transposeHof hof ct = case hof of localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ctBody return UnitVal transposeAtom s cts + RunIO _ -> error "Not implemented" Tile _ _ _ -> notImplemented While _ _ -> notImplemented Linearize _ -> error "Unexpected linearization" diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index f75e89c67..91e2af207 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -26,7 +26,8 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP singletonTypeVal, scopedDecls, embedScoped, extendScope, checkEmbed, embedExtend, unpackConsList, emitRunWriter, emitRunState, emitRunReader, tabGet, SubstEmbedT, SubstEmbed, runSubstEmbedT, - traverseAtom, ptrOffset, ptrLoad, evalBlockE, substTraversalDef, + traverseAtom, ptrOffset, ptrLoad, unsafePtrLoad, + evalBlockE, substTraversalDef, TraversalDef, traverseDecls, traverseDecl, traverseBlock, traverseExpr, clampPositive, buildNAbs, buildNAbsAux, buildNestedLam, zeroAt, transformModuleAsBlock, dropSub, appReduceTraversalDef, @@ -328,6 +329,13 @@ appTryReduce f x = case f of ptrOffset :: MonadEmbed m => Atom -> Atom -> m Atom ptrOffset x i = emitOp $ PtrOffset x i +unsafePtrLoad :: MonadEmbed m => Atom -> m Atom +unsafePtrLoad x = emit $ Hof $ RunIO $ Lam $ Abs (Ignore UnitTy) $ + (PlainArrow justIOEff, Block Empty (Op (PtrLoad x))) + +justIOEff :: EffectRow +justIOEff = EffectRow [(State, theWorld)] Nothing + ptrLoad :: MonadEmbed m => Atom -> m Atom ptrLoad x = emitOp $ PtrLoad x diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 8337bbbd5..5a09e6af9 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -83,7 +83,7 @@ toImpModule env backend cc entryName argBinders maybeDest block = do for (requiredFunctions env block) $ \(v, f) -> runImpM initCtx inVarScope $ toImpStandalone v f runImpM initCtx inVarScope $ do - (reconAtom, impBlock) <- scopedBlock $ translateTopLevel (maybeDest, block) + (reconAtom, impBlock) <- scopedBlock $ translateTopLevel env (maybeDest, block) otherFunctions <- toList <$> looks envFunctions let ty = IFunType cc (map binderAnn argBinders) (impBlockType impBlock) let mainFunction = ImpFunction (entryName:>ty) argBinders impBlock @@ -111,14 +111,14 @@ requiredFunctions scope expr = -- We don't emit any results when a destination is provided, since they are already -- going to be available through the dest. -translateTopLevel :: WithDest Block -> ImpM (AtomRecon, [IExpr]) -translateTopLevel (maybeDest, block) = do +translateTopLevel :: TopEnv -> WithDest Block -> ImpM (AtomRecon, [IExpr]) +translateTopLevel topEnv (maybeDest, block) = do outDest <- case maybeDest of Nothing -> makeAllocDest Unmanaged $ getType block Just dest -> return dest handleErrors $ void $ translateBlock mempty (Just outDest, block) resultAtom <- destToAtom outDest - let vsOut = envAsVars $ freeVars resultAtom + let vsOut = envAsVars $ freeVars resultAtom `envDiff` topEnv let reconAtom = Abs (toNest $ [Bind (v:>ty) | (v:>(ty, _)) <- vsOut]) resultAtom let resultIExprs = case maybeDest of Nothing -> [IVar (v:>fromScalarType ty) | (v:>(ty, _)) <- vsOut] @@ -272,14 +272,19 @@ toImpOp (maybeDest, op) = case op of IndexRef refDest i -> returnVal =<< destGet refDest i FstRef ~(Con (ConRef (PairCon ref _ ))) -> returnVal ref SndRef ~(Con (ConRef (PairCon _ ref))) -> returnVal ref + IOAlloc ty n -> do + ptr <- emitAlloc (AllocatedPtr, Heap CPU, ty) (fromScalarAtom n) + returnVal $ toScalarAtom ptr + IOFree ptr -> do + emitStatement $ Free $ fromScalarAtom ptr + return UnitVal PtrOffset arr off -> do buf <- impOffset (fromScalarAtom arr) (fromScalarAtom off) returnVal $ toScalarAtom buf PtrLoad arr -> returnVal . toScalarAtom =<< loadAnywhere (fromScalarAtom arr) - GetPtr tab -> do - (dest, ptr) <- makeAllocDestForPtr (getType tab) - copyAtom dest tab - returnVal ptr + PtrStore ptr x -> do + store (fromScalarAtom ptr) (fromScalarAtom x) + return UnitVal SliceOffset ~(Con (IndexSliceVal n _ tileOffset)) idx -> do i' <- indexToInt idx i <- iaddI (fromScalarAtom tileOffset) i' @@ -706,14 +711,14 @@ copyDataConArgs bindings args = loadDest :: MonadEmbed m => Dest -> m Atom loadDest (BoxedRef b ptrPtr _ body) = do - ptr <- ptrLoad ptrPtr + ptr <- unsafePtrLoad ptrPtr body' <- substEmbed (b@>ptr) body loadDest body' loadDest (DataConRef def params bs) = do DataCon def params 0 <$> loadDataConArgs bs loadDest (Con dest) = do case dest of - BaseTypeRef ptr -> ptrLoad ptr + BaseTypeRef ptr -> unsafePtrLoad ptr TabRef (TabVal b body) -> buildLam b TabArrow $ \i -> do body' <- substEmbed (b@>i) body result <- emitBlock body' @@ -793,17 +798,6 @@ makeAllocDestWithPtrs allocTy ty = do dest' <- impSubst env dest return (dest', ptrs) --- TODO: deallocation! -makeAllocDestForPtr :: Type -> ImpM (Dest, Atom) -makeAllocDestForPtr ty = do - (ptrSizes, dest) <- fromEmbed $ makeDest (LLVM, CPU, Unmanaged) ty - case ptrSizes of - [(Bind (ptr:>PtrTy ptrTy), size)] -> do - ptr' <- emitAlloc ptrTy $ fromScalarAtom size - dest' <- impSubst (ptr @> toScalarAtom ptr') dest - return (dest', toScalarAtom ptr') - _ -> error $ "expected a single pointer" - splitDest :: WithDest Block -> ([WithDest Decl], WithDest Expr, [(Dest, Atom)]) splitDest (maybeDest, (Block decls ans)) = do case (maybeDest, ans) of @@ -913,7 +907,7 @@ toScalarAtom ie = case ie of ILit l -> Con $ Lit l IVar (v:>b) -> Var (v:>BaseTy b) -fromScalarType :: Type -> IType +fromScalarType :: HasCallStack => Type -> IType fromScalarType (BaseTy b) = b fromScalarType ty = error $ "Not a scalar type: " ++ pprint ty diff --git a/src/lib/Interpreter.hs b/src/lib/Interpreter.hs index f9f43fa18..2c0e3edc0 100644 --- a/src/lib/Interpreter.hs +++ b/src/lib/Interpreter.hs @@ -66,7 +66,9 @@ evalExpr env expr = case expr of evalBlock env $ applyNaryAbs (alts !! i) (xss !! i) _ -> error $ "Not implemented: SumAsProd with tag " ++ pprint expr _ -> error $ "Unexpected scrutinee: " ++ pprint e - _ -> error $ "Not implemented: " ++ pprint expr + Hof hof -> case hof of + RunIO ~(Lam (Abs _ (_, body))) -> evalBlock env body + _ -> error $ "Not implemented: " ++ pprint expr evalOp :: Op -> InterpM Atom evalOp expr = case expr of diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index f9175dc8c..1442f81b6 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -476,8 +476,9 @@ simplifyHof hof = case hof of ans' <- applyRecon recon ans return $ PairVal ans' sOut RunIO lam -> do - ~(lam', Nothing) <- simplifyLam lam - emit $ Hof $ RunIO lam' + ~(lam', recon) <- simplifyLam lam + ans <- emit $ Hof $ RunIO lam' + applyRecon recon ans where applyRecon Nothing x = return x applyRecon (Just f) x = f x diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index b8b75945c..c29cf0480 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -321,12 +321,15 @@ data PrimOp e = | SndRef e | FFICall String e [e] | Inject e - | PtrOffset e e - | PtrLoad e - | GetPtr e - | MakePtrType e | SliceOffset e e -- Index slice first, inner index second | SliceCurry e e -- Index slice first, curried index second + -- Low-level memory operations + | MakePtrType e + | IOAlloc BaseType e + | IOFree e + | PtrOffset e e + | PtrLoad e + | PtrStore e e -- SIMD operations | VectorBinOp BinOp e e | VectorPack [e] -- List should have exactly vectorWidth elements @@ -560,9 +563,16 @@ data BaseType = Scalar ScalarBaseType data Device = CPU | GPU deriving (Show, Eq, Ord, Generic) data AddressSpace = Stack | Heap Device deriving (Show, Eq, Ord, Generic) -data PtrOrigin = DerivedPtr | AllocatedPtr deriving (Show, Eq, Ord, Generic) +data PtrOrigin = DerivedPtr | AllocatedPtr deriving (Show, Ord, Generic) type PtrType = (PtrOrigin, AddressSpace, BaseType) +instance Eq PtrOrigin where + -- XXX: this is a hack. We expose pointer operations to the surface language + -- but we don't yet expose the derived/allocated distinction, and they get + -- mixed up when we use ops like ptrOffset. + _ == _ = True + + sizeOf :: BaseType -> Int sizeOf t = case t of Scalar Int64Type -> 8 @@ -1539,9 +1549,11 @@ builtinNames = M.fromList , ("cast", OpExpr $ CastOp () ()) , ("sliceOffset", OpExpr $ SliceOffset () ()) , ("sliceCurry", OpExpr $ SliceCurry () ()) + , ("charAlloc", OpExpr $ IOAlloc (Scalar Word8Type) ()) + , ("charFree" , OpExpr $ IOFree ()) , ("ptrOffset", OpExpr $ PtrOffset () ()) , ("ptrLoad" , OpExpr $ PtrLoad ()) - , ("getPtr" , OpExpr $ GetPtr () ) + , ("ptrStore" , OpExpr $ PtrStore () ()) , ("makePtrType", OpExpr $ MakePtrType ()) , ("CharPtr" , ptrTy Word8Type) , ("dataConTag", OpExpr $ DataConTag ()) diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 11b0de085..43d20d6a3 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -704,16 +704,26 @@ typeCheckOp op = case op of SndRef ref -> do RefTy h (PairTy _ b) <- typeCheck ref return $ RefTy h b + IOAlloc t n -> do + n |: IdxRepTy + return $ PtrTy (AllocatedPtr, Heap CPU, t) + IOFree ptr -> do + PtrTy _ <- typeCheck ptr + declareEff (State, Just theWorld) + return UnitTy PtrOffset arr off -> do PtrTy (_, a, b) <- typeCheck arr off |: IdxRepTy return $ PtrTy (DerivedPtr, a, b) PtrLoad ptr -> do - PtrTy (_, _, t) <- typeCheck ptr + PtrTy (_, _, t) <- typeCheck ptr + declareEff (State, Just theWorld) return $ BaseTy t - GetPtr tab -> do - TabTy _ (BaseTy a) <- typeCheck tab - return $ BaseTy $ PtrType (AllocatedPtr, Heap CPU, a) + PtrStore ptr val -> do + PtrTy (_, _, t) <- typeCheck ptr + val |: BaseTy t + declareEff (State, Just theWorld) + return $ UnitTy MakePtrType ty -> ty|:TyKind >> return TyKind SliceOffset s i -> do TC (IndexSlice n l) <- typeCheck s From e9847eb08da7d8fca91acdf7509b05024e0cc549 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 21 Dec 2020 16:07:32 -0500 Subject: [PATCH 015/105] Generalize load/store/alloc/free to `Ptr a` for `Storable a`. --- lib/diagram.dx | 2 +- lib/io.dx | 11 ++++--- lib/png.dx | 5 +-- lib/prelude.dx | 81 ++++++++++++++++++++++++++++++++++++----------- src/lib/Imp.hs | 12 +++---- src/lib/JIT.hs | 1 + src/lib/Syntax.hs | 3 +- src/lib/Type.hs | 5 ++- tests/io-tests.dx | 8 +++++ 9 files changed, 93 insertions(+), 35 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index 1d15998f1..c3889bbb7 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -14,7 +14,7 @@ HtmlColor : Type = (Word8 & Word8 & Word8) def showHex (x:Int32) : String = unsafeIO \(). (n, ptr) = %ffi showHex (Int32 & CharPtr) x - stringFromCharPtr n ptr + stringFromCharPtr n (MkPtr ptr) -- TODO: we should add overloaded string literals so we don't need this def str (n:Int) ?-> (s:(Fin n=>Char)) : String = AsList _ s diff --git a/lib/io.dx b/lib/io.dx index d8502e730..b742d82f5 100644 --- a/lib/io.dx +++ b/lib/io.dx @@ -13,7 +13,7 @@ data Stream mode:StreamMode = MkStream CharPtr -- TODO: check the string contains no nulls def withCString (s:String) (action: CString -> {State World} a) : {State World} a = (AsList n s') = s <> (AsList _ "\NUL") - withTabPtr s' \ptr. action $ MkCString ptr + withTabPtr s' \(MkPtr ptr). action $ MkCString ptr def fopen (path:String) (mode:StreamMode) : {State World} (Stream mode) = modeStr = AsList _ case mode of @@ -31,7 +31,7 @@ def fclose (mode:StreamMode) ?-> (stream:Stream mode) : {State World} Unit = def fwrite (stream:Stream WriteMode) (s:String) : {State World} Unit = (MkStream stream') = stream (AsList n s') = s - withTabPtr s' \ptr. + withTabPtr s' \(MkPtr ptr). %ffi fwrite Int64 ptr (IToI64 1) (IToI64 n) stream' () @@ -39,8 +39,9 @@ def fread (stream:Stream ReadMode) : {State World} String = (MkStream stream') = stream -- TODO: allow reading longer files! n = 4096 - withAlloc n \ptr. - numRead = I64ToI $ %ffi fread Int64 ptr (IToI64 1) (IToI64 n) stream' + withAlloc n \ptr:(Ptr Char). + (MkPtr rawPtr) = ptr + numRead = I64ToI $ %ffi fread Int64 rawPtr (IToI64 1) (IToI64 n) stream' stringFromCharPtr numRead ptr def deleteFile (f:FilePath) : {State World} Unit = @@ -67,7 +68,7 @@ def writeTemp (s:String) : {State World} FilePath = -- file with same name after we ask for the name and before we create it. withCString (AsList _ "/tmp/dex-XXXXXX") \(MkCString ptr). %ffi mktemp CharPtr ptr - stringFromCharPtr 15 ptr + stringFromCharPtr 15 (MkPtr ptr) def withTempFile (action: FilePath -> {State World} a) : {State World} a = tmpFile = writeTemp (AsList _ []) diff --git a/lib/png.dx b/lib/png.dx index d131f2bf1..75033b535 100644 --- a/lib/png.dx +++ b/lib/png.dx @@ -96,8 +96,9 @@ Html : Type = List Char def makePNG (img:n=>m=>(Fin 3)=>Word8) : List Byte = unsafeIO \(). (AsList _ imgFlat) = toList for (i,j,k). img.i.j.k withTabPtr imgFlat \ptr. - (n, ptr) = %ffi encodePNG (Int & CharPtr) ptr (size m) (size n) - AsList n $ tabFromCharPtr ptr + (MkPtr rawPtr) = ptr + (n, ptr') = %ffi encodePNG (Int & CharPtr) rawPtr (size m) (size n) + AsList n $ tabFromPtr $ MkPtr ptr' def pngToHtml (png:List Byte) : List Char = (toList " : Ord (Fin n) = '## Raw pointer operations -def Ptr (ty:Type) : Type = %makePtrType ty +Int32Ptr : Type = %Int32Ptr +Word8Ptr : Type = %Word8Ptr -CharPtr : Type = %CharPtr +CharPtr = Word8Ptr --- TODO: generalize these to other pointer types (Storable type class?) -def malloc (n:Int) : {State World} CharPtr = %charAlloc n -def free (ptr:CharPtr) : {State World} Unit = %charFree ptr -def ptrStore (ptr:CharPtr) (x:Char) : {State World} Unit = %ptrStore ptr x -def ptrLoad (ptr:CharPtr) : {State World} Char = %ptrLoad ptr -def ptrOffset (ptr:CharPtr) (i:Int) : CharPtr = %ptrOffset ptr i +data Ptr a:Type = MkPtr Word8Ptr + +-- Is there a better way to select the right instance for `storageSize`?? +data TypeVehicle a:Type = MkTypeVehicle + +interface Storable a:Type where + store : Ptr a -> a -> {State World} Unit + load : Ptr a -> {State World} a + storageSize : TypeVehicle a -> Int + +-- TODO: there's a bug preventing us inlining these definitions into the instance +def charStore ((MkPtr ptr): Ptr Word8) (x:Word8) : {State World} Unit = %ptrStore ptr x +def charLoad ((MkPtr ptr): Ptr Word8) : {State World} Word8 = %ptrLoad ptr + +instance charStorable : Storable Word8 where + store = charStore + load = charLoad + storageSize = const 1 + +-- TODO: there's a bug preventing us inlining these definitions into the instance +def int32Store ((MkPtr ptr): Ptr Int32) (x:Int32) : {State World} Unit = + %ptrStore (internalCast %Int32Ptr ptr) x +def int32Load ((MkPtr ptr): Ptr Int32) : {State World} Int32 = + %ptrLoad (internalCast %Int32Ptr ptr) + +instance int32Storable : Storable Int32 where + store = int32Store + load = int32Load + storageSize = const 4 + +-- TODO: Storable instances for other types + +def malloc (_:Storable a) ?=> (n:Int) : {State World} (Ptr a) = + typeVehicle : TypeVehicle a = MkTypeVehicle + numBytes = storageSize typeVehicle * n + MkPtr $ %charAlloc numBytes + +def free (ptr:Ptr a) : {State World} Unit = + (MkPtr ptr') = ptr + %charFree ptr' + +def (+>>) (_:Storable a) ?=> (ptr:Ptr a) (i:Int) : Ptr a = + typeVehicle : TypeVehicle a = MkTypeVehicle + (MkPtr ptr') = ptr + i' = i * storageSize typeVehicle + MkPtr $ %ptrOffset ptr' i' -- TODO: generalize these brackets to allow other effects -def withAlloc (n:Int) (action: CharPtr -> {State World} a) : {State World} a = +def withAlloc (_:Storable a) ?=> + (n:Int) (action: Ptr a -> {State World} b) : {State World} b = ptr = malloc n result = action ptr free ptr result -def withTabPtr (xs:n=>Char) (action : CharPtr -> {State World} a) : {State World} a = +def withTabPtr (_:Storable a) ?=> + (xs:n=>a) (action : Ptr a -> {State World} b) : {State World} b = withAlloc (size n) \ptr. - for i. ptrStore (ptrOffset ptr (ordinal i)) xs.i + for i. store (ptr +>> ordinal i) xs.i action ptr -def tabFromCharPtr (ptr:CharPtr) : {State World} n=>Char = - for i. ptrLoad $ ptrOffset ptr (ordinal i) +def tabFromPtr (_:Storable a) ?=> (ptr:Ptr a) : {State World} n=>a = + for i. load $ ptr +>> ordinal i 'Misc @@ -893,8 +936,8 @@ def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = String : Type = List Char -def stringFromCharPtr (n:Int) (ptr:CharPtr) : {State World} String= - AsList n $ tabFromCharPtr ptr +def stringFromCharPtr (n:Int) (ptr:Ptr Char) : {State World} String= + AsList n $ tabFromPtr ptr -- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint def codepoint (c:Char) : Int = W8ToI c @@ -905,22 +948,22 @@ interface Show a:Type where instance showInt32 : Show Int32 where show = \x: Int32. unsafeIO \(). (n, ptr) = %ffi showInt32 (Int32 & CharPtr) x - stringFromCharPtr n ptr + stringFromCharPtr n $ MkPtr ptr instance showInt64 : Show Int64 where show = \x: Int64. unsafeIO \(). (n, ptr) = %ffi showInt64 (Int32 & CharPtr) x - stringFromCharPtr n ptr + stringFromCharPtr n $ MkPtr ptr instance showFloat32 : Show Float32 where show = \x: Float32.unsafeIO \(). (n, ptr) = %ffi showFloat32 (Int32 & CharPtr) x - stringFromCharPtr n ptr + stringFromCharPtr n $ MkPtr ptr instance showFloat64 : Show Float64 where show = \x: Float64.unsafeIO \(). (n, ptr) = %ffi showFloat64 (Int32 & CharPtr) x - stringFromCharPtr n ptr + stringFromCharPtr n $ MkPtr ptr -- def writeStdErr (s:String) : {State World} Unit = -- (AsList n cs) = s diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 5a09e6af9..e2c9cbd0e 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -1211,12 +1211,12 @@ instrTypeChecked instr = case instr of mapM_ checkIExpr args IPrimOp op -> (:[]) <$> checkImpOp op ICastOp dt x -> (:[]) <$> do - case getIType x of - Scalar _ -> return () - _ -> throw CompilerErr $ "Invalid cast source type: " ++ pprint dt - case dt of - Scalar _ -> return () - _ -> throw CompilerErr $ "Invalid cast destination type: " ++ pprint dt + let st = getIType x + case (dt, st) of + (PtrType _, PtrType _) -> return () + (Scalar _, Scalar _) -> return () + _ -> throw CompilerErr $ + "Can't cast " ++ pprint st ++ " to " ++ pprint dt return dt Alloc a ty _ -> (:[]) <$> do when (a /= Stack) assertHost diff --git a/src/lib/JIT.hs b/src/lib/JIT.hs index a7776b1da..0a624bc7e 100644 --- a/src/lib/JIT.hs +++ b/src/lib/JIT.hs @@ -301,6 +301,7 @@ compileInstr instr = case instr of GT -> emitInstr dt $ L.FPTrunc x dt [] (L.FloatingPointType _, L.IntegerType _) -> emitInstr dt $ L.FPToSI x dt [] (L.IntegerType _, L.FloatingPointType _) -> emitInstr dt $ L.SIToFP x dt [] + (L.PointerType _ _, L.PointerType eltTy _) -> castLPtr eltTy x _ -> error $ "Unsupported cast" ICall f@(fname:> IFunType cc argTys resultTys) args -> do -- TODO: consider having a separate calling convention specification rather diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index c29cf0480..9ad7b9c99 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -1532,6 +1532,8 @@ builtinNames = M.fromList , ("Int64" , TCExpr $ BaseType $ Scalar Int64Type) , ("Int32" , TCExpr $ BaseType $ Scalar Int32Type) , ("Word8" , TCExpr $ BaseType $ Scalar Word8Type) + , ("Int32Ptr", ptrTy Int32Type) + , ("Word8Ptr", ptrTy Word8Type) , ("IntRange", TCExpr $ IntRange () ()) , ("Ref" , TCExpr $ RefType (Just ()) ()) , ("PairType", TCExpr $ PairType () ()) @@ -1555,7 +1557,6 @@ builtinNames = M.fromList , ("ptrLoad" , OpExpr $ PtrLoad ()) , ("ptrStore" , OpExpr $ PtrStore () ()) , ("makePtrType", OpExpr $ MakePtrType ()) - , ("CharPtr" , ptrTy Word8Type) , ("dataConTag", OpExpr $ DataConTag ()) , ("toEnum" , OpExpr $ ToEnum () ()) ] diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 43d20d6a3..6237074b6 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -640,7 +640,9 @@ checkFloatBaseType allowVector t = case t of "floating-point type, but found: " ++ pprint t checkValidCast :: Type -> Type -> TypeM () -checkValidCast sourceTy destTy = checkScalarType sourceTy >> checkScalarType destTy +checkValidCast (BaseTy (PtrType _)) (BaseTy (PtrType _)) = return () +checkValidCast sourceTy destTy = + checkScalarType sourceTy >> checkScalarType destTy where checkScalarType ty = case ty of BaseTy (Scalar Int64Type ) -> return () @@ -648,6 +650,7 @@ checkValidCast sourceTy destTy = checkScalarType sourceTy >> checkScalarType des BaseTy (Scalar Word8Type ) -> return () BaseTy (Scalar Float64Type) -> return () BaseTy (Scalar Float32Type) -> return () + _ -> throw TypeErr $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy typeCheckOp :: Op -> TypeM Type diff --git a/tests/io-tests.dx b/tests/io-tests.dx index db26bae5a..6d3921ccd 100644 --- a/tests/io-tests.dx +++ b/tests/io-tests.dx @@ -10,3 +10,11 @@ include "io.dx" > (AsList 27 "lorem ipsum > dolor sit amet > ") + + +:p unsafeIO \(). + withAlloc 4 \ptr:(Ptr Int). + for i:(Fin 4). store (ptr +>> ordinal i) (ordinal i) + result : Fin 4 => Int = tabFromPtr ptr + result +> [0, 1, 2, 3] From 067ba2f8306a94d4046ba79ddb53a2604b0d11ae Mon Sep 17 00:00:00 2001 From: David Duvenaud Date: Mon, 21 Dec 2020 17:06:13 -0500 Subject: [PATCH 016/105] Fix atan2 bug and add more tests. --- lib/prelude.dx | 2 +- tests/trig-tests.dx | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index 2b05011bc..e90b538e7 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -956,7 +956,7 @@ def atan2 (y:Float) (x:Float) : Float = (min_abs_x_y, max_abs_x_y) = min_and_max abs_x abs_y a = atan_inner (min_abs_x_y / max_abs_x_y) a = select (abs_x <= abs_y) ((pi / 2.0) -a) a - a = select (x < 0.0) pi a + a = select (x < 0.0) (pi - a) a t = select (x < 0.0) pi 0.0 a = select (y == 0.0) t a t = select (x < 0.0) (3.0 * pi / 4.0) (pi / 4.0) diff --git a/tests/trig-tests.dx b/tests/trig-tests.dx index c8bd9dcf9..abd6776ae 100644 --- a/tests/trig-tests.dx +++ b/tests/trig-tests.dx @@ -20,6 +20,14 @@ > True :p atan2 (-sin (-0.44)) (cos (-0.44)) ~~ (0.44) > True +:p atan2 (-1.0) (-1.0) ~~ (-3.0/4.0*pi) +> True + +-- Test all the way around the circle. +angles = linspace (Fin 11) (-pi + 0.001) (pi) +:p all for i:(Fin 11). + angles.i ~~ atan2 (sin angles.i) (cos angles.i) +> True :p (atan2 infinity 1.0) ~~ ( pi / 2.0) > True From fc825c119d7a5cc45c1abd0263cfe2234390161e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 21 Dec 2020 14:59:41 +0000 Subject: [PATCH 017/105] Automate wrapping of JITed function pointers in Python callables This patch extends the export functionality with the ability to generate a description of how the user-facing arguments and results map to the exported native function. Then, I created a mini-language that can faithfully encode this description, and is exposed in the foreign API via the `dexGetFunctionSignature` function. On the Python side, I have implemented a parser for that language, which turns it into a few `NativeType` objects that describe how the arguments are to be (de)serialized when interfacing with the native function. This completely automates the manual labor of creating ctype wrappers and allows used to call the compiled function as if it was any other Python callable. Note that this interface also supports zero-copy bidirectional conversion of NumPy arrays and Dex tables with `Fin` index sets. For example: ```py import dex import numpy as np from textwrap import dedent sigmoid = dex.eval(r"\x:Float. 1.0 / (1.0 + exp(-x))").compile() print(sigmoid(-1.0), sigmoid(0.0), sigmoid(1.0)) transpose = dex.Module(dedent(""" def myTranspose (n: Int) ?-> (m: Int) ?-> (x : (Fin n)=>(Fin m)=>Float) : (Fin m)=>(Fin n)=>Float = for i j. x.j.i """)).myTranspose.compile() example = np.arange(25, dtype=np.float32).reshape((5, 5)) print(transpose(example)) # NB: Implicit arguments get inferred from shapes print(example.T) ``` --- dex.cabal | 2 +- python/dex/__init__.py | 140 +++++---------------- python/dex/api.py | 98 +++++++++++++++ python/dex/native_function.py | 221 +++++++++++++++++++++++++++++++++ python/tests/api_test.py | 20 --- python/tests/jit_test.py | 79 ++++++++++++ src/Dex/Foreign/API.hs | 2 + src/Dex/Foreign/JIT.hs | 65 +++++++--- src/Dex/Foreign/Serialize.hs | 1 - src/Dex/Foreign/Util.hs | 12 +- src/dex.hs | 5 +- src/lib/Export.hs | 226 ++++++++++++++++++++++++++++++++++ src/lib/LLVM/JIT.hs | 2 +- src/lib/Syntax.hs | 5 +- src/lib/TopLevel.hs | 103 +--------------- 15 files changed, 728 insertions(+), 253 deletions(-) create mode 100644 python/dex/api.py create mode 100644 python/dex/native_function.py create mode 100644 python/tests/jit_test.py create mode 100644 src/lib/Export.hs diff --git a/dex.cabal b/dex.cabal index 8ecdef23b..ebfc3d3ee 100644 --- a/dex.cabal +++ b/dex.cabal @@ -32,7 +32,7 @@ library exposed-modules: Env, Syntax, Type, Inference, JIT, LLVMExec, Parser, Util, Imp, Imp.Embed, Imp.Optimize, PPrint, Algebra, Parallelize, Optimize, Serialize - Actor, Cat, Flops, Embed, + Actor, Cat, Flops, Embed, Export, RenderHtml, LiveOutput, Simplify, TopLevel, Autodiff, Interpreter, Logging, PipeRPC, CUDA, LLVM.JIT, LLVM.Shims diff --git a/python/dex/__init__.py b/python/dex/__init__.py index e60ffaea6..84da18d93 100644 --- a/python/dex/__init__.py +++ b/python/dex/__init__.py @@ -5,117 +5,43 @@ # https://developers.google.com/open-source/licenses/bsd import itertools as it +import sys import ctypes -import pathlib -import atexit -from enum import Enum -from typing import List - -__all__ = ['execute'] - -here = pathlib.Path(__file__).parent.absolute() - -lib = ctypes.cdll.LoadLibrary(here / 'libDex.so') - -def tagged_union(name: str, members: List[type]): - named_members = [(f"t{i}", member) for i, member in enumerate(members)] - payload = type(name + "Payload", (ctypes.Union,), {"_fields_": named_members}) - union = type(name, (ctypes.Structure,), { - "_fields_": [("tag", ctypes.c_uint64), ("payload", payload)], - "value": property(lambda self: getattr(self.payload, f"t{self.tag}")), - }) - return union - -CLit = tagged_union("Lit", [ctypes.c_int64, ctypes.c_int32, ctypes.c_int8, ctypes.c_double, ctypes.c_float]) -class CRectArray(ctypes.Structure): - _fields_ = [("data", ctypes.c_void_p), - ("shape_ptr", ctypes.POINTER(ctypes.c_int64)), - ("strides_ptr", ctypes.POINTER(ctypes.c_int64))] -CAtom = tagged_union("CAtom", [CLit, CRectArray]) -assert ctypes.sizeof(CAtom) == 4 * 8 - -class HsAtom(ctypes.Structure): pass -class HsContext(ctypes.Structure): pass -class HsJIT(ctypes.Structure): pass -class NativeFunctionObj(ctypes.Structure): pass - -HsAtomPtr = ctypes.POINTER(HsAtom) -HsContextPtr = ctypes.POINTER(HsContext) -HsJITPtr = ctypes.POINTER(HsJIT) -CAtomPtr = ctypes.POINTER(CAtom) -NativeFunction = ctypes.POINTER(NativeFunctionObj) - -def _dex_func(name, *signature): - argtypes, restype = signature[:-1], signature[-1] - f = getattr(lib, name) - f.restype = restype - f.argtypes = argtypes - return f - -_init = _dex_func('dexInit', None) -_fini = _dex_func('dexFini', None) -_getError = _dex_func('dexGetError', ctypes.c_char_p) - -_create_context = _dex_func('dexCreateContext', HsContextPtr) -_destroy_context = _dex_func('dexDestroyContext', HsContextPtr, None) - -_eval = _dex_func('dexEval', HsContextPtr, ctypes.c_char_p, HsContextPtr) -_insert = _dex_func('dexInsert', HsContextPtr, ctypes.c_char_p, HsAtomPtr, HsContextPtr) -_evalExpr = _dex_func('dexEvalExpr', HsContextPtr, ctypes.c_char_p, HsAtomPtr) -_lookup = _dex_func('dexLookup', HsContextPtr, ctypes.c_char_p, HsAtomPtr) - -_print = _dex_func('dexPrint', HsAtomPtr, ctypes.c_char_p) -_toCAtom = _dex_func('dexToCAtom', HsAtomPtr, CAtomPtr, ctypes.c_int) - -_createJIT = _dex_func('dexCreateJIT', HsJITPtr) -_destroyJIT = _dex_func('dexDestroyJIT', HsJITPtr, None) -_compile = _dex_func('dexCompile', HsJITPtr, HsContextPtr, HsAtomPtr, NativeFunction) -_unload = _dex_func('dexUnload', HsJITPtr, NativeFunction, None) - -_init() -_jit = _createJIT() -_nofree = False -@atexit.register -def _teardown(): - global _nofree - _destroyJIT(_jit) - _fini() - _nofree = True # Don't destruct any Haskell objects after the RTS has been shutdown - - -def _as_cstr(x: str): - return ctypes.c_char_p(x.encode('ascii')) - -def _from_cstr(cx): - return cx.decode('ascii') +from typing import Any, List, Union +from . import api +from .native_function import NativeFunction +__all__ = [ + 'Module', + 'eval', +] class Module: __slots__ = ('_as_parameter_',) def __init__(self, source): - self._as_parameter_ = _eval(prelude, _as_cstr(source)) + self._as_parameter_ = api.eval(prelude, api.as_cstr(source)) if not self._as_parameter_: - raise RuntimeError(_from_cstr(_getError())) + api.raise_from_dex() def __del__(self): - if _nofree: + if api.nofree: return - _destroy_context(self) + api.destroyContext(self) def __getattr__(self, name): - result = _lookup(self, _as_cstr(name)) + result = api.lookup(self, api.as_cstr(name)) if not result: - raise RuntimeError(_from_cstr(_getError())) + api.raise_from_dex() return Atom(result, self) class Prelude(Module): __slots__ = () def __init__(self): - self._as_parameter_ = _create_context() + self._as_parameter_ = api.createContext() if not self._as_parameter_: - raise RuntimeError("Failed to initialize prelude!") + api.raise_from_dex() prelude = Prelude() @@ -123,9 +49,9 @@ def __init__(self): def eval(expr: str, module=prelude, _env=None): if _env is None: _env = module - result = _evalExpr(_env, _as_cstr(expr)) + result = api.evalExpr(_env, api.as_cstr(expr)) if not result: - raise RuntimeError(_from_cstr(_getError())) + api.raise_from_dex() return Atom(result, module) @@ -142,7 +68,7 @@ def __del__(self): def __repr__(self): # TODO: Free! - return _print(self).decode('ascii') + return api.from_cstr(api.print(self)) def __int__(self): return int(self._as_scalar()) @@ -151,12 +77,12 @@ def __float__(self): return float(self._as_scalar()) def _as_scalar(self): - result = CAtom() - success = _toCAtom(self, ctypes.pointer(result)) + result = api.CAtom() + success = api.toCAtom(self, ctypes.pointer(result)) if not success: - raise RuntimeError(_from_cstr(_getError())) + api.raise_from_dex() value = result.value - if not isinstance(value, CLit): + if not isinstance(value, api.CLit): raise TypeError("Atom is not a scalar value") return value.value @@ -167,22 +93,12 @@ def __call__(self, *args): # NB: Atoms can contain arbitrary references if atom.module is not prelude and atom.module is not self.module: raise RuntimeError("Mixing atoms coming from different Dex modules is not supported yet!") - old_env, env = env, _insert(env, _as_cstr(f"python_arg{i}"), atom) - _destroy_context(old_env) + old_env, env = env, api.insert(env, api.as_cstr(f"python_arg{i}"), atom) + api.destroyContext(old_env) return eval(" ".join(f"python_arg{i}" for i in range(len(args) + 1)), module=self.module, _env=env) def compile(self): - func_ptr = _compile(_jit, self.module, self) + func_ptr = api.compile(api.jit, self.module, self) if not func_ptr: - raise RuntimeError("Failed to JIT-compile a Dex function") - return NativeFunction(func_ptr) - - -class NativeFunction: - def __init__(self, ptr): - self._as_parameter_ = ptr - self.ptr = ptr - - def __del__(self): - if _nofree: return - _unload(_jit, self) + api.raise_from_dex() + return NativeFunction(api.jit, func_ptr) diff --git a/python/dex/api.py b/python/dex/api.py new file mode 100644 index 000000000..fcd881697 --- /dev/null +++ b/python/dex/api.py @@ -0,0 +1,98 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import ctypes +import pathlib +import atexit +from typing import List + +here = pathlib.Path(__file__).parent.absolute() + +lib = ctypes.cdll.LoadLibrary(here / 'libDex.so') + +def tagged_union(name: str, members: List[type]): + named_members = [(f"t{i}", member) for i, member in enumerate(members)] + payload = type(name + "Payload", (ctypes.Union,), {"_fields_": named_members}) + union = type(name, (ctypes.Structure,), { + "_fields_": [("tag", ctypes.c_uint64), ("payload", payload)], + "value": property(lambda self: getattr(self.payload, f"t{self.tag}")), + }) + return union + +CLit = tagged_union("Lit", [ctypes.c_int64, ctypes.c_int32, ctypes.c_int8, ctypes.c_double, ctypes.c_float]) +class CRectArray(ctypes.Structure): + _fields_ = [("data", ctypes.c_void_p), + ("shape_ptr", ctypes.POINTER(ctypes.c_int64)), + ("strides_ptr", ctypes.POINTER(ctypes.c_int64))] +CAtom = tagged_union("CAtom", [CLit, CRectArray]) +assert ctypes.sizeof(CAtom) == 4 * 8 + +class HsAtom(ctypes.Structure): pass +class HsContext(ctypes.Structure): pass +class HsJIT(ctypes.Structure): pass +class NativeFunctionObj(ctypes.Structure): pass +class NativeFunctionSignature(ctypes.Structure): + _fields_ = [("arg", ctypes.c_char_p), + ("res", ctypes.c_char_p), + ("ccall", ctypes.c_char_p)] + + +HsAtomPtr = ctypes.POINTER(HsAtom) +HsContextPtr = ctypes.POINTER(HsContext) +HsJITPtr = ctypes.POINTER(HsJIT) +CAtomPtr = ctypes.POINTER(CAtom) +NativeFunctionSignaturePtr = ctypes.POINTER(NativeFunctionSignature) +NativeFunction = ctypes.POINTER(NativeFunctionObj) + +def dex_func(name, *signature): + argtypes, restype = signature[:-1], signature[-1] + f = getattr(lib, name) + f.restype = restype + f.argtypes = argtypes + return f + +init = dex_func('dexInit', None) +fini = dex_func('dexFini', None) +getError = dex_func('dexGetError', ctypes.c_char_p) + +createContext = dex_func('dexCreateContext', HsContextPtr) +destroyContext = dex_func('dexDestroyContext', HsContextPtr, None) + +eval = dex_func('dexEval', HsContextPtr, ctypes.c_char_p, HsContextPtr) +insert = dex_func('dexInsert', HsContextPtr, ctypes.c_char_p, HsAtomPtr, HsContextPtr) +evalExpr = dex_func('dexEvalExpr', HsContextPtr, ctypes.c_char_p, HsAtomPtr) +lookup = dex_func('dexLookup', HsContextPtr, ctypes.c_char_p, HsAtomPtr) + +print = dex_func('dexPrint', HsAtomPtr, ctypes.c_char_p) +toCAtom = dex_func('dexToCAtom', HsAtomPtr, CAtomPtr, ctypes.c_int) + +createJIT = dex_func('dexCreateJIT', HsJITPtr) +destroyJIT = dex_func('dexDestroyJIT', HsJITPtr, None) +compile = dex_func('dexCompile', HsJITPtr, HsContextPtr, HsAtomPtr, NativeFunction) +unload = dex_func('dexUnload', HsJITPtr, NativeFunction, None) + +getFunctionSignature = dex_func('dexGetFunctionSignature', HsJITPtr, NativeFunction, NativeFunctionSignaturePtr) +freeFunctionSignature = dex_func('dexFreeFunctionSignature', NativeFunctionSignaturePtr) + +init() +jit = createJIT() +nofree = False +@atexit.register +def _teardown(): + global nofree + destroyJIT(jit) + fini() + nofree = True # Don't destruct any Haskell objects after the RTS has been shutdown + + +def as_cstr(x: str): + return ctypes.c_char_p(x.encode('ascii')) + +def from_cstr(cx): + return cx.decode('ascii') + +def raise_from_dex(): + raise RuntimeError(from_cstr(getError())) diff --git a/python/dex/native_function.py b/python/dex/native_function.py new file mode 100644 index 000000000..2277d6f8c --- /dev/null +++ b/python/dex/native_function.py @@ -0,0 +1,221 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import sys +import ctypes +import string +import numpy as np +from typing import Any, List, Union, Callable, Dict +from dataclasses import dataclass +from . import api + +ScalarCType = Union[ + ctypes.c_int64, ctypes.c_int32, + ctypes.c_uint8, + ctypes.c_double, ctypes.c_float +] +IdxRepTy = ctypes.c_int32 + +@dataclass(frozen=True) +class ScalarType: + ctype: Any + from_ctype: Callable + + @property + def arg_ctype(self): return self.ctype + + @property + def ref_ctype(self): return ctypes.POINTER(self.ctype) + + def to_ctype(self, value, name_cvalue): + return self.ctype(value) + + def create(self, name_cvalue): + instance = self.ctype() + return ctypes.pointer(instance), lambda: self.from_ctype(instance) + + +@dataclass(frozen=True) +class RectContArrayType: + ctype: ScalarType + shape: List[Union[str, int]] + + @property + def arg_ctype(self): + return ctypes.POINTER(self.ctype) + + @property + def ref_ctype(self): + return ctypes.POINTER(self.ctype) + + def unsafe_array_ptr(self, array): + ptr, _ = array.__array_interface__['data'] + return ctypes.cast(ctypes.c_void_p(ptr), ctypes.POINTER(self.ctype)) + + def to_ctype(self, array, name_cvalue): + if not isinstance(array, np.ndarray): + raise TypeError("Expected a NumPy ndarray for an array argument") + if array.ndim != len(self.shape): + raise ValueError(f"Expected a {len(self.shape)}D array, got {array.ndim}D") + expected_dtype = np.dtype(self.ctype) + if array.dtype != expected_dtype: + raise ValueError(f"Expected a {expected_dtype} array, got {array.dtype}") + expected_shape = tuple( + size if isinstance(size, int) else name_cvalue.setdefault(size, IdxRepTy(real_size)).value + for size, real_size in zip(self.shape, array.shape)) + if expected_shape != array.shape: + raise ValueError(f"Shape mismatch: expected {expected_shape}, but got {array.shape}") + if not array.flags['C_CONTIGUOUS']: + raise ValueError("Only contiguous arrays supported as arguments at the moment") + return self.unsafe_array_ptr(array) + + def create(self, name_cvalue): + shape = [size if isinstance(size, int) else name_cvalue[size].value + for size in self.shape] + result = np.empty(shape, dtype=self.ctype) + return self.unsafe_array_ptr(result), lambda: result + +NativeType = Union[ScalarType, RectContArrayType] + + +@dataclass(frozen=True) +class Binder: + name: str + type: NativeType + implicit: bool + + +class NativeFunction: + def __init__(self, jit, ptr): + self._as_parameter_ = ptr + self._jit = jit + sig_ptr = api.getFunctionSignature(jit, ptr) + if not sig_ptr: + raise RuntimeError("Failed to retrieve the function signature") + try: + signature = sig_ptr.contents + self.argument_signature = _SignatureParser(signature.arg).parse() + self.explicit_argument_signature = [arg for arg in self.argument_signature if not arg.implicit] + self.result_signature = _SignatureParser(signature.res).parse() + self.ccall_signature = [sys.intern(arg.decode('ascii')) for arg in signature.ccall.split(b',')] + finally: + api.freeFunctionSignature(sig_ptr) + + func_type = ctypes.CFUNCTYPE( + ctypes.c_int64, + *(arg.type.arg_ctype for arg in self.argument_signature), + *(res.type.ref_ctype for res in self.result_signature)) + self.callable = func_type(ctypes.cast(ptr, ctypes.c_void_p).value) + + def __del__(self): + if api.nofree: return + if hasattr(self, '_as_parameter_'): + api.unload(self._jit, self) + + def __call__(self, *args): + name_to_cval = {} + result_thunks = [] + assert len(self.explicit_argument_signature) == len(args) + for arg, binder in zip(args, self.explicit_argument_signature): + name_to_cval[binder.name] = binder.type.to_ctype(arg, name_to_cval) + for binder in self.result_signature: + value, result_thunk = binder.type.create(name_to_cval) + name_to_cval[binder.name] = value + result_thunks.append(result_thunk) + self.callable(*(name_to_cval[name] for name in self.ccall_signature)) + results = tuple(thunk() for thunk in result_thunks) + if len(results) == 1: + return results[0] + else: + return results + + +class _SignatureParser: + __slots__ = ('text', 'offset') + + def __init__(self, text): + self.text = text + + def consume(self, char: str): + assert self.text[self.offset] == ord(char) + self.offset += 1 + + def maybe_consume(self, char: str) -> bool: + if self.offset < len(self.text) and self.text[self.offset] == ord(char): + self.offset += 1 + return True + return False + + digit_codes = set(string.digits.encode('ascii')) + name_codes = set(string.ascii_letters.encode('ascii')) | digit_codes + + def parse_name(self) -> str: + end = self.offset + name_codes = self.name_codes + text = self.text + while text[end] in name_codes: + end += 1 + result = sys.intern(self.text[self.offset:end].decode('ascii')) + self.offset = end + return result + + scalar_types: Dict[bytes, ScalarType] = { + b'i64': ScalarType(ctypes.c_int64, np.int64), + b'i32': ScalarType(ctypes.c_int32, np.int32), + b'u8': ScalarType(ctypes.c_uint8, np.uint8), + b'f64': ScalarType(ctypes.c_double, np.float64), + b'f32': ScalarType(ctypes.c_float, np.float32), + } + + def parse_type(self) -> NativeType: + for name, scalar_type in self.scalar_types.items(): + if self.text.startswith(name, self.offset): + break + else: + raise RuntimeError(f"Invalid type specification: {sig[self.offset:self.offset+3].decode('ascii')}") + self.offset += len(name) + if self.maybe_consume('['): + if self.maybe_consume('?'): + raise RuntimeError("Only rectangular array types supported") + shape = [] + while True: + shape.append(self.parse_dim()) + if self.maybe_consume(']'): + break + else: + self.consume(',') + return RectContArrayType(scalar_type.ctype, shape) + else: + return scalar_type + + def parse_dim(self): + if self.text[self.offset] in self.digit_codes: + return self.parse_number() + else: + return self.parse_name() + + def parse_number(self) -> int: + end = self.offset + while self.text[end] in self.digit_codes: + end += 1 + result = int(self.text[self.offset:end].decode('ascii')) + self.offset = end + return result + + def parse(self): + self.offset = 0 + binders = [] + while True: + implicit = self.maybe_consume('?') + name = self.parse_name() + self.consume(':') + ty = self.parse_type() + binders.append(Binder(name, ty, implicit)) + if self.offset == len(self.text): + break + else: + self.consume(',') + return binders diff --git a/python/tests/api_test.py b/python/tests/api_test.py index 4b55d7b92..5b5606967 100644 --- a/python/tests/api_test.py +++ b/python/tests/api_test.py @@ -45,23 +45,3 @@ def addOne (x: Float) : Float = x + 1.0 def test_scalar_conversions(): assert float(dex.eval("5.0")) == 5.0 assert int(dex.eval("5")) == 5 - -def test_jit(): - m = dex.eval(r"\x:Float. 1.0 / (1.0 + exp(-x))") - native_func = m.compile() - func_ptr = ctypes.cast(native_func.ptr, ctypes.c_void_p).value - signature = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_float, ctypes.POINTER(ctypes.c_float)) - func = signature(func_ptr) - - def dex_sigmoid(x): - res = ctypes.c_float() - has_error = func(x, ctypes.pointer(res)) - assert not has_error - return res.value - - one = np.float32(1.0) - def py_sigmoid(x): return one / (one + np.exp(-x)) - - for value in map(np.float32, (-1.0, -0.5, 0.0, 0.5, 1.0)): - np.testing.assert_allclose(dex_sigmoid(value), py_sigmoid(value), - rtol=1e-4, atol=1e-6) diff --git a/python/tests/jit_test.py b/python/tests/jit_test.py new file mode 100644 index 000000000..a0594cd34 --- /dev/null +++ b/python/tests/jit_test.py @@ -0,0 +1,79 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import unittest +import ctypes +import numpy as np +import itertools as it +from textwrap import dedent + +# TODO: Write a proper setup.py instead of using this hack... +from pathlib import Path +import sys +sys.path.append(str(Path(__file__).parent.parent)) + +import dex + +example_floats = list(map(np.float32, (-1.0, -0.5, 0.0, 0.5, 1.0))) +example_ints = [-10, -5, 0, 5, 10] + +def check_atom(dex_atom, reference, args_iter): + compiled = dex_atom.compile() + ran_any_iter = False + for args in args_iter: + ran_any_iter = True + print(args) + np.testing.assert_allclose(compiled(*args), reference(*args), + rtol=1e-4, atol=1e-6) + assert ran_any_iter, "Empty argument iterator!" + +def expr_test(dex_source, reference, args_iter): + def test(): + return check_atom(dex.eval(dex_source), reference, args_iter) + return test + +test_sigmoid = expr_test(r"\x:Float. 1.0 / (1.0 + exp(-x))", + lambda x: np.float32(1.0) / (np.float32(1.0) + np.exp(-x)), + ((x,) for x in example_floats)) + +test_multi_arg = expr_test(r"\x:Float y:Float. atan2 x y", + np.arctan2, + ((x + 0.01, y) for x, y in it.product(example_floats, repeat=2) + if (x, y) != (0.0, 0.0))) + +test_int_arg = expr_test(r"\x:Int64 y:Int. I64ToI x + y", + lambda x, y: x + y, + it.product(example_ints, example_ints)) + +test_array_scalar = expr_test(r"\x:((Fin 10)=>Float). sum x", + np.sum, + [(np.arange(10, dtype=np.float32),)]) + +test_scalar_array = expr_test(r"\x:Int. for i:(Fin 10). x + ordinal i", + lambda x: x + np.arange(10, dtype=np.int32), + [(i,) for i in range(5)]) + +test_array_array = expr_test(r"\x:((Fin 10)=>Float). for i. exp x.i", + np.exp, + [(np.arange(10, dtype=np.float32),)]) + +def test_polymorphic_array_1d(): + m = dex.Module(dedent(""" + def addTwo (n: Int) ?-> (x: (Fin n)=>Float) : (Fin n)=>Float = for i. x.i + 2.0 + """)) + check_atom(m.addTwo, lambda x: x + 2, + [(np.arange(l, dtype=np.float32),) for l in (2, 5, 10)]) + +def test_polymorphic_array_2d(): + m = dex.Module(dedent(""" + def myTranspose (n: Int) ?-> (m: Int) ?-> + (x : (Fin n)=>(Fin m)=>Float) : (Fin m)=>(Fin n)=>Float = + for i j. x.j.i + """)) + check_atom(m.myTranspose, lambda x: x.T, + [(np.arange(a*b, dtype=np.float32).reshape((a, b)),) + for a, b in it.product((2, 5, 10), repeat=2)]) + diff --git a/src/Dex/Foreign/API.hs b/src/Dex/Foreign/API.hs index 7a284eb26..f6c8349c7 100644 --- a/src/Dex/Foreign/API.hs +++ b/src/Dex/Foreign/API.hs @@ -39,3 +39,5 @@ foreign export ccall "dexCreateJIT" dexCreateJIT :: IO (Ptr JIT) foreign export ccall "dexDestroyJIT" dexDestroyJIT :: Ptr JIT -> IO () foreign export ccall "dexCompile" dexCompile :: Ptr JIT -> Ptr Context -> Ptr Atom -> IO (Ptr NativeFunction) foreign export ccall "dexUnload" dexUnload :: Ptr JIT -> Ptr NativeFunction -> IO () +foreign export ccall "dexGetFunctionSignature" dexGetFunctionSignature :: Ptr JIT -> Ptr NativeFunction -> IO (Ptr ExportedSignature) +foreign export ccall "dexFreeFunctionSignature" dexFreeFunctionSignature :: Ptr ExportedSignature -> IO () diff --git a/src/Dex/Foreign/JIT.hs b/src/Dex/Foreign/JIT.hs index c188dfaf3..d40a4b4a0 100644 --- a/src/Dex/Foreign/JIT.hs +++ b/src/Dex/Foreign/JIT.hs @@ -5,18 +5,24 @@ -- https://developers.google.com/open-source/licenses/bsd {-# LANGUAGE RecordWildCards #-} +{-# OPTIONS_GHC -Wno-orphans #-} module Dex.Foreign.JIT ( - JIT, NativeFunction, + JIT, NativeFunction, ExportedSignature, dexCreateJIT, dexDestroyJIT, + dexGetFunctionSignature, dexFreeFunctionSignature, dexCompile, dexUnload ) where import Control.Monad.State.Strict import Foreign.Ptr +import Foreign.C.String +import Foreign.Storable +import Foreign.Marshal.Alloc import Data.IORef +import Data.Functor import qualified Data.Map.Strict as M import LLVM.Target (TargetMachine) @@ -30,51 +36,80 @@ import Logging import LLVMExec import JIT import Syntax hiding (sizeOf) -import TopLevel +import Export import Dex.Foreign.Util import Dex.Foreign.Context +data NativeFunction = + NativeFunction { nativeModule :: LLVM.JIT.NativeModule + , nativeSignature :: ExportedSignature } +type NativeFunctionAddr = Ptr NativeFunction + data JIT = ForeignJIT { jit :: LLVM.JIT.JIT , jitTargetMachine :: TargetMachine - , funcToModuleRef :: IORef (M.Map (Ptr NativeFunction) LLVM.JIT.NativeModule) + , addrTableRef :: IORef (M.Map NativeFunctionAddr NativeFunction) } +instance Storable ExportedSignature where + sizeOf _ = 3 * sizeOf (undefined :: Ptr ()) + alignment _ = alignment (undefined :: Ptr ()) + peek _ = error "peek not implemented for ExportedSignature" + poke addr sig = do + let strAddr = castPtr @ExportedSignature @CString addr + let (arg, res, ccall) = exportedSignatureDesc sig + pokeElemOff strAddr 0 =<< newCString arg + pokeElemOff strAddr 1 =<< newCString res + pokeElemOff strAddr 2 =<< newCString ccall dexCreateJIT :: IO (Ptr JIT) dexCreateJIT = do jitTargetMachine <- LLVM.Shims.newHostTargetMachine R.PIC CM.Large CGO.Aggressive jit <- LLVM.JIT.createJIT jitTargetMachine - funcToModuleRef <- newIORef mempty + addrTableRef <- newIORef mempty toStablePtr ForeignJIT{..} dexDestroyJIT :: Ptr JIT -> IO () dexDestroyJIT jitPtr = do ForeignJIT{..} <- fromStablePtr jitPtr - funcToModule <- readIORef funcToModuleRef - forM_ (M.toList funcToModule) $ \(_, m) -> LLVM.JIT.unloadNativeModule m + addrTable <- readIORef addrTableRef + forM_ (M.toList addrTable) $ \(_, m) -> LLVM.JIT.unloadNativeModule $ nativeModule m LLVM.JIT.destroyJIT jit LLVM.Shims.disposeTargetMachine jitTargetMachine -data NativeFunction - -dexCompile :: Ptr JIT -> Ptr Context -> Ptr Atom -> IO (Ptr NativeFunction) +dexCompile :: Ptr JIT -> Ptr Context -> Ptr Atom -> IO NativeFunctionAddr dexCompile jitPtr ctxPtr funcAtomPtr = do ForeignJIT{..} <- fromStablePtr jitPtr Context _ env <- fromStablePtr ctxPtr funcAtom <- fromStablePtr funcAtomPtr - let impMod = prepareFunctionForExport env "userFunc" funcAtom + let (impMod, nativeSignature) = prepareFunctionForExport env "userFunc" funcAtom nativeModule <- execLogger Nothing $ \logger -> do llvmAST <- impToLLVM logger impMod LLVM.JIT.compileModule jit llvmAST (standardCompilationPipeline logger ["userFunc"] jitTargetMachine) funcPtr <- castFunPtrToPtr <$> LLVM.JIT.getFunctionPtr nativeModule "userFunc" - modifyIORef funcToModuleRef $ M.insert funcPtr nativeModule + modifyIORef addrTableRef $ M.insert funcPtr NativeFunction{..} return $ funcPtr -dexUnload :: Ptr JIT -> Ptr NativeFunction -> IO () +dexGetFunctionSignature :: Ptr JIT -> NativeFunctionAddr -> IO (Ptr ExportedSignature) +dexGetFunctionSignature jitPtr funcPtr = do + ForeignJIT{..} <- fromStablePtr jitPtr + addrTable <- readIORef addrTableRef + case M.lookup funcPtr addrTable of + Nothing -> setError "Invalid function address" $> nullPtr + Just NativeFunction{..} -> putOnHeap nativeSignature + +dexFreeFunctionSignature :: Ptr ExportedSignature -> IO () +dexFreeFunctionSignature sigPtr = do + let strPtr = castPtr @ExportedSignature @CString sigPtr + free =<< peekElemOff strPtr 0 + free =<< peekElemOff strPtr 1 + free =<< peekElemOff strPtr 2 + free sigPtr + +dexUnload :: Ptr JIT -> NativeFunctionAddr -> IO () dexUnload jitPtr funcPtr = do ForeignJIT{..} <- fromStablePtr jitPtr - funcToModule <- readIORef funcToModuleRef - LLVM.JIT.unloadNativeModule $ funcToModule M.! funcPtr - modifyIORef funcToModuleRef $ M.delete funcPtr + addrTable <- readIORef addrTableRef + LLVM.JIT.unloadNativeModule $ nativeModule $ addrTable M.! funcPtr + modifyIORef addrTableRef $ M.delete funcPtr diff --git a/src/Dex/Foreign/Serialize.hs b/src/Dex/Foreign/Serialize.hs index 76560c8df..8d882ee49 100644 --- a/src/Dex/Foreign/Serialize.hs +++ b/src/Dex/Foreign/Serialize.hs @@ -9,7 +9,6 @@ module Dex.Foreign.Serialize ( dexPrint, dexToCAtom ) where -import Data.Int import Data.Word import Data.Functor diff --git a/src/Dex/Foreign/Util.hs b/src/Dex/Foreign/Util.hs index aaa3ce8ec..de156e983 100644 --- a/src/Dex/Foreign/Util.hs +++ b/src/Dex/Foreign/Util.hs @@ -4,13 +4,21 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module Dex.Foreign.Util (fromStablePtr, toStablePtr) where +module Dex.Foreign.Util (fromStablePtr, toStablePtr, putOnHeap) where -import Foreign.StablePtr import Foreign.Ptr +import Foreign.StablePtr +import Foreign.Storable +import Foreign.Marshal.Alloc fromStablePtr :: Ptr a -> IO a fromStablePtr = deRefStablePtr . castPtrToStablePtr . castPtr toStablePtr :: a -> IO (Ptr a) toStablePtr x = castPtr . castStablePtrToPtr <$> newStablePtr x + +putOnHeap :: Storable a => a -> IO (Ptr a) +putOnHeap x = do + ptr <- malloc + poke ptr x + return ptr diff --git a/src/dex.hs b/src/dex.hs index 9c36c9ae8..7c0fb8011 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -25,6 +25,7 @@ import Resources import TopLevel import Parser hiding (Parser) import LiveOutput +import Export data ErrorHandling = HaltOnErr | ContinueOnErr data DocFmt = ResultOnly | TextDoc | HTMLDoc | JSONDoc @@ -59,7 +60,9 @@ runMode evalMode preludeFile opts = do let errors = foldMap (\case (Result _ (Left err)) -> [err]; _ -> []) results putStr $ foldMap (nonEmptyNewline . pprint) errors let exportedFuns = foldMap (\case (ExportedFun name f) -> [(name, f)]; _ -> []) outputs - exportFunctions objPath exportedFuns env opts + unless (backendName opts == LLVM) $ liftEitherIO $ + throw CompilerErr "Export only supported with the LLVM CPU backend" + exportFunctions objPath exportedFuns env evalPrelude :: EvalConfig -> Maybe FilePath -> IO TopEnv evalPrelude opts fname = flip execStateT mempty $ do diff --git a/src/lib/Export.hs b/src/lib/Export.hs new file mode 100644 index 000000000..d5c9e9472 --- /dev/null +++ b/src/lib/Export.hs @@ -0,0 +1,226 @@ +-- Copyright 2020 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE RecordWildCards #-} + +module Export ( + exportFunctions, prepareFunctionForExport, exportedSignatureDesc, + ExportedSignature (..), ExportArrayType (..), ExportArg (..), ExportResult (..), + ) where + +import Control.Monad.State.Strict +import Control.Monad.Writer hiding (pass) +import qualified Data.Text as T +import Data.String +import Data.Foldable +import Data.List (nub, intercalate) + +import Algebra +import Syntax +import Embed +import Cat +import Env +import Type +import Simplify +import Imp +import JIT +import Logging +import LLVMExec +import PPrint +import Optimize + +exportFunctions :: FilePath -> [(String, Atom)] -> TopEnv -> IO () +exportFunctions objPath funcs env = do + let names = fmap fst funcs + unless (length (nub names) == length names) $ liftEitherIO $ + throw CompilerErr "Duplicate export names" + modules <- forM funcs $ \(name, funcAtom) -> do + let (impModule, _) = prepareFunctionForExport env name funcAtom + (,[name]) <$> execLogger Nothing (flip impToLLVM impModule) + exportObjectFile objPath modules + + +type CArgList = [IBinder] -- ^ List of arguments to the C call +data CArgEnv = CArgEnv { -- | Maps scalar atom binders to their CArgs. All atoms are Vars. + cargScalarScope :: Env Atom + -- | Tracks the CArg names used so far (globally scoped, unlike Embed) + , cargScope :: Env () } +type CArgM = WriterT CArgList (CatT CArgEnv Embed) + +instance Semigroup CArgEnv where + (CArgEnv a1 a2) <> (CArgEnv b1 b2) = CArgEnv (a1 <> b1) (a2 <> b2) + +instance Monoid CArgEnv where + mempty = CArgEnv mempty mempty + +runCArg :: CArgEnv -> CArgM a -> Embed (a, [IBinder], CArgEnv) +runCArg initEnv m = repack <$> runCatT (runWriterT m) initEnv + where repack ((ans, cargs), env) = (ans, cargs, env) + +prepareFunctionForExport :: TopEnv -> String -> Atom -> (ImpModule, ExportedSignature) +prepareFunctionForExport env nameStr func = do + -- Create a module that simulates an application of arguments to the function + -- TODO: Assert that the type of func is closed? + let ((dest, cargs, apiDesc), (_, decls)) = flip runEmbed (freeVars func) $ do + (args, cargArgs, cargEnv) <- runCArg mempty $ createArgs $ getType func + let (atomArgs, exportedArgSig) = unzip args + resultAtom <- naryApp func atomArgs + void $ emitTo outputName PlainLet $ Atom resultAtom + ((resultDest, exportedResSig), cdestArgs, _) <- runCArg cargEnv $ createDest mempty $ getType resultAtom + let cargs' = cargArgs <> cdestArgs + let exportedCCallSig = fmap (\(Bind (v:>_)) -> v) cargs' + return (resultDest, cargs', ExportedSignature{..}) + + let coreModule = Module Core decls mempty + let defunctionalized = simplifyModule env coreModule + let Module _ optDecls optBindings = optimizeModule defunctionalized + let (_, LetBound PlainLet outputExpr) = optBindings ! outputName + let block = Block optDecls outputExpr + + let name = Name TopFunctionName (fromString nameStr) 0 + let (_, impModule, _) = toImpModule env LLVM CEntryFun name cargs (Just dest) block + (impModule, apiDesc) + where + outputName = GlobalName "_ans_" + + createArgs :: Type -> CArgM [(Atom, ExportArg)] + createArgs ty = case ty of + PiTy b arrow result | arrow /= TabArrow -> do + argSubst <- looks cargScalarScope + let visibility = case arrow of + PlainArrow Pure -> ExplicitArg + PlainArrow _ -> error $ "Effectful functions cannot be exported" + ImplicitArrow -> ImplicitArg + _ -> error $ "Unexpected type for an exported function: " ++ pprint ty + (:) <$> createArg visibility (subst (argSubst, mempty) b) <*> createArgs result + _ -> return [] + + createArg :: ArgVisibility -> Binder -> CArgM (Atom, ExportArg) + createArg vis b = case ty of + BaseTy bt@(Scalar sbt) -> do + ~v@(Var (name:>_)) <- newCVar bt + extend $ mempty { cargScalarScope = b @> (Var $ name :> BaseTy bt) } + return (v, ExportScalarArg vis name sbt) + TabTy _ _ -> createTabArg vis mempty ty + _ -> error $ "Unsupported arg type: " ++ pprint ty + where ty = binderType b + + createTabArg :: ArgVisibility -> IndexStructure -> Type -> CArgM (Atom, ExportArg) + createTabArg vis idx ty = case ty of + BaseTy bt@(Scalar sbt) -> do + ~v@(Var (name:>_)) <- newCVar (ptrTy bt) + destAtom <- ptrLoad =<< applyIdxs v idx + funcArgScope <- looks cargScope + let exportArg = ExportArrayArg vis name $ case getRectShape funcArgScope idx of + Just rectShape -> RectContArrayPtr sbt rectShape + Nothing -> GeneralArrayPtr sbt + return (destAtom, exportArg) + TabTy b elemTy -> do + buildLamAux b (const $ return TabArrow) $ \(Var i) -> do + elemTy' <- substEmbed (b@>Var i) elemTy + createTabArg vis (idx <> Nest (Bind i) Empty) elemTy' + _ -> unsupported + where unsupported = error $ "Unsupported table type suffix: " ++ pprint ty + + createDest :: IndexStructure -> Type -> CArgM (Atom, ExportResult) + createDest idx ty = case ty of + BaseTy bt@(Scalar sbt) -> do + ~v@(Var (name:>_)) <- newCVar (ptrTy bt) + dest <- Con . BaseTypeRef <$> applyIdxs v idx + funcArgScope <- looks cargScope + let exportResult = case idx of + Empty -> ExportScalarResultPtr name sbt + _ -> ExportArrayResult name $ case getRectShape funcArgScope idx of + Just rectShape -> RectContArrayPtr sbt rectShape + Nothing -> GeneralArrayPtr sbt + return (dest, exportResult) + TabTy b elemTy -> do + (destTab, exportResult) <- buildLamAux b (const $ return TabArrow) $ \(Var i) -> do + elemTy' <- substEmbed (b@>Var i) elemTy + createDest (idx <> Nest (Bind i) Empty) elemTy' + return (Con $ TabRef destTab, exportResult) + _ -> unsupported + where unsupported = error $ "Unsupported result type: " ++ pprint ty + + -- TODO: I guess that the address space depends on the backend? + -- TODO: Have an ExternalPtr tag? + ptrTy ty = PtrType (DerivedPtr, Heap CPU, ty) + + getRectShape :: Env () -> IndexStructure -> Maybe [Either Name Int] + getRectShape scope idx = traverse (dimShape . binderType) $ toList idx + where + dimShape dimTy = case dimTy of + Fin (IdxRepVal n) -> Just $ Right $ fromIntegral n + Fin (Var v) | v `isin` scope -> Just $ Left $ varName v + _ -> Nothing + + newCVar :: BaseType -> CArgM Atom + newCVar bt = do + name <- genFresh (Name CArgName "arg" 0) <$> looks cargScope + extend $ mempty { cargScope = name @> () } + tell [Bind $ name :> bt] + return $ Var $ name :> BaseTy bt + +-- === Exported function signature === + +data ExportArrayType = GeneralArrayPtr ScalarBaseType + | RectContArrayPtr ScalarBaseType [Either Name Int] +data ArgVisibility = ImplicitArg | ExplicitArg +data ExportArg = ExportArrayArg ArgVisibility Name ExportArrayType + | ExportScalarArg ArgVisibility Name ScalarBaseType +data ExportResult = ExportArrayResult Name ExportArrayType + | ExportScalarResultPtr Name ScalarBaseType +data ExportedSignature = + ExportedSignature { exportedArgSig :: [ExportArg] + , exportedResSig :: ExportResult + , exportedCCallSig :: [Name] + } + +-- Serialization + +exportedSignatureDesc :: ExportedSignature -> (String, String, String) +exportedSignatureDesc ExportedSignature{..} = + ( intercalate "," $ fmap show exportedArgSig + , show exportedResSig + , intercalate "," $ fmap showCArgName exportedCCallSig + ) + +showExportSBT :: ScalarBaseType -> String +showExportSBT sbt = case sbt of + Int64Type -> "i64" + Int32Type -> "i32" + Word8Type -> "u8" + Float64Type -> "f64" + Float32Type -> "f32" + +showCArgName :: Name -> String +showCArgName ~name@(Name namespace tag idx) = case namespace of + CArgName -> T.unpack tag <> show idx + _ -> error $ "Expected a CArgName namespace: " ++ show name + +instance Show ExportArrayType where + show arr = case arr of + GeneralArrayPtr sbt -> showExportSBT sbt <> "[?]" + RectContArrayPtr sbt shape -> showExportSBT sbt <> showShape shape + where + showShape shape = "[" <> (intercalate "," $ fmap showDim shape) <> "]" + showDim size = case size of + Left name -> showCArgName name + Right lit -> show lit + +instance Show ExportArg where + show arg = case arg of + ExportArrayArg vis name ty -> showVis vis <> showCArgName name <> ":" <> show ty + ExportScalarArg vis name sbt -> showVis vis <> showCArgName name <> ":" <> showExportSBT sbt + where + showVis ImplicitArg = "?" + showVis ExplicitArg = "" + +instance Show ExportResult where + show res = case res of + ExportArrayResult name ty -> showCArgName name <> ":" <> show ty + ExportScalarResultPtr name sbt -> showCArgName name <> ":" <> showExportSBT sbt diff --git a/src/lib/LLVM/JIT.hs b/src/lib/LLVM/JIT.hs index 4649152c4..e10228a4c 100644 --- a/src/lib/LLVM/JIT.hs +++ b/src/lib/LLVM/JIT.hs @@ -151,5 +151,5 @@ getFunctionPtr :: NativeModule -> String -> IO (FunPtr a) getFunctionPtr NativeModule{..} funcName = do let JIT{..} = moduleJIT symbol <- OrcJIT.mangleSymbol compileLayer $ fromString funcName - Right (OrcJIT.JITSymbol funcAddr _) <- OrcJIT.findSymbol compileLayer symbol False + Right (OrcJIT.JITSymbol funcAddr _) <- OrcJIT.findSymbolIn compileLayer moduleKey symbol False return $ castPtrToFunPtr $ wordPtrToPtr funcAddr diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 04052e457..e4c17b142 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -53,7 +53,7 @@ module Syntax ( pattern IdxRepTy, pattern IdxRepVal, pattern IIdxRepVal, pattern IIdxRepTy, pattern TagRepTy, pattern TagRepVal, pattern Word8Ty, pattern IntLitExpr, pattern FloatLitExpr, - pattern UnitTy, pattern PairTy, pattern FunTy, + pattern UnitTy, pattern PairTy, pattern FunTy, pattern PiTy, pattern FixedIntRange, pattern Fin, pattern RefTy, pattern RawRefTy, pattern BaseTy, pattern PtrTy, pattern UnitVal, pattern PairVal, pattern PureArrow, @@ -1428,6 +1428,9 @@ fromConsList xs = case xs of pattern FunTy :: Binder -> EffectRow -> Type -> Type pattern FunTy b eff bodyTy = Pi (Abs b (PlainArrow eff, bodyTy)) +pattern PiTy :: Binder -> Arrow -> Type -> Type +pattern PiTy b arr bodyTy = Pi (Abs b (arr, bodyTy)) + pattern BinaryFunTy :: Binder -> Binder -> EffectRow -> Type -> Type pattern BinaryFunTy b1 b2 eff bodyTy = FunTy b1 Pure (FunTy b2 eff bodyTy) diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index e1957c523..1e05bfe1d 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -1,25 +1,23 @@ --- Copyright 2019 Google LLC +-- Copyright 2020 Google LLC -- -- Use of this source code is governed by a BSD-style -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE RecordWildCards #-} -module TopLevel (evalSourceBlock, evalDecl, evalSource, evalFile, - exportFunctions, prepareFunctionForExport, EvalConfig (..)) where +module TopLevel (evalSourceBlock, evalDecl, evalSource, evalFile, EvalConfig (..)) where import Control.Monad.State.Strict import Control.Monad.Reader -import Control.Monad.Writer hiding (pass) import Control.Monad.Except hiding (Except) import Data.Text.Prettyprint.Doc import Data.String -import Data.List (partition, nub) +import Data.List (partition) import Data.Time.Clock (getCurrentTime, diffUTCTime) import qualified Data.Map.Strict as M -import Algebra import Syntax import Embed import Cat @@ -250,99 +248,6 @@ logTop x = do logger <- asks logService logThis logger [x] -type CArgM = WriterT [IBinder] (CatT CArgEnv Embed) -type CArgEnv = (Env IBinder, Env ()) - -runCArg :: CArgEnv -> CArgM a -> Embed (a, [IBinder], CArgEnv) -runCArg initEnv m = repack <$> runCatT (runWriterT m) initEnv - where repack ((ans, cargs), env) = (ans, cargs, env) - -prepareFunctionForExport :: TopEnv -> String -> Atom -> ImpModule -prepareFunctionForExport env nameStr func = do - -- Create a module that simulates an application of arguments to the function - let ((dest, cargs), (_, decls)) = flip runEmbed (freeVars func) $ do - (args, cargArgs, cargEnv) <- runCArg mempty $ createArgs $ getType func - resultAtom <- naryApp func args - (resultDest, cdestArgs, _) <- runCArg cargEnv $ createDest mempty $ getType resultAtom - void $ emitTo outputName PlainLet $ Atom resultAtom - return (resultDest, cargArgs <> cdestArgs) - - let coreModule = Module Core decls mempty - let defunctionalized = simplifyModule env coreModule - let Module _ optDecls optBindings = optimizeModule defunctionalized - let (_, LetBound PlainLet outputExpr) = optBindings ! outputName - let block = Block optDecls outputExpr - - let name = Name TopFunctionName (fromString nameStr) 0 - let (_, impModule, _) = toImpModule env LLVM CEntryFun name cargs (Just dest) block - impModule - where - outputName = GlobalName "_ans_" - - createArgs :: Type -> CArgM [Atom] - createArgs ty = case ty of - FunTy b Pure result -> do - argSubst <- fmap (\(Bind (n:>bt)) -> Var $ n :> BaseTy bt) <$> looks fst - arg <- createArg $ subst (argSubst, mempty) $ b - (arg:) <$> createArgs result - FunTy _ _ _ -> error $ "Unexpected type for an exported function: " ++ pprint ty - _ -> return [] - - createArg :: Binder -> CArgM Atom - createArg b = case ty of - BaseTy bt@(Scalar _) -> do - ~v@(Var (name:>_)) <- newCVar bt - extend $ asFst $ b @> (Bind $ name :> bt) - return v - TabTy _ _ -> createTabArg mempty ty - _ -> error $ "Unsupported arg type: " ++ pprint ty - where ty = binderType b - - createTabArg :: IndexStructure -> Type -> CArgM Atom - createTabArg idx ty = case ty of - BaseTy bt@(Scalar _) -> do - ptrLoad =<< flip applyIdxs idx =<< newCVar (ptrTy bt) - TabTy b elemTy -> do - buildLam b TabArrow $ \(Var i) -> do - elemTy' <- substEmbed (b@>Var i) elemTy - createTabArg (idx <> Nest (Bind i) Empty) elemTy' - _ -> unsupported - where unsupported = error "Unsupported table type" - - createDest :: IndexStructure -> Type -> CArgM Atom - createDest idx ty = case ty of - BaseTy bt@(Scalar _) -> do - liftM (Con . BaseTypeRef) $ flip applyIdxs idx =<< newCVar (ptrTy bt) - TabTy b elemTy -> do - liftM (Con . TabRef) $ buildLam b TabArrow $ \(Var i) -> do - elemTy' <- substEmbed (b@>Var i) elemTy - createDest (idx <> Nest (Bind i) Empty) elemTy' - _ -> unsupported - where unsupported = error "Unsupported table type" - - -- TODO: I guess that the address space depends on the backend? - -- TODO: Have an ExternalPtr tag? - ptrTy ty = PtrType (DerivedPtr, Heap CPU, ty) - - newCVar :: BaseType -> CArgM Atom - newCVar bt = do - name <- genFresh (Name CArgName "arg" 0) <$> looks snd - extend $ asSnd $ name @> () - tell [Bind $ name :> bt] - return $ Var $ name :> BaseTy bt - -exportFunctions :: FilePath -> [(String, Atom)] -> TopEnv -> EvalConfig -> IO () -exportFunctions objPath funcs env opts = do - unless (backendName opts == LLVM) $ liftEitherIO $ - throw CompilerErr "Export only supported with the LLVM CPU backend" - let names = fmap fst funcs - unless (length (nub names) == length names) $ liftEitherIO $ - throw CompilerErr "Duplicate export names" - modules <- forM funcs $ \(name, funcAtom) -> do - let impModule = prepareFunctionForExport env name funcAtom - (,[name]) <$> execLogger Nothing (flip impToLLVM impModule) - exportObjectFile objPath modules - abstractPtrLiterals :: Block -> ([IBinder], [LitVal], Block) abstractPtrLiterals block = flip evalState mempty $ do block' <- traverseLiterals block $ \val -> case val of From 2212cc7cc1fee5944ea64bd35cbbbcf0d8485f74 Mon Sep 17 00:00:00 2001 From: David Duvenaud Date: Tue, 22 Dec 2020 12:44:39 -0500 Subject: [PATCH 018/105] Updated fluidsim example to have a gradient check, and be more general. --- examples/fluidsim.dx | 114 ++++++++++++++++++++----------------------- 1 file changed, 54 insertions(+), 60 deletions(-) diff --git a/examples/fluidsim.dx b/examples/fluidsim.dx index e684e1ce2..39bd393a6 100644 --- a/examples/fluidsim.dx +++ b/examples/fluidsim.dx @@ -4,55 +4,46 @@ Fluid simulation code based on include "plot.dx" -def zeroedges (dict:VSpace a) ?=> (n:Type) ?-> (m:Type) ?-> (x: n=>m=>a) : n=>m=>a = - -- Todo: update in place without starting with a copy. - snd $ withState x \buf. - for i j. - edge = i==0 || j==0 || i==ordinal n || i==ordinal m - select edge (buf!(i@n)!(j@m) := zero) () - def wrapidx (n:Type) -> (i:Int) : n = -- Index wrapping around at ends. asidx $ mod i $ size n -def incwrap (n:Type) ?-> (i:n) : n = - -- Increments index, wrapping around at ends. +def incwrap (i:n) : n = -- Increment index, wrapping around at ends. asidx $ mod ((ordinal i) + 1) $ size n -def decwrap (n:Type) ?-> (i:n) : n = - -- Decrements index, wrapping around at ends. +def decwrap (i:n) : n = -- Decrement index, wrapping around at ends. asidx $ mod ((ordinal i) - 1) $ size n -def finite_difference_neighbours (n:Type) ?-> (x:n=>Float) : n=>Float = +def finite_difference_neighbours (_:Add a) ?=> (x:n=>a) : n=>a = for i. x.(incwrap i) - x.(decwrap i) -def add_neighbours (n:Type) ?-> (x:n=>Float) : n=>Float = +def add_neighbours (_:Add a) ?=> (x:n=>a) : n=>a = for i. x.(incwrap i) + x.(decwrap i) -def apply_along_axis1 (f : b=>Float -> b=>Float) (x : b=>c=>Float) : b=>c=>Float = +def apply_along_axis1 (f:b=>a -> b=>a) (x:b=>c=>a) : b=>c=>a = transpose for j. f for i. x.i.j -def apply_along_axis2 (f : c=>Float -> c=>Float) (x : b=>c=>Float) : b=>c=>Float = +def apply_along_axis2 (f:c=>a -> c=>a) (x:b=>c=>a) : b=>c=>a = for i. f x.i -def fdx (x : n=>m=>Float) : (n=>m=>Float) = +def fdx (_:Add a) ?=> (x:n=>m=>a) : (n=>m=>a) = apply_along_axis1 finite_difference_neighbours x -def fdy (x : n=>m=>Float) : (n=>m=>Float) = +def fdy (_:Add a) ?=> (x:n=>m=>a) : (n=>m=>a) = apply_along_axis2 finite_difference_neighbours x -def divergence (vx : n=>m=>Float) (vy : n=>m=>Float) : (n=>m=>Float) = +def divergence (_:Add a) ?=> (vx:n=>m=>a) (vy:n=>m=>a) : (n=>m=>a) = fdx vx + fdy vy -def add_neighbours_2d (x : n=>m=>Float) : (n=>m=>Float) = +def add_neighbours_2d (_:Add a) ?=> (x:n=>m=>a) : (n=>m=>a) = ax1 = apply_along_axis1 add_neighbours x ax2 = apply_along_axis2 add_neighbours x ax1 + ax2 -def project (v: n=>m=>(Fin 2)=>Float) : n=>m=>(Fin 2)=>Float = +def project (_:VSpace a) ?=> (v: n=>m=>(Fin 2)=>a) : n=>m=>(Fin 2)=>a = -- Project the velocity field to be approximately mass-conserving, -- using a few iterations of Gauss-Seidel. - h = 0.01 -- todo: work out units + h = 1.0 / IToF (size n) -- unpack into two scalar fields vx = for i j. v.i.j.(fromOrdinal _ 0) @@ -60,43 +51,27 @@ def project (v: n=>m=>(Fin 2)=>Float) : n=>m=>(Fin 2)=>Float = div = -0.5 .* h .* (divergence vx vy) - p_init = for i. for j. 0.0 - p = snd $ withState p_init \state. + p = snd $ withState zero \state. for i:(Fin 10). - p = get state - state := (1.0 / 4.0) .* (div + add_neighbours_2d p) + state := (1.0 / 4.0) .* (div + add_neighbours_2d (get state)) vx = vx - (0.5 / h) .* fdx(p) vy = vy - (0.5 / h) .* fdy(p) - for i j. [vx.i.j, vy.i.j] -- pack back into a vector field - - -- zeroedges v -- BUG: Crashes with "Not implemented Int" + for i j. [vx.i.j, vy.i.j] -- pack back into a table. -def bilinear_interp (dict:VSpace a) ?=> (right_weight:Float) --o (bottom_weight:Float) --o - (topleft: a) --o (bottomleft: a) --o (topright: a) --o (bottomright: a) --o : a = +def bilinear_interp (_:VSpace a) ?=> (right_weight:Float) (bottom_weight:Float) + (topleft: a) (bottomleft: a) (topright: a) (bottomright: a) : a = left = (1.0 - right_weight) .* ((1.0 - bottom_weight) .* topleft + bottom_weight .* bottomleft) right = right_weight .* ((1.0 - bottom_weight) .* topright + bottom_weight .* bottomright) left + right - -N = Fin 100 -M = Fin 100 - --- BUG: Changing the order of implicit arguments causes an error further down. --- i.e. it doesn't work to start the next line with --- (n:Type) ?-> (m:Type) ?-> (dict:VSpace a) ?=> -def advect (dict:VSpace a) ?=> (n:Type) ?-> (m:Type) ?-> (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = +def advect (_:VSpace a) ?=> (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = -- Move field f according to x and y velocities (u and v) -- using an implicit Euler integrator. - -- Create table of cell locations. - -- BUG: using n and m below causes a crash, so I hardcoded it for now. - numrows = 100.0 -- IToF $ ordinal n - numcols = 100.0 -- IToF $ ordinal m - - cell_xs = linspace n 0.0 numrows - cell_ys = linspace m 0.0 numcols + cell_xs = linspace n 0.0 $ IToF (size n) + cell_ys = linspace m 0.0 $ IToF (size m) for i j. -- Location of source of flow for this cell. No meshgrid! @@ -108,22 +83,19 @@ def advect (dict:VSpace a) ?=> (n:Type) ?-> (m:Type) ?-> (f: n=>m=>a) (v: n=>m=> source_row = floor center_ys -- Relative weight of right-hand and bottom cells. - -- TODO: clipping shouldn't be necessary here, find out why it is. - right_weight = clip (0.0, 1.0) $ center_xs - source_col - bottom_weight = clip (0.0, 1.0) $ center_ys - source_row + right_weight = center_xs - source_col + bottom_weight = center_ys - source_row -- Cast back to indices, wrapping around edges. - source_col_int = FToI source_col - source_row_int = FToI source_row - l = wrapidx n source_col_int - r = wrapidx n (source_col_int + 1) - t = wrapidx m source_row_int - b = wrapidx m (source_row_int + 1) + l = wrapidx n (FToI source_col) + r = wrapidx n ((FToI source_col) + 1) + t = wrapidx m (FToI source_row) + b = wrapidx m ((FToI source_row) + 1) -- A convex weighting of the 4 surrounding cells. bilinear_interp right_weight bottom_weight f.l.t f.l.b f.r.t f.r.b -def fluidsim (dict: VSpace a) ?=> (num_steps: Int) (color_init: n=>m=>a) +def fluidsim (_: VSpace a) ?=> (num_steps: Int) (color_init: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = (color_final, v) = snd $ withState (color_init, v) \state. for i:(Fin num_steps). @@ -136,18 +108,40 @@ def fluidsim (dict: VSpace a) ?=> (num_steps: Int) (color_init: n=>m=>a) '### Demo +N = Fin 50 +M = Fin 50 + -- Create random velocity field. def ixkey3 (k:Key) (i:n) (j:m) (k2:o) : Key = hash (hash (hash k (ordinal i)) (ordinal j)) (ordinal k2) -v = for i:N j:M k:(Fin 2). 3.0 * (randn $ ixkey3 (newKey 0) i j k) +init_velocity = for i:N j:M k:(Fin 2). + 3.0 * (randn $ ixkey3 (newKey 0) i j k) -- Create diagonally-striped color pattern. init_color = for i:N j:M. - BToF $ (sin $ (IToF $ (ordinal j) + (ordinal i)) / 8.0) > 0.0 + r = BToF $ (sin $ (IToF $ (ordinal j) + (ordinal i)) / 8.0) > 0.0 + b = BToF $ (sin $ (IToF $ (ordinal j) - (ordinal i)) / 6.0) > 0.0 + g = BToF $ (sin $ (IToF $ (ordinal j) + (ordinal i)) / 4.0) > 0.0 + [r, g, b] -- Run fluid sim and plot it. -num_steps = 50 -final_color = fluidsim num_steps init_color v +num_steps = 5 +final_color = fluidsim num_steps init_color init_velocity + +:html imshow final_color +> + + + +'### Gradient test + +target = transpose init_color + +def objective (v:N=>M=>(Fin 2)=>Float) : Float = + final_color = fluidsim num_steps init_color v + sum for (i, j, c). sq (final_color.i.j.c - target.i.j.c) + +init_vel_grad = grad objective zero -:html matshow final_color +:html imshow for i j. [0.0, init_vel_grad.i.j.(0@_), init_vel_grad.i.j.(1@_)] > From a61f1fae821aa3874564083771c09a2952a2e5fc Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 23 Dec 2020 12:24:02 +0000 Subject: [PATCH 019/105] Clean up the Cabal file Remove unused dependencies, group the ones we actually need in categories. Also, remove some files that have bitrotted. --- dex.cabal | 41 +- misc/dex.tex | 56 --- misc/py/dex.py | 34 -- misc/py/dex_binary_object.py | 106 ------ misc/py/foo.dx | 6 - misc/py/generate-dex-data.py | 26 -- misc/py/jax_call.py | 490 ------------------------ misc/py/mnist_to_dxbo.py | 32 -- misc/py/test-dex2jax.py | 14 - src/lib/JAX.hs | 711 ----------------------------------- src/lib/PipeRPC.hs | 60 --- 11 files changed, 17 insertions(+), 1559 deletions(-) delete mode 100644 misc/dex.tex delete mode 100644 misc/py/dex.py delete mode 100644 misc/py/dex_binary_object.py delete mode 100644 misc/py/foo.dx delete mode 100644 misc/py/generate-dex-data.py delete mode 100644 misc/py/jax_call.py delete mode 100644 misc/py/mnist_to_dxbo.py delete mode 100644 misc/py/test-dex2jax.py delete mode 100644 src/lib/JAX.hs delete mode 100644 src/lib/PipeRPC.hs diff --git a/dex.cabal b/dex.cabal index ebfc3d3ee..114014720 100644 --- a/dex.cabal +++ b/dex.cabal @@ -10,6 +10,7 @@ name: dex version: 0.1.0.0 author: Dougal Maclaurin maintainer: dougalm@google.com +license-file: LICENSE build-type: Simple flag cuda @@ -34,15 +35,20 @@ library PPrint, Algebra, Parallelize, Optimize, Serialize Actor, Cat, Flops, Embed, Export, RenderHtml, LiveOutput, Simplify, TopLevel, - Autodiff, Interpreter, Logging, PipeRPC, CUDA, + Autodiff, Interpreter, Logging, CUDA, LLVM.JIT, LLVM.Shims - build-depends: base, containers, mtl, binary, bytestring, - time, tf-random, llvm-hs-pure ==9.*, llvm-hs ==9.*, - aeson, megaparsec >=8.0, warp, wai, filepath, - parser-combinators, http-types, prettyprinter, text, - blaze-html, cmark, diagrams-lib, ansi-terminal, - transformers, directory, mmap, unix, - process, primitive, store, dex-resources, temporary, + build-depends: base, containers, mtl, bytestring, time, + llvm-hs-pure, llvm-hs, + -- Parsing + megaparsec, parser-combinators, + -- Text output + prettyprinter, text, + -- Portable system utilities + filepath, directory, ansi-terminal, process, temporary, + -- Serialization + store, + -- Notebook support + warp, wai, blaze-html, aeson, http-types, cmark, binary if !os(darwin) exposed-modules: Resources hs-source-dirs: src/resources @@ -86,7 +92,9 @@ foreign-library Dex type: native-shared other-modules: Dex.Foreign.API, Dex.Foreign.Util, Dex.Foreign.JIT , Dex.Foreign.Context, Dex.Foreign.Serialize - build-depends: base, dex, dex-resources, mtl, llvm-hs, containers + build-depends: base, mtl, containers, llvm-hs, dex, dex-resources + if os(darwin) + build-depends: dex-resources hs-source-dirs: src/ c-sources: src/Dex/Foreign/rts.c cc-options: -std=c11 -fPIC @@ -97,18 +105,3 @@ foreign-library Dex ghc-options: -O3 else ghc-options: -O0 - -Test-Suite test-dex - type: exitcode-stdio-1.0 - main-is: PropTests.hs - build-depends: dex, base, prettyprinter, containers, - hedgehog, microlens-platform, mtl - other-modules: GenExpr, TestPass - default-language: Haskell2010 - hs-source-dirs: tests - ghc-options: cbits/libdex.so - -Wall - if flag(optimized) - ghc-options: -O3 - else - ghc-options: -O0 diff --git a/misc/dex.tex b/misc/dex.tex deleted file mode 100644 index 839328fbf..000000000 --- a/misc/dex.tex +++ /dev/null @@ -1,56 +0,0 @@ -\documentclass[12pt]{article} -\usepackage{amsmath} -\usepackage{geometry} -\geometry{legalpaper, landscape, margin=0in} - -\newcommand{\annot}[1]{\texttt{::} #1} -\newcommand{\ttt}[1]{~\texttt{#1}~} - -\begin{document} - -\vspace{-0.5cm} - -\begin{huge} -\begin{align*} -\text{Terms } \quad t ::&=~ -l \mid x - && \text{Literal / variable} \\ -&\mid \ttt{let} x \annot{\tau} = t \ttt{in} t - && \text{Let expression} \\ -&\mid \ttt{lam} x \annot{\tau} \ttt {.} t - \mid t ~ t - && \text{Lambda abstraction / application} \\ -&\mid \ttt{tlam} a \, . ~ t - \mid t ~ \tau - && \text{Type-lambda abstraction / application} \\ -&\mid \ttt{for} i \annot{\iota} \ttt {.} t - \mid t.t - && \text{Index comprehension / indexing} \\ -&\mid \ttt{pack} t, \iota \ttt{::} \exists n. ~ \tau - && \text{Existential packing} \\ -&\mid \ttt{let} x, n = \ttt{unpack} t \ttt{in} t - && \text{Existential unpacking} -\\ \\ -\text{Types} \quad \tau, \iota ::&= - \ttt{Int} | \ttt{ Real} \mid \ttt{Bool} \mid a - && \text{Base types and type variable} \\ -&\mid \tau \ttt{->} \tau && \text{Arrow type} \\ -&\mid \forall a. ~ \tau - && \text{Universal quantification} \\ -&\mid \iota \ttt{=>} \tau && \text{Table type} \\ -&\mid \ttt{\{<}l\ttt{\}} && \text{Index set literal} \\ -&\mid \exists n. ~ \tau - && \text{Existential quantification} -\end{align*} -% -\begin{align*} -\text{Term variables} \quad x, i \qquad -\text{Type variables} \quad a, n \qquad -\text{Literals} \quad l -\end{align*} -\end{huge} - -\end{document} - -%% to convert to slides-friendly png: -%% convert -alpha remove -density 300 -quality 85 dex.pdf -transparent white dex.png diff --git a/misc/py/dex.py b/misc/py/dex.py deleted file mode 100644 index c90cb21f9..000000000 --- a/misc/py/dex.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import print_function -# Copyright 2019 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import ctypes -import json - -libname = "./dex2jax.so" - -lib = ctypes.cdll.LoadLibrary(libname) -lib.hs_init(0, 0) # TODO should call lib.hs_exit() when done - -def setup_f(fname): - f = getattr(lib, fname) - f.argtypes = [ctypes.c_char_p] - f.restype = ctypes.c_char_p - return lambda x: json.loads(f(json.dumps(x))) - -loadSource, = map(setup_f, ["loadSource"]) - -class DexModule(object): - def __init__(self, functions): - for fname, definition in functions: - self.__dict__[fname] = definition - -def load(fname): - with open(fname) as f: - s = f.read() - top_level_functions = loadSource(s) - print(top_level_functions) - return DexModule(top_level_functions) diff --git a/misc/py/dex_binary_object.py b/misc/py/dex_binary_object.py deleted file mode 100644 index 9fee60af8..000000000 --- a/misc/py/dex_binary_object.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2019 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import itertools as it -from collections import namedtuple -import numpy as np - -TabType = namedtuple('TabType', ['index_set', 'element_type']) - -preheader_length = 81 -preheader_start = "-- dex-object-file-v0.0.1 num-header-bytes " - -def dump(obj, f): - ty = get_dex_ty(obj) - buffers = flatten_to_buffers(obj) - ty_str = "type: {}\n".format(pprint_ty(ty)) - sizes_str = "bufferSizes: [{}]\n".format(", ".join([str(get_buffer_size(x)) - for x in buffers])) - header_size = preheader_length + len(ty_str) + len(sizes_str) - pre_header_str = make_preheader(header_size) - header = pre_header_str + ty_str + sizes_str - assert header_size == len(header) - f.write(header) - f.flush() - for b in buffers: - buf_bytes = b.tobytes() - assert len(buf_bytes) == get_buffer_size(b), \ - "{} {} != {}".format(b, len(buf_bytes), get_buffer_size(b)) - f.buffer.write(buf_bytes) - f.flush() - -def get_dex_ty(obj): - if isinstance(obj, tuple): - return tuple(get_dex_ty(x) for x in obj) - elif isinstance(obj, np.ndarray): - base_ty = dtype_to_dex_ty(obj.dtype) - return make_tab_type(base_ty, obj.shape) - elif isinstance(obj, float): - return float - elif isinstance(obj, bool): - return bool - elif isinstance(obj, int): - return int - else: - raise Exception("No corresponding Dex type for {}".format(type(obj))) - -def flatten_to_buffers(obj): - if isinstance(obj, tuple): - return tuple(it.chain(*(flatten_to_buffers(x) for x in obj))) - elif isinstance(obj, np.ndarray): - flat_array = obj.ravel() - if obj.dtype == np.bool: - return [np.asarray(flat_array, dtype=np.int64)] - else: - return [flat_array] - elif isinstance(obj, float): - return [np.array(obj, dtype=np.float64)] - elif isinstance(obj, bool): - return [np.array(obj, dtype=np.int64)] - elif isinstance(obj, int): - return [np.array(obj, dtype=np.int64)] - else: - raise Exception("No corresponding Dex type for {}".format(type(obj))) - -def dtype_to_dex_ty(dtype): - if dtype == np.float64: - return float - elif dtype == np.int64: - return int - elif dtype == np.bool: - return bool - else: - raise Exception("Unrecognized dtype: " + str(dtype)) - -def make_tab_type(base_ty, shape): - shape = tuple(shape) - if shape == (): - return base_ty - else: - (n, *rest) = shape - return TabType(n, make_tab_type(base_ty, rest)) - -def get_buffer_size(array): - return array.size * 8 - -def pprint_ty(ty): - if isinstance(ty, TabType): - return "{}=>{}".format(str(ty.index_set), pprint_ty(ty.element_type)) - elif isinstance(ty, tuple): - return "({})".format(", ".join(map(pprint_ty, ty))) - if ty is int: - return "Int" - elif ty is float: - return "Real" - elif ty is bool: - return "Bool" - else: - raise Exception("Can't print type: {}".format(ty)) - -def make_preheader(n): - preheader_prefix = preheader_start + str(n) + " " - padding = '-' * (preheader_length - len(preheader_prefix) - 1) + "\n" - return preheader_prefix + padding diff --git a/misc/py/foo.dx b/misc/py/foo.dx deleted file mode 100644 index 3dc99813e..000000000 --- a/misc/py/foo.dx +++ /dev/null @@ -1,6 +0,0 @@ - - -addFloats :: Float -> Float -> Float -addFloats x y = x + y - - diff --git a/misc/py/generate-dex-data.py b/misc/py/generate-dex-data.py deleted file mode 100644 index 7981aefcd..000000000 --- a/misc/py/generate-dex-data.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2019 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -from collections import namedtuple -import numpy as np -import dex_binary_object as dbo - -data = (1.2, - 12, - (), - True, - False, - (-2, np.array([1.0, 2.0, 3.0])), - np.array([[10.0, 20.0, 30.0], [0.1, 0.2, 0.3]]) , - np.array([[10.0, 20.0, 30.0], [0.1, 0.2, 0.3]]).T, - 1.3, - np.array(0.123), - np.array([[[1]]]), - np.array([6,5,4,3]), - np.array([True, False, True])) - -with open("test-scratch/pydata.dxbo", "w") as f: - dbo.dump(data, f) diff --git a/misc/py/jax_call.py b/misc/py/jax_call.py deleted file mode 100644 index d92690076..000000000 --- a/misc/py/jax_call.py +++ /dev/null @@ -1,490 +0,0 @@ -# Copyright 2019 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import collections -import json -import pprint -import sys -import pprint as pp -import traceback -import numpy as np -import jax.numpy as jnp -from jax import jit, make_jaxpr, xla_computation -from jax import random -from jax import lax - -scary_map = map - -def map(f, *args): - return list(scary_map(f, *args)) - -class JaxFunction(object): - def __init__(self, binders, decls, results): - for b in binders: assert isinstance(b, Var) - for b, op in decls: - assert isinstance(b, Var) - assert isinstance(op, Operation) - for r in results: assert isinstance(r, Atom) - self.binders = binders - self.decls = decls - self.results = results - - def ser(self): - assert False - - @staticmethod - def des(obj): - binders_ser, (decls_ser, results_ser) = obj - binders = map(Var.des, binders_ser) - results = map(Atom.des, results_ser) - decls = [(Var.des(b), Operation.des(op)) for (b, op) in decls_ser] - return JaxFunction(binders, decls, results) - -class Name(object): - def __init__(self, namespace, root, i): - assert isinstance(i, int) - assert isinstance(namespace, str) - assert isinstance(root, str) - self._name = (namespace, root, i) - - @staticmethod - def des(obj): - namespace, root, i = obj - return Name(namespace, root, i) - - def ser(self): - return {"tag":"Name", "contents": list(self._name)} - - def __repr__(self): return str(self) - def __str__(self): - (_, root, i) = self._name - if i == 0: - return root - else: - return root + str(i) - - def __eq__(self, other): - assert isinstance(other, Name) - return self._name == other._name - - def __hash__(self): - return hash(self._name) - -class IdxVar(object): - def __init__(self, name, size): - assert isinstance(name, Name) - assert isinstance(size, int) - self.name = name - self.size = size - - def __repr__(self): return str(self) - def __str__(self): - return str(self.name) + ":" + str(self.size) - - def __eq__(self, other): - assert isinstance(other, IdxVar) - return self.name == other.name - - def __hash__(self): - return hash(self.name) - - @staticmethod - def des(obj): - name, idxSize = obj - assert name["tag"] == "Name" - return IdxVar(Name.des(name["contents"]), idxSize) - -class Var(object): - def __init__(self, name, ty): - assert isinstance(ty, Ty) - assert isinstance(name, Name) - self.name = name - self.ty = ty - - def __repr__(self): return str(self) - def __str__(self): - return str(self.name) + ":" + str(self.ty) - - def __eq__(self, other): - assert isinstance(other, Var) - return self.name == other.name - - def __hash__(self): - return hash(self.name) - - def ser(self): - return [self.name.ser(), self.ty.ser()] - - @staticmethod - def des(obj): - name, (shape, basetype) = obj - assert name["tag"] == "Name" - return Var(Name.des(name["contents"]), Ty(shape, basetype)) - -class Atom(object): - def __init__(self, case, data): - self.case = case - if case == "Var": - assert isinstance(data, Var) - self.var = data - elif case == "Lit": - assert isinstance(data, arrayish_types), type(data) - self.val = data - else: - assert False - - def __repr__(self): return str(self) - def __str__(self): - if self.case == "Var": - return str(self.var) - elif self.case == "Lit": - return str(self.val) - else: - assert False - - @property - def ty(self): - if self.case == "Var": - return self.var.ty - elif self.case == "Lit": - x = self.val - return array_ty(x) - else: - assert False - - @staticmethod - def des(obj): - if obj["tag"] == "JVar": - val = obj["contents"] - return Atom("Var", Var.des(val)) - elif obj["tag"] == "JLit": - shape, vec = obj["contents"] - val = np.array(vec["contents"], dtype=vec_dtype(vec)).reshape(shape) - return Atom("Lit", val) - -class IndexedAtom(object): - def __init__(self, atom, idxs): - assert isinstance(atom, Atom) - for i in idxs: assert isinstance(i, IdxVar) - self.atom = atom - self.idxs = idxs - - @property - def ty(self): - atom_ty = self.atom.ty - return Ty(atom_ty.shape[:len(self.idxs)], atom_ty.basetype) - - @staticmethod - def des(obj): - atom, idxs = obj - return IndexedAtom(Atom.des(atom), map(IdxVar.des, idxs)) - - def __repr__(self): return str(self) - def __str__(self): - return str(self.atom) + "".join("." + str(i) for i in self.idxs) - -class Ty(object): - def __init__(self, shape, basetype): - for n in shape: assert isinstance(n, int) - assert basetype in ["IntType", "BoolType", "RealType"] - self.basetype = basetype - self.shape = tuple(shape) - - def ser(self): - return [self.shape, self.basetype] - - def __eq__(self, other): - assert isinstance(other, Ty) - return self.basetype == other.basetype and self.shape == other.shape - - @staticmethod - def des(obj): - assert False - - def __repr__(self): return str(self) - def __str__(self): - return self.basetype + str(self.shape) - -MapIdx = "MapIdx" -SumIdx = "SumIdx" -class Operation(object): - def __init__(self, binders, op_name, size_args, args): - for (i, flavor) in binders: - assert isinstance(i, IdxVar) - assert flavor in (MapIdx, SumIdx) - - assert isinstance(op_name, str) - for size in size_args: assert isinstance(size, int) - for arg in args: assert isinstance(arg, IndexedAtom) - self.binders = binders - self.op_name = op_name - self.size_args = size_args - self.args = args - - @property - def all_idxs(self): - return [i for i, _ in self.binders] - - def ser(self): - assert False - - @staticmethod - def des(obj): - binders_ser, op_and_args_ser = obj - binders = [(IdxVar.des(i), fl) for i, fl in binders_ser] - op_name, size_args, args = des_op_and_args(op_and_args_ser) - return Operation(binders, op_name, size_args, args) - - def __repr__(self): return str(self) - def __str__(self): - return "for {} . {} {}".format( - self.binders, self.op_name, tuple(self.args)) - -def array_ty(x): - return Ty(x.shape, dtype_basetype(x.dtype)) - -def ser_array(arr): - assert isinstance(arr, arrayish_types) - return ser_flat_vec(arr.ravel()) - -def ser_flat_vec(vec): - if vec.dtype in [np.int32, np.int64]: - return {"tag":"IntVec", "contents": map(int, vec)} - if vec.dtype in [np.float32, np.float64]: - return {"tag":"DoubleVec", "contents": map(float, vec)} - else: - assert False - -def des_op_and_args(obj): - tag = obj["tag"] - if tag == "JScalarBinOp": - binop_name, x_ser, y_ser = obj["contents"] - x = IndexedAtom.des(x_ser) - y = IndexedAtom.des(y_ser) - return binop_name["tag"], [], [x, y] - if tag == "JScalarUnOp": - unop_name, x_ser = obj["contents"] - x = IndexedAtom.des(x_ser) - return unop_name, [], [x] - elif tag == "JIota": - size = obj["contents"] - assert isinstance(size, int) - return "Iota", [size], [] - elif tag == "JId": - x_ser = obj["contents"] - x = IndexedAtom.des(x_ser) - return "Id", [], [x] - elif tag == "JGet": - x_ser, y_ser = obj["contents"] - x = IndexedAtom.des(x_ser) - y = IndexedAtom.des(y_ser) - return "Get", [], [x, y] - elif tag == "JThreeFry2x32": - x_ser, y_ser = obj["contents"] - x = IndexedAtom.des(x_ser) - y = IndexedAtom.des(y_ser) - return "ThreeFry2x32", [], [x, y] - else: - raise Exception("Not implemented: " + str(tag)) - -global_env = {} - -def eval_op(op): - if op.op_name in ("FMul", "IMul"): - ans = eval_einsum(op) - return Atom("Lit", ans) - else: - broadcast_ans = eval_for(op) - sum_axes = tuple(i for (i, (_, fl)) in enumerate(op.binders) if fl == SumIdx) - if sum_axes == (): - return Atom("Lit", broadcast_ans) - else: - summed_ans = np.sum(broadcast_ans, axis=sum_axes) - return Atom("Lit", summed_ans) - -def eval_einsum(op): - assert op.op_name in ("FMul", "IMul") - x, y = op.args - x_axes = [str(i.name) for i in x.idxs] - y_axes = [str(i.name) for i in y.idxs] - out_axes = [str(i.name) for i, f in op.binders if f != SumIdx] - return jnp.einsum(x.atom.val, x_axes, y.atom.val, y_axes, out_axes) - -def eval_for(op): - if op.op_name in ("IAdd", "IMul", "FAdd", "FMul", "FDiv"): - x, y = op.args - x_bc = broadcast_dims(op.all_idxs, x.idxs, x.atom.val) - y_bc = broadcast_dims(op.all_idxs, y.idxs, y.atom.val) - if op.op_name in ("IAdd", "FAdd"): - return jnp.add(x_bc, y_bc) - elif op.op_name in ("IMul", "FMul"): - return jnp.multiply(x_bc, y_bc) - if op.op_name in ("FDiv",): - return jnp.divide(x_bc, y_bc) - else: - raise Exception("Not implemented: " + str(op.op_name)) - elif op.op_name == "Iota": - n, = op.size_args - val = jnp.arange(n) - val_bc = broadcast_dims(op.all_idxs, [], val) - return val_bc - elif op.op_name == "Id": - x, = op.args - x_bc = broadcast_dims(op.all_idxs, x.idxs, x.atom.val) - return x_bc - elif op.op_name == "Get": - x, idx = op.args - out_shape = [i.size for i in op.all_idxs] - x_idxs_used = get_stack_idxs_used(op.all_idxs, x.idxs) - leading_idx_arrays = [] - for i, idx_used in enumerate(x_idxs_used): - if idx_used: - leading_idx_arrays.append(nth_iota(out_shape, i)) - else: - pass - payload_idx_array = broadcast_dims(op.all_idxs, idx.idxs, idx.atom.val) - out = x.atom.val[tuple(leading_idx_arrays) + (payload_idx_array,)] - return out - elif op.op_name == "IntToReal": - x, = op.args - real_val = jnp.array(x.atom.val, dtype="float32") - x_bc = broadcast_dims(op.all_idxs, x.idxs, real_val) - return x_bc - elif op.op_name in ("FNeg", "INeg"): - x, = op.args - x_bc = broadcast_dims(op.all_idxs, x.idxs, jnp.negative(x.atom.val)) - return x_bc - elif op.op_name == "ThreeFry2x32": - convert_64_to_32s = lambda x: np.array([x]).view(np.uint32) - convert_32s_to_64 = lambda x: np.int64(np.array(x).view(np.int64).item()) - x, y = op.args - key, count = convert_64_to_32s(x.atom.val), convert_64_to_32s(y.atom.val) - result = convert_32s_to_64(random.threefry_2x32(key, count)) - x_bc = broadcast_dims(op.all_idxs, x.idxs, result) - return x_bc - else: - raise Exception("Unrecognized op: {}".format(op.op_name)) - -def broadcast_dims(for_idxs, idxs, x): - shape = [i.size for i in for_idxs] - idxs_used = get_stack_idxs_used(for_idxs, idxs) - bcast_dims = [i for i, b in enumerate(idxs_used) if b] - return lax.broadcast_in_dim(x, shape, bcast_dims) - -def broadcast_with(x, final_shape, idxs_used): - rem_shape = list(x.shape[sum(idxs_used):]) - reshape_shape = [size if use else 1 for (size, use) in zip(final_shape, idxs_used)] - x_singletons = jnp.reshape(x, reshape_shape + rem_shape) - return jnp.broadcast_to(x_singletons, final_shape + rem_shape) - -def nth_iota(shape, i): - size = shape[i] - iota = jnp.arange(size) - idxs_used = [Discard for _ in shape] - idxs_used[i] = Use - return broadcast_with(iota, shape, idxs_used) - -Use = True -Discard = False -def get_stack_idxs_used(for_idxs, idxs): - stack_vars = [] - cur_idxs = list(idxs) - for i in for_idxs: - if cur_idxs and i == cur_idxs[0]: - stack_vars.append(Use) - cur_idxs = cur_idxs[1:] - else: - stack_vars.append(Discard) - return stack_vars - -arrayish_types = (jnp.ndarray, np.ndarray, np.int64, np.float64, np.float32) - -def subst_op(env, op): - args = [IndexedAtom(subst_atom(env, x.atom), x.idxs) for x in op.args] - return Operation(op.binders, op.op_name, op.size_args, args) - -def subst_atom(env, x): - assert isinstance(x, Atom) - if x.case == "Var": - return env[x.var] - elif x.case == "Lit": - return x - else: - assert False - -def dtype_basetype(x): - if x in [np.int32, np.int64]: - return "IntType" - elif x in [np.float32, np.float64]: - return "RealType" - else: - assert False, x - -def vec_dtype(vec): - if vec["tag"] == "IntVec": - return np.int64 - elif vec["tag"] == "DoubleVec": - return np.float64 - else: - assert False - -def atom_as_var(x): - assert isinstance(x, Atom) - i = len(global_env) - name = Name("ArrayName", "arr", i) - v = Var(name, x.ty) - assert v not in global_env - global_env[v] = x - return v - -def eval_function_application(top_arg): - def run(): - f = JaxFunction.des(top_arg[0]) - args = [Atom("Var", Var.des(x)) for x in top_arg[1]] - env = global_env.copy() - args_subst = [subst_atom(env, arg) for arg in args] - for v, arg in zip(f.binders, args_subst): - env[v] = arg - for (v, op) in f.decls: - ans = eval_op(subst_op(env, op)) - if not (v.ty == ans.ty): - print(op) - raise Exception("Unexpected type. Expected {}, got {}".format(v.ty, ans.ty)) - env[v] = ans - return [subst_atom(env, r).val for r in f.results] - outs = run() - irdump = str(make_jaxpr(run)()) - return [atom_as_var(Atom("Lit", out)).ser() for out in outs], irdump - -def check_type(ty, val): - assert isinstance(ty, Ty) - -def retrieve_arrays(arrs): - vs = map(Var.des, arrs) - return [ser_array(global_env[v].val) for v in vs] - -def just_print_it(obj): - print(obj) - return () - -def run_server(functions): - readChan, writeChan = sys.argv[1:] - with open(writeChan, "w") as w: - for line in open(readChan): - (f_idx, arg) = json.loads(line) - try: - f = functions[f_idx] - ans = {"Right" : f(arg)} - except Exception as e: - traceback.print_exc() - ans = {"Left": traceback.format_exc()} - w.write(json.dumps(ans) + "\n") - w.flush() - -if __name__ == "__main__": - run_server([eval_function_application, - retrieve_arrays, - just_print_it]) diff --git a/misc/py/mnist_to_dxbo.py b/misc/py/mnist_to_dxbo.py deleted file mode 100644 index cf3b3e339..000000000 --- a/misc/py/mnist_to_dxbo.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2019 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import sys -import numpy as np -import dex_binary_object as dbo -sys.path.append("../jax") -from examples import datasets - -def oneHotToInt(xs): - xsInt = np.sum(xs * np.arange(10)[None,:], axis=1).astype(np.int64) - print(xsInt.shape) - assert np.max(xsInt) == 9 - return xsInt - -data = tuple(x.astype(np.float64) for x in datasets.mnist()) -train_images, train_labels, test_images, test_labels = data - -train_images_unflat = train_images.reshape((60000, 28, 28)) -test_images_unflat = test_images.reshape( (10000, 28, 28)) - -train_labels_int = oneHotToInt(train_labels) -test_labels_int = oneHotToInt(test_labels) - -data_out = (train_images_unflat, train_labels_int, - test_images_unflat, test_labels_int) - -with open("scratch/mnist.dxbo", "w") as f: - dbo.dump(data_out, f) diff --git a/misc/py/test-dex2jax.py b/misc/py/test-dex2jax.py deleted file mode 100644 index 553642e93..000000000 --- a/misc/py/test-dex2jax.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import print_function -# Copyright 2019 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import dex - -foo = dex.load("foo.dx") - -# print foo.addReals(1.0, 2.0) - -print(foo.f) diff --git a/src/lib/JAX.hs b/src/lib/JAX.hs deleted file mode 100644 index 7810b7220..000000000 --- a/src/lib/JAX.hs +++ /dev/null @@ -1,711 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE PatternSynonyms #-} -{-# OPTIONS_GHC -w #-} -- XXX: Disable once fixed -{-# OPTIONS_GHC -Wno-orphans #-} - -module JAX (JAtom (..), JVar, typeToJType, jTypeToType, - JExpr, JaxFunction, toJaxFunction, simplifyJaxFunction, - dceJaxFunction) where - -import Control.Applicative -import Control.Monad.Except hiding (Except) -import Control.Monad.Reader -import Control.Monad.Writer -import Control.Monad.State.Strict -import Data.Aeson hiding (Array) -import Data.Maybe -import Data.Text.Prettyprint.Doc -import GHC.Generics -import GHC.Stack - -import Env -import Syntax -import PPrint -import Type -import Cat -import Array - --- === JAXish IR === - -type AxisSize = Int -type JVar = VarP JType -type IdxVar = VarP AxisSize -data IdxFlavor = MapIdx | SumIdx deriving (Generic, Show, Eq) - -data JDecl = JLet JVar JFor deriving (Generic, Show, Eq) -data JExpr = JExpr [JDecl] [JAtom] deriving (Generic, Show, Eq) -data JAtom = JLit [Int] Array | JVar JVar deriving (Generic, Show, Eq) -data IdxAtom = IdxAtom JAtom [IdxVar] deriving (Generic, Show, Eq) -data JType = JType [AxisSize] ScalarBaseType deriving (Generic, Show, Eq) -data JaxFunction = JaxFunction [JVar] JExpr deriving (Generic, Show, Eq) - -type JOp = JOpP IdxAtom -data JOpP e = JId e - | JIota AxisSize - | JGet e e - | JScalarBinOp BinOp e e - | JThreeFry2x32 e e - | JScalarUnOp UnOp e - deriving (Generic, Functor, Foldable, Traversable, Show, Eq) - -data TmpAtom = TmpLeaf IdxAtom - | TmpRefName Var - | TmpCon (PrimCon TmpAtom) - deriving (Generic, Show, Eq) - -data JFor = JFor [(IdxVar, IdxFlavor)] (JOpP IdxAtom) - deriving (Generic, Show, Eq) - -type JScope = Env () -- TODO: put bindings here too - --- === lowering from Expr === - -type JEmbedEnv = (JScope, [JDecl]) -type JSubstEnv = Env TmpAtom -type EffectState = Env (Int, TmpAtom) -type IdxEnv = [IdxVar] -- for i j. --> [i, j] -type JaxM = ReaderT IdxEnv (StateT EffectState (Cat JEmbedEnv)) - -runJaxM :: JaxM a -> a -runJaxM m = fst $ flip runCat mempty $ - flip evalStateT mempty $ flip runReaderT mempty m - -toJaxFunction :: ([Var], Block) -> JaxFunction -toJaxFunction (vs, block) = runJaxM $ do - vs' <- mapM freshVar vs - let env = newEnv vs $ map varToTmpAtom vs - (result, (_, decls)) <- scoped $ do - result <- toJaxBlock env block - return $ flattenAtom result - let jvs = map (fmap typeToJType) vs' - return $ JaxFunction jvs $ JExpr decls result - -varToTmpAtom :: Var -> TmpAtom -varToTmpAtom v = TmpLeaf $ IdxAtom (JVar $ fmap typeToJType v) [] - -flattenAtom :: TmpAtom -> [JAtom] -flattenAtom atom = - execWriter $ traverseArrayLeaves atom $ \(IdxAtom x []) -> do - tell [x] - return $ IdxAtom x [] - -toJaxBlock :: JSubstEnv -> Block -> JaxM TmpAtom -toJaxBlock env (Block [] result) = toJaxExpr env result -toJaxBlock env (Block (decl:decls) result) = do - env' <- toJaxDecl env decl - toJaxBlock (env <> env') body - where body = Block decls result - -toJaxDecl :: JSubstEnv -> Decl -> JaxM JSubstEnv -toJaxDecl env (Let _ v rhs) = do - ans <- toJaxExpr env rhs - return $ v @> ans - -toJaxAtom :: JSubstEnv -> Atom -> TmpAtom -toJaxAtom env atom = case atom of - Var v@(_:>RefTy _ _) -> TmpRefName v - Var v -> fromMaybe (error "lookup failed") $ envLookup env v - Con (Lit x) -> tmpAtomScalarLit x - Con con -> TmpCon $ fmap (toJaxAtom env) con - _ -> error $ "Not implemented: " ++ pprint atom - -toJaxExpr :: JSubstEnv -> Expr -> JaxM TmpAtom -toJaxExpr env expr = case expr of - -- For _ (LamExpr i@(_ :> FixedIntRange 0 n) body) -> do - -- idxEnv <- ask - -- -- TODO: scope this to avoid burning through names - -- i' <-freshIdxVar n - -- iotaVar <- emitJFor $ JFor [] $ JIota n - -- let iotaAtom = iotaVarAsIdx (FixedIntRange 0 n) $ IdxAtom (JVar iotaVar) [i'] - -- let env' = env <> i @> iotaAtom - -- ans <- extendR [i'] $ toJaxBlock env' body - -- liftM (TmpCon . AFor (varAnn i)) $ traverseArrayLeaves ans $ \x -> do - -- ansVar <- emitJFor $ JFor (map (,MapIdx) (idxEnv ++ [i'])) $ JId x - -- return $ IdxAtom (JVar ansVar) idxEnv - -- TabGet xs i -> do - -- let (TmpCon (AFor _ tab)) = toJaxAtom env xs - -- let i' = toJaxAtom env i - -- traverseArrayLeaves tab $ \x -> emitOp $ JGet x $ fromScalarAtom i' - Op op -> toJaxOp $ fmap (toJaxAtom env) op - -toJaxOp :: PrimOp TmpAtom -> JaxM TmpAtom -toJaxOp op = case op of - ScalarBinOp op x y -> liftM toScalarAtom $ - emitOp $ JScalarBinOp op (fromScalarAtom x) (fromScalarAtom y) - IndexAsInt x -> liftM toScalarAtom $ - emitOp $ JId (fromScalarAtom x) - ScalarUnOp op x -> liftM toScalarAtom $ - emitOp $ JScalarUnOp op (fromScalarAtom x) - PrimEffect (TmpRefName refVar) m -> do - case m of - MTell x -> do - (depth, curAccum) <- gets (! refVar) - xSum <- sumPoly depth x - newAccum <- local (take depth) $ addPoly curAccum xSum - modify (<> refVar @> (depth, newAccum)) - return $ TmpCon $ UnitCon - _ -> error $ "Not implemented: " ++ show op - -- RecGet x i -> do - -- case x of - -- TmpCon (RecCon r) -> return $ recGet r i - -- val -> error $ "Expected a record, got: " ++ show val - FFICall s _ args | s == "threefry2x32" -> liftM toScalarAtom $ - emitOp $ JThreeFry2x32 (fromScalarAtom x) (fromScalarAtom y) - where x:y:[] = args - _ -> error $ "Not implemented: " ++ show op - --- toJaxHof :: PrimHof TmpAtom (LamExpr, JSubstEnv) -> JaxM TmpAtom --- toJaxHof hof = case hof of --- RunWriter (LamExpr refVar _ body, env) -> do --- idxEnvDepth <- asks length --- let (RefTy wTy) = varAnn refVar --- wInit <- zerosAt wTy --- modify (<> refVar @> (idxEnvDepth, wInit)) --- aResult <- toJaxBlock env body --- wFinal <- gets $ snd . (! refVar) --- modify $ envDelete (varName refVar) --- return $ TmpCon $ RecCon $ Tup [aResult, wFinal] --- _ -> error $ "Not implemented: " ++ show hof - -iotaVarAsIdx :: Type -> IdxAtom -> TmpAtom -iotaVarAsIdx = undefined --- iotaVarAsIdx n x = TmpCon $ AsIdx n $ toScalarAtom x - -fromScalarAtom :: HasCallStack => TmpAtom -> IdxAtom -fromScalarAtom atom = case atom of - TmpCon (Coerce _ x) -> fromScalarAtom x - --TmpCon (AGet (TmpLeaf x)) -> x - _ -> error $ "Not a scalar atom: " ++ show atom - -toScalarAtom :: IdxAtom -> TmpAtom -toScalarAtom x = undefined --TmpCon $ AGet $ TmpLeaf x - -traverseArrayLeaves :: HasCallStack => Monad m => TmpAtom -> (IdxAtom -> m IdxAtom) -> m TmpAtom -traverseArrayLeaves atom f = case atom of - TmpCon con -> liftM TmpCon $ case con of - --AFor n body -> liftM (AFor n) $ traverseArrayLeaves body f - --AGet (TmpLeaf x) -> liftM (AGet . TmpLeaf) $ f x - _ -> error $ "Not implemented: " ++ show atom - TmpLeaf x -> liftM TmpLeaf $ f x - TmpRefName _ -> error "Unexpected reference name" - -typeToJType :: Type -> JType -typeToJType ty = case ty of - TC (JArrayType dims b) -> JType dims b - _ -> error $ "Not a jax type: " ++ pprint ty - -jTypeToType :: JType -> Type -jTypeToType ty = case ty of - JType shape b -> TC $ JArrayType shape b - -emitOp :: JOpP IdxAtom -> JaxM IdxAtom -emitOp op = do - idxEnv <- ask - v <- emitJFor $ JFor (map (,MapIdx) idxEnv) op - return $ IdxAtom (JVar v) idxEnv - -zerosAt :: Type -> JaxM TmpAtom -zerosAt ty = case ty of - BaseTy (Scalar FloatType) -> return $ tmpAtomScalarLit $ FloatLit 0.0 - _ -> error "Not implemented" - -addPoly :: TmpAtom -> TmpAtom -> JaxM TmpAtom -addPoly x y = case getType x of - BaseTy (Scalar FloatType) -> liftM toScalarAtom $ - emitOp $ JScalarBinOp FAdd (fromScalarAtom x) (fromScalarAtom y) - ty -> error $ "Not implemented: " ++ pprint ty - -sumPoly :: Int -> TmpAtom -> JaxM TmpAtom -sumPoly depth atom = do - idxEnv <- ask - let (forIdxs, sumIdxs) = splitAt depth idxEnv - let idxBinders = zip forIdxs (repeat MapIdx) - <> zip sumIdxs (repeat SumIdx) - traverseArrayLeaves atom $ \x -> do - v <- emitJFor $ JFor idxBinders $ JId x - return $ IdxAtom (JVar v) forIdxs - -tmpAtomScalarLit :: LitVal -> TmpAtom -tmpAtomScalarLit x = toScalarAtom $ IdxAtom (JLit [] $ arrayFromScalar x) [] - -instance HasType TmpAtom where - typeCheck atom = case atom of - TmpLeaf idxAtom -> return $ jTypeToType $ getJType idxAtom - TmpRefName _ -> undefined - TmpCon con -> undefined - --- === Simplification pass on JAX IR === - -type BindingEnv = Env (VarUsage, JFor) -type SimpEnv = (Env JAtom, BindingEnv) -type SimpM = Cat JEmbedEnv - -pattern JForId :: JAtom -> JFor -pattern JForId x = JFor [] (JId (IdxAtom x [])) - -simplifyJaxFunction :: JaxFunction -> JaxFunction -simplifyJaxFunction (JaxFunction vs expr) = fst $ flip runCat mempty $ do - vs' <- mapM freshVar vs - let env = (newEnv vs (map JVar vs'), mempty) - (result', (_, decls')) <- scoped $ simplifyJaxExpr env expr - return $ JaxFunction vs' $ JExpr decls' result' - -simplifyJaxExpr :: SimpEnv -> JExpr -> SimpM [JAtom] -simplifyJaxExpr env expr@(JExpr decls results) = do - let usageEnv = collectUsage expr - (_, env') <- flip runCatT env $ mapM (simplifyJaxDecl usageEnv) decls - let (substEnv, _) = env <> env' - return $ fmap (jSubst substEnv) results - -simplifyJaxDecl :: UsageEnv -> JDecl -> CatT SimpEnv SimpM () -simplifyJaxDecl usageEnv (JLet v jfor) = do - (substEnv, bindingEnv) <- look - let usage = lookupUse usageEnv v - let jfor' = simpFix (simplifyJFor bindingEnv) $ jSubst substEnv jfor - case jfor' of - JForId x -> extend $ asFst (v @> x) - _ -> do - vOut <- lift $ emitJFor jfor' - extend $ (v @> JVar vOut, vOut @> (usage, jfor')) - -simplifyJFor :: BindingEnv -> JFor -> Maybe JFor -simplifyJFor env jfor@(JFor idxs op) = - liftM (JFor idxs) (mapParallel (inlineFromId env) op) - <|> inlineGetIota env jfor - <|> inlineIntoId env jfor - <|> liftM (JFor idxs) (algebraicSimp op) - <|> checkProgress etaReduce jfor - -inlineGetIota :: BindingEnv -> JFor -> Maybe JFor -inlineGetIota env (JFor idxBinders op) = do - let idxEnv = map fst idxBinders - JGet (IdxAtom x xIdxs) (IdxAtom (JVar v) idxs) <- return op - (_, varDef) <- envLookup env v - (JFor [] (JIota _), [i]) <- return $ betaReduce varDef idxs - let idxs' = xIdxs ++ [i] - -- TODO: have a more direct way to check index ordering condition - case checkIdxEnv idxs' idxEnv of - Left _ -> Nothing - Right () -> return $ JFor idxBinders $ JId $ IdxAtom x idxs' - -inlineIntoId :: BindingEnv -> JFor -> Maybe JFor -inlineIntoId env (JFor idxs op) = do - JId (IdxAtom (JVar v) appIdxs) <- return op - (UsedOnce, jfor) <- envLookup env v - let idxScope = foldMap ((@>()) . fst) idxs - let jforFresh = refreshIdxVars idxScope jfor - (jfor', []) <- return $ betaReduce jforFresh appIdxs - let (JFor idxs' op') = refreshIdxVars idxScope jfor' - return $ JFor (idxs <> idxs') op' - -inlineFromId :: BindingEnv -> IdxAtom -> Maybe IdxAtom -inlineFromId env idxAtom = do - IdxAtom (JVar v) idxs <- return idxAtom - (_, jfor) <- envLookup env v - (JFor [] (JId (IdxAtom x idxs')), idxs'') <- return $ betaReduce jfor idxs - return $ IdxAtom x (idxs' <> idxs'') - -algebraicSimp :: JOp -> Maybe JOp -algebraicSimp op = case op of - JScalarBinOp FAdd x y - | fromScalarLit x == Just (FloatLit 0) -> Just $ JId y - | fromScalarLit y == Just (FloatLit 0) -> Just $ JId x - _ -> Nothing - -fromScalarLit :: IdxAtom -> Maybe LitVal -fromScalarLit (IdxAtom (JLit [] x) []) = scalarFromArray x -fromScalarLit _ = Nothing - --- === variable usage pass === - -data VarUsage = Unused | UsedOnce | ArbitraryUse deriving (Show, Eq) - -type UsageEnv = MonMap Name VarUsage - -collectUsage :: JExpr -> UsageEnv -collectUsage (JExpr decls result) = snd $ flip runCat mempty $ do - extend $ useFreeVars ArbitraryUse result - forM_ (reverse decls) $ \(JLet v jfor) -> do - use <- looks $ flip lookupUse v - case use of - Unused -> return () - _ -> extend $ useFreeVars UsedOnce jfor - -lookupUse :: UsageEnv -> VarP ann -> VarUsage -lookupUse env (v:>_) = monMapLookup env v - -useFreeVars :: HasJVars a => VarUsage -> a -> UsageEnv -useFreeVars use x = foldMap (useVar use) $ envNames $ freeJVars x - -useVar :: VarUsage -> Name -> UsageEnv -useVar use v = monMapSingle v use - -instance Semigroup VarUsage where - Unused <> use = use - use <> Unused = use - _ <> _ = ArbitraryUse - -instance Monoid VarUsage where - mempty = Unused - -dceJaxFunction :: JaxFunction -> JaxFunction -dceJaxFunction (JaxFunction vs expr@(JExpr decls result)) = - JaxFunction vs (JExpr decls' result) - where - decls' = filter (\(JLet v _) -> lookupUse useEnv v /= Unused) decls - useEnv = collectUsage expr - --- === JAX IR builder === - -emitJFor :: MonadCat JEmbedEnv m => JFor -> m JVar -emitJFor jfor = do - v <- freshVar ("v":> getJType jfor) - extend $ (v @> (), [JLet v jfor]) - return v - -freshVar :: MonadCat JEmbedEnv m => VarP ann -> m (VarP ann) -freshVar v = do - scope <- looks fst - let v' = rename v scope - extend $ asFst (v' @> ()) - return v' - -freshIdxVar :: MonadCat JEmbedEnv m => AxisSize -> m IdxVar -freshIdxVar n = do - scope <- looks fst - let nameChoices = [Name JaxIdx name 0 | name <- ["i", "j", "k"]] - let v = renameChoice nameChoices scope :> n - extend $ asFst (v @> ()) - return v - --- === JAXy IR Types === - -type IdxTyEnv = [IdxVar] -type JTypeEnv = (Env JType, IdxEnv) - -instance Checkable JaxFunction where - checkValid (JaxFunction vs body) = do - let argTys = map varAnn vs - void $ checkJExprType (newEnv vs argTys, []) body - -checkJExprType :: JTypeEnv -> JExpr -> Except [JType] -checkJExprType initEnv (JExpr decls results) = - liftM fst $ flip runCatT initEnv $ do - forM_ decls $ \(JLet v@(_:>reqTy) jfor) -> do - env <- look - ty <- checkJType env jfor - assertEq reqTy ty "Annotation" - extend (v @> ty, []) - env <- look - forM results $ checkJType env - -class HasJType a where - getJType :: a -> JType - checkJType :: MonadError Err m => JTypeEnv -> a -> m JType - -instance HasJType JFor where - getJType (JFor idxs op) = JType (shape ++ shape') b - where - shape = [n | (_:>n, MapIdx) <- idxs] - (JType shape' b) = getJType op - - checkJType env jfor@(JFor idxs op) = - addContext ("\nChecking: " ++ pprint jfor) $ do - let idxBinders = map fst idxs - checkBinders env idxBinders - let env' = env <> (mempty, idxBinders) - let shape = [n | (_:>n, MapIdx) <- idxs] - (JType shape' b) <- checkJType env' op - return (JType (shape ++ shape') b) - -assertNoMutualShadows :: (MonadError Err m, Pretty b, Traversable f) - => f (VarP b) -> m () -assertNoMutualShadows bs = - void $ flip runCatT mempty $ forM bs $ \b -> do - env <- look - checkNoShadow env b - extend (b@>()) - -checkBinders :: (MonadError Err m, Pretty ann) => JTypeEnv -> [VarP ann] -> m () -checkBinders env bs = do - mapM_ (checkNoShadow (fst env)) bs - assertNoMutualShadows bs - -instance HasJType IdxAtom where - getJType (IdxAtom x idxs) = JType (drop (length idxs) shape) b - where JType shape b = getJType x - - checkJType (env, idxEnv) (IdxAtom x idxs) = do - JType shape b <- checkJType (env, []) x - throwIf (length idxs > length shape) CompilerErr $ - "Too many indices: " ++ pprint idxs - forM_ (zip idxs shape) $ \((_:>nAnn), nArr) -> - assertEq nArr nAnn "Index size doesn't match array shape" - checkIdxEnv idxs idxEnv - return $ JType (drop (length idxs) shape) b - -checkIdxEnv :: MonadError Err m => [IdxVar] -> IdxTyEnv -> m () -checkIdxEnv [] _ = return () -checkIdxEnv (i:_) [] = throw CompilerErr $ "Index not in env " ++ pprint i -checkIdxEnv (i:idxs) (i':idxEnv) - | varName i == varName i' = do - assertEq i' i "Index size doesn't match index env" - checkIdxEnv idxs idxEnv - | otherwise = checkIdxEnv (i:idxs) idxEnv - -instance HasJType JAtom where - getJType atom = case atom of - JVar (_:> ty) -> ty - JLit shape arr -> JType shape b - where (_, Scalar b) = arrayType arr - - checkJType (env,_) atom = case atom of - JVar v@(_:> ty) -> do - case envLookup env v of - Just reqTy -> do - assertEq reqTy ty "JVar" - return ty - _ -> throw CompilerErr $ "Lookup failed: " ++ pprint v - JLit shape arr -> return $ JType shape b - where (_, Scalar b) = arrayType arr - -instance (Pretty a, HasJType a) => HasJType (JOpP a) where - getJType op = ignoreExcept $ addContext ("Getting type of: " ++ pprint op) $ - traverseJOpType $ fmap getJType op - checkJType env op = do - op' <- traverse (checkJType env) op - traverseJOpType op' - -traverseJOpType :: MonadError Err m => JOpP JType -> m JType -traverseJOpType jop = case jop of - JScalarBinOp op xTy' yTy' -> do - assertEq (JType [] xTy) xTy' "Arg type mismatch" - assertEq (JType [] yTy) yTy' "Arg type mismatch" - return $ JType [] outTy - where (xTy, yTy, outTy) = binOpType op - JScalarUnOp op xTy' -> do - assertEq (JType [] xTy) xTy' "Arg type mismatch" - return $ JType [] outTy - where (xTy, outTy) = unOpType op - JThreeFry2x32 xTy yTy -> do - assertEq (JType [] IntType) xTy "Arg type mismatch" - assertEq (JType [] IntType) yTy "Arg type mismatch" - return $ JType [] IntType - JId ty -> return $ ty - JIota n -> return $ JType [n] IntType - JGet (JType (_:shape) b) idxTy -> do - assertEq (JType [] IntType) idxTy "Arg type mismatch" - return $ JType shape b - JGet (JType [] _) _ -> error "Attempting to index zero-dim array" - --- === free vars and substitutions === - -class HasJVars a where - freeJVars :: a -> Env () - jSubst :: Env JAtom -> a -> a - -instance HasJVars JFor where - freeJVars (JFor _ op) = freeJVars op - jSubst env (JFor idxs op) = JFor idxs $ jSubst env op - -instance HasJVars JAtom where - freeJVars x = case x of - JLit _ _ -> mempty - JVar v -> v @> () - jSubst env x = case x of - JLit _ _ -> x - JVar v -> env ! v - -instance HasJVars IdxAtom where - freeJVars (IdxAtom x _) = freeJVars x - jSubst env (IdxAtom x idxs) = IdxAtom (jSubst env x) idxs - -instance (Traversable f, HasJVars a) => HasJVars (f a) where - freeJVars xs = foldMap freeJVars xs - jSubst env op = fmap (jSubst env) op - -etaReduce :: JFor -> JFor -etaReduce (JFor [] op) = JFor [] op -etaReduce (JFor (b:bs) op) = do - let (JFor bs' op') = etaReduce (JFor bs op) - fromMaybe (JFor (b:bs') op') $ do - (i, MapIdx) <- return b - [] <- return bs' - JId (IdxAtom x idxs) <- return op' - (idxs', i') <- unsnoc idxs - unless (i == i') Nothing - return $ JFor bs' $ JId $ IdxAtom x idxs' - -betaReduce :: JFor -> [IdxVar] -> (JFor, [IdxVar]) -betaReduce jfor idxs = do - let freeVs = foldMap (@>()) idxs - let jfor' = refreshIdxVars freeVs jfor - betaReduceRec jfor' idxs - -betaReduceRec :: JFor -> [IdxVar] -> (JFor, [IdxVar]) -betaReduceRec jfor [] = (jfor, []) -betaReduceRec jfor idxs = do - let Just (rest, i) = unsnoc idxs - let (jfor', idxs') = betaReduceRec jfor rest - fromMaybe (jfor', idxs' ++ [i]) $ do - [] <- return idxs' - JFor ((b,MapIdx):bs) op <- return jfor' - return (JFor bs $ substOp (b @> i) op, []) - -refreshIdxVars :: JScope -> JFor -> JFor -refreshIdxVars scope (JFor binders op) = do - let (idxs, flavors) = unzip binders - let idxs' = fst $ renames idxs () scope - JFor (zip idxs' flavors) $ substOp (newEnv idxs idxs') op - --- TODO: extend `HasJVars` to handle index var substitution too -substOp :: Env IdxVar -> JOp -> JOp -substOp env op = flip fmap op $ \(IdxAtom x atomIdxs) -> - IdxAtom x $ map trySubst atomIdxs - where trySubst v = fromMaybe v (envLookup env v) - --- TODO: make a right-appending list we can actually pattern-match on -unsnoc :: [a] -> Maybe ([a], a) -unsnoc xs = case reverse xs of - [] -> Nothing - x:rest -> Just (reverse rest, x) - --- === simplification combinators === - --- Simplifiers must only produce `Just` if some progress was made. --- (e.g. avoid `mySimp x = trySimp x <|> pure x`) - -simpFix :: Eq a => (a -> Maybe a) -> a -> a -simpFix f x = case f x of - Nothing -> x - Just x' -> simpFix f x' - --- TODO: more efficient implementation without using Eq -mapParallel :: (Eq a, Eq (f a), Functor f) => (a -> Maybe a) -> f a -> Maybe (f a) -mapParallel f = checkProgress (fmap (\x -> fromMaybe x (f x))) - -checkProgress :: Eq a => (a -> a) -> a -> Maybe a -checkProgress f x | x' == x = Nothing - | otherwise = Just x' - where x' = f x - --- === instances === - -instance Pretty JaxFunction where - pretty (JaxFunction vs body) = "lambda" <+> pretty vs <> hardline <> pretty body - -instance Pretty JExpr where - pretty (JExpr decls results) = - foldMap (\d -> pretty d <> hardline) decls <> "results:" <+> pretty results - -instance Pretty IdxAtom where - pretty (IdxAtom x idxs) = pretty x <> foldMap (\(i:>_) -> "." <> pretty i) idxs - -instance Pretty JAtom where - pretty (JLit _ x) = pretty $ scalarFromArray x - pretty (JVar (v:>_)) = pretty v - -instance Pretty JDecl where - pretty (JLet v rhs) = pretty v <+> "=" <+> pretty rhs - -instance Pretty a => Pretty (JOpP a) where - pretty op = prettyOpName op <+> foldMap (\x -> parens (pretty x) <> " ") op - -instance Pretty JType where - pretty (JType s b) = pretty b <> pretty s - -instance Pretty JFor where - pretty (JFor [] op) = pretty op - pretty jfor@(JFor ((_,flavor):_) _) = - pretty s <+> prettyJForCtx flavor jfor - where - s :: String - s = case flavor of MapIdx -> "for" - SumIdx -> "sum" -instance Pretty TmpAtom where - pretty _ = "" - -prettyJForCtx :: IdxFlavor -> JFor -> Doc ann -prettyJForCtx flavor jfor@(JFor idxs op) = case idxs of - [] -> " . " <> pretty op - (i, flavor'):rest - | flavor == flavor' -> pretty (varName i) <+> - prettyJForCtx flavor (JFor rest op) - | otherwise -> pretty jfor - -prettyOpName :: JOpP a -> Doc ann -prettyOpName jop = case jop of - JScalarBinOp op _ _ -> pretty $ show op - JScalarUnOp op _ -> pretty $ show op - JThreeFry2x32 _ _ -> "threefry2x32" - JIota n -> "iota@" <> pretty n - JGet _ _ -> "get" - JId _ -> "id" - -instance ToJSON JDecl -instance FromJSON JDecl - -instance ToJSON JaxFunction -instance FromJSON JaxFunction - -instance ToJSON JExpr -instance FromJSON JExpr - -instance ToJSON JFor -instance FromJSON JFor - -instance ToJSON JAtom -instance FromJSON JAtom - -instance ToJSON IdxAtom -instance FromJSON IdxAtom - -instance ToJSON IdxFlavor -instance FromJSON IdxFlavor - -instance (ToJSON ann) => ToJSON (VarP ann) -instance (FromJSON ann) => FromJSON (VarP ann) - -instance (ToJSON e) => ToJSON (JOpP e) -instance (FromJSON e) => FromJSON (JOpP e) - -instance ToJSON JType -instance FromJSON JType - -instance ToJSON Name -instance FromJSON Name - -instance ToJSON NameSpace -instance FromJSON NameSpace - -instance ToJSON BinOp -instance FromJSON BinOp - -instance ToJSON UnOp -instance FromJSON UnOp - -instance ToJSON CmpOp -instance FromJSON CmpOp - -instance ToJSON LitVal -instance FromJSON LitVal - -instance ToJSON BaseType -instance FromJSON BaseType - -instance ToJSON ScalarBaseType -instance FromJSON ScalarBaseType - -instance ToJSON Array -instance FromJSON Array - -instance ToJSON Vec -instance FromJSON Vec diff --git a/src/lib/PipeRPC.hs b/src/lib/PipeRPC.hs deleted file mode 100644 index 4369160f2..000000000 --- a/src/lib/PipeRPC.hs +++ /dev/null @@ -1,60 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -module PipeRPC (PipeServer, startPipeServer, callPipeServer, psPop) where - -import Control.Concurrent.MVar -import Control.Monad -import Control.Monad.IO.Class -import Data.Aeson -import Data.ByteString.Lazy.Char8 (pack, unpack) -import GHC.IO.Handle.FD -import System.IO -import System.Process - -data PipeServer f = PipeServer { _psLock :: MVar () - , _psSendHandle :: Handle - , _psReceiveHandle :: Handle - , psFunctionIndex :: Int} - -startPipeServer :: MonadIO m => FilePath -> [String] -> m (PipeServer f) -startPipeServer cmd args = liftIO $ do - ((clientRead, _), (_, serverWrite)) <- createPipeWithNames - ((_, serverRead), (clientWrite, _)) <- createPipeWithNames - void $ createProcess $ proc cmd $ args ++ [serverRead, serverWrite] - lock <- newMVar () - return $ PipeServer lock clientWrite clientRead 0 - -psPop :: PipeServer (head, tail) -> PipeServer tail -psPop server = server { psFunctionIndex = 1 + psFunctionIndex server } - -callPipeServer :: (MonadIO m, ToJSON a, FromJSON b) - => PipeServer (a -> b, tail) -> a -> m b -callPipeServer (PipeServer lock sendHandle receiveHandle fIdx) arg = liftIO $ do - void $ takeMVar lock - let request = unpack $ encode (fIdx, arg) - hPutStrLn sendHandle request - response <- hGetLine receiveHandle - putMVar lock () - case eitherDecode (pack response) of - Right x -> case x of - Right x' -> return x' - Left s -> error $ "Error thrown by server:\n" ++ s - Left s -> error $ s ++ "\nDecoding error. Full response:\n" ++ response - -createPipeWithNames :: IO ((Handle, String), (Handle, String)) -createPipeWithNames = do - (r, w) <- createPipe - hSetBuffering r LineBuffering - hSetBuffering w LineBuffering - rName <- unixHandleName r - wName <- unixHandleName w - return ((r,rName), (w, wName)) - -unixHandleName :: Handle -> IO String -unixHandleName h = do - fd <- handleToFd h - return $ "/dev/fd/" ++ show fd From 75c4cb6891034136b6440ab3221c09746efdb5b2 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 23 Dec 2020 17:05:06 +0000 Subject: [PATCH 020/105] Improve the benchmark harness Use a stable timer, and take a single measurement for a whole bunch of runs, to amortize the measurement cost (~1us). Thanks to this, the lowest numbers we can possibly benchmark are ~150ns on my machine (this is for running a `1 + 1` program). This is still suspiciously high, but I think it might be connected to the fact that the program we emit actually does contain a call to `posix_memalign`, and so it might just be the cost of memory allocation (especially that we never free the result!). --- dex.cabal | 2 +- src/lib/LLVMExec.hs | 56 ++++++++++++++++++++++----------------------- src/lib/PPrint.hs | 4 ++-- src/lib/Syntax.hs | 2 +- src/lib/TopLevel.hs | 34 +++++++++++++-------------- src/lib/Util.hs | 9 ++++++++ 6 files changed, 58 insertions(+), 49 deletions(-) diff --git a/dex.cabal b/dex.cabal index 114014720..b1390b84c 100644 --- a/dex.cabal +++ b/dex.cabal @@ -37,7 +37,7 @@ library RenderHtml, LiveOutput, Simplify, TopLevel, Autodiff, Interpreter, Logging, CUDA, LLVM.JIT, LLVM.Shims - build-depends: base, containers, mtl, bytestring, time, + build-depends: base, containers, mtl, bytestring, llvm-hs-pure, llvm-hs, -- Parsing megaparsec, parser-combinators, diff --git a/src/lib/LLVMExec.hs b/src/lib/LLVMExec.hs index 43d2c27ae..71fa92aa7 100644 --- a/src/lib/LLVMExec.hs +++ b/src/lib/LLVMExec.hs @@ -29,7 +29,6 @@ import qualified LLVM.PassManager as P import qualified LLVM.Transforms as P import qualified LLVM.Target as T import LLVM.Context -import Data.Time.Clock (getCurrentTime, diffUTCTime) import System.IO import System.IO.Unsafe import System.IO.Temp @@ -54,6 +53,7 @@ import Syntax import Resources import CUDA (synchronizeCUDA) import LLVM.JIT +import Util (measureSeconds) -- === One-shot evaluation === @@ -62,48 +62,48 @@ compileAndEval logger ast fname args resultTypes = do allocaBytes (length args * cellSize) $ \argsPtr -> allocaBytes (length resultTypes * cellSize) $ \resultPtr -> do storeLitVals argsPtr args - evalTime <- compileOneOff logger ast fname $ checkedCallFunPtr False argsPtr resultPtr + evalTime <- compileOneOff logger ast fname $ checkedCallFunPtr argsPtr resultPtr logThis logger [EvalTime evalTime Nothing] loadLitVals resultPtr resultTypes -compileAndBench :: Logger [Output] -> L.Module -> String -> [LitVal] -> [BaseType] -> IO [LitVal] -compileAndBench logger ast fname args resultTypes = do +compileAndBench :: Bool -> Logger [Output] -> L.Module -> String -> [LitVal] -> [BaseType] -> IO [LitVal] +compileAndBench shouldSyncCUDA logger ast fname args resultTypes = do allocaBytes (length args * cellSize) $ \argsPtr -> allocaBytes (length resultTypes * cellSize) $ \resultPtr -> do storeLitVals argsPtr args compileOneOff logger ast fname $ \fPtr -> do - -- First warmup iteration, which we also use to get the results - void $ checkedCallFunPtr True argsPtr resultPtr fPtr - results <- loadLitVals resultPtr resultTypes - let run = do - time <- checkedCallFunPtr True argsPtr resultPtr fPtr - _benchResults <- loadLitVals resultPtr resultTypes - -- TODO: Free results! - return time - exampleDuration <- run - let timeBudget = 2 -- seconds - let benchRuns = (ceiling $ timeBudget / exampleDuration) :: Int - times <- forM [1..benchRuns] $ const run - let avgTime = sum times / (fromIntegral benchRuns) - logThis logger [EvalTime avgTime (Just benchRuns)] + ((avgTime, benchRuns, results), totalTime) <- measureSeconds $ do + -- First warmup iteration, which we also use to get the results + void $ checkedCallFunPtr argsPtr resultPtr fPtr + results <- loadLitVals resultPtr resultTypes + let run = do + exitCode <- callFunPtr fPtr argsPtr resultPtr + unless (exitCode == 0) $ throwIO $ Err RuntimeErr Nothing "" + -- TODO: Free results! + exampleDuration <- snd <$> measureSeconds run + let timeBudget = 2 -- seconds + let benchRuns = (ceiling $ timeBudget / exampleDuration) :: Int + totalTime <- liftM snd $ measureSeconds $ do + forM_ [1..benchRuns] $ const run + when shouldSyncCUDA $ synchronizeCUDA + let avgTime = totalTime / (fromIntegral benchRuns) + return (avgTime, benchRuns, results) + logThis logger [EvalTime avgTime (Just (benchRuns, totalTime))] return results -foreign import ccall "dynamic" +foreign import ccall unsafe "dynamic" callFunPtr :: DexExecutable -> Ptr () -> Ptr () -> IO DexExitCode type DexExecutable = FunPtr (Ptr () -> Ptr () -> IO DexExitCode) type DexExitCode = Int -checkedCallFunPtr :: Bool -> Ptr () -> Ptr () -> DexExecutable -> IO Double -checkedCallFunPtr sync argsPtr resultPtr fPtr = do - t1 <- getCurrentTime - exitCode <- callFunPtr fPtr argsPtr resultPtr - when sync $ synchronizeCUDA - t2 <- getCurrentTime +checkedCallFunPtr :: Ptr () -> Ptr () -> DexExecutable -> IO Double +checkedCallFunPtr argsPtr resultPtr fPtr = do + (exitCode, duration) <- measureSeconds $ do + exitCode <- callFunPtr fPtr argsPtr resultPtr + return exitCode unless (exitCode == 0) $ throwIO $ Err RuntimeErr Nothing "" - return $ t2 `secondsSince` t1 - where - secondsSince end start = realToFrac $ end `diffUTCTime` start + return duration compileOneOff :: Logger [Output] -> L.Module -> String -> (DexExecutable -> IO a) -> IO a compileOneOff logger ast name f = do diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 6022fd00d..5f2bb073c 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -521,8 +521,8 @@ instance Pretty Output where benchName <> hardline <> "Compile time: " <> prettyDuration compileTime <> hardline <> "Run time: " <> prettyDuration runTime <+> - (case stats of Just runs -> "\t" <> parens ("based on" <+> p runs <+> "runs") - Nothing -> "") + (case stats of Just (runs, _) -> "\t" <> parens ("based on" <+> p runs <+> "runs") + Nothing -> "") where benchName = case name of "" -> "" _ -> "\n" <> p name pretty (PassInfo name s) = "===" <+> p name <+> "===" <> hardline <> p s diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index e4c17b142..1da6b7889 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -614,7 +614,7 @@ type LitProg = [(SourceBlock, Result)] type SrcCtx = Maybe SrcPos data Result = Result [Output] (Except ()) deriving (Show, Eq) -type BenchStats = Int -- number of runs +type BenchStats = (Int, Double) -- number of runs, total benchmarking time data Output = TextOut String | HtmlOut String | PassInfo PassName String diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 1e05bfe1d..a36c18b86 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -14,8 +14,8 @@ import Control.Monad.Reader import Control.Monad.Except hiding (Except) import Data.Text.Prettyprint.Doc import Data.String +import Data.Maybe import Data.List (partition) -import Data.Time.Clock (getCurrentTime, diffUTCTime) import qualified Data.Map.Strict as M import Syntax @@ -34,7 +34,7 @@ import Logging import LLVMExec import PPrint import Parser -import Util (highlightRegion) +import Util (highlightRegion, measureSeconds) import Optimize import Parallelize @@ -130,16 +130,18 @@ processLogs logLevel logs = case logLevel of where (compileTime, runTime, benchStats) = timesFromLogs logs timesFromLogs :: [Output] -> (Double, Double, Maybe BenchStats) -timesFromLogs logs = (totalTime - evalTime, evalTime, benchStats) +timesFromLogs logs = (totalTime - totalEvalTime, singleEvalTime, benchStats) where - (evalTime, benchStats) = case [(t, stats) | EvalTime t stats <- logs] of - [] -> (0.0, Nothing) - [(t, stats)] -> (t, stats) - _ -> error "Expect at most one result" + (totalEvalTime, singleEvalTime, benchStats) = + case [(t, stats) | EvalTime t stats <- logs] of + [] -> (0.0 , 0.0, Nothing) + [(t, stats)] -> (total, t , stats) + where total = fromMaybe t $ fmap snd stats + _ -> error "Expect at most one result" totalTime = case [tTotal | TotalTime tTotal <- logs] of - [] -> 0.0 - [t] -> t - _ -> error "Expect at most one result" + [] -> 0.0 + [t] -> t + _ -> error "Expect at most one result" isLogInfo :: Output -> Bool isLogInfo out = case out of @@ -199,8 +201,8 @@ evalBackend env block = do let (ptrBinders, ptrVals, block') = abstractPtrLiterals block let funcName = "entryFun" let mainName = Name TopFunctionName (fromString funcName) 0 - let cc = case backend of LLVMCUDA -> EntryFun CUDARequired - _ -> EntryFun CUDANotRequired + let (cc, needsSync) = case backend of LLVMCUDA -> (EntryFun CUDARequired , True ) + _ -> (EntryFun CUDANotRequired, False) let (mainFunc, impModuleUnoptimized, reconAtom) = toImpModule env backend cc mainName ptrBinders Nothing block' -- TODO: toImpModule might generate invalid Imp code, because GPU allocations @@ -213,17 +215,15 @@ evalBackend env block = do checkPass ImpPass impModule llvmAST <- liftIO $ impToLLVM logger impModule let IFunType _ _ resultTypes = impFunType $ mainFunc - let llvmEvaluate = if bench then compileAndBench else compileAndEval + let llvmEvaluate = if bench then compileAndBench needsSync else compileAndEval resultVals <- liftM (map (Con . Lit)) $ liftIO $ llvmEvaluate logger llvmAST funcName ptrVals resultTypes return $ applyNaryAbs reconAtom resultVals withCompileTime :: TopPassM a -> TopPassM a withCompileTime m = do - t1 <- liftIO $ getCurrentTime - ans <- m - t2 <- liftIO $ getCurrentTime - logTop $ TotalTime $ realToFrac $ t2 `diffUTCTime` t1 + (ans, t) <- measureSeconds $ m + logTop $ TotalTime t return ans checkPass :: (Pretty a, Checkable a) => PassName -> a -> TopPassM () diff --git a/src/lib/Util.hs b/src/lib/Util.hs index b85fee9fd..eb405a1c0 100644 --- a/src/lib/Util.hs +++ b/src/lib/Util.hs @@ -12,6 +12,7 @@ module Util (IsBool (..), group, ungroup, pad, padLeft, delIdx, replaceIdx, scanM, composeN, mapMaybe, uncons, repeated, transitiveClosure, showErr, listDiff, splitMap, enumerate, restructure, onSnd, onFst, highlightRegion, findReplace, swapAt, uncurry3, + measureSeconds, bindM2, foldMapM, lookupWithIdx, (...), zipWithT, for) where import Data.Functor.Identity (Identity(..)) @@ -21,6 +22,7 @@ import Prelude import qualified Data.Set as Set import qualified Data.Map.Strict as M import Control.Monad.State.Strict +import System.CPUTime import Cat @@ -232,3 +234,10 @@ transitiveClosure getParents seeds = unless (x `Set.member` visited) $ do extend $ Set.singleton x mapM_ go $ getParents x + +measureSeconds :: MonadIO m => m a -> m (a, Double) +measureSeconds m = do + t1 <- liftIO $ getCPUTime + ans <- m + t2 <- liftIO $ getCPUTime + return (ans, (fromIntegral $ t2 - t1) / 1e12) From d83d756c62bb32c3ac962e2c6b91f6f18eee21e8 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 23 Dec 2020 15:04:02 -0500 Subject: [PATCH 021/105] Make a "virtual stdout" stream that we can write to using the IO effect. We introduce a special global variable, `OUT_STREAM_PTR`, which at runtime resolves to a C stdio `FILE*` pointer. The runtime creates a fresh pipe every time it runs an LLVM program. It puts the write end in `OUT_STREAM_PTR`, and captures what comes out the other end in a separate thread while the LLVM program runs. The Haskell thread that reads the output needs to be able to concurrently with the LLVM program in case the pipe fills up. To ensure this can happen, we use the `-threaded` GHC option. --- dex.cabal | 3 +- lib/io.dx | 7 +++ src/lib/Imp.hs | 1 + src/lib/JIT.hs | 35 +++++++++++++-- src/lib/LLVMExec.hs | 105 +++++++++++++++++++++++++++++++------------- src/lib/Syntax.hs | 15 +++++-- src/lib/dexrt.cpp | 4 ++ tests/io-tests.dx | 23 ++++++++++ 8 files changed, 155 insertions(+), 38 deletions(-) diff --git a/dex.cabal b/dex.cabal index 8fe9d6835..e8e5f798a 100644 --- a/dex.cabal +++ b/dex.cabal @@ -51,7 +51,7 @@ library build-depends: dex-resources default-language: Haskell2010 hs-source-dirs: src/lib - ghc-options: -Wall -fPIC + ghc-options: -Wall -fPIC -threaded cxx-sources: src/lib/dexrt.cpp cxx-options: -std=c++11 -fPIC default-extensions: CPP, DeriveTraversable, TypeApplications, OverloadedStrings, @@ -77,6 +77,7 @@ executable dex default-language: Haskell2010 hs-source-dirs: src default-extensions: CPP, LambdaCase + ghc-options: -threaded if flag(optimized) ghc-options: -O3 else diff --git a/lib/io.dx b/lib/io.dx index b742d82f5..414e65303 100644 --- a/lib/io.dx +++ b/lib/io.dx @@ -33,6 +33,7 @@ def fwrite (stream:Stream WriteMode) (s:String) : {State World} Unit = (AsList n s') = s withTabPtr s' \(MkPtr ptr). %ffi fwrite Int64 ptr (IToI64 1) (IToI64 n) stream' + %ffi fflush Int64 stream' () def fread (stream:Stream ReadMode) : {State World} String = @@ -75,3 +76,9 @@ def withTempFile (action: FilePath -> {State World} a) : {State World} a = result = action tmpFile deleteFile tmpFile result + +def getOutputStream (_:Unit) : {State World} Stream WriteMode = + MkStream $ %ptrLoad OUT_STREAM_PTR + +def print (s:String) : {State World} Unit = + fwrite (getOutputStream ()) (s <> AsList _ "\n") diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index e2c9cbd0e..bcff11c93 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -1164,6 +1164,7 @@ instance Checkable ImpFunction where checkValid f@(ImpFunction (_:> IFunType cc _ _) bs block) = addContext ctx $ do let scope = foldMap (binderAsEnv . fmap (const ())) bs let env = foldMap (binderAsEnv ) bs + <> fmap (fromScalarType . fst) initTopEnv void $ flip runReaderT (env, deviceFromCallingConvention cc) $ flip runStateT scope $ checkBlock block where ctx = "Checking:\n" ++ pprint f diff --git a/src/lib/JIT.hs b/src/lib/JIT.hs index 0a624bc7e..ffd1db3a6 100644 --- a/src/lib/JIT.hs +++ b/src/lib/JIT.hs @@ -90,8 +90,10 @@ compileFunction logger fun@(ImpFunction f bs body) = case cc of extraSpecs <- gets funSpecs return ([L.GlobalDefinition mainFun], extraSpecs) EntryFun requiresCUDA -> return $ runCompile CPU $ do + (streamFDParam , streamFDOperand ) <- freshParamOpPair attrs $ i32 (argPtrParam , argPtrOperand ) <- freshParamOpPair attrs $ hostPtrTy i64 (resultPtrParam, resultPtrOperand) <- freshParamOpPair attrs $ hostPtrTy i64 + initializeOutputStream streamFDOperand argOperands <- forM (zip [0..] argTys) $ \(i, ty) -> gep argPtrOperand (i64Lit i) >>= castLPtr (scalarTy ty) >>= load when (toBool requiresCUDA) ensureHasCUDAContext @@ -99,9 +101,9 @@ compileFunction logger fun@(ImpFunction f bs body) = case cc of forM_ (zip [0..] results) $ \(i, x) -> gep resultPtrOperand (i64Lit i) >>= castLPtr (L.typeOf x) >>= flip store x mainFun <- makeFunction (asLLVMName name) - [argPtrParam, resultPtrParam] (Just $ i64Lit 0) + [streamFDParam, argPtrParam, resultPtrParam] (Just $ i64Lit 0) extraSpecs <- gets funSpecs - return ([L.GlobalDefinition mainFun], extraSpecs) + return ([L.GlobalDefinition mainFun, outputStreamPtrDef], extraSpecs) where attrs = [L.NoAlias, L.NoCapture, L.NonNull] CUDAKernelLaunch -> do (CUDAKernel kernelText) <- compileCUDAKernel logger $ impKernelToLLVMGPU fun @@ -858,12 +860,36 @@ cpuBinaryIntrinsic op x y = case L.typeOf x of floatIntrinsic ty name = ExternFunSpec (L.mkName name) ty [] [] [ty, ty] callFloatIntrinsic ty name = emitExternCall (floatIntrinsic ty name) [x, y] +-- === Output stream === + +outputStreamPtrLName :: L.Name +outputStreamPtrLName = asLLVMName outputStreamPtrName + +outputStreamPtrDef :: L.Definition +outputStreamPtrDef = L.GlobalDefinition $ L.globalVariableDefaults + { L.name = outputStreamPtrLName + , L.type' = hostVoidp + , L.linkage = L.Private + , L.initializer = Just $ C.Null hostVoidp } + +outputStreamPtr :: Operand +outputStreamPtr = L.ConstantOperand $ C.GlobalReference + (hostPtrTy hostVoidp) outputStreamPtrLName + +initializeOutputStream :: Operand -> Compile () +initializeOutputStream streamFD = do + streamPtr <- emitExternCall fdopenFun [streamFD] + store outputStreamPtr streamPtr + +outputStreamEnv :: OperandEnv +outputStreamEnv = outputStreamPtrName @> outputStreamPtr + -- === Compile monad utilities === runCompile :: Device -> Compile a -> a runCompile dev m = evalState (runReaderT m env) initState where - env = CompileEnv mempty dev + env = CompileEnv outputStreamEnv dev initState = CompileState [] [] [] "start_block" mempty mempty mempty extendOperands :: OperandEnv -> Compile a -> Compile a @@ -945,6 +971,9 @@ mallocFun = ExternFunSpec "malloc_dex" (hostPtrTy i8) [L.NoAlias] [] [i64] freeFun :: ExternFunSpec freeFun = ExternFunSpec "free_dex" L.VoidType [] [] [hostPtrTy i8] +fdopenFun :: ExternFunSpec +fdopenFun = ExternFunSpec "fdopen_w" (hostPtrTy i8) [L.NoAlias] [] [i32] + boolTy :: L.Type boolTy = i8 diff --git a/src/lib/LLVMExec.hs b/src/lib/LLVMExec.hs index e6dcc2095..fc01acd6c 100644 --- a/src/lib/LLVMExec.hs +++ b/src/lib/LLVMExec.hs @@ -29,7 +29,10 @@ import qualified LLVM.Transforms as P import qualified LLVM.Target as T import qualified LLVM.Linking as Linking import LLVM.Context +import Data.Int import Data.Time.Clock (getCurrentTime, diffUTCTime) +import GHC.IO.FD +import GHC.IO.Handle.FD import System.IO import System.IO.Unsafe import System.IO.Temp @@ -40,14 +43,17 @@ import System.Exit import Foreign.Marshal.Alloc import Foreign.Ptr +import Foreign.C.Types (CInt (..)) import Foreign.Storable hiding (alignment) import Control.Monad +import Control.Concurrent import Control.Exception hiding (throw) import Data.ByteString.Short (ShortByteString) import Data.ByteString.Char8 (unpack, pack) import qualified Data.ByteString.Char8 as B import qualified Data.Map as M import qualified Data.Set as S +import qualified Control.Exception as E import Logging import Syntax @@ -59,45 +65,84 @@ import LLVM.JIT compileAndEval :: Logger [Output] -> L.Module -> String -> [LitVal] -> [BaseType] -> IO [LitVal] compileAndEval logger ast fname args resultTypes = do - allocaBytes (length args * cellSize) $ \argsPtr -> - allocaBytes (length resultTypes * cellSize) $ \resultPtr -> do - storeLitVals argsPtr args - evalTime <- compileOneOff logger ast fname $ checkedCallFunPtr False argsPtr resultPtr - logThis logger [EvalTime evalTime Nothing] - loadLitVals resultPtr resultTypes + withPipeToLogger logger $ \fd -> + allocaBytes (length args * cellSize) $ \argsPtr -> + allocaBytes (length resultTypes * cellSize) $ \resultPtr -> do + storeLitVals argsPtr args + evalTime <- compileOneOff logger ast fname $ + checkedCallFunPtr False fd argsPtr resultPtr + logThis logger [EvalTime evalTime Nothing] + loadLitVals resultPtr resultTypes compileAndBench :: Logger [Output] -> L.Module -> String -> [LitVal] -> [BaseType] -> IO [LitVal] compileAndBench logger ast fname args resultTypes = do - allocaBytes (length args * cellSize) $ \argsPtr -> - allocaBytes (length resultTypes * cellSize) $ \resultPtr -> do - storeLitVals argsPtr args - compileOneOff logger ast fname $ \fPtr -> do - -- First warmup iteration, which we also use to get the results - void $ checkedCallFunPtr True argsPtr resultPtr fPtr - results <- loadLitVals resultPtr resultTypes - let run = do - time <- checkedCallFunPtr True argsPtr resultPtr fPtr - _benchResults <- loadLitVals resultPtr resultTypes - -- TODO: Free results! - return time - exampleDuration <- run - let timeBudget = 2 -- seconds - let benchRuns = (ceiling $ timeBudget / exampleDuration) :: Int - times <- forM [1..benchRuns] $ const run - let avgTime = sum times / (fromIntegral benchRuns) - logThis logger [EvalTime avgTime (Just benchRuns)] - return results + withPipeToLogger logger $ \fd -> + allocaBytes (length args * cellSize) $ \argsPtr -> + allocaBytes (length resultTypes * cellSize) $ \resultPtr -> do + storeLitVals argsPtr args + compileOneOff logger ast fname $ \fPtr -> do + -- First warmup iteration, which we also use to get the results + void $ checkedCallFunPtr True fd argsPtr resultPtr fPtr + results <- loadLitVals resultPtr resultTypes + let run = do + time <- checkedCallFunPtr True fd argsPtr resultPtr fPtr + _benchResults <- loadLitVals resultPtr resultTypes + -- TODO: Free results! + return time + exampleDuration <- run + let timeBudget = 2 -- seconds + let benchRuns = (ceiling $ timeBudget / exampleDuration) :: Int + times <- forM [1..benchRuns] $ const run + let avgTime = sum times / (fromIntegral benchRuns) + logThis logger [EvalTime avgTime (Just benchRuns)] + return results foreign import ccall "dynamic" - callFunPtr :: DexExecutable -> Ptr () -> Ptr () -> IO DexExitCode + callFunPtr :: DexExecutable -> Int32 -> Ptr () -> Ptr () -> IO DexExitCode -type DexExecutable = FunPtr (Ptr () -> Ptr () -> IO DexExitCode) +type DexExecutable = FunPtr (Int32 -> Ptr () -> Ptr () -> IO DexExitCode) type DexExitCode = Int -checkedCallFunPtr :: Bool -> Ptr () -> Ptr () -> DexExecutable -> IO Double -checkedCallFunPtr sync argsPtr resultPtr fPtr = do +withPipeToLogger :: Logger [Output] -> (FD -> IO a) -> IO a +withPipeToLogger logger writeAction = do + snd <$> withPipe + (\h -> readStream h $ \s -> logThis logger [TextOut s]) + (\h -> handleToFd h >>= writeAction) + +withPipe :: (Handle -> IO a) -> (Handle -> IO b) -> IO (a, b) +withPipe readAction writeAction = do + (readHandle, writeHandle) <- createPipe + readResult <- forkWithResult $ readAction readHandle + writeResult <- forkWithResult $ writeAction writeHandle + y <- writeResult <* hClose writeHandle + x <- readResult <* hClose readHandle + return (x, y) + +forkWithResult :: IO a -> IO (IO a) +forkWithResult action = do + resultMVar <- newEmptyMVar + void $ forkIO $ catch (do result <- action + putMVar resultMVar $ Right result) + (\e -> putMVar resultMVar $ Left (e::SomeException)) + return $ do + result <- takeMVar resultMVar + case result of + Left e -> E.throw e + Right result' -> return result' + +readStream :: Handle -> (String -> IO ()) -> IO () +readStream h action = go + where + go :: IO () + go = do + eof <- hIsEOF h + unless eof $ hGetLine h >>= action >> go + +checkedCallFunPtr :: Bool -> FD -> Ptr () -> Ptr () -> DexExecutable -> IO Double +checkedCallFunPtr sync fd argsPtr resultPtr fPtr = do + let (CInt fd') = fdFD fd t1 <- getCurrentTime - exitCode <- callFunPtr fPtr argsPtr resultPtr + exitCode <- callFunPtr fPtr fd' argsPtr resultPtr when sync $ synchronizeCUDA t2 <- getCurrentTime unless (exitCode == 0) $ throwIO $ Err RuntimeErr Nothing "" diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 9ad7b9c99..3a193967a 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -46,7 +46,7 @@ module Syntax ( subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, applyNaryAbs, applyDataDefParams, freshSkolemVar, IndexStructure, mkConsList, mkConsListTy, fromConsList, fromConsListTy, extendEffRow, - getProjection, theWorld, initTopEnv, + getProjection, theWorld, outputStreamPtrName, initTopEnv, varType, binderType, isTabTy, LogLevel (..), IRVariant (..), applyIntBinOp, applyIntCmpOp, applyFloatBinOp, applyFloatUnOp, getIntLit, getFloatLit, sizeOf, vectorWidth, @@ -75,7 +75,7 @@ import qualified Data.List.NonEmpty as NE import qualified Data.Set as S import Data.Store (Store) import Data.Tuple (swap) -import Data.Foldable (toList) +import Data.Foldable (toList, fold) import Data.Int import Data.Word import Foreign.Ptr @@ -452,9 +452,16 @@ instance Eq EffectRow where theWorld :: Name theWorld = GlobalName "World" +outputStreamPtrName :: Name +outputStreamPtrName = GlobalName "OUT_STREAM_PTR" + initTopEnv :: TopEnv -initTopEnv = - (theWorld:>TyKind) @> (TyKind, LamBound ImplicitArrow) +initTopEnv = fold [v @> (ty, LamBound ImplicitArrow) | (v, ty) <- + [ (theWorld , TyKind) + , (outputStreamPtrName , BaseTy $ hostPtrTy $ hostPtrTy $ Scalar Word8Type)]] + +hostPtrTy :: BaseType -> BaseType +hostPtrTy ty = PtrType (AllocatedPtr, Heap CPU, ty) -- === top-level constructs === diff --git a/src/lib/dexrt.cpp b/src/lib/dexrt.cpp index d99e189df..77ef1a626 100644 --- a/src/lib/dexrt.cpp +++ b/src/lib/dexrt.cpp @@ -38,6 +38,10 @@ void free_dex(char* ptr) { free(ptr); } +void* fdopen_w(int fd) { + return fdopen(fd, "w"); +} + uint32_t rotate_left(uint32_t x, uint32_t d) { return (x << d) | (x >> (32 - d)); } diff --git a/tests/io-tests.dx b/tests/io-tests.dx index 6d3921ccd..ebd786161 100644 --- a/tests/io-tests.dx +++ b/tests/io-tests.dx @@ -18,3 +18,26 @@ include "io.dx" result : Fin 4 => Int = tabFromPtr ptr result > [0, 1, 2, 3] + +unsafeIO \(). + print $ AsList _ "testing log" +> testing log +> () + +unsafeIO \(). + for i':(Fin 10). + i = ordinal i' + if rem i 2 == 0 + then print $ show i <> AsList _ " is even" + else print $ show i <> AsList _ " is odd" +> 0 is even +> 1 is odd +> 2 is even +> 3 is odd +> 4 is even +> 5 is odd +> 6 is even +> 7 is odd +> 8 is even +> 9 is odd +> [(), (), (), (), (), (), (), (), (), ()] From 459da30a279c31281f8dc7071c5ad0dc0079b1bd Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 23 Dec 2020 17:37:22 -0500 Subject: [PATCH 022/105] Parse string literals as lists instead of tables. We can still create fixed-sized tables of characters using `['a', 'b', 'c']`. We should do the same on the pretty-printing side. --- lib/diagram.dx | 81 +++++++++++++++++++--------------------- lib/io.dx | 10 ++--- lib/png.dx | 8 ++-- lib/prelude.dx | 4 +- src/lib/Parser.hs | 7 ++-- src/lib/Syntax.hs | 4 ++ tests/eval-tests.dx | 2 +- tests/io-tests.dx | 13 +++---- tests/serialize-tests.dx | 4 +- 9 files changed, 67 insertions(+), 66 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index c3889bbb7..8663769af 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -16,9 +16,6 @@ def showHex (x:Int32) : String = unsafeIO \(). (n, ptr) = %ffi showHex (Int32 & CharPtr) x stringFromCharPtr n (MkPtr ptr) --- TODO: we should add overloaded string literals so we don't need this -def str (n:Int) ?-> (s:(Fin n=>Char)) : String = AsList _ s - black : HtmlColor = (IToW8 0, IToW8 0, IToW8 0) white : HtmlColor = (IToW8 255, IToW8 255, IToW8 255) red : HtmlColor = (IToW8 255, IToW8 0, IToW8 0) @@ -55,9 +52,9 @@ def applyTransformation (transformGeom: Geom -> Geom) (d:Diagram) : Diagram = (MkDiagram (AsList _ objs)) = d - (MkDiagram $ AsList _ for i. + MkDiagram $ toList for i. (attr, p, geom) = objs.i - (attr, transformPoint p, transformGeom geom)) + (attr, transformPoint p, transformGeom geom) flipY : Diagram -> Diagram = applyTransformation (\(x,y). (x, -y)) \geom. case geom of @@ -77,7 +74,7 @@ def moveXY ((offX, offY) : Point) : (Diagram -> Diagram) = applyTransformation (\(x,y). (x + offX, y + offY) ) id def singletonDefault (geom:Geom) : Diagram = - MkDiagram $ AsList _ [(defaultGeomStyle, (0.0, 0.0), geom)] + MkDiagram $ toList [(defaultGeomStyle, (0.0, 0.0), geom)] def pointDiagram : Diagram = singletonDefault PointGeom def circle (r:Float) : Diagram = singletonDefault $ Circle r @@ -86,7 +83,7 @@ def line (p:Point) : Diagram = singletonDefault $ Line p def updateGeom (update: GeomStyle -> GeomStyle) (d:Diagram) : Diagram = (MkDiagram (AsList _ objs)) = d - MkDiagram $ AsList _ for i. + MkDiagram $ toList for i. (attr , geoms) = objs.i (update attr, geoms) @@ -105,67 +102,67 @@ def strCatUncurried ((xs,ys):(String & String)) : String = xs <> ys def (<.>) (xs:String) (ys:String) : String = strCatUncurried (xs, ys) -def quote (s:String) : String = str "\"" <.> s <.> str "\"" +def quote (s:String) : String = "\"" <.> s <.> "\"" @noinline def strSpaceCatUncurried ((s1,s2):(String & String)) : String = - s1 <.> str " " <.> s2 + s1 <.> " " <.> s2 def (<+>) (s1:String) (s2:String) : String = strSpaceCatUncurried (s1, s2) -def selfClosingBrackets (s:String) : String = str "<" <.> s <.> str "/>" +def selfClosingBrackets (s:String) : String = "<" <.> s <.> "/>" def tagBrackets (tag:String) (s:String) : String = - str "<" <.> tag <.> str ">" <.> s <.> str " tag <.> str ">" + "<" <.> tag <.> ">" <.> s <.> " tag <.> ">" @noinline def tagBracketsAttrUncurried ((tag, attr, s):(String & String & String)) : String = - str "<" <.> tag <+> attr <.> str ">" <.> s <.> str " tag <.> str ">" + "<" <.> tag <+> attr <.> ">" <.> s <.> " tag <.> ">" def tagBracketsAttr (tag:String) (attr:String) (s:String) : String = tagBracketsAttrUncurried (tag, attr, s) def makeAttr (attr:String) (val:String) : String = - attr <.> str "=" <.> quote val + attr <.> "=" <.> quote val -def htmlColorStr (cs:HtmlColor) : String = +def htmlColor(cs:HtmlColor) : String = (r, g, b) = cs - toList "#" <> (showHex $ W8ToI r) <> (showHex $ W8ToI g) <> (showHex $ W8ToI b) + "#" <> (showHex $ W8ToI r) <> (showHex $ W8ToI g) <> (showHex $ W8ToI b) -def optionalHtmlColorStr (c: Maybe HtmlColor) : String = +def optionalHtmlColor(c: Maybe HtmlColor) : String = case c of - Nothing -> str "none" - Just c' -> htmlColorStr c' + Nothing -> "none" + Just c' -> htmlColor c' @noinline def attrString (attr:GeomStyle) : String = - ( -- makeAttr (str "stroke") (optionalHtmlColorStr $ getAt #strokeColor attr) - makeAttr (str "fill") (optionalHtmlColorStr $ getAt #fillColor attr) - <+> makeAttr (str "stroke-width") (show $ getAt #strokeWidth attr)) + ( -- makeAttr "stroke" (optionalHtmlColor$ getAt #strokeColor attr) + makeAttr "fill" (optionalHtmlColor$ getAt #fillColor attr) + <+> makeAttr "stroke-width" (show $ getAt #strokeWidth attr)) def renderGeom (attr:GeomStyle) ((x,y):Point) (geom:Geom) : String = case geom of PointGeom -> pointAttr = setAt #fillColor (getAt #strokeColor attr) attr - tagBracketsAttr (str "g") (attrString pointAttr) $ selfClosingBrackets $ - (str "circle" <+> - str "cx=" <.> quote (show x) <.> - str "cy=" <.> quote (show y) <.> - str "r=\"1\"") + tagBracketsAttr "g" (attrString pointAttr) $ selfClosingBrackets $ + ("circle" <+> + "cx=" <.> quote (show x) <.> + "cy=" <.> quote (show y) <.> + "r=\"1\"") Circle r -> - tagBracketsAttr (str "g") (attrString attr) $ selfClosingBrackets $ - (str "circle" <+> - str "cx=" <.> quote (show x) <.> - str "cy=" <.> quote (show y) <.> - str "r=" <.> quote (show r)) + tagBracketsAttr "g" (attrString attr) $ selfClosingBrackets $ + ("circle" <+> + "cx=" <.> quote (show x) <.> + "cy=" <.> quote (show y) <.> + "r=" <.> quote (show r)) Rectangle w h -> - tagBracketsAttr (str "g") (attrString attr) $ selfClosingBrackets $ - (str "rect" <+> - str "width=" <.> quote (show w) <.> - str "height=" <.> quote (show h) <.> - str "x=" <.> quote (show (x - (w/2.0))) <.> - str "y=" <.> quote (show (y - (h/2.0)))) + tagBracketsAttr "g" (attrString attr) $ selfClosingBrackets $ + ("rect" <+> + "width=" <.> quote (show w) <.> + "height=" <.> quote (show h) <.> + "x=" <.> quote (show (x - (w/2.0))) <.> + "y=" <.> quote (show (y - (h/2.0)))) BoundingBox : Type = (Point & Point) @@ -175,13 +172,13 @@ def renderSVG (d:Diagram) (bounds:BoundingBox) : String = scaleFactor = imgWidth / (xmax - xmin) imgHeight = (ymax - ymin) * scaleFactor (MkDiagram (AsList _ objs)) = d |> flipY |> scale scaleFactor - viewBoxStr = makeAttr (str "viewBox") $ + viewBoxStr = makeAttr "viewBox" $ (show (xmin * scaleFactor) <+> show (-(ymax * scaleFactor)) <+> show imgWidth <+> show imgHeight) - svgAttrStr = ( makeAttr (str "width" ) (show imgWidth) - <+> makeAttr (str "height") (show imgHeight) - <+> viewBoxStr) - tagBracketsAttr (str "svg") svgAttrStr $ + svgAttr= ( makeAttr "width" (show imgWidth) + <+> makeAttr "height" (show imgHeight) + <+> viewBoxStr) + tagBracketsAttr "svg" svgAttr$ concat for i. (attr, pos, geom) = objs.i renderGeom attr pos geom diff --git a/lib/io.dx b/lib/io.dx index 414e65303..d2a4e4c33 100644 --- a/lib/io.dx +++ b/lib/io.dx @@ -12,11 +12,11 @@ data Stream mode:StreamMode = MkStream CharPtr -- TODO: check the string contains no nulls def withCString (s:String) (action: CString -> {State World} a) : {State World} a = - (AsList n s') = s <> (AsList _ "\NUL") + (AsList n s') = s <> "\NUL" withTabPtr s' \(MkPtr ptr). action $ MkCString ptr def fopen (path:String) (mode:StreamMode) : {State World} (Stream mode) = - modeStr = AsList _ case mode of + modeStr = case mode of ReadMode -> "r" WriteMode -> "w" withCString path \(MkCString pathPtr). @@ -67,12 +67,12 @@ def readFile (f:FilePath) : {State World} String = def writeTemp (s:String) : {State World} FilePath = -- TODO: Make this properly atomic. It can fail if another process creates a -- file with same name after we ask for the name and before we create it. - withCString (AsList _ "/tmp/dex-XXXXXX") \(MkCString ptr). + withCString "/tmp/dex-XXXXXX" \(MkCString ptr). %ffi mktemp CharPtr ptr stringFromCharPtr 15 (MkPtr ptr) def withTempFile (action: FilePath -> {State World} a) : {State World} a = - tmpFile = writeTemp (AsList _ []) + tmpFile = writeTemp "" result = action tmpFile deleteFile tmpFile result @@ -81,4 +81,4 @@ def getOutputStream (_:Unit) : {State World} Stream WriteMode = MkStream $ %ptrLoad OUT_STREAM_PTR def print (s:String) : {State World} Unit = - fwrite (getOutputStream ()) (s <> AsList _ "\n") + fwrite (getOutputStream ()) (s <> "\n") diff --git a/lib/png.dx b/lib/png.dx index 75033b535..f9d0a7303 100644 --- a/lib/png.dx +++ b/lib/png.dx @@ -98,12 +98,12 @@ def makePNG (img:n=>m=>(Fin 3)=>Word8) : List Byte = unsafeIO \(). withTabPtr imgFlat \ptr. (MkPtr rawPtr) = ptr (n, ptr') = %ffi encodePNG (Int & CharPtr) rawPtr (size m) (size n) - AsList n $ tabFromPtr $ MkPtr ptr' + toList $ tabFromPtr (Fin n) $ MkPtr ptr' def pngToHtml (png:List Byte) : List Char = - (toList " base64Encode png - <> toList "\">") + (" base64Encode png + <> "\">") '## API entry point diff --git a/lib/prelude.dx b/lib/prelude.dx index d4ac95828..0cba8f14b 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -478,7 +478,7 @@ def withTabPtr (_:Storable a) ?=> for i. store (ptr +>> ordinal i) xs.i action ptr -def tabFromPtr (_:Storable a) ?=> (ptr:Ptr a) : {State World} n=>a = +def tabFromPtr (_:Storable a) ?=> (n:Type) -> (ptr:Ptr a) : {State World} n=>a = for i. load $ ptr +>> ordinal i 'Misc @@ -937,7 +937,7 @@ def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = String : Type = List Char def stringFromCharPtr (n:Int) (ptr:Ptr Char) : {State World} String= - AsList n $ tabFromPtr ptr + AsList n $ tabFromPtr _ ptr -- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint def codepoint (c:Char) : Int = W8ToI c diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 001df0cee..2a7eac7b3 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -217,8 +217,9 @@ uType = expr uString :: Lexer UExpr uString = do (s, pos) <- withPos $ strLit - let cs = map (WithSrc (Just pos) . charExpr) s - return $ WithSrc (Just pos) $ UTabCon cs + let addSrc = WithSrc (Just pos) + let cs = map (addSrc . charExpr) s + return $ mkApp (addSrc "toList") $ addSrc $ UTabCon cs uLit :: Parser UExpr uLit = withSrc $ uLitParser @@ -924,7 +925,7 @@ mkSymName s = mkName $ "(" <> s <> ")" prefixNegOp :: Operator Parser UExpr prefixNegOp = Prefix $ label "negation" $ do ((), pos) <- withPos $ sym "-" - let f = WithSrc (Just pos) $ UVar $ mkName "neg" :> () + let f = WithSrc (Just pos) "neg" return $ \case -- Special case: negate literals directly WithSrc litpos (IntLitExpr i) diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index c07dc5488..e1d544d88 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -78,6 +78,7 @@ import Data.Tuple (swap) import Data.Foldable (toList, fold) import Data.Int import Data.Word +import Data.String (IsString, fromString) import Foreign.Ptr import GHC.Generics @@ -270,6 +271,9 @@ data WithSrc a = WithSrc SrcCtx a srcPos :: WithSrc a -> SrcCtx srcPos (WithSrc pos _) = pos +instance IsString UExpr' where + fromString s = UVar $ Name SourceName (fromString s) 0 :> () + -- === primitive constructors and operators === data PrimExpr e = diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx index 4dd02dfe8..141881219 100644 --- a/tests/eval-tests.dx +++ b/tests/eval-tests.dx @@ -769,7 +769,7 @@ str = ['x', 'y'] s1 = "hello world" :p s1 -> "hello world" +> (AsList 11 "hello world") :p codepoint 'a' > 97 diff --git a/tests/io-tests.dx b/tests/io-tests.dx index ebd786161..c6de5efa3 100644 --- a/tests/io-tests.dx +++ b/tests/io-tests.dx @@ -4,8 +4,8 @@ include "io.dx" :p unsafeIO \(). withTempFile \fname. withFile fname WriteMode \stream. - fwrite stream $ AsList _ "lorem ipsum\n" - fwrite stream $ AsList _ "dolor sit amet\n" + fwrite stream "lorem ipsum\n" + fwrite stream "dolor sit amet\n" readFile fname > (AsList 27 "lorem ipsum > dolor sit amet @@ -15,12 +15,11 @@ include "io.dx" :p unsafeIO \(). withAlloc 4 \ptr:(Ptr Int). for i:(Fin 4). store (ptr +>> ordinal i) (ordinal i) - result : Fin 4 => Int = tabFromPtr ptr - result + tabFromPtr (Fin 4) ptr > [0, 1, 2, 3] unsafeIO \(). - print $ AsList _ "testing log" + print "testing log" > testing log > () @@ -28,8 +27,8 @@ unsafeIO \(). for i':(Fin 10). i = ordinal i' if rem i 2 == 0 - then print $ show i <> AsList _ " is even" - else print $ show i <> AsList _ " is odd" + then print $ show i <> " is even" + else print $ show i <> " is odd" > 0 is even > 1 is odd > 2 is even diff --git a/tests/serialize-tests.dx b/tests/serialize-tests.dx index ccbb34b6c..d35c66705 100644 --- a/tests/serialize-tests.dx +++ b/tests/serialize-tests.dx @@ -19,7 +19,7 @@ :p () > () -x = "ab" +x = ['a', 'b'] :p for (i,j). [x.i, x.j] > ["aa", "ab", "ba", "bb"]@(Fin 2 & Fin 2) @@ -29,7 +29,7 @@ x = "ab" > {a = 1, b = 2} :p {a="1234", b=[1, 2, 3]} -> {a = "1234", b = [1, 2, 3]} +> {a = (AsList 4 "1234"), b = [1, 2, 3]} :p [{| a=1 |}, {| b=2.0 |}] : (Fin 2) => {a:Int | b:Float} > [{| a = 1 |}, {| b = 2.0 |}] From cf8b2609db44fd54ca51234c80b7decf7a54a38f Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 23 Dec 2020 21:41:43 -0500 Subject: [PATCH 023/105] Add strings to runtime error messages. It looks like a bigger change than it is because I had to manually re-toposort the prelude. --- lib/io.dx | 84 ---------- lib/prelude.dx | 392 +++++++++++++++++++++++++++----------------- src/lib/JIT.hs | 3 +- src/lib/LLVMExec.hs | 25 +-- tests/adt-tests.dx | 3 +- tests/eval-tests.dx | 3 +- tests/io-tests.dx | 3 - 7 files changed, 257 insertions(+), 256 deletions(-) delete mode 100644 lib/io.dx diff --git a/lib/io.dx b/lib/io.dx deleted file mode 100644 index d2a4e4c33..000000000 --- a/lib/io.dx +++ /dev/null @@ -1,84 +0,0 @@ - -'File system operations - -FilePath : Type = String -data CString = MkCString CharPtr - -data StreamMode = - ReadMode - WriteMode - -data Stream mode:StreamMode = MkStream CharPtr - --- TODO: check the string contains no nulls -def withCString (s:String) (action: CString -> {State World} a) : {State World} a = - (AsList n s') = s <> "\NUL" - withTabPtr s' \(MkPtr ptr). action $ MkCString ptr - -def fopen (path:String) (mode:StreamMode) : {State World} (Stream mode) = - modeStr = case mode of - ReadMode -> "r" - WriteMode -> "w" - withCString path \(MkCString pathPtr). - withCString modeStr \(MkCString modePtr). - MkStream $ %ffi fopen CharPtr pathPtr modePtr - -def fclose (mode:StreamMode) ?-> (stream:Stream mode) : {State World} Unit = - (MkStream stream') = stream - %ffi fclose Int64 stream' - () - -def fwrite (stream:Stream WriteMode) (s:String) : {State World} Unit = - (MkStream stream') = stream - (AsList n s') = s - withTabPtr s' \(MkPtr ptr). - %ffi fwrite Int64 ptr (IToI64 1) (IToI64 n) stream' - %ffi fflush Int64 stream' - () - -def fread (stream:Stream ReadMode) : {State World} String = - (MkStream stream') = stream - -- TODO: allow reading longer files! - n = 4096 - withAlloc n \ptr:(Ptr Char). - (MkPtr rawPtr) = ptr - numRead = I64ToI $ %ffi fread Int64 rawPtr (IToI64 1) (IToI64 n) stream' - stringFromCharPtr numRead ptr - -def deleteFile (f:FilePath) : {State World} Unit = - withCString f \(MkCString ptr). - %ffi remove Int64 ptr - () - -def withFile (f:FilePath) (mode:StreamMode) - (action: Stream mode -> {State World} a) - : {State World} a = - stream = fopen f mode - result = action stream - fclose stream - result - -def writeFile (f:FilePath) (s:String) : {State World} Unit = - withFile f WriteMode \stream. fwrite stream s - -def readFile (f:FilePath) : {State World} String = - withFile f ReadMode \stream. fread stream - -def writeTemp (s:String) : {State World} FilePath = - -- TODO: Make this properly atomic. It can fail if another process creates a - -- file with same name after we ask for the name and before we create it. - withCString "/tmp/dex-XXXXXX" \(MkCString ptr). - %ffi mktemp CharPtr ptr - stringFromCharPtr 15 (MkPtr ptr) - -def withTempFile (action: FilePath -> {State World} a) : {State World} a = - tmpFile = writeTemp "" - result = action tmpFile - deleteFile tmpFile - result - -def getOutputStream (_:Unit) : {State World} Stream WriteMode = - MkStream $ %ptrLoad OUT_STREAM_PTR - -def print (s:String) : {State World} Unit = - fwrite (getOutputStream ()) (s <> "\n") diff --git a/lib/prelude.dx b/lib/prelude.dx index 0cba8f14b..3295887b8 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -211,10 +211,7 @@ def select (p:Bool) (x:a) (y:a) : a = case p of False -> y def BToI (x:Bool) : Int = W8ToI $ BToW8 x - def BToF (x:Bool) : Float = IToF (BToI x) -def todo (a:Type) ?-> : a = %throwError a -def throw (a:Type) ?-> : a = %throwError a 'Effects @@ -391,14 +388,6 @@ def Fin (n:Int) : Type = Range 0 n def ordinal (i:a) : Int = %toOrdinal i def size (n:Type) : Int = %idxSetSize n def unsafeFromOrdinal (n : Type) (i : Int) : n = %unsafeFromOrdinal n i - -def fromOrdinal (n:Type) (i:Int) : n = - case (0 <= i) && (i < size n) of - True -> unsafeFromOrdinal _ i - False -> throw - -def asidx (n:Type) ?-> (i:Int) : n = fromOrdinal n i -def (@) (i:Int) (n:Type) : n = fromOrdinal n i def iota (n:Type) : n=>Int = for i. ordinal i -- TODO: we want Eq and Ord for all index sets, not just `Fin n` @@ -495,9 +484,6 @@ def sq (d:Mul a) ?=> (x:a) : a = x * x def abs (_:Add a) ?=> (_:Ord a) ?=> (x:a) : a = select (x > zero) x (zero - x) def mod (x:Int) (y:Int) : Int = rem (y + rem x y) y -def slice (xs:n=>a) (start:Int) (m:Type) : m=>a = - for i. xs.(fromOrdinal _ (ordinal i + start)) - def reindex (ixr: b -> a) (tab: a=>v) : b=>v = for i. tab.(ixr i) def scan (init:a) (body:n->a->(a&b)) : (a & n=>b) = @@ -575,10 +561,6 @@ def randn (k:Key) : Float = u2 = rand k2 sqrt ((-2.0) * log u1) * cos (2.0 * pi * u2) -def randIdx (n:Type) ?-> (k:Key) : n = - unif = rand k - fromOrdinal n $ FToI $ floor $ unif * IToF (size n) - -- TODO: Make this better... def randInt (k:Key) : Int = (I64ToI k) `mod` 2147483647 @@ -594,47 +576,6 @@ def cumSum (xs: n=>Float) : n=>Float = total := newTotal newTotal -interface Arbitrary a:Type where - arb : Key -> a - -instance float32Arb : Arbitrary Float32 where - arb = randn - -instance in32Arb : Arbitrary Int32 where - arb = \key. FToI $ randn key * 5.0 - -instance tabArb : Arbitrary a ?=> Arbitrary (n=>a) where - arb = \key. for i. arb $ ixkey key i - -instance finArb : n:Int ?-> Arbitrary (Fin n) where - arb = randIdx - -'min / max etc - -def minBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x < f y) x y -def maxBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x > f y) x y - -def min (_:Ord o) ?=> (x1: o) -> (x2: o) : o = minBy id x1 x2 -def max (_:Ord o) ?=> (x1: o) -> (x2: o) : o = maxBy id x1 x2 - -def minimumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = - reduce xs.(0@_) (minBy f) xs -def maximumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = - reduce xs.(0@_) (maxBy f) xs - -def minimum (_:Ord o) ?=> (xs:n=>o) : o = minimumBy id xs -def maximum (_:Ord o) ?=> (xs:n=>o) : o = maximumBy id xs - -def argmin (_:Ord o) ?=> (xs:n=>o) : n = - zeroth = (0@_, xs.(0@_)) - compare = \(idx1, x1) (idx2, x2). - select (x1 < x2) (idx1, x1) (idx2, x2) - zipped = for i. (i, xs.i) - fst $ reduce zeroth compare zipped - -def clip (_:Ord a) ?=> ((low,high):(a&a)) (x:a) : a = - min high $ max low x - 'Automatic differentiation -- TODO: add vector space constraints @@ -693,55 +634,6 @@ def checkDerivBase (f:Float->Float) (x:Float) : Bool = def checkDeriv (f:Float->Float) (x:Float) : Bool = checkDerivBase f x && checkDerivBase (deriv f) x -'Control flow - -def while - (eff:Effects) ?-> - (cond: Unit -> {|eff} Bool) - (body: Unit -> {|eff} Unit) - : {|eff} Unit = - cond' : Unit -> {|eff} Word8 = \_. BToW8 $ cond () - %while cond' body - -data IterResult a:Type b:Type = - Continue a - Done b - --- A little iteration combinator --- TODO: allow effects (bug #267) -def iter (init:a) (body: Int -> a -> IterResult a b) : b = - result = snd $ withState Nothing \resultRef. - withState init \carryRef. - withState 0 \i. - while (\(). isNothing (get resultRef)) \(). - case body (get i) (get carryRef) of - Continue carry -> - i := get i + 1 - carryRef := carry - Done result -> - resultRef := Just result - case result of - Just ans -> ans - Nothing -> todo -- should be unreachable - --- returns the highest index `i` such that `xs.i <= x` -def searchSorted (_:Ord a) ?=> (xs:n=>a) (x:a) : Maybe n = - case size n == 0 of - True -> Nothing - False -> case x < xs.(fromOrdinal _ 0) of - True -> Nothing - False -> - iter (0, size n) \_ (low, high). - numLeft = high - low - case numLeft == 1 of - True -> Done $ Just $ fromOrdinal _ low - False -> - centerIx = low + idiv (high - low) 2 - case x < xs.(fromOrdinal _ centerIx) of - True -> Continue (low, centerIx) - False -> Continue (centerIx, high) - - 'Vector support -- TODO: Reenable vector suport once fixed-width types are supported. @@ -806,27 +698,6 @@ def tile1 (n : Type) ?-> (l : Type) ?-> (m : Type) ?-> -- arr.(t +> UNSAFEFromOrdinal idx 2) -- arr.(t +> UNSAFEFromOrdinal idx 3)) -'Numerical utilities - -def logsumexp (x: n=>Float) : Float = - m = maximum x - m + (log $ sum for i. exp (x.i - m)) - -def logsoftmax (x: n=>Float) : n=>Float = - lse = logsumexp x - for i. x.i - lse - -def softmax (x: n=>Float) : n=>Float = - m = maximum x - e = for i. exp (x.i - m) - s = sum e - for i. e.i / s - -def evalpoly (_:VSpace v) ?=> (coefficients:n=>v) (x:Float) : v = - -- Evaluate a polynomial at x. Same as Numpy's polyval. - fold zero \i c. coefficients.i + x .* c - - 'Monoid interface Monoid a:Type where @@ -840,6 +711,13 @@ interface Monoid a:Type where data List a:Type = AsList n:Int foo:(Fin n => a) +def unsafeCastTable (m:Type) (xs:n=>a) : m=>a = + for i. xs.(unsafeFromOrdinal _ (ordinal i)) + +def toList (n:Type) ?-> (xs:n=>a) : List a = + n' = size n + AsList _ $ unsafeCastTable (Fin n') xs + instance monoidList : Monoid (List a) where mempty = AsList _ [] mcombine = \x y. @@ -965,15 +843,6 @@ instance showFloat64 : Show Float64 where (n, ptr) = %ffi showFloat64 (Int32 & CharPtr) x stringFromCharPtr n $ MkPtr ptr --- def writeStdErr (s:String) : {State World} Unit = --- (AsList n cs) = s --- %ffi writeToStdErr Int n (%getPtr cs) --- () - --- def throwMsg (s:String) : a = unsafeIO \(). --- writeStdErr s --- %throwError a - -- pipe-like reverse function application def (|>) (x:a) (f: a -> b) : b = f x @@ -1004,6 +873,218 @@ def isnan (x:Float) : Bool = not (x >= x && x <= x) -- Todo: use IEEE-754R 5.11: Floating Point Comparison Relation cmpUnordered. def either_is_nan (x:Float) (y:Float) : Bool = (isnan x) || (isnan y) +'File system operations + +FilePath : Type = String +data CString = MkCString CharPtr + +data StreamMode = + ReadMode + WriteMode + +data Stream mode:StreamMode = MkStream CharPtr + +-- TODO: check the string contains no nulls +def withCString (s:String) (action: CString -> {State World} a) : {State World} a = + (AsList n s') = s <> "\NUL" + withTabPtr s' \(MkPtr ptr). action $ MkCString ptr + +def fopen (path:String) (mode:StreamMode) : {State World} (Stream mode) = + modeStr = case mode of + ReadMode -> "r" + WriteMode -> "w" + withCString path \(MkCString pathPtr). + withCString modeStr \(MkCString modePtr). + MkStream $ %ffi fopen CharPtr pathPtr modePtr + +def fclose (mode:StreamMode) ?-> (stream:Stream mode) : {State World} Unit = + (MkStream stream') = stream + %ffi fclose Int64 stream' + () + +def fwrite (stream:Stream WriteMode) (s:String) : {State World} Unit = + (MkStream stream') = stream + (AsList n s') = s + withTabPtr s' \(MkPtr ptr). + %ffi fwrite Int64 ptr (IToI64 1) (IToI64 n) stream' + %ffi fflush Int64 stream' + () + +def fread (stream:Stream ReadMode) : {State World} String = + (MkStream stream') = stream + -- TODO: allow reading longer files! + n = 4096 + withAlloc n \ptr:(Ptr Char). + (MkPtr rawPtr) = ptr + numRead = I64ToI $ %ffi fread Int64 rawPtr (IToI64 1) (IToI64 n) stream' + stringFromCharPtr numRead ptr + +def deleteFile (f:FilePath) : {State World} Unit = + withCString f \(MkCString ptr). + %ffi remove Int64 ptr + () + +def withFile (f:FilePath) (mode:StreamMode) + (action: Stream mode -> {State World} a) + : {State World} a = + stream = fopen f mode + result = action stream + fclose stream + result + +def writeFile (f:FilePath) (s:String) : {State World} Unit = + withFile f WriteMode \stream. fwrite stream s + +def readFile (f:FilePath) : {State World} String = + withFile f ReadMode \stream. fread stream + +def writeTemp (s:String) : {State World} FilePath = + -- TODO: Make this properly atomic. It can fail if another process creates a + -- file with same name after we ask for the name and before we create it. + withCString "/tmp/dex-XXXXXX" \(MkCString ptr). + %ffi mktemp CharPtr ptr + stringFromCharPtr 15 (MkPtr ptr) + +def withTempFile (action: FilePath -> {State World} a) : {State World} a = + tmpFile = writeTemp "" + result = action tmpFile + deleteFile tmpFile + result + +def getOutputStream (_:Unit) : {State World} Stream WriteMode = + MkStream $ %ptrLoad OUT_STREAM_PTR + +def print (s:String) : {State World} Unit = + fwrite (getOutputStream ()) (s <> "\n") + +'Partial functions + +def error (s:String) : a = unsafeIO \(). + print s + %throwError a + +def todo (a:Type) ?-> : a = error "TODO: implement it!" + +def fromOrdinal (n:Type) (i:Int) : n = + case (0 <= i) && (i < size n) of + True -> unsafeFromOrdinal _ i + False -> error $ + "Ordinal index out of range:" <> show i <> " >= " <> show (size n) + +-- TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy +-- TODO: safe (runtime-checked) and unsafe versions +def castTable (m:Type) (xs:n=>a) : m=>a = + case size m == size n of + True -> unsafeCastTable _ xs + False -> error $ + "Table size mismatch in cast: " <> show (size m) <> " vs " <> show (size n) + +def asidx (n:Type) ?-> (i:Int) : n = fromOrdinal n i +def (@) (i:Int) (n:Type) : n = fromOrdinal n i + +def slice (xs:n=>a) (start:Int) (m:Type) : m=>a = + for i. xs.(fromOrdinal _ (ordinal i + start)) + +def tail (n:Type) ?-> (xs:n=>a) (start:Int) : List a = + numElts = size n - start + toList $ slice xs start (Fin numElts) + +def randIdx (n:Type) ?-> (k:Key) : n = + unif = rand k + fromOrdinal n $ FToI $ floor $ unif * IToF (size n) + +'Type class for generating example values + +interface Arbitrary a:Type where + arb : Key -> a + +instance float32Arb : Arbitrary Float32 where + arb = randn + +instance in32Arb : Arbitrary Int32 where + arb = \key. FToI $ randn key * 5.0 + +instance tabArb : Arbitrary a ?=> Arbitrary (n=>a) where + arb = \key. for i. arb $ ixkey key i + +instance finArb : n:Int ?-> Arbitrary (Fin n) where + arb = randIdx + +'Control flow + +def while + (eff:Effects) ?-> + (cond: Unit -> {|eff} Bool) + (body: Unit -> {|eff} Unit) + : {|eff} Unit = + cond' : Unit -> {|eff} Word8 = \_. BToW8 $ cond () + %while cond' body + +data IterResult a:Type b:Type = + Continue a + Done b + +-- A little iteration combinator +-- TODO: allow effects (bug #267) +def iter (init:a) (body: Int -> a -> IterResult a b) : b = + result = snd $ withState Nothing \resultRef. + withState init \carryRef. + withState 0 \i. + while (\(). isNothing (get resultRef)) \(). + case body (get i) (get carryRef) of + Continue carry -> + i := get i + 1 + carryRef := carry + Done result -> + resultRef := Just result + case result of + Just ans -> ans + Nothing -> error "should be unreachable" + +-- returns the highest index `i` such that `xs.i <= x` +def searchSorted (_:Ord a) ?=> (xs:n=>a) (x:a) : Maybe n = + case size n == 0 of + True -> Nothing + False -> case x < xs.(fromOrdinal _ 0) of + True -> Nothing + False -> + iter (0, size n) \_ (low, high). + numLeft = high - low + case numLeft == 1 of + True -> Done $ Just $ fromOrdinal _ low + False -> + centerIx = low + idiv (high - low) 2 + case x < xs.(fromOrdinal _ centerIx) of + True -> Continue (low, centerIx) + False -> Continue (centerIx, high) + + + +'min / max etc + +def minBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x < f y) x y +def maxBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x > f y) x y + +def min (_:Ord o) ?=> (x1: o) -> (x2: o) : o = minBy id x1 x2 +def max (_:Ord o) ?=> (x1: o) -> (x2: o) : o = maxBy id x1 x2 + +def minimumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = + reduce xs.(0@_) (minBy f) xs +def maximumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = + reduce xs.(0@_) (maxBy f) xs + +def minimum (_:Ord o) ?=> (xs:n=>o) : o = minimumBy id xs +def maximum (_:Ord o) ?=> (xs:n=>o) : o = maximumBy id xs + +def argmin (_:Ord o) ?=> (xs:n=>o) : n = + zeroth = (0@_, xs.(0@_)) + compare = \(idx1, x1) (idx2, x2). + select (x1 < x2) (idx1, x1) (idx2, x2) + zipped = for i. (i, xs.i) + fst $ reduce zeroth compare zipped + +def clip (_:Ord a) ?=> ((low,high):(a&a)) (x:a) : a = + min high $ max low x '## Trigonometric functions. @@ -1175,25 +1256,10 @@ def (.&.) (x:Byte) (y:Byte) : Byte = %and x y '## Misc --- TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy --- TODO: safe (runtime-checked) and unsafe versions -def castTable (m:Type) (xs:n=>a) : m=>a = - case size m == size n of - True -> for i. xs.(unsafeFromOrdinal _ (ordinal i)) - False -> throw - def reverse (x:n=>a) : n=>a = s = size n for i. x.((s - 1 - ordinal i)@_) -def toList (n:Type) ?-> (xs:n=>a) : List a = - n' = size n - AsList _ $ castTable (Fin n') xs - -def tail (n:Type) ?-> (xs:n=>a) (start:Int) : List a = - numElts = size n - start - toList $ slice xs start (Fin numElts) - def padTo (n:Type) ?-> (m:Type) (x:a) (xs:n=>a) : (m=>a) = n' = size n for i. @@ -1266,3 +1332,23 @@ def categorical (logprobs: n=>Float) (key: Key) : n = def categoricalBatch (logprobs: n=>Float) (key: Key) : m=>n = cdf = cdfForCategorical logprobs for i. categoricalFromCDF cdf $ ixkey key i + +'Numerical utilities + +def logsumexp (x: n=>Float) : Float = + m = maximum x + m + (log $ sum for i. exp (x.i - m)) + +def logsoftmax (x: n=>Float) : n=>Float = + lse = logsumexp x + for i. x.i - lse + +def softmax (x: n=>Float) : n=>Float = + m = maximum x + e = for i. exp (x.i - m) + s = sum e + for i. e.i / s + +def evalpoly (_:VSpace v) ?=> (coefficients:n=>v) (x:Float) : v = + -- Evaluate a polynomial at x. Same as Numpy's polyval. + fold zero \i c. coefficients.i + x .* c diff --git a/src/lib/JIT.hs b/src/lib/JIT.hs index ffd1db3a6..b4bf3e77e 100644 --- a/src/lib/JIT.hs +++ b/src/lib/JIT.hs @@ -312,8 +312,7 @@ compileInstr instr = case instr of let resultTys' = map scalarTy resultTys case cc of FFIFun -> do - let [resultTy] = resultTys' - ans <- emitInstr resultTy $ externCall (makeFunSpec f) args' + ans <- emitExternCall (makeFunSpec f) args' return [ans] FFIMultiResultFun -> do resultPtr <- makeMultiResultAlloc resultTys' diff --git a/src/lib/LLVMExec.hs b/src/lib/LLVMExec.hs index bb15a1ba2..b2957cf5b 100644 --- a/src/lib/LLVMExec.hs +++ b/src/lib/LLVMExec.hs @@ -111,9 +111,12 @@ compileAndBench shouldSyncCUDA logger ast fname args resultTypes = do withPipeToLogger :: Logger [Output] -> (FD -> IO a) -> IO a withPipeToLogger logger writeAction = do - snd <$> withPipe + result <- snd <$> withPipe (\h -> readStream h $ \s -> logThis logger [TextOut s]) (\h -> handleToFd h >>= writeAction) + case result of + Left e -> E.throw e + Right ans -> return ans checkedCallFunPtr :: FD -> Ptr () -> Ptr () -> DexExecutable -> IO Double checkedCallFunPtr fd argsPtr resultPtr fPtr = do @@ -411,26 +414,24 @@ ptxDataLayout = (L.defaultDataLayout L.LittleEndian) -- ==== unix pipe utilities === -withPipe :: (Handle -> IO a) -> (Handle -> IO b) -> IO (a, b) +type IOExcept a = Either SomeException a + +withPipe :: (Handle -> IO a) -> (Handle -> IO b) -> IO (IOExcept a, IOExcept b) withPipe readAction writeAction = do (readHandle, writeHandle) <- createPipe - readResult <- forkWithResult $ readAction readHandle - writeResult <- forkWithResult $ writeAction writeHandle - y <- writeResult <* hClose writeHandle - x <- readResult <* hClose readHandle + waitForReader <- forkWithResult $ readAction readHandle + waitForWriter <- forkWithResult $ writeAction writeHandle + y <- waitForWriter `finally` hClose writeHandle + x <- waitForReader `finally` hClose readHandle return (x, y) -forkWithResult :: IO a -> IO (IO a) +forkWithResult :: IO a -> IO (IO (IOExcept a)) forkWithResult action = do resultMVar <- newEmptyMVar void $ forkIO $ catch (do result <- action putMVar resultMVar $ Right result) (\e -> putMVar resultMVar $ Left (e::SomeException)) - return $ do - result <- takeMVar resultMVar - case result of - Left e -> E.throw e - Right result' -> return result' + return $ takeMVar resultMVar readStream :: Handle -> (String -> IO ()) -> IO () readStream h action = go diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx index 364b73d25..e8fb17566 100644 --- a/tests/adt-tests.dx +++ b/tests/adt-tests.dx @@ -97,7 +97,8 @@ myTab = [MyLeft 1, MyRight 3.5, MyLeft 123, MyLeft 456] :p for i. case myTab.i of MyLeft val -> val - MyRight _ -> todo + MyRight _ -> error "nope" +> nope > Runtime error :p diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx index 141881219..57fa06f56 100644 --- a/tests/eval-tests.dx +++ b/tests/eval-tests.dx @@ -815,9 +815,10 @@ triLit def fromLeftFloat (x:(Float | Int)) : Float = case x of Left x' -> x' - Right _ -> throw + Right _ -> error "this is an error" :p fromLeftFloat $ Right 1 +> this is an error > Runtime error :p fromLeftFloat $ Left 1.2 diff --git a/tests/io-tests.dx b/tests/io-tests.dx index c6de5efa3..53a985e60 100644 --- a/tests/io-tests.dx +++ b/tests/io-tests.dx @@ -1,6 +1,4 @@ -include "io.dx" - :p unsafeIO \(). withTempFile \fname. withFile fname WriteMode \stream. @@ -11,7 +9,6 @@ include "io.dx" > dolor sit amet > ") - :p unsafeIO \(). withAlloc 4 \ptr:(Ptr Int). for i:(Fin 4). store (ptr +>> ordinal i) (ordinal i) From e74e30659719de2cb10660ebe721ec6f0aed7561 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 24 Dec 2020 07:53:57 -0500 Subject: [PATCH 024/105] Fix `stack build` on macOS. (#378) Disable `-Wnonportable-include-path`. I am not sure what earlier change surfaced this macOS build error or why macOS CI does not encounter it. --- dex.cabal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dex.cabal b/dex.cabal index b1390b84c..eb8744f54 100644 --- a/dex.cabal +++ b/dex.cabal @@ -57,7 +57,7 @@ library build-depends: dex-resources default-language: Haskell2010 hs-source-dirs: src/lib - ghc-options: -Wall -fPIC + ghc-options: -Wall -fPIC -optP-Wno-nonportable-include-path cxx-sources: src/lib/dexrt.cpp cxx-options: -std=c++11 -fPIC default-extensions: CPP, DeriveTraversable, TypeApplications, OverloadedStrings, From 56b128e9454249feb81e02e57f3a4d5921504a0b Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 24 Dec 2020 14:56:33 -0500 Subject: [PATCH 025/105] Add an IO-based dynamically expanding buffer and use it to implement `fread`. --- lib/prelude.dx | 128 ++++++++++++++++++++++++++++++++++++++++------ src/lib/Syntax.hs | 8 +-- src/lib/Type.hs | 5 ++ tests/io-tests.dx | 32 ++++++++++++ 4 files changed, 154 insertions(+), 19 deletions(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index 3295887b8..650146d33 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -409,6 +409,7 @@ data Ptr a:Type = MkPtr Word8Ptr -- Is there a better way to select the right instance for `storageSize`?? data TypeVehicle a:Type = MkTypeVehicle +def typeVehicle (a:Type) : TypeVehicle a = MkTypeVehicle interface Storable a:Type where store : Ptr a -> a -> {State World} Unit @@ -435,11 +436,48 @@ instance int32Storable : Storable Int32 where load = int32Load storageSize = const 4 +def unpackPairPtr (_:Storable a) ?=> (_:Storable b) ?=> + (pairPtr: Ptr (a & b)) : (Ptr a & Ptr b) = + (MkPtr rawPtrX) = pairPtr + rawPtrY = %ptrOffset rawPtrX (storageSize (typeVehicle a)) + (MkPtr rawPtrX, MkPtr rawPtrY) + +def pairStore (_:Storable a) ?=> (_:Storable b) ?=> + (pairPtr:Ptr (a & b)) ((x, y):(a & b)) : {State World} Unit = + (xPtr, yPtr) = unpackPairPtr pairPtr + store xPtr x + store yPtr y + +def pairLoad (_:Storable a) ?=> (_:Storable b) ?=> + (pairPtr:Ptr (a & b)) : {State World} (a & b) = + (xPtr, yPtr) = unpackPairPtr pairPtr + (load xPtr, load yPtr) + +def pairStorageSize (_:Storable a) ?=> (_:Storable b) ?=> + (_:TypeVehicle (a & b)) : Int = + storageSize (typeVehicle a) + storageSize (typeVehicle b) + +instance pairStorable : Storable a ?=> Storable b ?=> Storable (a & b) where + store = pairStore + load = pairLoad + storageSize = pairStorageSize + +def ptrPtrStore ((MkPtr ptr): Ptr (Ptr a)) (x:(Ptr a)) : {State World} Unit = + (MkPtr x') = x + %ptrStore (internalCast %PtrPtr ptr) x' + +def ptrPtrLoad ((MkPtr ptr): Ptr (Ptr a)) : {State World} (Ptr a) = + MkPtr $ %ptrLoad (internalCast %PtrPtr ptr) + +instance ptrStorable : Storable (Ptr a) where + store = ptrPtrStore + load = ptrPtrLoad + storageSize = const 8 -- TODO: something more portable? + -- TODO: Storable instances for other types def malloc (_:Storable a) ?=> (n:Int) : {State World} (Ptr a) = - typeVehicle : TypeVehicle a = MkTypeVehicle - numBytes = storageSize typeVehicle * n + numBytes = storageSize (typeVehicle a) * n MkPtr $ %charAlloc numBytes def free (ptr:Ptr a) : {State World} Unit = @@ -447,13 +485,21 @@ def free (ptr:Ptr a) : {State World} Unit = %charFree ptr' def (+>>) (_:Storable a) ?=> (ptr:Ptr a) (i:Int) : Ptr a = - typeVehicle : TypeVehicle a = MkTypeVehicle (MkPtr ptr') = ptr - i' = i * storageSize typeVehicle + i' = i * storageSize (typeVehicle a) MkPtr $ %ptrOffset ptr' i' -- TODO: generalize these brackets to allow other effects +-- TODO: consider making a Storable instance for tables instead +def storeTab (_:Storable a) ?=> (ptr: Ptr a) (tab:n=>a) : {State World} Unit = + for_ i. store (ptr +>> ordinal i) tab.i + +def memcpy (_:Storable a) ?=> (dest:Ptr a) (src:Ptr a) (n:Int) : {State World} Unit = + for_ i:(Fin n). + i' = ordinal i + store (dest +>> i') (load $ src +>> i') + def withAlloc (_:Storable a) ?=> (n:Int) (action: Ptr a -> {State World} b) : {State World} b = ptr = malloc n @@ -810,6 +856,54 @@ splitV : Iso a ({|} | a) = def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = reindex (buildWith $ splitV &>> iso) tab +'Dynamic buffer + +-- TODO: should we be able to use `Ref World Int` instead of `Ptr Int`? +-- TODO: would be nice to be able to use records here +data DynBuffer a:Type = MkDynBuffer (Ptr (Int & Int & Ptr a)) -- size, max size, buf ptr + +def withDynamicBuffer (_:Storable a) ?=> + (action: DynBuffer a -> {State World} b) : {State World} b = + initMaxSize = 256 + withAlloc 1 \dbPtr. + bufPtr = malloc initMaxSize + store dbPtr (0, initMaxSize, bufPtr) + result = action $ MkDynBuffer dbPtr + (_, _, bufPtr') = load dbPtr + free bufPtr' + result + +def maybeIncreaseBufferSize (_:Storable a) ?=> + (buf: DynBuffer a) (sizeDelta:Int) : {State World} Unit = + (MkDynBuffer dbPtr) = buf + (size, maxSize, bufPtr) = load dbPtr + newSize = sizeDelta + size + if newSize > maxSize + then + -- TODO: maybe this should use integer arithmetic? + newMaxSize = FToI $ pow 2.0 (ceil $ log2 $ IToF newSize) + newBufPtr = malloc newMaxSize + memcpy newBufPtr bufPtr size + store dbPtr (size, newMaxSize, newBufPtr) + else () + +def extendDynBuffer (_:Storable a) ?=> + (buf: DynBuffer a) (new:List a) : {State World} Unit = + (AsList n xs) = new + maybeIncreaseBufferSize buf n + (MkDynBuffer dbPtr) = buf + (size, maxSize, bufPtr) = load dbPtr + newSize = n + size + storeTab (bufPtr +>> size) xs + store dbPtr (newSize, maxSize, bufPtr) + +def loadDynBuffer (_:Storable a) ?=> + (buf: DynBuffer a) : {State World} (List a) = + (MkDynBuffer dbPtr) = buf + (size, _, bufPtr) = load dbPtr + AsList size $ tabFromPtr _ bufPtr + + '## Strings and Characters String : Type = List Char @@ -910,14 +1004,26 @@ def fwrite (stream:Stream WriteMode) (s:String) : {State World} Unit = %ffi fflush Int64 stream' () +def while + (eff:Effects) ?-> + (cond: Unit -> {|eff} Bool) + (body: Unit -> {|eff} Unit) + : {|eff} Unit = + cond' : Unit -> {|eff} Word8 = \_. BToW8 $ cond () + %while cond' body + def fread (stream:Stream ReadMode) : {State World} String = (MkStream stream') = stream -- TODO: allow reading longer files! n = 4096 withAlloc n \ptr:(Ptr Char). - (MkPtr rawPtr) = ptr - numRead = I64ToI $ %ffi fread Int64 rawPtr (IToI64 1) (IToI64 n) stream' - stringFromCharPtr numRead ptr + withDynamicBuffer \buf. + while (\(). + (MkPtr rawPtr) = ptr + numRead = I64ToI $ %ffi fread Int64 rawPtr (IToI64 1) (IToI64 n) stream' + extendDynBuffer buf $ stringFromCharPtr numRead ptr + numRead == n) (\(). ()) + loadDynBuffer buf def deleteFile (f:FilePath) : {State World} Unit = withCString f \(MkCString ptr). @@ -1012,14 +1118,6 @@ instance finArb : n:Int ?-> Arbitrary (Fin n) where 'Control flow -def while - (eff:Effects) ?-> - (cond: Unit -> {|eff} Bool) - (body: Unit -> {|eff} Unit) - : {|eff} Unit = - cond' : Unit -> {|eff} Word8 = \_. BToW8 $ cond () - %while cond' body - data IterResult a:Type b:Type = Continue a Done b diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index e1d544d88..62c2f0067 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -1546,8 +1546,9 @@ builtinNames = M.fromList , ("Int64" , TCExpr $ BaseType $ Scalar Int64Type) , ("Int32" , TCExpr $ BaseType $ Scalar Int32Type) , ("Word8" , TCExpr $ BaseType $ Scalar Word8Type) - , ("Int32Ptr", ptrTy Int32Type) - , ("Word8Ptr", ptrTy Word8Type) + , ("Int32Ptr", TCExpr $ BaseType $ ptrTy $ Scalar Int32Type) + , ("Word8Ptr", TCExpr $ BaseType $ ptrTy $ Scalar Word8Type) + , ("PtrPtr" , TCExpr $ BaseType $ ptrTy $ ptrTy $ Scalar Word8Type) , ("IntRange", TCExpr $ IntRange () ()) , ("Ref" , TCExpr $ RefType (Just ()) ()) , ("PairType", TCExpr $ PairType () ()) @@ -1578,8 +1579,7 @@ builtinNames = M.fromList vbinOp op = OpExpr $ VectorBinOp op () () binOp op = OpExpr $ ScalarBinOp op () () unOp op = OpExpr $ ScalarUnOp op () - ptrTy ty = TCExpr $ BaseType $ PtrType $ - (AllocatedPtr, Heap CPU, Scalar ty) + ptrTy ty = PtrType (AllocatedPtr, Heap CPU, ty) instance Store a => Store (PrimOp a) instance Store a => Store (PrimCon a) diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 6237074b6..e093ace0a 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -276,6 +276,10 @@ exprEffs expr = case expr of MAsk -> S.singleton (Reader, h) MTell _ -> S.singleton (Writer, h) where RefTy (Var (h:>_)) _ = getType ref + IOAlloc _ _ -> S.singleton (State, theWorld) + IOFree _ -> S.singleton (State, theWorld) + PtrLoad _ -> S.singleton (State, theWorld) + PtrStore _ _ -> S.singleton (State, theWorld) FFICall _ _ _ -> S.singleton (State, theWorld) _ -> NoEffects Hof hof -> case hof of @@ -709,6 +713,7 @@ typeCheckOp op = case op of return $ RefTy h b IOAlloc t n -> do n |: IdxRepTy + declareEff (State, Just theWorld) return $ PtrTy (AllocatedPtr, Heap CPU, t) IOFree ptr -> do PtrTy _ <- typeCheck ptr diff --git a/tests/io-tests.dx b/tests/io-tests.dx index 53a985e60..55fd74ed8 100644 --- a/tests/io-tests.dx +++ b/tests/io-tests.dx @@ -37,3 +37,35 @@ unsafeIO \(). > 8 is even > 9 is odd > [(), (), (), (), (), (), (), (), (), ()] + +:p storageSize (typeVehicle Int) +> 4 + +:p unsafeIO \(). + withAlloc 1 \ptr:(Ptr (Int & Int)). + store ptr (4, 3) + load ptr +> (4, 3) + +:p unsafeIO \(). + withAlloc 1 \ptr:(Ptr Int). + store ptr 3 + load ptr +> 3 + +:p unsafeIO \(). + withDynamicBuffer \buf. + extendDynBuffer buf $ toList for i:(Fin 1000). ordinal i + extendDynBuffer buf $ toList for i:(Fin 1000). ordinal i + (AsList _ xs) = loadDynBuffer buf + sum xs +> 999000 + +:p unsafeIO \(). + s = for i:(Fin 10000). IToW8 $ FToI $ 128.0 * rand (ixkey (newKey 0) i) + withTempFile \fname. + withFile fname WriteMode \stream. + fwrite stream $ AsList _ s + (AsList _ s') = readFile fname + sum (for i. W8ToI s.i) == sum (for i. W8ToI s'.i) +> True From 20433345aea440924aea441ee92f9c8ce90057e1 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 24 Dec 2020 16:43:39 -0500 Subject: [PATCH 026/105] Add animations by shelling out to imagemagick to turn PNGs into a GIF. --- examples/fluidsim.dx | 21 ++++++++++----------- lib/plot.dx | 2 +- lib/png.dx | 30 ++++++++++++++++++++++++------ lib/prelude.dx | 13 +++++++++++++ 4 files changed, 48 insertions(+), 18 deletions(-) diff --git a/examples/fluidsim.dx b/examples/fluidsim.dx index 39bd393a6..52f745057 100644 --- a/examples/fluidsim.dx +++ b/examples/fluidsim.dx @@ -96,15 +96,15 @@ def advect (_:VSpace a) ?=> (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = bilinear_interp right_weight bottom_weight f.l.t f.l.b f.r.t f.r.b def fluidsim (_: VSpace a) ?=> (num_steps: Int) (color_init: n=>m=>a) - (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = - (color_final, v) = snd $ withState (color_init, v) \state. + (v: n=>m=>(Fin 2)=>Float) : (Fin num_steps)=>n=>m=>a = + fst $ withState (color_init, v) \state. for i:(Fin num_steps). (color, v) = get state v = advect v v -- Move velocities v = project v -- Project to be volume-preserving - color = advect color v -- Move color - state := (color, v) - color_final + color' = advect color v -- Move color + state := (color', v) + color '### Demo @@ -126,19 +126,18 @@ init_color = for i:N j:M. -- Run fluid sim and plot it. num_steps = 5 -final_color = fluidsim num_steps init_color init_velocity - -:html imshow final_color +:html imseqshow $ fluidsim num_steps init_color init_velocity > - - '### Gradient test target = transpose init_color +-- This is partial +def last (xs:n=>a) : a = xs.((size n - 1)@_) + def objective (v:N=>M=>(Fin 2)=>Float) : Float = - final_color = fluidsim num_steps init_color v + final_color = last $ fluidsim num_steps init_color v sum for (i, j, c). sq (final_color.i.j.c - target.i.j.c) init_vel_grad = grad objective zero diff --git a/lib/plot.dx b/lib/plot.dx index 396bd7498..0212ad537 100644 --- a/lib/plot.dx +++ b/lib/plot.dx @@ -122,6 +122,6 @@ def yPlot (ys:n=>Float) : Plot n Float Float Unit = def matshow (img:n=>m=>Float) : Html = low = minimum $ for (i,j). img.i.j high = maximum $ for (i,j). img.i.j - pngToHtml $ makePNG for i j. + imgToHtml $ makePNG for i j. x = floatTo8Bit $ (img.i.j - low) / (high - low) [x, x, x] diff --git a/lib/png.dx b/lib/png.dx index f9d0a7303..261cf8c8a 100644 --- a/lib/png.dx +++ b/lib/png.dx @@ -91,24 +91,42 @@ def base64Decode (s:String) : Maybe String = '## PNG FFI -Html : Type = List Char +Html : Type = String +Png : Type = String +Gif : Type = String -def makePNG (img:n=>m=>(Fin 3)=>Word8) : List Byte = unsafeIO \(). +def makePNG (img:n=>m=>(Fin 3)=>Word8) : Png = unsafeIO \(). (AsList _ imgFlat) = toList for (i,j,k). img.i.j.k withTabPtr imgFlat \ptr. (MkPtr rawPtr) = ptr (n, ptr') = %ffi encodePNG (Int & CharPtr) rawPtr (size m) (size n) toList $ tabFromPtr (Fin n) $ MkPtr ptr' -def pngToHtml (png:List Byte) : List Char = +def pngsToGif (delay:Int) (pngs:t=>Png) : Gif = unsafeIO \(). + withTempFiles \pngFiles. + for i. writeFile pngFiles.i pngs.i + withTempFile \gifFile. + shellOut $ + "convert" <> " -delay " <> show delay <> " " <> + concat (for i. "png:" <> pngFiles.i <> " ") <> + "gif:" <> gifFile + readFile gifFile + +def imgToHtml (png:String) : Html = (" base64Encode png <> "\">") -'## API entry point - def floatTo8Bit (x:Float) : Word8 = IToW8 $ FToI $ 255.0 * clip (0.0, 1.0) x +def imgToPng (img:n=>m=>(Fin 3)=>Float) : Png = + makePNG for i j k. floatTo8Bit img.i.j.k + +'## API entry point + def imshow (img:n=>m=>(Fin 3)=>Float) : Html = - pngToHtml $ makePNG for i j k. floatTo8Bit img.i.j.k + imgToHtml $ imgToPng img + +def imseqshow (imgs:t=>n=>m=>(Fin 3)=>Float) : Html = + imgToHtml $ pngsToGif 50 $ map imgToPng imgs diff --git a/lib/prelude.dx b/lib/prelude.dx index 650146d33..7b05143c3 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -1057,12 +1057,25 @@ def withTempFile (action: FilePath -> {State World} a) : {State World} a = deleteFile tmpFile result +def withTempFiles (action: (n=>FilePath) -> {State World} a) : {State World} a = + tmpFiles = for i. writeTemp "" + result = action tmpFiles + for i. deleteFile tmpFiles.i + result + def getOutputStream (_:Unit) : {State World} Stream WriteMode = MkStream $ %ptrLoad OUT_STREAM_PTR def print (s:String) : {State World} Unit = fwrite (getOutputStream ()) (s <> "\n") +def shellOut (command:String) : {State World} String = + modeStr = "r" + withCString command \(MkCString commandPtr). + withCString modeStr \(MkCString modePtr). + pipe = MkStream %ffi popen CharPtr commandPtr modePtr + fread pipe + 'Partial functions def error (s:String) : a = unsafeIO \(). From 9674910b34daa27e9b7bd3de3db172e967966b24 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 23 Dec 2020 23:22:42 +0000 Subject: [PATCH 027/105] Make show identity on strings --- lib/prelude.dx | 3 +++ tests/show-tests.dx | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/lib/prelude.dx b/lib/prelude.dx index e90b538e7..46b05fe53 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -869,6 +869,9 @@ def codepoint (c:Char) : Int = W8ToI c interface Show a:Type where show : a -> String +instance showString : Show String where + show = id + instance showInt32 : Show Int32 where show = \x: Int32. (n, ptr) = %ffi showInt32 (Int32 & CharPtr) x diff --git a/tests/show-tests.dx b/tests/show-tests.dx index 0f3e26171..da1cbaa6e 100644 --- a/tests/show-tests.dx +++ b/tests/show-tests.dx @@ -1,4 +1,8 @@ '# `Show` instances +-- String + +:p show (AsList _ "abc") +> (AsList 3 "abc") -- Int32 From 269c26cc6369606e4be45f2793314d45393cc679 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 24 Dec 2020 01:15:16 +0000 Subject: [PATCH 028/105] Refactor diagram.dx remove excess whitespace remove viewBoxDims variable get ride of viewboxdims --- lib/diagram.dx | 60 +++++++++++++++++++++++++------------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index e19fd325b..1598b8ca2 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -112,8 +112,8 @@ def quote (s:String) : String = str "\"" <.> s <.> str "\"" def strSpaceCatUncurried ((s1,s2):(String & String)) : String = s1 <.> str " " <.> s2 -def (<+>) (s1:String) (s2:String) : String = - strSpaceCatUncurried (s1, s2) +def (<+>) (_:Show a) ?=> (_:Show b) ?=> (s1:a) (s2:b) : String = + strSpaceCatUncurried ((show s1), (show s2)) def selfClosingBrackets (s:String) : String = str "<" <.> s <.> str "/>" @@ -127,8 +127,8 @@ def tagBracketsAttrUncurried ((tag, attr, s):(String & String & String)) : Strin def tagBracketsAttr (tag:String) (attr:String) (s:String) : String = tagBracketsAttrUncurried (tag, attr, s) -def makeAttr (attr:String) (val:String) : String = - attr <.> str "=" <.> quote val +def (<=>) (_:Show b) ?=> (attr:String) (val:b) : String = + attr <.> str "=" <.> quote (show val) def htmlColorStr (cs:HtmlColor) : String = (r, g, b) = cs @@ -141,32 +141,33 @@ def optionalHtmlColorStr (c: Maybe HtmlColor) : String = @noinline def attrString (attr:GeomStyle) : String = - ( -- makeAttr (str "stroke") (optionalHtmlColorStr $ getAt #strokeColor attr) - makeAttr (str "fill") (optionalHtmlColorStr $ getAt #fillColor attr) - <+> makeAttr (str "stroke-width") (show $ getAt #strokeWidth attr)) + ( -- (str "stroke") <=> (optionalHtmlColorStr $ getAt #strokeColor attr) + (str "fill") <=> (optionalHtmlColorStr $ getAt #fillColor attr) + <+> (str "stroke-width") <=> (getAt #strokeWidth attr)) def renderGeom (attr:GeomStyle) ((x,y):Point) (geom:Geom) : String = + groupEle = \attr. tagBracketsAttr (str "g") (attrString attr) case geom of PointGeom -> - pointAttr = setAt #fillColor (getAt #strokeColor attr) attr - tagBracketsAttr (str "g") (attrString pointAttr) $ selfClosingBrackets $ - (str "circle" <+> - str "cx=" <.> quote (show x) <.> - str "cy=" <.> quote (show y) <.> - str "r=\"1\"") + pointAttr = setAt #fillColor (getAt #strokeColor attr) attr + groupEle pointAttr $ selfClosingBrackets $ + (str "circle" <+> + str "cx" <=> x <.> + str "cy" <=> y <.> + str "r=\"1\"") Circle r -> - tagBracketsAttr (str "g") (attrString attr) $ selfClosingBrackets $ - (str "circle" <+> - str "cx=" <.> quote (show x) <.> - str "cy=" <.> quote (show y) <.> - str "r=" <.> quote (show r)) + groupEle attr $ selfClosingBrackets $ + (str "circle" <+> + str "cx" <=> x <.> + str "cy" <=> y <.> + str "r" <=> r) Rectangle w h -> - tagBracketsAttr (str "g") (attrString attr) $ selfClosingBrackets $ - (str "rect" <+> - str "width=" <.> quote (show w) <.> - str "height=" <.> quote (show h) <.> - str "x=" <.> quote (show (x - (w/2.0))) <.> - str "y=" <.> quote (show (y - (h/2.0)))) + groupEle attr $ selfClosingBrackets $ + (str "rect" <+> + str "width" <=> w <.> + str "height" <=> h <.> + str "x" <=> (x - (w/2.0)) <.> + str "y" <=> (y - (h/2.0))) BoundingBox : Type = (Point & Point) @@ -175,13 +176,12 @@ def renderSVG (d:Diagram) (bounds:BoundingBox) : String = imgWidth = 400.0 scaleFactor = imgWidth / (xmax - xmin) imgHeight = (ymax - ymin) * scaleFactor + imgXMin = xmin * scaleFactor + imgYMin = -ymax * scaleFactor (MkDiagram (AsList _ objs)) = d |> flipY |> scale scaleFactor - viewBoxStr = makeAttr (str "viewBox") $ - (show (xmin * scaleFactor) <+> show (-(ymax * scaleFactor)) <+> - show imgWidth <+> show imgHeight) - svgAttrStr = ( makeAttr (str "width" ) (show imgWidth) - <+> makeAttr (str "height") (show imgHeight) - <+> viewBoxStr) + svgAttrStr = ( str "width" <=> imgWidth + <+> str "height" <=> imgHeight + <+> str "viewBox" <=> (imgXMin <+> imgYMin <+> imgWidth <+> imgHeight)) tagBracketsAttr (str "svg") svgAttrStr $ concat for i. (attr, pos, geom) = objs.i From bdf1df132cca4353c3073b45259e2c4c859514d7 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Sat, 26 Dec 2020 14:08:52 -0500 Subject: [PATCH 029/105] Make changes suggested in review. --- lib/diagram.dx | 2 +- lib/png.dx | 2 +- lib/prelude.dx | 54 ++++++++++++++++++++---------------------- src/lib/Autodiff.hs | 2 -- src/lib/Embed.hs | 6 ++--- src/lib/Export.hs | 2 +- src/lib/Imp.hs | 39 +++++++++++++++--------------- src/lib/Interpreter.hs | 10 ++++---- src/lib/JIT.hs | 8 +++---- src/lib/PPrint.hs | 4 ---- src/lib/Syntax.hs | 23 +++++------------- src/lib/Type.hs | 21 ++++++---------- src/lib/dexrt.cpp | 8 ------- 13 files changed, 74 insertions(+), 107 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index 8663769af..f87be1e85 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -13,7 +13,7 @@ data Geom = HtmlColor : Type = (Word8 & Word8 & Word8) def showHex (x:Int32) : String = unsafeIO \(). - (n, ptr) = %ffi showHex (Int32 & CharPtr) x + (n, ptr) = %ffi showHex (Int32 & RawPtr) x stringFromCharPtr n (MkPtr ptr) black : HtmlColor = (IToW8 0, IToW8 0, IToW8 0) diff --git a/lib/png.dx b/lib/png.dx index 261cf8c8a..131f7c609 100644 --- a/lib/png.dx +++ b/lib/png.dx @@ -99,7 +99,7 @@ def makePNG (img:n=>m=>(Fin 3)=>Word8) : Png = unsafeIO \(). (AsList _ imgFlat) = toList for (i,j,k). img.i.j.k withTabPtr imgFlat \ptr. (MkPtr rawPtr) = ptr - (n, ptr') = %ffi encodePNG (Int & CharPtr) rawPtr (size m) (size n) + (n, ptr') = %ffi encodePNG (Int & RawPtr) rawPtr (size m) (size n) toList $ tabFromPtr (Fin n) $ MkPtr ptr' def pngsToGif (delay:Int) (pngs:t=>Png) : Gif = unsafeIO \(). diff --git a/lib/prelude.dx b/lib/prelude.dx index 7b05143c3..7a0f62274 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -400,12 +400,9 @@ def finOrd (n:Int) ?-> : Ord (Fin n) = '## Raw pointer operations -Int32Ptr : Type = %Int32Ptr -Word8Ptr : Type = %Word8Ptr +RawPtr : Type = %Word8Ptr -CharPtr = Word8Ptr - -data Ptr a:Type = MkPtr Word8Ptr +data Ptr a:Type = MkPtr RawPtr -- Is there a better way to select the right instance for `storageSize`?? data TypeVehicle a:Type = MkTypeVehicle @@ -416,13 +413,15 @@ interface Storable a:Type where load : Ptr a -> {State World} a storageSize : TypeVehicle a -> Int --- TODO: there's a bug preventing us inlining these definitions into the instance -def charStore ((MkPtr ptr): Ptr Word8) (x:Word8) : {State World} Unit = %ptrStore ptr x -def charLoad ((MkPtr ptr): Ptr Word8) : {State World} Word8 = %ptrLoad ptr +-- TODO: we can't inline these into the instance definitions until we change +-- type inference to push types down into record constructors or allow `def` in +-- instance definitions. +def word8Store ((MkPtr ptr): Ptr Word8) (x:Word8) : {State World} Unit = %ptrStore ptr x +def word8Load ((MkPtr ptr): Ptr Word8) : {State World} Word8 = %ptrLoad ptr -instance charStorable : Storable Word8 where - store = charStore - load = charLoad +instance word8Storable : Storable Word8 where + store = word8Store + load = word8Load storageSize = const 1 -- TODO: there's a bug preventing us inlining these definitions into the instance @@ -478,11 +477,11 @@ instance ptrStorable : Storable (Ptr a) where def malloc (_:Storable a) ?=> (n:Int) : {State World} (Ptr a) = numBytes = storageSize (typeVehicle a) * n - MkPtr $ %charAlloc numBytes + MkPtr $ %alloc numBytes def free (ptr:Ptr a) : {State World} Unit = (MkPtr ptr') = ptr - %charFree ptr' + %free ptr' def (+>>) (_:Storable a) ?=> (ptr:Ptr a) (i:Int) : Ptr a = (MkPtr ptr') = ptr @@ -908,7 +907,7 @@ def loadDynBuffer (_:Storable a) ?=> String : Type = List Char -def stringFromCharPtr (n:Int) (ptr:Ptr Char) : {State World} String= +def stringFromCharPtr (n:Int) (ptr:Ptr Char) : {State World} String = AsList n $ tabFromPtr _ ptr -- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint @@ -919,22 +918,22 @@ interface Show a:Type where instance showInt32 : Show Int32 where show = \x: Int32. unsafeIO \(). - (n, ptr) = %ffi showInt32 (Int32 & CharPtr) x + (n, ptr) = %ffi showInt32 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr instance showInt64 : Show Int64 where show = \x: Int64. unsafeIO \(). - (n, ptr) = %ffi showInt64 (Int32 & CharPtr) x + (n, ptr) = %ffi showInt64 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr instance showFloat32 : Show Float32 where show = \x: Float32.unsafeIO \(). - (n, ptr) = %ffi showFloat32 (Int32 & CharPtr) x + (n, ptr) = %ffi showFloat32 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr instance showFloat64 : Show Float64 where show = \x: Float64.unsafeIO \(). - (n, ptr) = %ffi showFloat64 (Int32 & CharPtr) x + (n, ptr) = %ffi showFloat64 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr -- pipe-like reverse function application @@ -970,13 +969,13 @@ def either_is_nan (x:Float) (y:Float) : Bool = (isnan x) || (isnan y) 'File system operations FilePath : Type = String -data CString = MkCString CharPtr +data CString = MkCString RawPtr data StreamMode = ReadMode WriteMode -data Stream mode:StreamMode = MkStream CharPtr +data Stream mode:StreamMode = MkStream RawPtr -- TODO: check the string contains no nulls def withCString (s:String) (action: CString -> {State World} a) : {State World} a = @@ -989,7 +988,7 @@ def fopen (path:String) (mode:StreamMode) : {State World} (Stream mode) = WriteMode -> "w" withCString path \(MkCString pathPtr). withCString modeStr \(MkCString modePtr). - MkStream $ %ffi fopen CharPtr pathPtr modePtr + MkStream $ %ffi fopen RawPtr pathPtr modePtr def fclose (mode:StreamMode) ?-> (stream:Stream mode) : {State World} Unit = (MkStream stream') = stream @@ -1044,21 +1043,20 @@ def writeFile (f:FilePath) (s:String) : {State World} Unit = def readFile (f:FilePath) : {State World} String = withFile f ReadMode \stream. fread stream -def writeTemp (s:String) : {State World} FilePath = - -- TODO: Make this properly atomic. It can fail if another process creates a - -- file with same name after we ask for the name and before we create it. +def newTempFile (_:Unit) : {State World} FilePath = withCString "/tmp/dex-XXXXXX" \(MkCString ptr). - %ffi mktemp CharPtr ptr + fd = %ffi mkstemp Int32 ptr + %ffi close Int32 fd stringFromCharPtr 15 (MkPtr ptr) def withTempFile (action: FilePath -> {State World} a) : {State World} a = - tmpFile = writeTemp "" + tmpFile = newTempFile () result = action tmpFile deleteFile tmpFile result def withTempFiles (action: (n=>FilePath) -> {State World} a) : {State World} a = - tmpFiles = for i. writeTemp "" + tmpFiles = for i. newTempFile () result = action tmpFiles for i. deleteFile tmpFiles.i result @@ -1073,7 +1071,7 @@ def shellOut (command:String) : {State World} String = modeStr = "r" withCString command \(MkCString commandPtr). withCString modeStr \(MkCString modePtr). - pipe = MkStream %ffi popen CharPtr commandPtr modePtr + pipe = MkStream %ffi popen RawPtr commandPtr modePtr fread pipe 'Partial functions diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index 59cdf502d..27eb99967 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -141,7 +141,6 @@ linearizeOp op = case op of IOFree _ -> emitDiscrete TabCon ty xs -> (TabCon ty <$> traverse la xs) `bindLin` emitOp Inject _ -> emitDiscrete - MakePtrType _ -> emitDiscrete SliceOffset _ _ -> emitDiscrete SliceCurry _ _ -> emitDiscrete VectorBinOp _ _ _ -> notImplemented @@ -599,7 +598,6 @@ transposeOp op ct = case op of else transposeAtom y =<< mul ct =<< substNonlin x ScalarBinOp FDiv x y -> transposeAtom x =<< div' ct =<< substNonlin y ScalarBinOp _ _ _ -> notLinear - MakePtrType _ -> notLinear PrimEffect refArg m -> do refArg' <- substTranspose linRefSubst refArg let emitEff = emitOp . PrimEffect refArg' diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 91e2af207..92e43df78 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -332,9 +332,9 @@ ptrOffset x i = emitOp $ PtrOffset x i unsafePtrLoad :: MonadEmbed m => Atom -> m Atom unsafePtrLoad x = emit $ Hof $ RunIO $ Lam $ Abs (Ignore UnitTy) $ (PlainArrow justIOEff, Block Empty (Op (PtrLoad x))) - -justIOEff :: EffectRow -justIOEff = EffectRow [(State, theWorld)] Nothing + where + justIOEff :: EffectRow + justIOEff = EffectRow [(State, theWorld)] Nothing ptrLoad :: MonadEmbed m => Atom -> m Atom ptrLoad x = emitOp $ PtrLoad x diff --git a/src/lib/Export.hs b/src/lib/Export.hs index d5c9e9472..0db91a501 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -148,7 +148,7 @@ prepareFunctionForExport env nameStr func = do -- TODO: I guess that the address space depends on the backend? -- TODO: Have an ExternalPtr tag? - ptrTy ty = PtrType (DerivedPtr, Heap CPU, ty) + ptrTy ty = PtrType (Heap CPU, ty) getRectShape :: Env () -> IndexStructure -> Maybe [Either Name Int] getRectShape scope idx = traverse (dimShape . binderType) $ toList idx diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index bcff11c93..d2b655818 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -101,7 +101,8 @@ requiredFunctions scope expr = flip foldMap (transitiveClosure getParents immediateParents) $ \fname -> case scope ! fname of (_, LetBound _ (Atom f)) -> [(fname, f)] - _ -> [] + (_, LamBound _) -> [] + _ -> error "Shouldn't have other free variables left" where getParents :: Name -> [Name] getParents fname = envNames $ freeVars $ scope ! fname @@ -273,7 +274,7 @@ toImpOp (maybeDest, op) = case op of FstRef ~(Con (ConRef (PairCon ref _ ))) -> returnVal ref SndRef ~(Con (ConRef (PairCon _ ref))) -> returnVal ref IOAlloc ty n -> do - ptr <- emitAlloc (AllocatedPtr, Heap CPU, ty) (fromScalarAtom n) + ptr <- emitAlloc (Heap CPU, ty) (fromScalarAtom n) returnVal $ toScalarAtom ptr IOFree ptr -> do emitStatement $ Free $ fromScalarAtom ptr @@ -513,7 +514,7 @@ toImpHof env (maybeDest, hof) = do copyAtom sDest =<< impSubst env s void $ translateBlock (env <> ref @> sDest) (Just aDest, body) PairVal <$> destToAtom aDest <*> destToAtom sDest - RunIO ~(Lam (Abs _ (_, body))) -> do + RunIO ~(Lam (Abs _ (_, body))) -> translateBlock env (maybeDest, body) Linearize _ -> error "Unexpected Linearize" Transpose _ -> error "Unexpected Transpose" @@ -649,7 +650,7 @@ makeBaseTypePtr ty = do -- where they could cast shadows let addrSpace = chooseAddrSpace allocInfo numel let ptrName = genFresh (Name AllocPtrName "ptr" 0) (scope <> ptrScope) - let ptrTy = PtrTy (AllocatedPtr, addrSpace, ty) + let ptrTy = PtrTy (addrSpace, ty) extend (ptrName @> (ptrTy, numel)) let ptr = Var (ptrName :> ptrTy) applyIdxs ptr idxs @@ -792,7 +793,7 @@ makeAllocDestWithPtrs allocTy ty = do (env, ptrs) <- flip foldMapM ptrsSizes $ \(Bind (ptr:>PtrTy ptrTy), size) -> do ptr' <- emitAlloc ptrTy $ fromScalarAtom size case ptrTy of - (_, Heap _, _) | allocTy == Managed -> extendAlloc ptr' + (Heap _, _) | allocTy == Managed -> extendAlloc ptr' _ -> return () return (ptr @> toScalarAtom ptr', [ptr']) dest' <- impSubst env dest @@ -880,7 +881,7 @@ isSmall numel = case numel of allocateBuffer :: AddressSpace -> Bool -> BaseType -> IExpr -> ImpM IExpr allocateBuffer addrSpace mustFree b numel = do - buffer <- emitAlloc (AllocatedPtr, addrSpace, b) numel + buffer <- emitAlloc (addrSpace, b) numel when mustFree $ extendAlloc buffer return buffer @@ -976,7 +977,7 @@ addToAtom dest src = case (dest, src) of loadAnywhere :: IExpr -> ImpM IExpr loadAnywhere ptr = do curDev <- asks curDevice - let (PtrType (_, addrSpace, ty)) = getIType ptr + let (PtrType (addrSpace, ty)) = getIType ptr case addrSpace of Heap ptrDev | ptrDev /= curDev -> do localPtr <- allocateStackSingleton ty @@ -987,7 +988,7 @@ loadAnywhere ptr = do storeAnywhere :: IExpr -> IExpr -> ImpM () storeAnywhere ptr val = do curDev <- asks curDevice - let (PtrType (_, addrSpace, ty)) = getIType ptr + let (PtrType (addrSpace, ty)) = getIType ptr case addrSpace of Heap ptrDev | ptrDev /= curDev -> do localPtr <- allocateStackSingleton ty @@ -1114,7 +1115,7 @@ extendAlloc :: IExpr -> ImpM () extendAlloc v = extend $ mempty { envPtrsToFree = [v] } emitAlloc :: HasCallStack => PtrType -> IExpr -> ImpM IExpr -emitAlloc (_, addr, ty) n = emitInstr $ Alloc addr ty n +emitAlloc (addr, ty) n = emitInstr $ Alloc addr ty n scopedErrBlock :: ImpM () -> ImpM ImpBlock scopedErrBlock body = liftM snd $ scopedBlock $ handleErrors body $> ((),[]) @@ -1221,14 +1222,14 @@ instrTypeChecked instr = case instr of return dt Alloc a ty _ -> (:[]) <$> do when (a /= Stack) assertHost - return $ PtrType (AllocatedPtr, a, ty) + return $ PtrType (a, ty) MemCopy dest src numel -> [] <$ do - PtrType (_, _, destTy) <- checkIExpr dest - PtrType (_, _, srcTy) <- checkIExpr src + PtrType (_, destTy) <- checkIExpr dest + PtrType (_, srcTy) <- checkIExpr src assertEq destTy srcTy "pointer type mismatch" checkInt numel Store dest val -> [] <$ do - PtrType (_, addr, ty) <- checkIExpr dest + PtrType (addr, ty) <- checkIExpr dest checkAddrAccessible addr valTy <- checkIExpr val assertEq ty valTy "Type mismatch in store" @@ -1289,12 +1290,12 @@ checkImpOp op = do checkIntBaseType False $ BaseTy ibt return $ Scalar ty PtrLoad ref -> do - PtrType (_, addr, ty) <- return ref + PtrType (addr, ty) <- return ref checkAddrAccessible addr return ty PtrOffset ref _ -> do -- TODO: check offset too - PtrType (_, addr, ty) <- return ref - return $ PtrType (DerivedPtr, addr, ty) + PtrType (addr, ty) <- return ref + return $ PtrType (addr, ty) _ -> error $ "Not allowed in Imp IR: " ++ pprint op where checkEq :: (Pretty a, Show a, Eq a) => a -> a -> ImpCheckM () @@ -1322,7 +1323,7 @@ impInstrTypes :: ImpInstr -> [IType] impInstrTypes instr = case instr of IPrimOp op -> [impOpType op] ICastOp t _ -> [t] - Alloc a ty _ -> [PtrType (AllocatedPtr, a, ty)] + Alloc a ty _ -> [PtrType (a, ty)] Store _ _ -> [] Free _ -> [] IThrowError -> [] @@ -1358,8 +1359,8 @@ impOpType pop = case pop of Select _ x _ -> getIType x VectorPack xs -> Vector ty where Scalar ty = getIType $ head xs VectorIndex x _ -> Scalar ty where Vector ty = getIType x - PtrLoad ref -> ty where PtrType (_, _, ty) = getIType ref - PtrOffset ref _ -> PtrType (DerivedPtr, addr, ty) where PtrType (_, addr, ty) = getIType ref + PtrLoad ref -> ty where PtrType (_, ty) = getIType ref + PtrOffset ref _ -> PtrType (addr, ty) where PtrType (addr, ty) = getIType ref _ -> unreachable where unreachable = error $ "Not allowed in Imp IR: " ++ pprint pop diff --git a/src/lib/Interpreter.hs b/src/lib/Interpreter.hs index 2c0e3edc0..3f9911119 100644 --- a/src/lib/Interpreter.hs +++ b/src/lib/Interpreter.hs @@ -96,14 +96,14 @@ evalOp expr = case expr of "randunif" -> Float64Val $ c_unif x where [Int64Val x] = args "threefry2x32" -> Int64Val $ c_threefry x y where [Int64Val x, Int64Val y] = args _ -> error $ "FFI function not recognized: " ++ name - PtrOffset (Con (Lit (PtrLit (_, a, t) p))) (IdxRepVal i) -> - return $ Con $ Lit $ PtrLit (DerivedPtr, a, t) $ p `plusPtr` (sizeOf t * fromIntegral i) - PtrLoad (Con (Lit (PtrLit (_, Heap CPU, t) p))) -> Con . Lit <$> loadLitVal p t - PtrLoad (Con (Lit (PtrLit (_, Heap GPU, t) p))) -> + PtrOffset (Con (Lit (PtrLit (a, t) p))) (IdxRepVal i) -> + return $ Con $ Lit $ PtrLit (a, t) $ p `plusPtr` (sizeOf t * fromIntegral i) + PtrLoad (Con (Lit (PtrLit (Heap CPU, t) p))) -> Con . Lit <$> loadLitVal p t + PtrLoad (Con (Lit (PtrLit (Heap GPU, t) p))) -> allocaBytes (sizeOf t) $ \hostPtr -> do loadCUDAArray hostPtr p (sizeOf t) Con . Lit <$> loadLitVal hostPtr t - PtrLoad (Con (Lit (PtrLit (_, Stack, _) _))) -> + PtrLoad (Con (Lit (PtrLit (Stack, _) _))) -> error $ "Unexpected stack pointer in interpreter" ToOrdinal idxArg -> case idxArg of Con (IntRangeVal _ _ i) -> return i diff --git a/src/lib/JIT.hs b/src/lib/JIT.hs index b4bf3e77e..5261b519b 100644 --- a/src/lib/JIT.hs +++ b/src/lib/JIT.hs @@ -266,15 +266,15 @@ compileInstr instr = case instr of GPU -> cuMemAlloc elemTy numBytes where elemTy = scalarTy t Free ptr -> [] <$ do - let PtrType (_, addr, _) = getIType ptr + let PtrType (addr, _) = getIType ptr ptr' <- compileExpr ptr case addr of Heap CPU -> free ptr' Heap GPU -> cuMemFree ptr' Stack -> error "Shouldn't be freeing alloca" MemCopy dest src numel -> [] <$ do - let PtrType (_, destAddr, ty) = getIType dest - let PtrType (_, srcAddr , _ ) = getIType src + let PtrType (destAddr, ty) = getIType dest + let PtrType (srcAddr , _ ) = getIType src destDev <- deviceFromAddr destAddr srcDev <- deviceFromAddr srcAddr dest' <- compileExpr dest >>= castVoidPtr @@ -766,7 +766,7 @@ scalarTy b = case b of Float64Type -> fp64 Float32Type -> fp32 Vector sb -> L.VectorType (fromIntegral vectorWidth) $ scalarTy $ Scalar sb - PtrType (_, s, t) -> L.PointerType (scalarTy t) (lAddress s) + PtrType (s, t) -> L.PointerType (scalarTy t) (lAddress s) hostPtrTy :: L.Type -> L.Type hostPtrTy ty = L.PointerType ty $ L.AddrSpace 0 diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 5f2bb073c..d37847ac1 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -118,10 +118,6 @@ instance PrettyPrec BaseType where Vector sb -> atPrec ArgPrec $ "<" <> p vectorWidth <+> "x" <+> p sb <> ">" PtrType ty -> atPrec AppPrec $ "Ptr" <+> p ty -instance Pretty PtrOrigin where - pretty AllocatedPtr = "a" - pretty DerivedPtr = "d" - instance Pretty AddressSpace where pretty Stack = "stack" pretty (Heap d) = p (show d) diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 62c2f0067..b6d9dc01b 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -38,7 +38,7 @@ module Syntax ( addSrcContext, catchIOExcept, liftEitherIO, (-->), (--@), (==>), boundUVars, PassName (..), boundVars, renamingSubst, bindingsAsVars, freeVars, freeUVars, Subst, HasVars, BindsVars, Ptr, PtrType, - AddressSpace (..), PtrOrigin (..), showPrimName, strToPrimName, primNameToStr, + AddressSpace (..), showPrimName, strToPrimName, primNameToStr, monMapSingle, monMapLookup, Direction (..), Limit (..), UExpr, UExpr' (..), UType, UPatAnn, UPiPatAnn, UAnnBinder, UVar, UPat, UPat' (..), UModule (..), UDecl (..), UArrow, arrowEff, @@ -328,7 +328,6 @@ data PrimOp e = | SliceOffset e e -- Index slice first, inner index second | SliceCurry e e -- Index slice first, curried index second -- Low-level memory operations - | MakePtrType e | IOAlloc BaseType e | IOFree e | PtrOffset e e @@ -465,7 +464,7 @@ initTopEnv = fold [v @> (ty, LamBound ImplicitArrow) | (v, ty) <- , (outputStreamPtrName , BaseTy $ hostPtrTy $ hostPtrTy $ Scalar Word8Type)]] hostPtrTy :: BaseType -> BaseType -hostPtrTy ty = PtrType (AllocatedPtr, Heap CPU, ty) +hostPtrTy ty = PtrType (Heap CPU, ty) -- === top-level constructs === @@ -574,15 +573,7 @@ data BaseType = Scalar ScalarBaseType data Device = CPU | GPU deriving (Show, Eq, Ord, Generic) data AddressSpace = Stack | Heap Device deriving (Show, Eq, Ord, Generic) -data PtrOrigin = DerivedPtr | AllocatedPtr deriving (Show, Ord, Generic) -type PtrType = (PtrOrigin, AddressSpace, BaseType) - -instance Eq PtrOrigin where - -- XXX: this is a hack. We expose pointer operations to the surface language - -- but we don't yet expose the derived/allocated distinction, and they get - -- mixed up when we use ops like ptrOffset. - _ == _ = True - +type PtrType = (AddressSpace, BaseType) sizeOf :: BaseType -> Int sizeOf t = case t of @@ -1566,12 +1557,11 @@ builtinNames = M.fromList , ("cast", OpExpr $ CastOp () ()) , ("sliceOffset", OpExpr $ SliceOffset () ()) , ("sliceCurry", OpExpr $ SliceCurry () ()) - , ("charAlloc", OpExpr $ IOAlloc (Scalar Word8Type) ()) - , ("charFree" , OpExpr $ IOFree ()) + , ("alloc", OpExpr $ IOAlloc (Scalar Word8Type) ()) + , ("free" , OpExpr $ IOFree ()) , ("ptrOffset", OpExpr $ PtrOffset () ()) , ("ptrLoad" , OpExpr $ PtrLoad ()) , ("ptrStore" , OpExpr $ PtrStore () ()) - , ("makePtrType", OpExpr $ MakePtrType ()) , ("dataConTag", OpExpr $ DataConTag ()) , ("toEnum" , OpExpr $ ToEnum () ()) ] @@ -1579,7 +1569,7 @@ builtinNames = M.fromList vbinOp op = OpExpr $ VectorBinOp op () () binOp op = OpExpr $ ScalarBinOp op () () unOp op = OpExpr $ ScalarUnOp op () - ptrTy ty = PtrType (AllocatedPtr, Heap CPU, ty) + ptrTy ty = PtrType (Heap CPU, ty) instance Store a => Store (PrimOp a) instance Store a => Store (PrimCon a) @@ -1611,6 +1601,5 @@ instance Store LitVal instance Store ScalarBaseType instance Store BaseType instance Store AddressSpace -instance Store PtrOrigin instance Store Device instance Store DataConRefBinding diff --git a/src/lib/Type.hs b/src/lib/Type.hs index e093ace0a..e177709b6 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -158,7 +158,7 @@ instance HasType Atom where checkDataConRefBindings argBs' args return $ RawRefTy $ TypeCon def params BoxedRef b ptr numel body -> do - PtrTy (_, _, t) <- typeCheck ptr + PtrTy (_, t) <- typeCheck ptr checkEq (binderAnn b) (BaseTy t) numel |: IdxRepTy void $ typeCheck b @@ -591,7 +591,7 @@ typeCheckCon con = case con of IndexRangeVal t l h i -> i|:IdxRepTy >> return (TC $ IndexRange t l h) IndexSliceVal _ _ _ -> error "not implemented" BaseTypeRef p -> do - (PtrTy (_, _, b)) <- typeCheck p + (PtrTy (_, b)) <- typeCheck p return $ RawRefTy $ BaseTy b TabRef tabTy -> do TabTy b (RawRefTy a) <- typeCheck tabTy @@ -654,7 +654,6 @@ checkValidCast sourceTy destTy = BaseTy (Scalar Word8Type ) -> return () BaseTy (Scalar Float64Type) -> return () BaseTy (Scalar Float32Type) -> return () - _ -> throw TypeErr $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy typeCheckOp :: Op -> TypeM Type @@ -714,25 +713,24 @@ typeCheckOp op = case op of IOAlloc t n -> do n |: IdxRepTy declareEff (State, Just theWorld) - return $ PtrTy (AllocatedPtr, Heap CPU, t) + return $ PtrTy (Heap CPU, t) IOFree ptr -> do PtrTy _ <- typeCheck ptr declareEff (State, Just theWorld) return UnitTy PtrOffset arr off -> do - PtrTy (_, a, b) <- typeCheck arr + PtrTy (a, b) <- typeCheck arr off |: IdxRepTy - return $ PtrTy (DerivedPtr, a, b) + return $ PtrTy (a, b) PtrLoad ptr -> do - PtrTy (_, _, t) <- typeCheck ptr + PtrTy (_, t) <- typeCheck ptr declareEff (State, Just theWorld) return $ BaseTy t PtrStore ptr val -> do - PtrTy (_, _, t) <- typeCheck ptr + PtrTy (_, t) <- typeCheck ptr val |: BaseTy t declareEff (State, Just theWorld) return $ UnitTy - MakePtrType ty -> ty|:TyKind >> return TyKind SliceOffset s i -> do TC (IndexSlice n l) <- typeCheck s l' <- typeCheck i @@ -1069,9 +1067,4 @@ typeReduceExpr scope expr = case expr of typeReduceBlock scope $ subst (b@>x', scope) block TypeCon con xs -> Just $ TypeCon con $ xs ++ [x'] _ -> Nothing - Op (MakePtrType ty) -> do - let ty' = typeReduceAtom scope ty - case ty' of - BaseTy b -> return $ PtrTy (AllocatedPtr, Heap CPU, b) - _ -> Nothing _ -> Nothing diff --git a/src/lib/dexrt.cpp b/src/lib/dexrt.cpp index 77ef1a626..389455c40 100644 --- a/src/lib/dexrt.cpp +++ b/src/lib/dexrt.cpp @@ -188,14 +188,6 @@ void doubleVec(char **resultPtr, int32_t n, float* xs) { *result2Ptr = p2; } -int32_t writeToStdErr(int32_t numBytes, char* bytes) { - fwrite(bytes, 1, (size_t) numBytes, stderr); - fprintf(stderr, "\n"); - fflush(stderr); - return 0; -} - - void encodePNG(char **resultPtr, int8_t* pixels, int32_t width, int32_t height) { png_image img; memset(&img, 0, sizeof(img)); From 5eb53d7bd83e87c431ee0cf87d8698caa7b91e18 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 26 Dec 2020 22:08:39 +0100 Subject: [PATCH 030/105] Update MCMC example --- examples/mcmc.dx | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/examples/mcmc.dx b/examples/mcmc.dx index a33dfc3ca..9607a3a7b 100644 --- a/examples/mcmc.dx +++ b/examples/mcmc.dx @@ -1,5 +1,6 @@ +'# Markov Chain Monte Carlo --- === General MCMC utilities === +'## General MCMC utilities include "plot.dx" @@ -24,8 +25,8 @@ def propose (proposal : a) (k : Key) : a = - acceptProb = exp (logDensity proposal) / exp (logDensity cur) - select (bern acceptProb k) proposal cur + accept = logDensity proposal > (logDensity cur + log (rand k)) + select accept proposal cur def meanAndCovariance (n:Type) ?-> (d:Type) ?-> (xs:n=>d=>Float) : (d=>Float & d=>d=>Float) = @@ -35,7 +36,7 @@ def meanAndCovariance (n:Type) ?-> (d:Type) ?-> (xs.j.i - xsMean.i ) ) / IToF (size n - 1) (xsMean, xsCov) --- === Metropolis-Hastings implementation === +'## Metropolis-Hastings implementation MHParams : Type = Float -- step size @@ -49,7 +50,7 @@ def mhStep proposal = x + stepSize .* randnVec k1 propose logProb x proposal k2 --- === HMC implementation === +'## HMC implementation HMCParams : Type = (Int & Float) -- leapfrog steps, step size @@ -79,28 +80,31 @@ def hmcStep proposal = leapfrogIntegrate params logProb (x, p) fst $ propose hamiltonian (x, p) proposal k2 --- === test it out === +'## Test it out + +'Generate samples from a multivariate normal distribution N([1.5, 2.5], [[1., 0.], [0., 0.05]]). def myLogProb (x:(Fin 2)=>Float) : LogProb = x' = x - [1.5, 2.5] neg $ 0.5 * inner x' [[1.,0.],[0.,20.]] x' -hmcParams = (10, 0.1) -mhParams = 0.1 -numSamples = 500 +numSamples = 10000 k0 = newKey 1 +mhParams = 0.1 mhSamples = runChain randnVec (mhStep mhParams myLogProb) numSamples k0 -hmcSamples = runChain randnVec (hmcStep hmcParams myLogProb) numSamples k0 :p meanAndCovariance mhSamples -> ([0.64555484, 2.4140575], [[0.38236195, 0.17941256], [0.17941256, 0.22895703]]) +> ([1.5165595, 2.493105], [[1.0373966, 1.1821209e-2], [1.1821209e-2, 5.3775612e-2]]) :html showPlot $ yPlot (for i. mhSamples.i.(0@_)) > +hmcParams = (10, 0.1) +hmcSamples = runChain randnVec (hmcStep hmcParams myLogProb) numSamples k0 + :p meanAndCovariance hmcSamples -> ([1.4468338, 2.4944723], [[1.065676, 2.047594e-2], [2.047594e-2, 5.288498e-2]]) +> ([1.50457, 2.5000212], [[0.9738671, 3.4229287e-3], [3.4229287e-3, 5.0585825e-2]]) :html showPlot $ yPlot (for i. hmcSamples.i.(0@_)) > From 70a645f15c85e8ae100d41f935c2087ec2d087f2 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 26 Dec 2020 17:10:49 -0500 Subject: [PATCH 031/105] Fix `stack build` on macOS again. (#381) Disable `-Wnonportable-include-path`. I am not sure what earlier change surfaced this macOS build error or why macOS CI does not encounter it. --- dex.cabal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dex.cabal b/dex.cabal index 16a3b64a7..3cf1683c6 100644 --- a/dex.cabal +++ b/dex.cabal @@ -99,7 +99,7 @@ foreign-library Dex hs-source-dirs: src/ c-sources: src/Dex/Foreign/rts.c cc-options: -std=c11 -fPIC - ghc-options: -Wall -fPIC + ghc-options: -Wall -fPIC -optP-Wno-nonportable-include-path default-language: Haskell2010 default-extensions: TypeApplications, ScopedTypeVariables, LambdaCase if flag(optimized) From 55df89781c9d47e14df6961f89f8f2a63e734faf Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Sat, 26 Dec 2020 17:48:59 -0500 Subject: [PATCH 032/105] Tweak MCMC example to fix quine test and make it faster to run. --- examples/mcmc.dx | 12 ++++++++---- lib/prelude.dx | 2 ++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/mcmc.dx b/examples/mcmc.dx index 9607a3a7b..0df6c857d 100644 --- a/examples/mcmc.dx +++ b/examples/mcmc.dx @@ -95,16 +95,20 @@ mhParams = 0.1 mhSamples = runChain randnVec (mhStep mhParams myLogProb) numSamples k0 :p meanAndCovariance mhSamples -> ([1.5165595, 2.493105], [[1.0373966, 1.1821209e-2], [1.1821209e-2, 5.3775612e-2]]) +> ( [1.5165595, 2.493105] +> , [[1.0373967, 1.1820998e-2], [1.1820998e-2, 5.377563e-2]] ) -:html showPlot $ yPlot (for i. mhSamples.i.(0@_)) +:html showPlot $ yPlot $ + slice (map head mhSamples) 0 (Fin 1000) > hmcParams = (10, 0.1) hmcSamples = runChain randnVec (hmcStep hmcParams myLogProb) numSamples k0 :p meanAndCovariance hmcSamples -> ([1.50457, 2.5000212], [[0.9738671, 3.4229287e-3], [3.4229287e-3, 5.0585825e-2]]) +> ( [1.5045699, 2.5000212] +> , [[0.97386724, 3.422921e-3], [3.422921e-3, 5.058581e-2]] ) -:html showPlot $ yPlot (for i. hmcSamples.i.(0@_)) +:html showPlot $ yPlot $ + slice (map head hmcSamples) 0 (Fin 1000) > diff --git a/lib/prelude.dx b/lib/prelude.dx index b98b0a548..80d9dfea0 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -1105,6 +1105,8 @@ def (@) (i:Int) (n:Type) : n = fromOrdinal n i def slice (xs:n=>a) (start:Int) (m:Type) : m=>a = for i. xs.(fromOrdinal _ (ordinal i + start)) +def head (xs:n=>a) : a = xs.(0@_) + def tail (n:Type) ?-> (xs:n=>a) (start:Int) : List a = numElts = size n - start toList $ slice xs start (Fin numElts) From e15835ecf2e20991f4854aa245a86eefee130639 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 27 Dec 2020 02:53:58 +0100 Subject: [PATCH 033/105] Extend rejection sampling example --- examples/rejection-sampler.dx | 82 ++++++++++++++++++++++++++++------- 1 file changed, 67 insertions(+), 15 deletions(-) diff --git a/examples/rejection-sampler.dx b/examples/rejection-sampler.dx index 503dcfb69..de60b7b9a 100644 --- a/examples/rejection-sampler.dx +++ b/examples/rejection-sampler.dx @@ -1,34 +1,86 @@ +'# Rejection sampler of a Binomial distribution + +'We implement rejection sampling from a Binomial distribution using a uniform proposal. def rejectionSample (try: Key -> Maybe a) (k:Key) : a = - ans = fst $ withState 0 \i. + fromJust $ fst $ withState 0 \i. snd $ withState Nothing \sample. while (\(). isNothing (get sample)) \(). i := get i + 1 sample := try $ hash k (get i) - case ans of Just sample -> sample Prob = Float LogProb = Float -def binomialSample (n:Int) (p:Prob) (k:Key) : Int = todo - +-- log probability density of a Binomial distribution def logBinomialProb (n:Int) (p:Prob) (counts:Int) : LogProb = pSuccess = log p * IToF counts pFailure = log1p (-p) * IToF (n - counts) normConst = (lbeta (1. + IToF counts) (1. + IToF n - IToF counts) + - log (1. + IToF n)) + log1p (IToF n)) pSuccess + pFailure - normConst -def binomialProb (n:Int) (p:Prob) (count:Int) : Prob = - exp $ logBinomialProb n p count - def trySampleBinomial (n:Int) (p:Prob) (k:Key) : Maybe Int = - (k1, k2) = splitKey k + [k1, k2] = splitKey k proposal = FToI $ floor $ rand k1 * IToF (n + 1) - acceptance = rand k2 < binomialProb n p proposal - case proposal < (n + 1) && acceptance of - True -> Just proposal - False -> Nothing + if proposal > n + then Nothing + else + acceptance = log (rand k2) < logBinomialProb n p proposal + if acceptance + then Just proposal + else Nothing + +'## Example + +'We test the implementation by sampling from a Binomial distribution with 10 trials and success probability 0.4. + +-- parameters +n = 10 +p = 0.4 +numSamples = 5000 +k0 = newKey 0 + +rejectionSamples = randVec numSamples (rejectionSample $ trySampleBinomial n p) k0 + +:p slice rejectionSamples 0 $ Fin 10 +> [4, 2, 5, 4, 6, 7, 3, 6, 4, 3] + +'The Binomial distribution has mean 4 and variance 2.4. + +def meanAndVariance (xs:n=>Float) : (Float&Float) = (mean xs, sq $ std xs) + +:p meanAndVariance $ map IToF rejectionSamples +> (3.9933999, 2.3585567) + +'## Alternative: Inversion sampling + +'Alternatively, we can use inversion sampling. + +def binomialSample (n:Int) (p:Prob) (k:Key) : Int = + m = n + 1 + logprobs = for i:(Fin m). logBinomialProb n p $ ordinal i + ordinal $ categorical logprobs k + +inversionSamples = randVec numSamples (binomialSample n p) k0 + +:p slice inversionSamples 0 $ Fin 10 +> [6, 7, 6, 5, 3, 2, 4, 4, 3, 4] + +:p meanAndVariance $ map IToF inversionSamples +> (3.9977999, 2.4097958) + +'The following variant is guaranteed to evaluate the CDF only once. + +def binomialBatch (n:Int) (p:Prob) (k:Key) : a => Int = + m = n + 1 + logprobs = for i:(Fin m). logBinomialProb n p $ ordinal i + map ordinal $ categoricalBatch logprobs k + +inversionBatchSamples = (binomialBatch n p k0) : Fin numSamples => Int + +:p slice inversionBatchSamples 0 $ Fin 10 +> [6, 7, 6, 5, 3, 2, 4, 4, 3, 4] -:p randVec 10 (rejectionSample (trySampleBinomial 10 0.5)) (newKey 0) -> [4, 2, 5, 4, 6, 7, 3, 6, 6, 3] +:p meanAndVariance $ map IToF inversionBatchSamples +> (3.9977999, 2.4097958) From 49c04114ac02c78c05add1007ab0a88ba9eb40ba Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 28 Dec 2020 09:09:21 -0500 Subject: [PATCH 034/105] Treat effect rows as sets of effects. Fixes #267. --- src/lib/Autodiff.hs | 11 +++--- src/lib/Embed.hs | 5 +-- src/lib/Inference.hs | 22 ++++++------ src/lib/Optimize.hs | 2 +- src/lib/PPrint.hs | 4 +-- src/lib/Parallelize.hs | 7 ++-- src/lib/Parser.hs | 3 +- src/lib/Syntax.hs | 54 ++++++++++++---------------- src/lib/Type.hs | 80 +++++++++++++++++++++--------------------- tests/monad-tests.dx | 7 ++++ 10 files changed, 97 insertions(+), 98 deletions(-) diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index 27eb99967..f47c88de1 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -429,8 +429,10 @@ tangentFunAsLambda m = do -- TODO: Only bind tangents for free variables? let activeVarBinders = map (Bind . fmap (tangentRefRegion regionMap)) $ envAsVars activeVars buildNestedLam PureArrow activeVarBinders $ \activeVarArgs -> - buildLam (Ignore UnitTy) (PlainArrow $ EffectRow effs' Nothing) $ \_ -> - runReaderT tanFun $ TangentEnv (newEnv (envNames activeVars) activeVarArgs) hVarNames (newEnv rematList $ fmap Var rematArgs) + buildLam (Ignore UnitTy) (PlainArrow $ EffectRow (S.fromList effs') Nothing) $ \_ -> + runReaderT tanFun $ TangentEnv + (newEnv (envNames activeVars) activeVarArgs) hVarNames + (newEnv rematList $ fmap Var rematArgs) case rematList of [] -> return tanLam _ -> deShadow tanLam <$> getScope @@ -753,11 +755,12 @@ freeLinVars x = do isLin :: HasVars a => a -> TransposeM Bool isLin x = not . null <$> freeLinVars x -isLinEff :: EffectSummary -> TransposeM Bool -isLinEff effs = do +isLinEff :: EffectRow -> TransposeM Bool +isLinEff (EffectRow effs Nothing) = do regions <- asks linRegions return $ not $ null $ effRegions `envIntersect` regions where effRegions = newEnv (S.map snd effs) (repeat ()) +isLinEff _ = error "Can't transpose polymorphic effects" emitCTToRef :: Maybe Atom -> Atom -> TransposeM () emitCTToRef mref ct = case mref of diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 92e43df78..1dc88144e 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -331,10 +331,7 @@ ptrOffset x i = emitOp $ PtrOffset x i unsafePtrLoad :: MonadEmbed m => Atom -> m Atom unsafePtrLoad x = emit $ Hof $ RunIO $ Lam $ Abs (Ignore UnitTy) $ - (PlainArrow justIOEff, Block Empty (Op (PtrLoad x))) - where - justIOEff :: EffectRow - justIOEff = EffectRow [(State, theWorld)] Nothing + (PlainArrow (oneEffect ioEffect), Block Empty (Op (PtrLoad x))) ptrLoad :: MonadEmbed m => Atom -> m Atom ptrLoad x = emitOp $ PtrLoad x diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 9912b3b3d..05178f075 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -19,6 +19,7 @@ import Data.Functor import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M import Data.String (fromString) +import qualified Data.Set as S import Data.Text.Prettyprint.Doc import Syntax @@ -394,7 +395,7 @@ checkULam (p, ann) body piTy = do checkUEff :: EffectRow -> UInferM EffectRow checkUEff (EffectRow effs t) = do - effs' <- forM effs $ \(effName, region) -> do + effs' <- liftM S.fromList $ forM (toList effs) $ \(effName, region) -> do (Var (v:>TyKind)) <- lookupSourceVar (region:>()) return (effName, v) t' <- forM t $ \tv -> lookupVarName EffKind tv @@ -794,7 +795,7 @@ freshType EffKind = Eff <$> freshEff freshType k = Var . (:>k) <$> freshInferenceName k freshEff :: (MonadError Err m, MonadCat SolverEnv m) => m EffectRow -freshEff = EffectRow [] . Just <$> freshInferenceName EffKind +freshEff = EffectRow mempty . Just <$> freshInferenceName EffKind constrainEq :: (MonadCat SolverEnv m, MonadError Err m) => Type -> Type -> m () @@ -877,19 +878,16 @@ unifyEff r1 r2 = do vs <- looks solverVars case (r1', r2') of _ | r1' == r2' -> return () - (r, EffectRow [] (Just v)) | v `isin` vs -> bindQ (v:>EffKind) (Eff r) - (EffectRow [] (Just v), r) | v `isin` vs -> bindQ (v:>EffKind) (Eff r) - (EffectRow effs1@(_:_) t1, EffectRow effs2@(_:_) t2) -> do - let extras1 = effs1 `setDiff` effs2 - let extras2 = effs2 `setDiff` effs1 + (r, EffectRow effs (Just v)) | S.null effs && v `isin` vs -> bindQ (v:>EffKind) (Eff r) + (EffectRow effs (Just v), r) | S.null effs && v `isin` vs -> bindQ (v:>EffKind) (Eff r) + (EffectRow effs1 t1, EffectRow effs2 t2) | not (S.null effs1 || S.null effs2) -> do + let extras1 = effs1 `S.difference` effs2 + let extras2 = effs2 `S.difference` effs1 newRow <- freshEff - unifyEff (EffectRow [] t1) (extendEffRow extras2 newRow) - unifyEff (extendEffRow extras1 newRow) (EffectRow [] t2) + unifyEff (EffectRow mempty t1) (extendEffRow extras2 newRow) + unifyEff (extendEffRow extras1 newRow) (EffectRow mempty t2) _ -> throw TypeErr "" -setDiff :: Eq a => [a] -> [a] -> [a] -setDiff xs ys = filter (`notElem` ys) xs - bindQ :: (MonadCat SolverEnv m, MonadError Err m) => Var -> Type -> m () bindQ v t | v `occursIn` t = throw TypeErr $ "Occurs check failure: " ++ pprint (v, t) | hasSkolems t = throw TypeErr "Can't unify with skolem vars" diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index dd88dac68..92158b51c 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -121,7 +121,7 @@ inlineTraverseExpr expr = case expr of -- optimization will waste a bunch of memory by keeping the large intermediates alive. LamVal ib block@(Block Empty (Atom _)) -> return $ Atom $ TabVal ib block -- Pure broadcasts - LamVal ib@(Ignore _) block | blockEffs block == NoEffects -> do + LamVal ib@(Ignore _) block | blockEffs block == Pure -> do result <- dropSub $ evalBlockE inlineTraversalDef block Atom <$> buildLam ib TabArrow (\_ -> return $ result) _ -> return $ Hof $ For d newBody diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index d37847ac1..4236ac959 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -659,9 +659,9 @@ spaced :: (Foldable f, Pretty a) => f a -> Doc ann spaced xs = hsep $ map p $ toList xs instance Pretty EffectRow where - pretty (EffectRow [] Nothing) = mempty + pretty Pure = mempty pretty (EffectRow effs tailVar) = - braces $ hsep (punctuate "," (fmap prettyEff effs)) <> tailStr + braces $ hsep (punctuate "," (fmap prettyEff (toList effs))) <> tailStr where prettyEff (effName, region) = p effName <+> p region tailStr = case tailVar of diff --git a/src/lib/Parallelize.hs b/src/lib/Parallelize.hs index 4595d1d4c..445591f7a 100644 --- a/src/lib/Parallelize.hs +++ b/src/lib/Parallelize.hs @@ -70,9 +70,10 @@ parallelTraverseExpr expr = case expr of -- TODO: functionEffs is an overapproximation of the effects that really appear inside refs <- gets activeAccs let allowedRegions = foldMap (\(varType -> RefTy (Var reg) _) -> reg @> ()) refs - bodyEffs <- substEmbedR $ functionEffs fbody - let onlyAllowedEffects = flip all bodyEffs $ \(eff, reg) -> eff == Writer && reg `isin` allowedRegions - case onlyAllowedEffects of + (EffectRow bodyEffs t) <- substEmbedR $ functionEffs fbody + let onlyAllowedEffects = flip all (toList bodyEffs) $ \(eff, reg) -> + eff == Writer && reg `isin` allowedRegions + case t == Nothing && onlyAllowedEffects of True -> do b' <- substEmbedR b liftM Atom $ runLoopM $ withLoopBinder b' $ buildParallelBlock $ asABlock body diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 2a7eac7b3..9cbff0e0f 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -18,6 +18,7 @@ import Data.Functor import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.Map.Strict as M import Data.Void +import qualified Data.Set as S import Data.String (fromString) import qualified Text.Megaparsec.Char.Lexer as L import qualified Text.Megaparsec.Debug @@ -490,7 +491,7 @@ effects = braces someEffects <|> return Pure someEffects = do effs <- liftM2 (,) effectName (lowerName <|> upperName) `sepBy` sym "," v <- optional $ symbol "|" >> lowerName - return $ EffectRow effs v + return $ EffectRow (S.fromList effs) v effectName :: Parser EffectName effectName = (keyWord WriteKW $> Writer) diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index b6d9dc01b..510f9fc8d 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -19,8 +19,7 @@ module Syntax ( Effect, EffectName (..), EffectRow (..), ClassName (..), TyQual (..), SrcPos, Var, Binder, Block (..), Decl (..), Expr (..), Atom (..), ArrowP (..), Arrow, PrimTC (..), Abs (..), - PrimExpr (..), PrimCon (..), LitVal (..), - PrimEffect (..), PrimOp (..), EffectSummary, pattern NoEffects, + PrimExpr (..), PrimCon (..), LitVal (..), PrimEffect (..), PrimOp (..), PrimHof (..), LamExpr, PiType, WithSrc (..), srcPos, LetAnn (..), BinOp (..), UnOp (..), CmpOp (..), SourceBlock (..), ReachedEOF, SourceBlock' (..), SubstEnv, ScopedSubstEnv, @@ -70,7 +69,6 @@ import Control.Monad.Identity import Control.Monad.Writer hiding (Alt) import Control.Monad.Except hiding (Except) import qualified Data.ByteString.Char8 as B -import Data.List (sort) import qualified Data.List.NonEmpty as NE import qualified Data.Set as S import Data.Store (Store) @@ -425,32 +423,13 @@ showPrimName prim = primNameToStr $ fmap (const ()) prim -- === effects === type Effect = (EffectName, Name) -data EffectRow = EffectRow [Effect] (Maybe Name) - deriving (Show, Generic) +data EffectRow = EffectRow (S.Set Effect) (Maybe Name) + deriving (Show, Eq, Generic) data EffectName = Reader | Writer | State deriving (Show, Eq, Ord, Generic) -type EffectSummary = S.Set Effect - -instance HasVars EffectSummary where - freeVars effs = foldMap (\(_, reg) -> reg @> (TyKind, UnknownBinder)) effs - -instance Subst EffectSummary where - subst (env, _) effs = S.map substEff effs - where - substEff (eff, name) = case envLookup env name of - Just ~(Var (name':>_)) -> (eff, name') - Nothing -> (eff, name) - pattern Pure :: EffectRow -pattern Pure = EffectRow [] Nothing - -pattern NoEffects :: EffectSummary -pattern NoEffects <- ((S.null) -> True) - where NoEffects = mempty - -instance Eq EffectRow where - EffectRow effs t == EffectRow effs' t' = - sort effs == sort effs' && t == t' +pattern Pure <- ((\(EffectRow effs t) -> (S.null effs, t)) -> (True, Nothing)) + where Pure = mempty theWorld :: Name theWorld = GlobalName "World" @@ -466,6 +445,19 @@ initTopEnv = fold [v @> (ty, LamBound ImplicitArrow) | (v, ty) <- hostPtrTy :: BaseType -> BaseType hostPtrTy ty = PtrType (Heap CPU, ty) +instance Semigroup EffectRow where + EffectRow effs t <> EffectRow effs' t' = + EffectRow (S.union effs effs') newTail + where + newTail = case (t, t') of + (Nothing, effTail) -> effTail + (effTail, Nothing) -> effTail + _ | t == t' -> t + | otherwise -> error "Can't combine effect rows with mismatched tails" + +instance Monoid EffectRow where + mempty = EffectRow mempty Nothing + -- === top-level constructs === data SourceBlock = SourceBlock @@ -1139,7 +1131,7 @@ instance HasVars EffectRow where <> foldMap (\v -> v@>(EffKind, UnknownBinder)) t instance Subst EffectRow where subst (env, _) (EffectRow row t) = extendEffRow - (fmap (\(effName, v) -> (effName, substName env v)) row) + (S.map (\(effName, v) -> (effName, substName env v)) row) (substEffTail env t) instance HasVars BinderInfo where @@ -1174,10 +1166,10 @@ instance Subst (ExtLabeledItems Type Name) where prefixExtLabeledItems (subst env items) (substExtLabeledItemsTail env' rest) substEffTail :: SubstEnv -> Maybe Name -> EffectRow -substEffTail _ Nothing = EffectRow [] Nothing +substEffTail _ Nothing = EffectRow mempty Nothing substEffTail env (Just v) = case envLookup env (v:>()) of - Nothing -> EffectRow [] (Just v) - Just (Var (v':>_)) -> EffectRow [] (Just v') + Nothing -> EffectRow mempty (Just v) + Just (Var (v':>_)) -> EffectRow mempty (Just v') Just (Eff r) -> r _ -> error "Not a valid effect substitution" @@ -1187,7 +1179,7 @@ substName env v = case envLookup env (v:>()) of Just (Var (v':>_)) -> v' _ -> error "Should only substitute with a name" -extendEffRow :: [Effect] -> EffectRow -> EffectRow +extendEffRow :: S.Set Effect -> EffectRow -> EffectRow extendEffRow effs (EffectRow effs' t) = EffectRow (effs <> effs') t substExtLabeledItemsTail :: SubstEnv -> Maybe Name -> ExtLabeledItems Type Name diff --git a/src/lib/Type.hs b/src/lib/Type.hs index e177709b6..d687feff2 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -13,7 +13,7 @@ module Type ( isPure, functionEffs, exprEffs, blockEffs, extendEffect, isData, checkBinOp, checkUnOp, checkIntBaseType, checkFloatBaseType, withBinder, isDependent, indexSetConcreteSize, checkNoShadow, traceCheckM, traceCheck, projectLength, - typeReduceBlock, typeReduceAtom, typeReduceExpr) where + typeReduceBlock, typeReduceAtom, typeReduceExpr, oneEffect, ioEffect) where import Prelude hiding (pi) import Control.Monad @@ -257,7 +257,7 @@ checkApp fTy x = do return resultTy -- TODO: replace with something more precise (this is too cautious) -blockEffs :: Block -> EffectSummary +blockEffs :: Block -> EffectRow blockEffs (Block decls result) = foldMap declEffs decls <> exprEffs result where declEffs (Let _ _ expr) = exprEffs expr @@ -265,23 +265,23 @@ blockEffs (Block decls result) = isPure :: Expr -> Bool isPure expr = exprEffs expr == mempty -exprEffs :: Expr -> EffectSummary +exprEffs :: Expr -> EffectRow exprEffs expr = case expr of - Atom _ -> NoEffects + Atom _ -> Pure App f _ -> functionEffs f Op op -> case op of PrimEffect ref m -> case m of - MGet -> S.singleton (State, h) - MPut _ -> S.singleton (State, h) - MAsk -> S.singleton (Reader, h) - MTell _ -> S.singleton (Writer, h) + MGet -> oneEffect (State, h) + MPut _ -> oneEffect (State, h) + MAsk -> oneEffect (Reader, h) + MTell _ -> oneEffect (Writer, h) where RefTy (Var (h:>_)) _ = getType ref - IOAlloc _ _ -> S.singleton (State, theWorld) - IOFree _ -> S.singleton (State, theWorld) - PtrLoad _ -> S.singleton (State, theWorld) - PtrStore _ _ -> S.singleton (State, theWorld) - FFICall _ _ _ -> S.singleton (State, theWorld) - _ -> NoEffects + IOAlloc _ _ -> oneEffect ioEffect + IOFree _ -> oneEffect ioEffect + PtrLoad _ -> oneEffect ioEffect + PtrStore _ _ -> oneEffect ioEffect + FFICall _ _ _ -> oneEffect ioEffect + _ -> Pure Hof hof -> case hof of For _ f -> functionEffs f Tile _ _ _ -> error "not implemented" @@ -292,17 +292,16 @@ exprEffs expr = case expr of RunWriter f -> handleRunner Writer f RunState _ f -> handleRunner State f PTileReduce _ _ -> mempty - RunIO ~(Lam (Abs _ (PlainArrow (EffectRow effs Nothing), _))) -> - S.delete (State, theWorld) $ S.fromList effs + RunIO ~(Lam (Abs _ (PlainArrow (EffectRow effs t), _))) -> + EffectRow (S.delete ioEffect effs) t Case _ alts _ -> foldMap (\(Abs _ block) -> blockEffs block) alts where - handleRunner effName ~(BinaryFunVal (Bind (h:>_)) _ (EffectRow effs Nothing) _) = - S.delete (effName, h) $ S.fromList effs + handleRunner effName ~(BinaryFunVal (Bind (h:>_)) _ (EffectRow effs t) _) = + EffectRow (S.delete (effName, h) effs) t -functionEffs :: Atom -> EffectSummary +functionEffs :: Atom -> EffectRow functionEffs f = case getType f of - Pi (Abs _ (arr, _)) -> S.fromList effs - where EffectRow effs Nothing = arrowEff arr + Pi (Abs _ (arr, _)) -> arrowEff arr _ -> error "Expected a function type" instance HasType Block where @@ -481,10 +480,8 @@ checkEffRow (EffectRow effs effTail) = do Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v Just (ty, _) -> assertEq EffKind ty "Effect var" -declareEff :: (EffectName, Maybe Name) -> TypeM () -declareEff (effName, Just h) = - declareEffs $ EffectRow [(effName, h)] Nothing -declareEff (_, Nothing) = return () +declareEff :: Effect -> TypeM () +declareEff eff = declareEffs $ oneEffect eff declareEffs :: EffectRow -> TypeM () declareEffs effs = checkWithEnv $ \(_, allowedEffects) -> @@ -501,7 +498,13 @@ checkExtends allowed (EffectRow effs effTail) = do "\nAllowed: " ++ pprint allowed extendEffect :: Effect -> EffectRow -> EffectRow -extendEffect eff (EffectRow effs t) = EffectRow (eff:effs) t +extendEffect eff (EffectRow effs t) = EffectRow (S.insert eff effs) t + +oneEffect :: Effect -> EffectRow +oneEffect eff = EffectRow (S.singleton eff) Nothing + +ioEffect :: Effect +ioEffect = (State, theWorld) -- === labeled row types === @@ -682,7 +685,7 @@ typeCheckOp op = case op of BaseTy _ -> return () _ -> throw TypeErr $ "All arguments of FFI calls have to be " ++ "fixed-width base types, but got: " ++ pprint argTy - declareEff (State, Just theWorld) + declareEff ioEffect return ansTy Inject i -> do TC tc <- typeCheck i @@ -691,15 +694,12 @@ typeCheckOp op = case op of ParIndexRange ty _ _ -> return ty _ -> throw TypeErr $ "Unsupported inject argument type: " ++ pprint (TC tc) PrimEffect ref m -> do - TC (RefType h s) <- typeCheck ref - let h'' = case h of - Just ~(Var (h':>TyKind)) -> Just h' - Nothing -> Nothing + TC (RefType ~(Just (Var (h':>TyKind))) s) <- typeCheck ref case m of - MGet -> declareEff (State , h'') $> s - MPut x -> x|:s >> declareEff (State , h'') $> UnitTy - MAsk -> declareEff (Reader, h'') $> s - MTell x -> x|:s >> declareEff (Writer, h'') $> UnitTy + MGet -> declareEff (State , h') $> s + MPut x -> x|:s >> declareEff (State , h') $> UnitTy + MAsk -> declareEff (Reader, h') $> s + MTell x -> x|:s >> declareEff (Writer, h') $> UnitTy IndexRef ref i -> do RefTy h (TabTyAbs a) <- typeCheck ref i |: absArgType a @@ -712,11 +712,11 @@ typeCheckOp op = case op of return $ RefTy h b IOAlloc t n -> do n |: IdxRepTy - declareEff (State, Just theWorld) + declareEff ioEffect return $ PtrTy (Heap CPU, t) IOFree ptr -> do PtrTy _ <- typeCheck ptr - declareEff (State, Just theWorld) + declareEff ioEffect return UnitTy PtrOffset arr off -> do PtrTy (a, b) <- typeCheck arr @@ -724,12 +724,12 @@ typeCheckOp op = case op of return $ PtrTy (a, b) PtrLoad ptr -> do PtrTy (_, t) <- typeCheck ptr - declareEff (State, Just theWorld) + declareEff ioEffect return $ BaseTy t PtrStore ptr val -> do PtrTy (_, t) <- typeCheck ptr val |: BaseTy t - declareEff (State, Just theWorld) + declareEff ioEffect return $ UnitTy SliceOffset s i -> do TC (IndexSlice n l) <- typeCheck s @@ -879,7 +879,7 @@ typeCheckHof hof = case hof of return $ PairTy resultTy stateTy RunIO f -> do FunTy _ eff resultTy <- typeCheck f - extendAllowedEffect (State, theWorld) $ declareEffs eff + extendAllowedEffect ioEffect $ declareEffs eff return resultTy checkAction :: EffectName -> Atom -> TypeM (Type, Type) diff --git a/tests/monad-tests.dx b/tests/monad-tests.dx index 4612954c0..b29840ff8 100644 --- a/tests/monad-tests.dx +++ b/tests/monad-tests.dx @@ -167,3 +167,10 @@ symmetrizeInPlace [[1.,2.],[3.,4.]] ref += [1.,2.,3.] ref += [2.,4.,5.] > [3.0, 6.0, 8.0] + +def effectsAtZero (eff:Effects)?-> (f: Int ->{|eff} Unit) : {|eff} Unit = + f 0 + () + +:p withState 0 \ref. effectsAtZero \_. ref := 1 +> ((), 1) From e0742a55fb22302e9990db8a402b17759b8a29f9 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 28 Dec 2020 17:38:45 -0500 Subject: [PATCH 035/105] Add a special case to effect inference to accept subsets of allowed effects. Previously, we couldn't use functions with a rigid effect variable in a context with extra concrete effects, like this: def liftState (f:(a -> {|eff} b)) (x:a) : {State h |eff} b = f x That meant we couldn't have an effect-polymorphic version of iter, `iter : a -> (Int -> a -> {|eff} IterResult a b) -> {|eff} b`. With this change, we can write helper functions like `liftState` that make it possible to implement `iter`. It's not ideal, but it unblock us for now. --- src/lib/Embed.hs | 1 + src/lib/Inference.hs | 17 ++++++++++++++--- src/lib/Type.hs | 2 +- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 1dc88144e..f580776bf 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -366,6 +366,7 @@ mkBinaryEffFun newEff v ty body = do buildForAnnAux :: MonadEmbed m => ForAnn -> Binder -> (Atom -> m (Atom, a)) -> m (Atom, a) buildForAnnAux ann i body = do + -- TODO: consider only tracking the effects that are actually needed. eff <- getAllowedEffects (lam, aux) <- buildLamAux i (const $ return $ PlainArrow eff) body (,aux) <$> (emit $ Hof $ For ann lam) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 05178f075..2c3f31e82 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -633,9 +633,20 @@ emitZonked expr = zonk expr >>= emit addEffects :: EffectRow -> UInferM () addEffects eff = do - eff' <- openEffectRow eff - allowedEffects <- getAllowedEffects - constrainEq (Eff allowedEffects) (Eff eff') + allowed <- checkAllowedUnconditionally eff + unless allowed $ do + allowedEffects <- getAllowedEffects + eff' <- openEffectRow eff + constrainEq (Eff allowedEffects) (Eff eff') + +checkAllowedUnconditionally :: EffectRow -> UInferM Bool +checkAllowedUnconditionally Pure = return True +checkAllowedUnconditionally eff = do + eff' <- zonk eff + effAllowed <- getAllowedEffects >>= zonk + return $ case checkExtends effAllowed eff' of + Left _ -> False + Right () -> True openEffectRow :: EffectRow -> UInferM EffectRow openEffectRow (EffectRow effs Nothing) = extendEffRow effs <$> freshEff diff --git a/src/lib/Type.hs b/src/lib/Type.hs index d687feff2..88b7f61c9 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -11,7 +11,7 @@ module Type ( getType, checkType, HasType (..), Checkable (..), litType, isPure, functionEffs, exprEffs, blockEffs, extendEffect, isData, checkBinOp, checkUnOp, - checkIntBaseType, checkFloatBaseType, withBinder, isDependent, + checkIntBaseType, checkFloatBaseType, withBinder, isDependent, checkExtends, indexSetConcreteSize, checkNoShadow, traceCheckM, traceCheck, projectLength, typeReduceBlock, typeReduceAtom, typeReduceExpr, oneEffect, ioEffect) where From afcddfea7021230a7c74cfe06868a1a2dcc89718 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 28 Dec 2020 19:15:51 -0500 Subject: [PATCH 036/105] Make a simpler iteration combinator that lets effects carry any needed state. --- examples/raytrace.dx | 102 ++++++++++++++++++++----------------------- lib/prelude.dx | 91 +++++++++++++++++++++----------------- 2 files changed, 100 insertions(+), 93 deletions(-) diff --git a/examples/raytrace.dx b/examples/raytrace.dx index 08d6758b7..cb02f838f 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -170,32 +170,30 @@ data RayMarchResult = HitNothing def raymarch (scene:Scene n) (ray:Ray) : RayMarchResult = - max_iters = 100 + maxIters = 100 tol = 0.01 startLength = 10.0 * tol -- trying to escape the current surface (rayOrigin, rayDir) = ray - iter (10.0 * tol) \i rayLength. - case i >= max_iters of - True -> Done HitNothing - False -> - rayPos = rayOrigin + rayLength .* rayDir - (obj, d) = sdScene scene $ rayPos - -- 0.9 ensures we come close to the surface but don't touch it - dNew = rayLength + 0.9 * d - case d < tol of - False -> Continue $ dNew - True -> - surfNorm = calcNormal obj rayPos - case positiveProjection rayDir surfNorm of - True -> - -- Oops, we didn't escape the surface we're leaving.. - -- (Is there a more standard way to do this?) - Continue dNew - False -> - -- We made it! - Done $ case obj of - PassiveObject _ surf -> HitObj (rayPos, rayDir) (surfNorm, surf) - Light _ _ radiance -> HitLight radiance + fst $ withState (10.0 * tol) \rayLength. + boundedIter maxIters HitNothing \_. + rayPos = rayOrigin + get rayLength .* rayDir + (obj, d) = sdScene scene $ rayPos + -- 0.9 ensures we come close to the surface but don't touch it + rayLength := get rayLength + 0.9 * d + case d < tol of + False -> Continue + True -> + surfNorm = calcNormal obj rayPos + case positiveProjection rayDir surfNorm of + True -> + -- Oops, we didn't escape the surface we're leaving.. + -- (Is there a more standard way to do this?) + Continue + False -> + -- We made it! + Done $ case obj of + PassiveObject _ surf -> HitObj (rayPos, rayDir) (surfNorm, surf) + Light _ _ radiance -> HitLight radiance def rayDirectRadiance (scene:Scene n) (ray:Ray) : Radiance = case raymarch scene ray of @@ -220,42 +218,38 @@ def sampleLightRadiance Light lightPos hw _ -> (dirToLight, distToLight) = directionAndLength $ lightPos + sampleSquare hw k - rayPos - case positiveProjection dirToLight surfNor of - False -> () -- light on the far side of current surface - True -> - fracSolidAngle = (relu $ dot dirToLight yHat) * sq hw / (pi * sq distToLight) - outRay = (rayPos, dirToLight) - coeff = fracSolidAngle * probReflection osurf inRay outRay - radiance += coeff .* rayDirectRadiance scene outRay - -def trace (params:Params) (scene:Scene n) (init_ray:Ray) (k:Key) : Color = - -- TODO: we ought to be able to use an accumulator here, but there's a bug + when (positiveProjection dirToLight surfNor) \(). + -- light on this far side of current surface + fracSolidAngle = (relu $ dot dirToLight yHat) * sq hw / (pi * sq distToLight) + outRay = (rayPos, dirToLight) + coeff = fracSolidAngle * probReflection osurf inRay outRay + radiance += coeff .* rayDirectRadiance scene outRay + +def trace (params:Params) (scene:Scene n) (initRay:Ray) (k:Key) : Color = noFilter = [1.0, 1.0, 1.0] - iter (noFilter, zero, init_ray) $ - \i (filter, radiance, ray). - case i >= getAt #maxBounces params of - True -> Done radiance - False -> case raymarch scene ray of - HitNothing -> Done radiance - HitLight intensity -> case i == 0 of - True -> Done intensity -- TODO: scale etc - False -> Done radiance + snd $ withAccum \radiance. + withState noFilter \filter. + withState initRay \ray. + boundedIter (getAt #maxBounces params) () \i. + case raymarch scene $ get ray of + HitNothing -> Done () + HitLight intensity -> + when (i == 0) \(). radiance += intensity -- TODO: scale etc + Done () HitObj incidentRay osurf -> [k1, k2] = splitKey $ hash k i - lightRadiance = sampleLightRadiance scene osurf incidentRay k1 - outRayHemisphere = sampleReflection osurf incidentRay k2 - newFilter = surfaceFilter filter (snd osurf) - newRadiance = radiance + applyFilter newFilter lightRadiance - Continue (newFilter, newRadiance, outRayHemisphere) - --- TODO: add number of pixels once we can hide sizes --- sensor half-width, pinhole-sensor distance, pinhole position --- (Assumes we're looking towards -z.) + lightRadiance = sampleLightRadiance scene osurf incidentRay k1 + ray := sampleReflection osurf incidentRay k2 + filter := surfaceFilter (get filter) (snd osurf) + radiance += applyFilter (get filter) lightRadiance + Continue + +-- Assumes we're looking towards -z. Camera = { numPix : Int - & pos : Position - & halfWidth : Float - & sensorDist : Float } + & pos : Position -- pinhole position + & halfWidth : Float -- sensor half-width + & sensorDist : Float } -- pinhole-sensor distance -- TODO: might be better with an anonymous dependent pair for the result def cameraRays (n:Int) (camera:Camera) : Fin n => Fin n => (Key -> Ray) = diff --git a/lib/prelude.dx b/lib/prelude.dx index 80d9dfea0..c06102a4f 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -249,6 +249,9 @@ def withState def unsafeIO (f: Unit -> {State World|eff} a) : {|eff} a = %runIO f +def unreachable (():Unit) : a = unsafeIO \(). + %throwError a + 'Type classes data Eq a:Type = MkEq (a -> a -> Bool) @@ -1014,17 +1017,50 @@ def while cond' : Unit -> {|eff} Word8 = \_. BToW8 $ cond () %while cond' body +data IterResult a:Type = + Continue + Done a + +def when (cond:Bool) (f:Unit -> {|eff} Unit) : {|eff} Unit = + if cond + then f () + else () + +-- TODO: can we improve effect inference so we don't need this? +def liftState (ref: Ref h c) (f:a -> {|eff} b) (x:a) : {State h|eff} b = + f x + +-- A little iteration combinator +def iter (body: Int -> {|eff} IterResult a) : {|eff} a = + result = snd $ withState Nothing \resultRef. withState 0 \i. + while (\(). isNothing $ get resultRef) \(). + case liftState resultRef (liftState i body) (get i) of + Continue -> i := get i + 1 + Done result -> resultRef := Just result + case result of + Just ans -> ans + Nothing -> unreachable () + +def boundedIter (maxIters:Int) (fallback:a) + (body: Int -> {|eff} IterResult a) : {|eff} a = + iter \i. + if i >= maxIters + then Done fallback + else body i + def fread (stream:Stream ReadMode) : {State World} String = (MkStream stream') = stream -- TODO: allow reading longer files! n = 4096 withAlloc n \ptr:(Ptr Char). withDynamicBuffer \buf. - while (\(). + iter \_. (MkPtr rawPtr) = ptr numRead = I64ToI $ %ffi fread Int64 rawPtr (IToI64 1) (IToI64 n) stream' extendDynBuffer buf $ stringFromCharPtr numRead ptr - numRead == n) (\(). ()) + if numRead == n + then Continue + else Done () loadDynBuffer buf def deleteFile (f:FilePath) : {State World} Unit = @@ -1134,45 +1170,22 @@ instance finArb : n:Int ?-> Arbitrary (Fin n) where 'Control flow -data IterResult a:Type b:Type = - Continue a - Done b - --- A little iteration combinator --- TODO: allow effects (bug #267) -def iter (init:a) (body: Int -> a -> IterResult a b) : b = - result = snd $ withState Nothing \resultRef. - withState init \carryRef. - withState 0 \i. - while (\(). isNothing (get resultRef)) \(). - case body (get i) (get carryRef) of - Continue carry -> - i := get i + 1 - carryRef := carry - Done result -> - resultRef := Just result - case result of - Just ans -> ans - Nothing -> error "should be unreachable" - -- returns the highest index `i` such that `xs.i <= x` def searchSorted (_:Ord a) ?=> (xs:n=>a) (x:a) : Maybe n = - case size n == 0 of - True -> Nothing - False -> case x < xs.(fromOrdinal _ 0) of - True -> Nothing - False -> - iter (0, size n) \_ (low, high). - numLeft = high - low - case numLeft == 1 of - True -> Done $ Just $ fromOrdinal _ low - False -> - centerIx = low + idiv (high - low) 2 - case x < xs.(fromOrdinal _ centerIx) of - True -> Continue (low, centerIx) - False -> Continue (centerIx, high) - - + if size n == 0 + then Nothing + else if x < xs.(fromOrdinal _ 0) + then Nothing + else fst $ withState 0 \low. fst $ withState (size n) \high. iter \_. + numLeft = get high - get low + if numLeft == 1 + then Done $ Just $ fromOrdinal _ $ get low + else + centerIx = get low + idiv numLeft 2 + if x < xs.(fromOrdinal _ centerIx) + then high := centerIx + else low := centerIx + Continue 'min / max etc From e13897148611fdc6fc78b9512710ae874c324068 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 28 Dec 2020 20:27:27 -0500 Subject: [PATCH 037/105] Add `do ` as syntactic sugar for `\(). `. --- examples/ode-integrator.dx | 2 +- examples/raytrace.dx | 4 ++-- examples/rejection-sampler.dx | 2 +- lib/prelude.dx | 20 ++++++++++---------- misc/dex.el | 2 +- src/lib/Parser.hs | 12 ++++++++++-- 6 files changed, 25 insertions(+), 17 deletions(-) diff --git a/examples/ode-integrator.dx b/examples/ode-integrator.dx index dc784b177..6c9a4e2ca 100644 --- a/examples/ode-integrator.dx +++ b/examples/ode-integrator.dx @@ -135,7 +135,7 @@ def odeint (func: d=>Float -> Time -> d=>Float) -- Take steps until we pass target_t new_state = snd $ withState init_carry \state. - while (\(). stopping_condition (get state)) \(). + while (do stopping_condition (get state)) do state := possible_step (get state) (_, _, t, _, last_t, interp_coeff) = new_state diff --git a/examples/raytrace.dx b/examples/raytrace.dx index cb02f838f..45fea7b6b 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -218,7 +218,7 @@ def sampleLightRadiance Light lightPos hw _ -> (dirToLight, distToLight) = directionAndLength $ lightPos + sampleSquare hw k - rayPos - when (positiveProjection dirToLight surfNor) \(). + when (positiveProjection dirToLight surfNor) do -- light on this far side of current surface fracSolidAngle = (relu $ dot dirToLight yHat) * sq hw / (pi * sq distToLight) outRay = (rayPos, dirToLight) @@ -234,7 +234,7 @@ def trace (params:Params) (scene:Scene n) (initRay:Ray) (k:Key) : Color = case raymarch scene $ get ray of HitNothing -> Done () HitLight intensity -> - when (i == 0) \(). radiance += intensity -- TODO: scale etc + when (i == 0) do radiance += intensity -- TODO: scale etc Done () HitObj incidentRay osurf -> [k1, k2] = splitKey $ hash k i diff --git a/examples/rejection-sampler.dx b/examples/rejection-sampler.dx index 503dcfb69..fb2100a16 100644 --- a/examples/rejection-sampler.dx +++ b/examples/rejection-sampler.dx @@ -2,7 +2,7 @@ def rejectionSample (try: Key -> Maybe a) (k:Key) : a = ans = fst $ withState 0 \i. snd $ withState Nothing \sample. - while (\(). isNothing (get sample)) \(). + while (do isNothing (get sample)) do i := get i + 1 sample := try $ hash k (get i) case ans of Just sample -> sample diff --git a/lib/prelude.dx b/lib/prelude.dx index c06102a4f..daf9ac32a 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -249,7 +249,7 @@ def withState def unsafeIO (f: Unit -> {State World|eff} a) : {|eff} a = %runIO f -def unreachable (():Unit) : a = unsafeIO \(). +def unreachable (():Unit) : a = unsafeIO do %throwError a 'Type classes @@ -588,7 +588,7 @@ def inner (x:n=>Float) (mat:n=>m=>Float) (y:m=>Float) : Float = -- TODO: newtype Key = Int64 -def hash (x:Key) (y:Int32) : Key = unsafeIO \(). +def hash (x:Key) (y:Int32) : Key = unsafeIO do y64 = IToI64 y %ffi threefry2x32 Int64 x y64 def newKey (x:Int) : Key = hash (IToI64 0) x @@ -596,7 +596,7 @@ def many (f:Key->a) (k:Key) (i:n) : a = f (hash k (ordinal i)) def ixkey (k:Key) (i:n) : Key = hash k (ordinal i) def ixkey2 (k:Key) (i:n) (j:m) : Key = hash (hash k (ordinal i)) (ordinal j) def splitKey (n:Int) ?-> (k:Key) : Fin n => Key = for i. ixkey k i -def rand (k:Key) : Float = unsafeIO \(). F64ToF $ %ffi randunif Float64 k +def rand (k:Key) : Float = unsafeIO do F64ToF $ %ffi randunif Float64 k def randVec (n:Int) (f: Key -> a) (k: Key) : Fin n => a = for i:(Fin n). f (ixkey k i) @@ -923,22 +923,22 @@ instance showString : Show String where show = id instance showInt32 : Show Int32 where - show = \x: Int32. unsafeIO \(). + show = \x: Int32. unsafeIO do (n, ptr) = %ffi showInt32 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr instance showInt64 : Show Int64 where - show = \x: Int64. unsafeIO \(). + show = \x: Int64. unsafeIO do (n, ptr) = %ffi showInt64 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr instance showFloat32 : Show Float32 where - show = \x: Float32.unsafeIO \(). + show = \x: Float32.unsafeIO do (n, ptr) = %ffi showFloat32 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr instance showFloat64 : Show Float64 where - show = \x: Float64.unsafeIO \(). + show = \x: Float64.unsafeIO do (n, ptr) = %ffi showFloat64 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr @@ -1033,7 +1033,7 @@ def liftState (ref: Ref h c) (f:a -> {|eff} b) (x:a) : {State h|eff} b = -- A little iteration combinator def iter (body: Int -> {|eff} IterResult a) : {|eff} a = result = snd $ withState Nothing \resultRef. withState 0 \i. - while (\(). isNothing $ get resultRef) \(). + while (do isNothing $ get resultRef) do case liftState resultRef (liftState i body) (get i) of Continue -> i := get i + 1 Done result -> resultRef := Just result @@ -1115,7 +1115,7 @@ def shellOut (command:String) : {State World} String = 'Partial functions -def error (s:String) : a = unsafeIO \(). +def error (s:String) : a = unsafeIO do print s %throwError a @@ -1425,7 +1425,7 @@ def concat (lists:n=>(List a)) : List a = AsList _ $ fst $ withState 0 \listIdx. fst $ withState 0 \eltIdx. for i:(Fin totalSize). - while (\(). get eltIdx >= listLength (lists.((get listIdx)@_))) \(). + while (do get eltIdx >= listLength (lists.((get listIdx)@_))) do eltIdx := 0 listIdx := get listIdx + 1 (AsList _ xs) = lists.((get listIdx)@_) diff --git a/misc/dex.el b/misc/dex.el index 85f0acad4..9381371a5 100644 --- a/misc/dex.el +++ b/misc/dex.el @@ -10,7 +10,7 @@ ("^'\\(.\\|\n.\\)*\n\n" . font-lock-comment-face) ("\\w+:" . font-lock-comment-face) ("^:\\w*" . font-lock-preprocessor-face) - ("\\bdef\\b\\|\\bfor\\b\\|\\brof\\b\\|\\bcase\\b\\|\\bdata\\b\\|\\bwhere\\b\\|\\bof\\b\\|\\bif\\b\\|\\bthen\\b\\|\\belse\\b\\|\\binterface\\b\\|\\binstance\\b" . + ("\\bdef\\b\\|\\bfor\\b\\|\\brof\\b\\|\\bcase\\b\\|\\bdata\\b\\|\\bwhere\\b\\|\\bof\\b\\|\\bif\\b\\|\\bthen\\b\\|\\belse\\b\\|\\binterface\\b\\|\\binstance\\b\\|\\bdo\\b" . font-lock-keyword-face) ("--o" . font-lock-variable-name-face) ("[-.,!$^&*:~+/=<>|?\\\\]" . font-lock-variable-name-face) diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 9cbff0e0f..3c7d50401 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -204,6 +204,7 @@ leafExpr = parens (mayPair $ makeExprParser leafExpr ops) <|> unitCon <|> (uLabeledExprs `fallBackTo` uVariantExpr) <|> uIsoSugar + <|> uDoSugar "expression" containedExpr :: Parser UExpr @@ -714,6 +715,12 @@ uLabeledExprs = withSrc $ varPun :: SrcPos -> Label -> UExpr varPun pos str = WithSrc (Just pos) $ UVar (mkName str :> ()) +uDoSugar :: Parser UExpr +uDoSugar = withSrc $ do + keyWord DoKW + body <- blockOrExpr + return $ ULam (WithSrc Nothing UPatUnit, Nothing) (PlainArrow ()) body + uIsoSugar :: Parser UExpr uIsoSugar = withSrc (char '#' *> options) where options = (recordFieldIso <$> fieldLabel) @@ -1019,7 +1026,7 @@ type Lexer = Parser data KeyWord = DefKW | ForKW | For_KW | RofKW | Rof_KW | CaseKW | OfKW | ReadKW | WriteKW | StateKW | DataKW | InterfaceKW - | InstanceKW | WhereKW | IfKW | ThenKW | ElseKW + | InstanceKW | WhereKW | IfKW | ThenKW | ElseKW | DoKW upperName :: Lexer Name upperName = liftM mkName $ label "upper-case name" $ lexeme $ @@ -1056,11 +1063,12 @@ keyWord kw = lexeme $ try $ string s >> notFollowedBy nameTailChar InterfaceKW -> "interface" InstanceKW -> "instance" WhereKW -> "where" + DoKW -> "do" keyWordStrs :: [String] keyWordStrs = ["def", "for", "for_", "rof", "rof_", "case", "of", "llam", "Read", "Write", "Accum", "data", "interface", - "instance", "where", "if", "then", "else"] + "instance", "where", "if", "then", "else", "do"] fieldLabel :: Lexer Label fieldLabel = label "field label" $ lexeme $ From 42193e3ad9a60eda0a5caa80d6f5fd9e5305f4ed Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 28 Dec 2020 02:20:58 +0100 Subject: [PATCH 038/105] Add shell.nix. --- shell.nix | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 shell.nix diff --git a/shell.nix b/shell.nix new file mode 100644 index 000000000..93215131e --- /dev/null +++ b/shell.nix @@ -0,0 +1,15 @@ +{ nixpkgs ? import {} }: +with nixpkgs; +stdenv.mkDerivation { + name = "dex"; + buildInputs = [ + cabal-install + haskell.compiler.ghc884 + llvm_9 + clang_9 + pkg-config + libpng + git + cacert + ]; +} From dab1c4f4197c01da947a2a2b364ec9ec3bcaa6cc Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 28 Dec 2020 02:21:58 +0100 Subject: [PATCH 039/105] Add cabal.project. This is necessary for cabal to find dependencies stored in Git repositories. --- cabal.project | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 cabal.project diff --git a/cabal.project b/cabal.project new file mode 100644 index 000000000..7ab885353 --- /dev/null +++ b/cabal.project @@ -0,0 +1,13 @@ +packages: dex.cabal + +source-repository-package + type: git + location: https://github.com/apaszke/llvm-hs + tag: a9a74be1a7c15f3d21b2fffd35a425002ae7736f + subdir: llvm-hs + +source-repository-package + type: git + location: https://github.com/apaszke/llvm-hs + tag: a9a74be1a7c15f3d21b2fffd35a425002ae7736f + subdir: llvm-hs-pure From c595c6a9e5c1a4693966cf36b362b7c13d03ae92 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 28 Dec 2020 22:05:22 -0500 Subject: [PATCH 040/105] Print floats with fewer decimal places to make quine tests less brittle. --- examples/ctc.dx | 8 +- examples/isomorphisms.dx | 22 ++-- examples/mcmc.dx | 6 +- examples/ode-integrator.dx | 4 +- examples/particle-swarm-optimizer.dx | 18 ++-- examples/pi.dx | 4 +- examples/regression.dx | 2 +- makefile | 14 ++- src/lib/PPrint.hs | 6 +- tests/ad-tests.dx | 78 +++++++------- tests/adt-tests.dx | 2 +- tests/complex-tests.dx | 16 +-- tests/eval-tests.dx | 126 +++++++++++----------- tests/monad-tests.dx | 18 ++-- tests/parser-tests.dx | 18 ++-- tests/record-variant-tests.dx | 22 ++-- tests/repl-multiline-test-expected-output | 4 +- tests/serialize-tests.dx | 8 +- tests/type-tests.dx | 4 +- tests/uexpr-tests.dx | 22 ++-- 20 files changed, 205 insertions(+), 197 deletions(-) diff --git a/examples/ctc.dx b/examples/ctc.dx index 4e6544eb2..d0dc7979a 100644 --- a/examples/ctc.dx +++ b/examples/ctc.dx @@ -118,7 +118,7 @@ labels = for i:position. randIdxNoZero Vocab (newKey (ordinal i)) -- Evaluate marginal probability of labels given logits :p exp $ ctc blank logits labels -> 1.0398488e-3 +> 0.00104 @@ -132,12 +132,12 @@ labels = for i:position. randIdxNoZero Vocab (newKey (ordinal i)) :p sum for i:Vocab. exp $ ctc blank logits [i] -> 0.14146839 +> 0.141468 :p sum for (i, j):(Vocab & Vocab). exp $ ctc blank logits [i, j] -> 0.7091234 +> 0.709123 :p sum for (i, j, k):(Vocab & Vocab & Vocab). exp $ ctc blank logits [i, j, k] -> 0.9251011 +> 0.925101 diff --git a/examples/isomorphisms.dx b/examples/isomorphisms.dx index b127c16a2..9668eac42 100644 --- a/examples/isomorphisms.dx +++ b/examples/isomorphisms.dx @@ -16,21 +16,21 @@ cycleThree : Iso (a & b & c) (b & c & a) = and flipped with `flipIso` :p appIso cycleThree (1, 2.0, 3) -> (2.0, (3, 1)) +> (2., (3, 1)) :p revIso cycleThree (1, 2.0, 3) -> (3, (1, 2.0)) +> (3, (1, 2.)) :p appIso (flipIso cycleThree) (1, 2.0, 3) -> (3, (1, 2.0)) +> (3, (1, 2.)) 'They can also be composed with `&>>`: :p appIso (cycleThree &>> cycleThree) (1, 2.0, 3) -> (3, (1, 2.0)) +> (3, (1, 2.)) :p appIso (cycleThree &>> cycleThree &>> cycleThree) (1, 2.0, 3) -> (1, (2.0, 3)) +> (1, (2., 3)) 'Note that we assume but do not check that the isomorphism is lawful (i.e. `appIso iso $ revIso iso x == x` for all `x`, or equivalently @@ -88,13 +88,13 @@ Record accessor isomorphisms can be passed into the helper function `getAt`: 'We can also do other types of things: :p popAt #foo {foo=1, bar=2.0} -> {bar = 2.0} +> {bar = 2.} :p pushAt #foo 3.0 {foo=1, bar=2.0} -> {bar = 2.0, foo = 3.0, foo = 1} +> {bar = 2., foo = 3., foo = 1} :p setAt #foo 2 {foo=1, bar=2.0} -> {bar = 2.0, foo = 2} +> {bar = 2., foo = 2} 'These helper functions work with any "lens-like" isomorphism. For instance, we can select everything except for a particular field: @@ -103,7 +103,7 @@ we can select everything except for a particular field: > ((a:Type) ?-> (b:Type) ?-> (c:Type) ?-> (Iso a (b & c)) -> Iso a (c & b)) :p getAt (exceptLens #foo) {foo=1, bar=2.0, baz=3} -> {bar = 2.0, baz = 3} +> {bar = 2., baz = 3} '## Variant accessors and prism-like helpers @@ -127,7 +127,7 @@ Similarly, there are prism-like helpers > {| foo = 3 |} :p matchWith (exceptPrism #?foo) $ {|bar = 1.0|}:{foo:Int | bar:Float} -> (Just {| bar = 1.0 |}) +> (Just {| bar = 1. |}) '## Record zipper isomorphisms The isomorphisms shown above are specialized for removing a single field from @@ -200,7 +200,7 @@ ordinary record accessor lens. '`splitR` can be used if you want to process multiple fields at once: :p pushAt (splitR &>> #&a &>> #&b) {a=1, b=2.0} {c=3, d=4.0} -> {a = 1, b = 2.0, c = 3, d = 4.0} +> {a = 1, b = 2., c = 3, d = 4.} '## Variant zipper isomorphisms Just as there are record zipper isomorphisms, there are also variant diff --git a/examples/mcmc.dx b/examples/mcmc.dx index 0df6c857d..1cf4229c0 100644 --- a/examples/mcmc.dx +++ b/examples/mcmc.dx @@ -95,8 +95,7 @@ mhParams = 0.1 mhSamples = runChain randnVec (mhStep mhParams myLogProb) numSamples k0 :p meanAndCovariance mhSamples -> ( [1.5165595, 2.493105] -> , [[1.0373967, 1.1820998e-2], [1.1820998e-2, 5.377563e-2]] ) +> ([1.51656, 2.493105], [[1.037397, 0.011821], [0.011821, 0.053776]]) :html showPlot $ yPlot $ slice (map head mhSamples) 0 (Fin 1000) @@ -106,8 +105,7 @@ hmcParams = (10, 0.1) hmcSamples = runChain randnVec (hmcStep hmcParams myLogProb) numSamples k0 :p meanAndCovariance hmcSamples -> ( [1.5045699, 2.5000212] -> , [[0.97386724, 3.422921e-3], [3.422921e-3, 5.058581e-2]] ) +> ([1.50457, 2.500021], [[0.973867, 0.003423], [0.003423, 0.050586]]) :html showPlot $ yPlot $ slice (map head hmcSamples) 0 (Fin 1000) diff --git a/examples/ode-integrator.dx b/examples/ode-integrator.dx index dc784b177..4e77b1b02 100644 --- a/examples/ode-integrator.dx +++ b/examples/ode-integrator.dx @@ -160,12 +160,12 @@ t1 = [1.0] approx_e = odeint myDyn z0 t0 t1 :p approx_e -> [[2.7201762]] +> [[2.720176]] exact_e = [[exp 1.0]] :p (approx_e - exact_e) -- amount of numerical error -> [[1.894474e-3]] +> [[0.001894]] times = linspace (Fin 100) 0.00001 1.0 ys = odeint myDyn z0 t0 times diff --git a/examples/particle-swarm-optimizer.dx b/examples/particle-swarm-optimizer.dx index 3677c15e7..21b0cab5d 100644 --- a/examples/particle-swarm-optimizer.dx +++ b/examples/particle-swarm-optimizer.dx @@ -16,16 +16,16 @@ rosenbrock2 : ((Fin 2)=>Float) -> Float = ' Min should be at 1.0, 1.0 :p rosenbrock 1.0 1.000 -> 0.0 +> 0. :p rosenbrock2 [1.0, 1.000] -> 0.0 +> 0. :p rosenbrock 1.0 1.02 -> 3.199994e-2 +> 0.032 :p rosenbrock2 [1.0, 1.02] -> 3.199994e-2 +> 0.032 ' ## Helper functions @@ -43,7 +43,7 @@ randBounded : Key -> (d=>Float)->(d=>Float)->(d=>Float) = for i. lb.i + ((rand $ ixkey key i) * (ub.i - lb.i)) :p randBounded (newKey 4) [1.0, -2.0] [-1.0, 2.0] -> [-0.35101044, 1.4935503] +> [-0.35101, 1.49355] ' ## The Optimizer itself. We have **arguments**: @@ -103,13 +103,13 @@ Run it for more iterations and result should improve. Which it indeed does. :p optimize 50 10 (newKey 0) rosenbrock2 ([-10.0, -10.0],[20.0, 20.0]) (0.5,0.3,0.4) -> [7.698643e-2, 0.23281813] +> [0.076986, 0.232818] :p optimize 50 20 (newKey 0) rosenbrock2 ([-10.0, -10.0],[20.0, 20.0]) (0.5,0.3,0.4) -> [0.90125036, 0.75044703] +> [0.90125, 0.750447] :p optimize 50 100 (newKey 0) rosenbrock2 ([-10.0, -10.0],[20.0, 20.0]) (0.5,0.3,0.4) -> [0.9990686, 0.9981924] +> [0.999069, 0.998192] :p optimize 50 1000 (newKey 0) rosenbrock2 ([-10.0, -10.0],[20.0, 20.0]) (0.5,0.3,0.4) -> [1.0, 1.0] +> [1., 1.] diff --git a/examples/pi.dx b/examples/pi.dx index 6c05e4b86..c8cc314e7 100644 --- a/examples/pi.dx +++ b/examples/pi.dx @@ -18,7 +18,7 @@ def meanAndStdDev (n:Int) (f: Key -> Float) (key:Key) : (Float & Float) = numSamps = 1000000 :p meanAndStdDev numSamps estimatePiArea (newKey 0) -> (3.143452, 1.6408892) +> (3.143452, 1.640889) :p meanAndStdDev numSamps estimatePiAvgVal (newKey 0) -> (3.1437902, 0.88649935) +> (3.14379, 0.886499) diff --git a/examples/regression.dx b/examples/regression.dx index fbf484aa1..ca8ce2731 100644 --- a/examples/regression.dx +++ b/examples/regression.dx @@ -62,7 +62,7 @@ def rmsErr (truth:n=>Float) (pred:n=>Float) : Float = sqrt $ mean for i. sq (pred.i - truth.i) :p rmsErr ys (map predict xs) -> 0.25269496 +> 0.252695 def tabCat (xs:n=>a) (ys:m=>a) : ({left:n|right:m})=>a = diff --git a/makefile b/makefile index ef8adb0f2..a72e4123e 100644 --- a/makefile +++ b/makefile @@ -95,7 +95,8 @@ all-names = $(test-names:%=tests/%) $(example-names:%=examples/%) quine-test-targets = $(all-names:%=run-%) -update-targets = $(example-names:%=update-%) +update-test-targets = $(test-names:%=update-tests-%) +update-example-targets = $(example-names:%=update-examples-%) doc-names = $(example-names:%=doc/%.html) @@ -115,10 +116,15 @@ run-examples/%: examples/%.dx build prop-tests: cbits/libdex.so $(STACK) test $(PROF) -update-all: $(update-targets) - update-%: export DEX_ALLOW_CONTRACTIONS=0 -update-%: tests/%.dx build + +update-all: $(update-test-targets) $(update-example-targets) + +update-tests-%: tests/%.dx build + $(dex) script --allow-errors $< > $<.tmp + mv $<.tmp $< + +update-examples-%: examples/%.dx build $(dex) script --allow-errors $< > $<.tmp mv $<.tmp $< diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index d37847ac1..ecc816193 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -134,12 +134,16 @@ instance PrettyPrec ScalarBaseType where printDouble :: Double -> Doc ann printDouble x = p (double2Float x) +printFloat :: Float -> Doc ann +printFloat x = p $ reverse $ dropWhile (=='0') $ reverse $ + showFFloat (Just 6) x "" + instance Pretty LitVal where pretty = prettyFromPrettyPrec instance PrettyPrec LitVal where prettyPrec (Int64Lit x) = atPrec ArgPrec $ p x prettyPrec (Int32Lit x) = atPrec ArgPrec $ p x prettyPrec (Float64Lit x) = atPrec ArgPrec $ printDouble x - prettyPrec (Float32Lit x) = atPrec ArgPrec $ p x + prettyPrec (Float32Lit x) = atPrec ArgPrec $ printFloat x prettyPrec (Word8Lit x) = atPrec ArgPrec $ p $ show $ toEnum @Char $ fromIntegral x prettyPrec (PtrLit ty x) = atPrec ArgPrec $ "Ptr" <+> p ty <+> p (show x) prettyPrec (VecLit l) = atPrec ArgPrec $ encloseSep "<" ">" ", " $ fmap p l diff --git a/tests/ad-tests.dx b/tests/ad-tests.dx index 5d5b4e7c1..d541bc363 100644 --- a/tests/ad-tests.dx +++ b/tests/ad-tests.dx @@ -5,50 +5,50 @@ def sum' (xs:n=>Float) : Float = snd $ withAccum \ref. for i. ref += xs.i :p f : Float -> Float = \x. x jvp f 3.0 2.0 -> 2.0 +> 2. :p f = \x. x * x jvp f 3.0 1.5 -> 9.0 +> 9. :p f = \x. x + x jvp f 3.0 2.0 -> 4.0 +> 4. :p f = \x. x * x * x jvp f 2.0 1.5 -> 18.0 +> 18. :p f : Float --o Float = \x. x transposeLinear f 2.0 -> 2.0 +> 2. :p f : Float --o Float = \x. x + x transposeLinear f 1.0 -> 2.0 +> 2. :p f : Float --o Float = \x. x + (x + x) * 2.0 transposeLinear f 1.0 -> 5.0 +> 5. :p f : Float --o Float = \x. x * 2.0 transposeLinear f 1.0 -> 2.0 +> 2. :p f : Float --o Float = \x. 2.0 * x transposeLinear f 1.0 -> 2.0 +> 2. :p grad (\x. x * x) 1.0 -> 2.0 +> 2. :p deriv (\x. 3.0 / x) 2.0 > -0.75 @@ -61,49 +61,49 @@ def sum' (xs:n=>Float) : Float = snd $ withAccum \ref. for i. ref += xs.i \xs. for i. xs.i * xs.i jvp f [1.,2.] [3.,4.] -> [6.0, 16.0] +> [6., 16.] :p jvp transpose [[1.,2.], [3.,4.]] [[10.,20.], [30.,40.]] -> [[10.0, 30.0], [20.0, 40.0]] +> [[10., 30.], [20., 40.]] :p jvp sum' [1., 2.] [10.0, 20.0] -> 30.0 +> 30. f : Float -> Float = \x. snd $ withAccum \ref. ref += x :p jvp f 1.0 1.0 -> 1.0 +> 1. :p f = \x. x * x * x jvp (\x. jvp f x 1.0) 2.0 1.0 -> 12.0 +> 12. :p f = \x. 4.0 * x * x * x deriv (deriv (deriv f)) 1.234 -> 24.0 +> 24. :p f : Float --o (Float & Float) = \x. (x, 2.0 * x) transposeLinear f (1.0, 3.0) -> 7.0 +> 7. :p f : (Float & Float) --o Float = \(x,y). x + 2.0 * y transposeLinear f 1.0 -> (1.0, 2.0) +> (1., 2.) :p deriv cos 0.0 -> 0.0 +> 0. :p deriv sin 0.0 -> 1.0 +> 1. :p (sin 1.0, deriv (deriv sin) 1.0) -> (0.84147096, -0.84147096) +> (0.841471, -0.841471) :p (cos 1.0, deriv (deriv (deriv sin)) 1.0) -> (0.5403023, -0.5403023) +> (0.540302, -0.540302) :p checkDeriv sin 1.0 > True @@ -141,30 +141,30 @@ f : Float -> Float = \x. snd $ withAccum \ref. ref += x -- Perturbation confusion test suggested by Barak Pearlmutter -- https://github.com/HIPS/autograd/issues/4 :p deriv (\x. x * deriv (\y. x * y) 2.0) 1.0 -> 2.0 +> 2. tripleit : Float --o Float = \x. x + x + x :p tripleit 1.0 -> 3.0 +> 3. :p transposeLinear tripleit 1.0 -> 3.0 +> 3. :p transposeLinear (transposeLinear tripleit) 1.0 -> 3.0 +> 3. :p f : n:Type ?-> Float --o n=>Float = \x. for i. x transposeLinear f [1.0, 2.0] -> 3.0 +> 3. :p f : n:Type ?-> n=>Float --o n=>Float = \x. for i. x.i * 2.0 transposeLinear f [1.0, 2.0] -> [2.0, 4.0] +> [2., 4.] myOtherSquare : Float -> Float = \x. snd $ withAccum \w. w += x * x @@ -177,50 +177,50 @@ myOtherSquare : Float -> Float = \x. fst (x * x, 2 + 1) jvp f 1.0 3.0 -> 6.0 +> 6. :p f : Float -> Float = \x. x * IToF (1 + 1) jvp f 1.0 2.0 -> 4.0 +> 4. :p f : (Fin 2)=>Float -> Float = \xs. xs.(0 @ Fin 2) * xs.(1 @ Fin 2) jvp f [1., 2.] [3.0, 4.0] -> 10.0 +> 10. :p f : (Float & Float) -> Float = \(x,y). x * y jvp f (1., 2.) (3.0, 4.0) -> 10.0 +> 10. :p f : n:Type ?-> n=>Float -> n=>Float = \xs. for i. xs.i * xs.i jvp f [1.,2.] [3.,4.] -> [6.0, 16.0] +> [6., 16.] :p jvp sum' [1., 2.] [3.0, 4.0] -> 7.0 +> 7. :p grad sum' [1.,2.] -> [1.0, 1.0] +> [1., 1.] vec = [1.] :p jvp (\x. vec) [1.] [1.] -> [0.0] +> [0.] :p grad (\(x, y). vdot x y) ([1.,2.], [3.,4.]) -> ([3.0, 4.0], [1.0, 2.0]) +> ([3., 4.], [1., 2.]) :p f : Float -> Float = \x. @@ -229,7 +229,7 @@ vec = [1.] a += x * 2.0 a += y grad f 1.0 -> 4.0 +> 4. :p f : Float -> Float = \x. @@ -269,7 +269,7 @@ vec = [1.] :p f = \x. for i:(Fin 4). { x=x * x * (IToF $ ordinal i) } jvp f 2.0 1.0 -> [{x = 0.0}, {x = 4.0}, {x = 8.0}, {x = 12.0}] +> [{x = 0.}, {x = 4.}, {x = 8.}, {x = 12.}] :p s : { a : Float | b : Float } = case 2 == 2 of diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx index e8fb17566..72ff717d8 100644 --- a/tests/adt-tests.dx +++ b/tests/adt-tests.dx @@ -243,7 +243,7 @@ def zerosLikeList (l : List a) : (Fin (listLength l))=>Float = for i:(Fin $ listLength l). 0.0 :p zerosLikeList l2 -> [0.0, 0.0, 0.0] +> [0., 0., 0.] data Graph a:Type = MkGraph n:Type nodes:(n=>a) m:Type edges:(m=>(n & n)) diff --git a/tests/complex-tests.dx b/tests/complex-tests.dx index 7af6ea331..bf799cb39 100644 --- a/tests/complex-tests.dx +++ b/tests/complex-tests.dx @@ -1,11 +1,11 @@ :p complex_floor $ MkComplex 0.3 0.6 -> (MkComplex 0.0 0.0) +> (MkComplex 0. 0.) :p complex_floor $ MkComplex 0.6 0.8 -> (MkComplex 0.0 1.0) +> (MkComplex 0. 1.) :p complex_floor $ MkComplex 0.8 0.6 -> (MkComplex 1.0 0.0) +> (MkComplex 1. 0.) :p complex_floor $ MkComplex 0.6 0.3 -> (MkComplex 0.0 0.0) +> (MkComplex 0. 0.) a = MkComplex 2.1 0.4 b = MkComplex (-1.1) 1.3 @@ -48,10 +48,10 @@ b = MkComplex (-1.1) 1.3 > True :p sinh (MkComplex 1.2 3.2) -> (MkComplex -1.5068874 -0.10569556) +> (MkComplex -1.506887 -0.105696) :p cosh (MkComplex 1.2 3.2) -> (MkComplex -1.807568 8.811359e-2) +> (MkComplex -1.807568 0.088114) :p tanh (MkComplex 1.1 0.1) -> (MkComplex 0.80337524 3.5809334e-2) +> (MkComplex 0.803375 0.035809) :p tan (MkComplex 1.2 3.2) -> (MkComplex 2.2501666e-3 1.002451) +> (MkComplex 0.00225 1.002451) diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx index 57fa06f56..b1d443bb4 100644 --- a/tests/eval-tests.dx +++ b/tests/eval-tests.dx @@ -1,10 +1,10 @@ :p 1.0 + 2.0 -> 3.0 +> 3. :p double = \x. x * 2.0 double 10.0 -> 20.0 +> 20. :p sum (iota (Fin 10)) > 45 @@ -25,14 +25,14 @@ x = iota (Fin 3) y = map IToF x vdot' y y -> 10.0 +> 10. :p x = iota $ Fin 3 y = iota $ Fin 4 z = for i j. IToF x.i * IToF y.j sum (for i. sum z.i) -> 18.0 +> 18. -- :p randint (hash 0 0) 10 -- :p let x = unpack range 10000 @@ -51,7 +51,7 @@ arr = iota NArr fun = \y. sum (map IToF arr) + y :p fun 3.0 -> 24.0 +> 24. :p arr > [0, 1, 2, 3, 4, 5, 6] @@ -60,10 +60,10 @@ fun = \y. sum (map IToF arr) + y > 21 :p 6.0 - 10.0 -> -4.0 +> -4. :p (\(x, y). x + y) (1.0, 2.0) -> 3.0 +> 3. :p f : a:Type ?-> b:Type ?-> (a -> b & a) -> b = @@ -75,40 +75,40 @@ fun = \y. sum (map IToF arr) + y (x,y) = ((1.0,2.0),3.0) (x1, x2) = x x1 + x2 + y -> 6.0 +> 6. :p x = (1.0,2.0) (y,z) = x y + z -> 3.0 +> 3. -- :p let f (x, y) = x + 2 * y; -- z.i = (x.i, x.i * x.i) -- in sum (for i. f z.i) :p exp 1.0 -> 2.7182817 +> 2.718282 :p exp2 3.0 -> 8.0 +> 8. :p log 1.0 -> 0.0 +> 0. :p log2 8.0 -> 3.0 +> 3. :p log10 100.0 -> 2.0 +> 2. :p sqrt 2.0 -> 1.4142135 +> 1.414214 :p sin 3.14159 -> 2.5351817e-6 +> 0.000003 :p cos 0.0 -> 1.0 +> 1. :p tan 1.57079 > 159378.27 @@ -125,7 +125,7 @@ fun = \y. sum (map IToF arr) + y s = 1.0 :p s -> 1.0 +> 1. :p [2, 4, 8] > [2, 4, 8] @@ -141,7 +141,7 @@ cumsumplus : n=>Float -> n=>Float = (ans, 1.0 + ans) :p cumsumplus [1.0, 2.0, 3.0] -> [2.0, 4.0, 7.0] +> [2., 4., 7.] :p [False, False, True] > [False, False, True] @@ -213,15 +213,15 @@ litArr = [10, 5, 3] :p k = newKey 0 mean for i:(Fin 100). randn (ixkey k i) -> -0.1157995 +> -0.1158 :p k = newKey 0 mean for i:(Fin 100). sq $ randn (ixkey k i) -> 1.2581898 +> 1.25819 :p for i:(Fin 3) j:(Fin 2). rand $ ixkey2 (newKey 11) i j -> [[0.47415292, 0.9145164], [0.7944602, 0.27679908], [0.58958626, 0.7116251]] +> [[0.474153, 0.914516], [0.79446, 0.276799], [0.589586, 0.711625]] :p x = for i:(Fin 3). 0 @@ -234,7 +234,7 @@ litArr = [10, 5, 3] > [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]] :p fold (for i:(Fin 3). 0.0) $ \i:(Fin 2) c. (for j. c.j + IToF (ordinal j)) -> [0.0, 2.0, 4.0] +> [0., 2., 4.] :p mat2 = for i:(Fin 4) j:(Fin 4) . ordinal i @@ -275,7 +275,7 @@ litArr = [10, 5, 3] > 1 :p select False 1.0 2.0 -> 2.0 +> 2. :p select True [1,2,3] [10,20,30] > [1, 2, 3] @@ -296,7 +296,7 @@ litArr = [10, 5, 3] > (False, (True, (True, True))) :p [(for i:(Fin 1). (False, for j:(Fin 2). 1.0)), [(True, for k:(Fin 2) . 2.0)]] -> [[(False, [1.0, 1.0])], [(True, [2.0, 2.0])]] +> [[(False, [1., 1.])], [(True, [2., 2.])]] -- TODO: parse negative integer literals -- :p (mod 5 3, mod 7 3, mod (-1) 3, mod -5 3) @@ -324,15 +324,15 @@ litArr = [10, 5, 3] > 2 :p fold (1.0, 2.0) \i:(Fin 2) (x, y). (y, x) -> (1.0, 2.0) +> (1., 2.) :p fold (1.0, 2.0) \i:(Fin 3) (x, y). (y, x) -> (2.0, 1.0) +> (2., 1.) :p id 2 > 2 :p min 2.0 3.0 -> 2.0 +> 2. :p minBy sq 0.5 (-2.0) > 0.5 @@ -344,16 +344,16 @@ litArr = [10, 5, 3] > (1.5, 15) :p max 2.0 3.0 -> 3.0 +> 3. :p maxBy sq 0.5 (-2.0) -> -2.0 +> -2. :p maximum [2.0, 4.0, 1.5, 7.0] -> 7.0 +> 7. :p maximumBy fst [(2.0, 20), (1.5, 15), (10.0, 100)] -> (10.0, 100) +> (10., 100) :p (1 == 2, (-1) == (-1), 1 < 2, -1 < 2, 2 < (-1)) > (False, (True, (True, (True, False)))) @@ -367,7 +367,7 @@ litArr = [10, 5, 3] σ = 1.0 + 2.0 :p σ -> 3.0 +> 3. δ : Int -> Int = \x. x @@ -399,59 +399,59 @@ litArr = [10, 5, 3] -- line comment should be ok here 2.0 * x f 1.0 -> 2.0 +> 2. -- Not sure why the ordinary `sum/for` version doesn't work anymore :p n = 3 + 7 fsum \i:(Fin n). 1.0 -> 10.0 +> 10. :p n = 4 fsum \i:(Fin n). 1.0 -> 4.0 +> 4. :p one = fromOrdinal (Fin 4) 1 for i:(Fin 4). sum for j:(one..i). 1.0 -> [0.0, 1.0, 2.0, 3.0] +> [0., 1., 2., 3.] :p one = fromOrdinal (Fin 4) 1 for i:(Fin 4). sum for j:(one<..i). 1.0 -> [0.0, 0.0, 1.0, 2.0] +> [0., 0., 1., 2.] :p one = fromOrdinal (Fin 4) 1 for i:(Fin 4). sum for j:(one<..i). 1.0 -> [0.0, 0.0, 1.0, 2.0] +> [0., 0., 1., 2.] :p one = fromOrdinal (Fin 4) 1 for i:(Fin 4). sum for j:(one.. [0.0, 0.0, 1.0, 2.0] +> [0., 0., 1., 2.] :p one = fromOrdinal (Fin 4) 1 for i:(Fin 4). sum for j:(one<.. [0.0, 0.0, 0.0, 1.0] +> [0., 0., 0., 1.] :p for i:(Fin 4). sum for j:(..i). 1.0 -> [1.0, 2.0, 3.0, 4.0] +> [1., 2., 3., 4.] :p for i:(Fin 4). sum for j:(.. [0.0, 1.0, 2.0, 3.0] +> [0., 1., 2., 3.] :p for i:(Fin 4). sum for j:(i..). 1.0 -> [4.0, 3.0, 2.0, 1.0] +> [4., 3., 2., 1.] :p for i:(Fin 4). sum for j:(i<..). 1.0 -> [3.0, 2.0, 1.0, 0.0] +> [3., 2., 1., 0.] :p idiv 10 3 > 3 @@ -464,7 +464,7 @@ litArr = [10, 5, 3] ys = [1.,2.,3.] xys = for (i,j). xs.i + ys.j sum xys -> 102.0 +> 102. :p xs = [10.,20.] @@ -472,7 +472,7 @@ litArr = [10, 5, 3] zs = [1.] xys = for (i,(j,k)). xs.i + ys.j + zs.k sum xys -> 108.0 +> 108. :p xs = [[1,2],[3,4]] @@ -506,14 +506,14 @@ litArr = [10, 5, 3] c = get ref ref := c + 1.0 c -> ([0.0, 1.0, 2.0, 3.0], 4.0) +> ([0., 1., 2., 3.], 4.) :p withState 0.0 \ref. rof i:(Fin 4). c = get ref ref := c + 1.0 c -> ([3.0, 2.0, 1.0, 0.0], 4.0) +> ([3., 2., 1., 0.], 4.) def eitherFloor (x:(Int|Float)) : Int = case x of Left i -> i @@ -582,7 +582,7 @@ def unflatten (params:Params n m) : (Weights n m & Biases n) = -- TODO: within-module version of this (currently fails in Imp checking) upperBound = sum $ for i:(Fin 4). 1 :p for j:(Fin upperBound). 1.0 -> [1.0, 1.0, 1.0, 1.0] +> [1., 1., 1., 1.] :p (for i:(Fin upperBound). 1, for j:(Fin 2). 2) > ([1, 1, 1, 1], [2, 2]) @@ -609,7 +609,7 @@ for i:(Range 0 x). 1.0 -- Make sure that we can construct and print an array using a pair index set for i:(Fin 2 & Fin 2). 1.0 -> [1.0, 1.0, 1.0, 1.0]@(Fin 2 & Fin 2) +> [1., 1., 1., 1.]@(Fin 2 & Fin 2) 1@(Fin 2 & Fin 2) > ((0@Fin 2), (1@Fin 2)) @@ -618,11 +618,11 @@ for i:(Fin 5). for j:(..i). ir = IToF $ ordinal i jr = IToF $ ordinal j ir * (ir + 1.0) / 2.0 + jr -> [ [0.0]@(..(0@Fin 5)) -> , [1.0, 2.0]@(..(1@Fin 5)) -> , [3.0, 4.0, 5.0]@(..(2@Fin 5)) -> , [6.0, 7.0, 8.0, 9.0]@(..(3@Fin 5)) -> , [10.0, 11.0, 12.0, 13.0, 14.0]@(..(4@Fin 5)) ] +> [ [0.]@(..(0@Fin 5)) +> , [1., 2.]@(..(1@Fin 5)) +> , [3., 4., 5.]@(..(2@Fin 5)) +> , [6., 7., 8., 9.]@(..(3@Fin 5)) +> , [10., 11., 12., 13., 14.]@(..(4@Fin 5)) ] -- TODO: fix! -- -- Exercise the use of free variables in the sum solver @@ -640,7 +640,7 @@ for i:(Fin 5). for j:(..i). > [-1, -8] :p 2.0 .* [[1.0, 2.0], [3.0, 4.0]] -> [[2.0, 4.0], [6.0, 8.0]] +> [[2., 4.], [6., 8.]] def newtonIter (f: Float -> Float) (x:Float) : Float = x - (f x / deriv f x) @@ -651,7 +651,7 @@ def newtonSolve (tol:Float) (f : Float -> Float) (x0:Float) : Float = x := newtonIter f $ get x :p newtonSolve 0.001 (\x. sq x - 2.0) 1.0 -> 1.4142157 +> 1.414216 -- :p -- x = for i:(Fin 3). for j:(Fin 200). 1.0 @@ -699,7 +699,7 @@ def newtonSolve (tol:Float) (f : Float -> Float) (x0:Float) : Float = w = 2 \z. z + w (f 5, w) -> (7, 2.0) +> (7, 2.) -- def add (n : Type) ?-> (a : n=>Float) (b : n=>Float) : n=>Float = -- (tile (\t:(Tile n (Fin VectorWidth)). storeVector $ loadTile t a + loadTile t b) @@ -744,22 +744,22 @@ easy = [(-2.0), 3.0, 3.0, 0.1, 0.0] hard = [(-1000.0), 1000.0, 1000.0, 0.1, 0.0] :p logsumexp easy - (log $ sum for j. exp easy.j) -> 0.0 +> 0. :p sum $ softmax hard -> 1.0 +> 1. :p all for i. ((softmax hard).i >= 0.0) > True :p sum for i. exp $ (logsoftmax hard).i -> 0.9999709 +> 0.999971 :p all for i. abs ((softmax hard).i - exp (logsoftmax hard).i) < 0.0001 > True :p evalpoly [2.0, 3.0, 4.0] 10.0 -> 234.0 +> 234. str = ['x', 'y'] diff --git a/tests/monad-tests.dx b/tests/monad-tests.dx index 4612954c0..42b230ad7 100644 --- a/tests/monad-tests.dx +++ b/tests/monad-tests.dx @@ -22,7 +22,7 @@ ref := (z * 3.0) withState 1.0 stateAction -> ((), 9.0) +> ((), 9.) :p def rwsAction @@ -40,7 +40,7 @@ withState True \s. withAccum \w. rwsAction r w s -> ((4, 6.0), False) +> ((4, 6.), False) :p def m (h:Type) ?-> (s:Ref h (Fin 3=>Int)) : {State h} Unit = @@ -61,7 +61,7 @@ x = get s w += x withState 1.0 \s. withAccum \w . m w s -> (((), 1.0), 1.0) +> (((), 1.), 1.) def myAction (w:Ref hw Float) (r:Ref hr Float) : {Read hr, Accum hw} Unit = x = ask r @@ -79,7 +79,7 @@ def myAction (w:Ref hw Float) (r:Ref hr Float) : {Read hr, Accum hw} Unit = w2 += 3.0 w1 += 1.0 withAccum \w1. withAccum \w2. m w1 w2 -> (((), 3.0), 2.0) +> (((), 3.), 2.) def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = s!(fromOrdinal _ 0) := 1 @@ -134,7 +134,7 @@ def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = w' += x + x s := 4 (x, y) -> ((((2.0, 3), 4), 4.0), 2.0) +> ((((2., 3), 4), 4.), 2.) def symmetrizeInPlace (mat:n=>n=>Float) : n=>n=>Float = snd $ withState mat \ref. @@ -146,7 +146,7 @@ def symmetrizeInPlace (mat:n=>n=>Float) : n=>n=>Float = ref!j!i := avg symmetrizeInPlace [[1.,2.],[3.,4.]] -> [[1.0, 2.5], [2.5, 4.0]] +> [[1., 2.5], [2.5, 4.]] :p withReader 5 \r. () > () @@ -155,15 +155,15 @@ symmetrizeInPlace [[1.,2.],[3.,4.]] for i:(Fin 2). w += 1.0 w += 1.0 -> 4.0 +> 4. :p snd $ withAccum \w. for i:(Fin 2). w += 1.0 w += 1.0 -> 3.0 +> 3. :p snd $ withAccum \ref. ref += [1.,2.,3.] ref += [2.,4.,5.] -> [3.0, 6.0, 8.0] +> [3., 6., 8.] diff --git a/tests/parser-tests.dx b/tests/parser-tests.dx index 81e8acc72..ac388a409 100644 --- a/tests/parser-tests.dx +++ b/tests/parser-tests.dx @@ -1,28 +1,28 @@ 'For now, arithmetic is not sensitive to whitespace: :p 1.0+1.0 -> 2.0 +> 2. :p 1.0 +1.0 -> 2.0 +> 2. :p 1.0+ 1.0 -> 2.0 +> 2. :p 1.0 + 1.0 -> 2.0 +> 2. :p 1.0-1.0 -> 0.0 +> 0. :p 1.0 -1.0 -> 0.0 +> 0. :p 1.0- 1.0 -> 0.0 +> 0. :p 1.0 - 1.0 -> 0.0 +> 0. 'Applying a function to a negative literal thus requires parentheses. @@ -37,7 +37,7 @@ f = \x. x + 10. > ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ :p f (-1.0) -> 9.0 +> 9. 'Lambdas can have specific arrow annotations. diff --git a/tests/record-variant-tests.dx b/tests/record-variant-tests.dx index 871fa2941..1e0c259ea 100644 --- a/tests/record-variant-tests.dx +++ b/tests/record-variant-tests.dx @@ -48,7 +48,7 @@ Syntax for records, variants, and their types. x = {a=5.0, b=2} y : {a:Int & b:Int & ..._} = {a=3, a=4, ...x} y -> {a = 3, a = 4, a = 5.0, b = 2} +> {a = 3, a = 4, a = 5., b = 2} 'Variant (enum) types @@ -72,7 +72,7 @@ Syntax for records, variants, and their types. > {| a = 3 |} :p {| a | a = 3.0 |} : {a:Int | a:Float | a:Int} -> {|a| a = 3.0 |} +> {|a| a = 3. |} :t {| a | a = 3.0 |} : {a:Int | a:Float | a:Int} > {a: Int32 | a: Float32 | a: Int32} @@ -148,7 +148,7 @@ def getTwoFoosAndABar (rest : Fields)?-> (f1, f2, b) :p getTwoFoosAndABar {foo=1, bar=2, foo=0.0, foo=4, baz=3.0, bar=7} -> (1, (0.0, 2)) +> (1, (0., 2)) :p ({b=b, a=a1, a=a2}) = {a=1, b=2} @@ -183,7 +183,7 @@ x : {a:Int | a:Float | a:Int} = {| a | a = 3.0 |} foo = 1 bar = 2.0 {foo, bar} -> {bar = 2.0, foo = 1} +> {bar = 2., foo = 1} :p ({foo, ...}) = {foo=1, bar=2.0} @@ -207,7 +207,7 @@ x : {a:Int | a:Float | a:Int} = {| a | a = 3.0 |} {| a = x |} -> IToF x {| a | a = x |} -> x {| b = x |} -> IToF x -> 3.0 +> 3. 'Table values and imp lowering @@ -215,23 +215,23 @@ myRecordTable : (Fin 2)=>{a:Int & b:Float} = [{a=1, b=2.0}, {a=3, b=4.0}] :p myRecordTable -> [{a = 1, b = 2.0}, {a = 3, b = 4.0}] +> [{a = 1, b = 2.}, {a = 3, b = 4.}] :p for i:(Fin 2). ({a=a, b=b}) = myRecordTable.i {a=b, b=a} -> [{a = 2.0, b = 1}, {a = 4.0, b = 3}] +> [{a = 2., b = 1}, {a = 4., b = 3}] myVariantTable : (Fin 2)=>{a:Int | b:Float} = [{| a=1 |}, {| b=2.0 |}] :p myVariantTable -> [{| a = 1 |}, {| b = 2.0 |}] +> [{| a = 1 |}, {| b = 2. |}] :p for i:(Fin 2). v : {a:_ | b:_} = case myVariantTable.i of {| a=a |} -> {| b=a |} {| b=b |} -> {| a=b |} v -> [{| b = 1 |}, {| a = 2.0 |}] +> [{| b = 1 |}, {| a = 2. |}] -- Known variant, unused tail pattern :p @@ -240,7 +240,7 @@ myVariantTable : (Fin 2)=>{a:Int | b:Float} = [{| a=1 |}, {| b=2.0 |}] {| a = x |} -> 1.0 {| a | a = x |} -> x {|a|a| ..._ |} -> 5.0 -> 3.0 +> 3. -- Known variant, missing pattern :p @@ -248,7 +248,7 @@ myVariantTable : (Fin 2)=>{a:Int | b:Float} = [{| a=1 |}, {| b=2.0 |}] case x of {| a = x |} -> 1.0 {| a | a = x |} -> x -> 3.0 +> 3. -- Known variant, used tail pattern myVal = diff --git a/tests/repl-multiline-test-expected-output b/tests/repl-multiline-test-expected-output index 6ff96b0fe..129d35875 100644 --- a/tests/repl-multiline-test-expected-output +++ b/tests/repl-multiline-test-expected-output @@ -3,7 +3,7 @@ >=> >=> >=> ->=> ... ... ... ... 30.0 +>=> ... ... ... ... 30. >=> >=> ... ... >=> @@ -11,5 +11,5 @@ >=> >=> >=> ->=> 3.0 +>=> 3. >=> \ No newline at end of file diff --git a/tests/serialize-tests.dx b/tests/serialize-tests.dx index d35c66705..c4f27e80e 100644 --- a/tests/serialize-tests.dx +++ b/tests/serialize-tests.dx @@ -2,13 +2,13 @@ > 1 :p 1.0 -> 1.0 +> 1. :p [1, 2, 3] > [1, 2, 3] :p [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] -> [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] +> [[1., 2., 3.], [4., 5., 6.]] :p fromOrdinal (Fin 10) 7 > (7@Fin 10) @@ -32,10 +32,10 @@ x = ['a', 'b'] > {a = (AsList 4 "1234"), b = [1, 2, 3]} :p [{| a=1 |}, {| b=2.0 |}] : (Fin 2) => {a:Int | b:Float} -> [{| a = 1 |}, {| b = 2.0 |}] +> [{| a = 1 |}, {| b = 2. |}] :p {table = [{| a=1 |}, {| b=2.0 |}]} : {table: (Fin 2) => {a:Int | b:Float}} -> {table = [{| a = 1 |}, {| b = 2.0 |}]} +> {table = [{| a = 1 |}, {| b = 2. |}]} 'Values without a pretty-printer (currently shows warning message): diff --git a/tests/type-tests.dx b/tests/type-tests.dx index 396540e69..b9f1b1f0f 100644 --- a/tests/type-tests.dx +++ b/tests/type-tests.dx @@ -152,7 +152,7 @@ MyPair : Type -> Type = ((1, 2), (1.0, 2.0)) pairs -> ((1, 2), (1.0, 2.0)) +> ((1, 2), (1., 2.)) -- TODO: put source annotation on effect for a better message here @@ -367,7 +367,7 @@ def triRefIndex (ref:Ref h (i':n=>(..i')=>Float)) (i:n) : Ref h ((..i)=>Float) = -- There was a time when this wasn't possible, because checking mode would unify the -- input type with a non-dependent function type, leading to a later unification errors. id (for i:(Fin 2). for j:(..i). 1.0) -> [[1.0]@(..(0@Fin 2)), [1.0, 1.0]@(..(1@Fin 2))] +> [[1.]@(..(0@Fin 2)), [1., 1.]@(..(1@Fin 2))] def weakerInferenceReduction (l : i:n=>(..i)=>Float) (j:n): Unit = for i:(..j). diff --git a/tests/uexpr-tests.dx b/tests/uexpr-tests.dx index 7c76018b3..eb3f0c880 100644 --- a/tests/uexpr-tests.dx +++ b/tests/uexpr-tests.dx @@ -13,12 +13,12 @@ def returnFirstArg (a:Type) (b:Type) (x:a) (y:b) : a = x > 1 :p 1.0 + 2.0 -> 3.0 +> 3. def triple (x:Float) : Float = x + x + x :p triple 1.0 -> 3.0 +> 3. def idExplicit (a:Type) (x:a) : a = x @@ -39,7 +39,7 @@ idImplicit2 : (a:Type ?-> a -> a) = \x. x > 1 :p (\x y. x + y) 1.0 2.0 -> 3.0 +> 3. :p 1.0 + 1 > Type error: @@ -137,8 +137,8 @@ myPair = (1, 2.3) > [1, 2, 3, 4]@(Fin 2 & Fin 2) -:p sin 1.0 -> 0.84147096 +:p sin 1.01 +> 0.846832 :p (x,y) = (1,2) @@ -198,7 +198,7 @@ def passthrough (eff:Effects) ?-> (f:(a -> {|eff} b)) (x:a) : {|eff} b = f x > (((), 2), 1) :p (\f x y. f x y) (+) 1.0 2.0 -> 3.0 +> 3. :p myId = fst (\x. x, 2) @@ -220,23 +220,23 @@ def myOtherFst ((x, _):(a&b)) : a = x > 1 :p sum [1.,2.] -> 3.0 +> 3. :p xs = fanout _ 1.0 for i:(Fin 3). xs.i + xs.i -> [2.0, 2.0, 2.0] +> [2., 2., 2.] :p f = \x. x * x * x jvp f 2.0 1.5 -> 18.0 +> 18. :p f : Float --o Float = \x. 2.0 * (x + x) transposeLinear f 1.0 -> 4.0 +> 4. -- FIXME: This fails due to shadowing! --def transpose' (x:n=>m=>Float) --o : m=>n=>Float = for i j. x.j.i @@ -248,7 +248,7 @@ def myOtherFst ((x, _):(a&b)) : a = x f : Float --o (Fin 3=>Float) = \x. for i. x * 2.0 transposeLinear f [1.0, 2.0, 3.0] -> 12.0 +> 12. id'' : b -> b = id From 28b96bdb2df18fc6dabc6dbacbed8f37652f2aaf Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 29 Dec 2020 09:26:53 -0500 Subject: [PATCH 041/105] Add rejection sampler to tests and update it to use iteration combinator. --- examples/rejection-sampler.dx | 14 ++++++-------- makefile | 2 +- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/rejection-sampler.dx b/examples/rejection-sampler.dx index e136a569c..b62ede490 100644 --- a/examples/rejection-sampler.dx +++ b/examples/rejection-sampler.dx @@ -3,11 +3,9 @@ 'We implement rejection sampling from a Binomial distribution using a uniform proposal. def rejectionSample (try: Key -> Maybe a) (k:Key) : a = - fromJust $ fst $ withState 0 \i. - snd $ withState Nothing \sample. - while (do isNothing (get sample)) do - i := get i + 1 - sample := try $ hash k (get i) + iter \i. case try $ hash k i of + Nothing -> Continue + Just x -> Done x Prob = Float LogProb = Float @@ -51,7 +49,7 @@ rejectionSamples = randVec numSamples (rejectionSample $ trySampleBinomial n p) def meanAndVariance (xs:n=>Float) : (Float&Float) = (mean xs, sq $ std xs) :p meanAndVariance $ map IToF rejectionSamples -> (3.9933999, 2.3585567) +> (3.9984, 2.361596) '## Alternative: Inversion sampling @@ -68,7 +66,7 @@ inversionSamples = randVec numSamples (binomialSample n p) k0 > [6, 7, 6, 5, 3, 2, 4, 4, 3, 4] :p meanAndVariance $ map IToF inversionSamples -> (3.9977999, 2.4097958) +> (3.9978, 2.409796) 'The following variant is guaranteed to evaluate the CDF only once. @@ -83,4 +81,4 @@ inversionBatchSamples = (binomialBatch n p k0) : Fin numSamples => Int > [6, 7, 6, 5, 3, 2, 4, 4, 3, 4] :p meanAndVariance $ map IToF inversionBatchSamples -> (3.9977999, 2.4097958) +> (3.9978, 2.409796) diff --git a/makefile b/makefile index a72e4123e..5eafc76bf 100644 --- a/makefile +++ b/makefile @@ -80,7 +80,7 @@ build-python: build # --- running tests --- # TODO: re-enable linear-tests ad-tests include-test chol -example-names = mandelbrot pi sierpinski \ +example-names = mandelbrot pi sierpinski rejection-sampler \ regression brownian_motion particle-swarm-optimizer \ ode-integrator mcmc ctc raytrace particle-filter \ isomorphisms ode-integrator linear_algebra fluidsim From be7de0537bb8cc138ed919d92602f7f0e3b62d8f Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 26 Dec 2020 09:24:06 -0500 Subject: [PATCH 042/105] Improve HTML doc generation. Improve makefile: - Add support for building docs for files in lib/. - Add `mkdir -p doc` to commands so they do not fail if doc/ does not exist. - Add standalone `doc-prelude` target so that `make docs` is fully incremental. Improve prelude formatting: - Use consistent formatting (h2, ##) for section headers. --- README.md | 16 ++++++------- lib/prelude.dx | 46 ++++++++++++++++++-------------------- makefile | 23 +++++++++++++------ src/lib/RenderHtml.hs | 11 +++++++-- src/resources/Resources.hs | 9 ++++++-- 5 files changed, 62 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 0b6e6fac3..c3c3462d9 100644 --- a/README.md +++ b/README.md @@ -12,14 +12,14 @@ To learn more, check out our or these example programs: * [Dex prelude](https://google-research.github.io/dex-lang/prelude.html) - * [Mandelbrot set](https://google-research.github.io/dex-lang/mandelbrot.html) - * [Ray tracer](https://google-research.github.io/dex-lang/raytrace.html) - * [Estimating pi](https://google-research.github.io/dex-lang/pi.html) - * [Hamiltonian Monte Carlo](https://google-research.github.io/dex-lang/mcmc.html) - * [ODE integrator](https://google-research.github.io/dex-lang/ode-integrator.html) - * [Sierpinski triangle](https://google-research.github.io/dex-lang/sierpinski.html) - * [Basis function regression](https://google-research.github.io/dex-lang/regression.html) - * [Brownian bridge](https://google-research.github.io/dex-lang/brownian_motion.html) + * [Mandelbrot set](https://google-research.github.io/dex-lang/examples/mandelbrot.html) + * [Ray tracer](https://google-research.github.io/dex-lang/examples/raytrace.html) + * [Estimating pi](https://google-research.github.io/dex-lang/examples/pi.html) + * [Hamiltonian Monte Carlo](https://google-research.github.io/dex-lang/examples/mcmc.html) + * [ODE integrator](https://google-research.github.io/dex-lang/oexamples/de-integrator.html) + * [Sierpinski triangle](https://google-research.github.io/dex-lang/examples/sierpinski.html) + * [Basis function regression](https://google-research.github.io/dex-lang/examples/regression.html) + * [Brownian bridge](https://google-research.github.io/dex-lang/examples/brownian_motion.html) ⚠️ Dex is an experimental research project at an early stage of development. Expect monstrous bugs and razor-sharp edges. Contributions welcome! ⚠️ diff --git a/lib/prelude.dx b/lib/prelude.dx index daf9ac32a..728855e58 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -1,9 +1,8 @@ - -'## Dex prelude +'# Dex prelude 'Runs before every Dex program unless an alternative is provided with `--prelude`. -'Wrappers around primitives +'## Wrappers around primitives Unit = %UnitType Type = %TyKind @@ -137,7 +136,7 @@ instance float64Fractional : Fractional Float64 where instance float32Fractional : Fractional Float32 where divide = \x:Float32 y:Float32. %fdiv x y -'Basic polymorphic functions and types +'## Basic polymorphic functions and types def (&) (a:Type) (b:Type) : Type = %PairType a b def (,) (x:a) (y:b) : (a & b) = %pair x y @@ -152,7 +151,7 @@ flip : (a -> b -> c) -> (b -> a -> c) = \f x y. f y x uncurry : (a -> b -> c) -> (a & b) -> c = \f (x,y). f x y const : a -> b -> a = \x _. x -'Vector spaces +'## Vector spaces data VSpace a:Type = MkVSpace (Add a) (Float -> a -> a) @@ -168,7 +167,7 @@ def neg (_:VSpace a) ?=> (v:a) : a = (-1.0) .* v @instance tabVS : VSpace a ?=> VSpace (n=>a) = MkVSpace tabAdd \s xs. for i. s .* xs.i @instance unitVS : VSpace Unit = MkVSpace unitAdd \s u. () -'Bool type +'## Boolean type data Bool = False @@ -192,7 +191,7 @@ def not (x:Bool) : Bool = x' = BToW8 x W8ToB $ %not x' -'Sum types +'## Sum types data Maybe a:Type = Nothing @@ -213,7 +212,7 @@ def select (p:Bool) (x:a) (y:a) : a = case p of def BToI (x:Bool) : Int = W8ToI $ BToW8 x def BToF (x:Bool) : Float = IToF (BToI x) -'Effects +'## Effects def Ref (r:Type) (a:Type) : Type = %Ref r a def get (ref:Ref h s) : {State h} s = %get ref @@ -252,7 +251,7 @@ def unsafeIO (f: Unit -> {State World|eff} a) : {|eff} a = def unreachable (():Unit) : a = unsafeIO do %throwError a -'Type classes +'## Type classes data Eq a:Type = MkEq (a -> a -> Bool) data Ord a:Type = MkOrd (Eq a) (a -> a -> Bool) (a -> a -> Bool) -- eq, gt, lt @@ -308,7 +307,7 @@ def tabEq (n:Type) ?-> (eqA: Eq a) ?=> : Eq (n=>a) = MkEq $ ref += (IToF (BToI (xs.i /= ys.i))) numDifferent == 0.0 -'Transcencendental functions +'## Transcencendental functions interface Floating a:Type where exp : a -> a @@ -384,7 +383,7 @@ instance float32Floating : Floating Float32 where pow = \x:Float32 y:Float32. %fpow x y lgamma = \x:Float32. %lgamma x -'Working with index sets +'## Index set utilities def Range (low:Int) (high:Int) : Type = %IntRange low high def Fin (n:Int) : Type = Range 0 n @@ -518,7 +517,7 @@ def withTabPtr (_:Storable a) ?=> def tabFromPtr (_:Storable a) ?=> (n:Type) -> (ptr:Ptr a) : {State World} n=>a = for i. load $ ptr +>> ordinal i -'Misc +'## Miscellaneous common utilities pi : Float = 3.141592653589793 @@ -559,7 +558,6 @@ def std (xs:n=>Float) : Float = sqrt $ mean (map sq xs) - sq (mean xs) def any (xs:n=>Bool) : Bool = reduce False (||) xs def all (xs:n=>Bool) : Bool = reduce True (&&) xs - def applyN (n:Int) (x:a) (f:a -> a) : a = snd $ withState x \ref. for _:(Fin n). ref := f (get ref) @@ -583,7 +581,7 @@ def dot (_:VSpace v) ?=> (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j def inner (x:n=>Float) (mat:n=>m=>Float) (y:m=>Float) : Float = fsum \(i,j). x.i * mat.i.j * y.j -'Functions for working with the pseudorandom number generator +'## Pseudorandom number generator utilities -- TODO: newtype Key = Int64 @@ -624,7 +622,7 @@ def cumSum (xs: n=>Float) : n=>Float = total := newTotal newTotal -'Automatic differentiation +'## Automatic differentiation -- TODO: add vector space constraints def linearize (f:a->b) (x:a) : (b & a --o b) = %linearize f x @@ -682,7 +680,7 @@ def checkDerivBase (f:Float->Float) (x:Float) : Bool = def checkDeriv (f:Float->Float) (x:Float) : Bool = checkDerivBase f x && checkDerivBase (deriv f) x -'Vector support +'## Vector support -- TODO: Reenable vector suport once fixed-width types are supported. -- def UNSAFEFromOrdinal (n : Type) (i : Int) : n = %unsafeAsIndex n i @@ -718,7 +716,7 @@ def checkDeriv (f:Float->Float) (x:Float) : Bool = -- @instance vectorFloatVSpace : VSpace VectorFloat = -- MkVSpace vectorFloatAdd \x v. broadcastVector x * v -'Tiling +'## Tiling functions def Tile (n : Type) (m : Type) : Type = %IndexSlice n m @@ -746,7 +744,7 @@ def tile1 (n : Type) ?-> (l : Type) ?-> (m : Type) ?-> -- arr.(t +> UNSAFEFromOrdinal idx 2) -- arr.(t +> UNSAFEFromOrdinal idx 3)) -'Monoid +'## Monoid typeclass interface Monoid a:Type where mempty : a @@ -754,7 +752,7 @@ interface Monoid a:Type where (<>) : Monoid a ?=> a -> a -> a = mcombine -'Length-erased lists +'## Length-erased lists data List a:Type = AsList n:Int foo:(Fin n => a) @@ -778,7 +776,7 @@ instance monoidList : Monoid (List a) where True -> xs.(unsafeFromOrdinal _ i') False -> ys.(unsafeFromOrdinal _ (i' - nx)) -'Isomorphisms +'## Isomorphisms data Iso a:Type b:Type = MkIso { fwd: a -> b & bwd: b -> a } @@ -945,7 +943,7 @@ instance showFloat64 : Show Float64 where -- pipe-like reverse function application def (|>) (x:a) (f: a -> b) : b = f x -'## Floating point helper functions +'## Floating-point helper functions def sign (x:Float) : Float = case x > 0.0 of @@ -1213,7 +1211,7 @@ def argmin (_:Ord o) ?=> (xs:n=>o) : n = def clip (_:Ord a) ?=> ((low,high):(a&a)) (x:a) : a = min high $ max low x -'## Trigonometric functions. +'## Trigonometric functions def atan_inner (x:Float) : Float = -- From "Computing accurate Horner form approximations to @@ -1255,7 +1253,6 @@ def atan2 (y:Float) (x:Float) : Float = def atan (x:Float) : Float = atan2 x 1.0 - '## Complex numbers data Complex = MkComplex Float Float -- real, imaginary @@ -1381,7 +1378,7 @@ def (>>) (x:Byte) (y:Int) : Byte = %shr x (IToW8 y) def (.|.) (x:Byte) (y:Byte) : Byte = %or x y def (.&.) (x:Byte) (y:Byte) : Byte = %and x y -'## Misc +'## Miscellaneous utilities def reverse (x:n=>a) : n=>a = s = size n @@ -1439,6 +1436,7 @@ def cumSumLow (xs: n=>Float) : n=>Float = oldTotal = get total total := oldTotal + xs.i oldTotal + -- cdf should include 0.0 but not 1.0 def categoricalFromCDF (cdf: n=>Float) (key: Key) : n = r = rand key diff --git a/makefile b/makefile index 5eafc76bf..2c827bb75 100644 --- a/makefile +++ b/makefile @@ -91,6 +91,8 @@ test-names = uexpr-tests adt-tests type-tests eval-tests show-tests \ record-variant-tests simple-include-test \ typeclass-tests complex-tests trig-tests +lib-names = diagram plot png + all-names = $(test-names:%=tests/%) $(example-names:%=examples/%) quine-test-targets = $(all-names:%=run-%) @@ -98,7 +100,9 @@ quine-test-targets = $(all-names:%=run-%) update-test-targets = $(test-names:%=update-tests-%) update-example-targets = $(example-names:%=update-examples-%) -doc-names = $(example-names:%=doc/%.html) +doc-example-names = $(example-names:%=doc/examples/%.html) + +doc-lib-names = $(lib-names:%=doc/lib/%.html) tests: quine-tests repl-test export-tests @@ -172,16 +176,21 @@ bench-summary: # --- building docs --- -slow-docs = doc/mnist-nearest-neighbors.html +slow-docs = doc/examples/mnist-nearest-neighbors.html -docs: doc/style.css $(doc-names) $(slow-docs) - $(dex) --prelude /dev/null script prelude.dx --html > doc/prelude.html +docs: doc-prelude $(doc-example-names) $(doc-lib-names) $(slow-docs) -doc/%.html: examples/%.dx +doc-prelude: lib/prelude.dx + mkdir -p doc + $(dex) --prelude /dev/null script lib/prelude.dx --outfmt HTML > doc/prelude.html + +doc/examples/%.html: examples/%.dx + mkdir -p doc/examples $(dex) script $^ --outfmt HTML > $@ -doc/%.css: static/%.css - cp $^ $@ +doc/lib/%.html: lib/%.dx + mkdir -p doc/lib + $(dex) script $^ --outfmt HTML > $@ clean: $(STACK) clean diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index cfd03ad15..e63aaa0df 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -12,6 +12,7 @@ module RenderHtml (pprintHtml, progHtml, ToMarkup) where import Text.Blaze.Html5 as H hiding (map) import Text.Blaze.Html5.Attributes as At import Text.Blaze.Html.Renderer.String +import Data.Char (isSpace) import Data.Text (pack) import CMark (commonmarkToHtml) @@ -19,10 +20,11 @@ import Control.Monad import Text.Megaparsec hiding (chunk) import Text.Megaparsec.Char as C +import Resources (cssSource) import Syntax import PPrint import Parser -import Serialize() +import Serialize () pprintHtml :: ToMarkup a => a -> String pprintHtml x = renderHtml $ toMarkup x @@ -31,10 +33,15 @@ progHtml :: LitProg -> String progHtml blocks = renderHtml $ wrapBody $ map toHtmlBlock blocks where toHtmlBlock (block,result) = toMarkup block <> toMarkup result +-- Minifies the given CSS snippet. +-- Currently, this simply removes all whitespace. +minifyCSS :: String -> String +minifyCSS = filter (not . isSpace) + wrapBody :: [Html] -> Html wrapBody blocks = docTypeHtml $ do H.head $ do - H.link ! rel "stylesheet" ! href "style.css" ! type_ "text/css" + H.style ! type_ "text/css" $ toHtml $ minifyCSS cssSource H.meta ! charset "UTF-8" H.body $ H.div inner ! At.id "main-output" where inner = foldMap (cdiv "cell") blocks diff --git a/src/resources/Resources.hs b/src/resources/Resources.hs index a21e98b29..d834767e1 100644 --- a/src/resources/Resources.hs +++ b/src/resources/Resources.hs @@ -1,6 +1,6 @@ {-# LANGUAGE TemplateHaskell #-} -module Resources (dexrtBC, preludeSource, curResourceVersion) where +module Resources (dexrtBC, preludeSource, cssSource, curResourceVersion) where import qualified Data.ByteString.Char8 as B import Data.FileEmbed @@ -11,5 +11,10 @@ curResourceVersion = __TIME__ dexrtBC :: B.ByteString dexrtBC = $(embedFile "src/lib/dexrt.bc") +-- The Dex prelude source code. preludeSource :: String -preludeSource = B.unpack $ $(embedFile "lib/prelude.dx") +preludeSource = B.unpack $(embedFile "lib/prelude.dx") + +-- The source code of the CSS used for rendering Dex programs as HTML. +cssSource :: String +cssSource = B.unpack $(embedFile "static/style.css") \ No newline at end of file From ab33856f83ed62204ac2d21d55d1687af3103dc4 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 29 Dec 2020 10:17:28 -0500 Subject: [PATCH 043/105] Remove css minification. --- src/lib/RenderHtml.hs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index e63aaa0df..6a2308a66 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -12,7 +12,6 @@ module RenderHtml (pprintHtml, progHtml, ToMarkup) where import Text.Blaze.Html5 as H hiding (map) import Text.Blaze.Html5.Attributes as At import Text.Blaze.Html.Renderer.String -import Data.Char (isSpace) import Data.Text (pack) import CMark (commonmarkToHtml) @@ -33,15 +32,10 @@ progHtml :: LitProg -> String progHtml blocks = renderHtml $ wrapBody $ map toHtmlBlock blocks where toHtmlBlock (block,result) = toMarkup block <> toMarkup result --- Minifies the given CSS snippet. --- Currently, this simply removes all whitespace. -minifyCSS :: String -> String -minifyCSS = filter (not . isSpace) - wrapBody :: [Html] -> Html wrapBody blocks = docTypeHtml $ do H.head $ do - H.style ! type_ "text/css" $ toHtml $ minifyCSS cssSource + H.style ! type_ "text/css" $ toHtml cssSource H.meta ! charset "UTF-8" H.body $ H.div inner ! At.id "main-output" where inner = foldMap (cdiv "cell") blocks From 3467bd0717cafc9515050b293dbe43b5bf004749 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 29 Dec 2020 12:09:26 -0500 Subject: [PATCH 044/105] Add a `getEnv` wrapper around POSIX `getenv`. Also added casts between pointers and Int64 so we can check for null pointers. --- lib/prelude.dx | 37 +++++++++++++++++++++++++++++++++++-- makefile | 2 ++ src/lib/Imp.hs | 2 ++ src/lib/JIT.hs | 3 +++ src/lib/Type.hs | 2 ++ tests/io-tests.dx | 6 ++++++ 6 files changed, 50 insertions(+), 2 deletions(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index 728855e58..e794d5e64 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -18,6 +18,8 @@ Word8 = %Word8 Byte = Word8 Char = Byte +RawPtr : Type = %Word8Ptr + Int = Int32 Float = Float32 @@ -35,6 +37,8 @@ def IToI32 (x : Int) : Int32 = internalCast _ x def IToW8 (x : Int) : Word8 = internalCast _ x def IToF (x:Int) : Float = internalCast _ x def FToI (x:Float) : Int = internalCast _ x +def I64ToRawPtr (x:Int64 ) : RawPtr = internalCast _ x +def RawPtrToI64 (x:RawPtr) : Int64 = internalCast _ x interface Add a:Type where add : a -> a -> a @@ -274,6 +278,7 @@ def (>=) (d:Ord a) ?=> (x:a) (y:a) : Bool = x>y || x==y @instance word8Eq : Eq Word8 = MkEq \x:Word8 y:Word8. W8ToB $ %ieq x y @instance boolEq : Eq Bool = MkEq \x y. BToW8 x == BToW8 y @instance unitEq : Eq Unit = MkEq \x y. True +@instance rawPtrEq : Eq RawPtr = MkEq \x y. RawPtrToI64 x == RawPtrToI64 y @instance float64Ord : Ord Float64 = (MkOrd float64Eq (\x y. W8ToB $ %fgt x y) (\x y. W8ToB $ %flt x y)) @@ -402,8 +407,6 @@ def finOrd (n:Int) ?-> : Ord (Fin n) = '## Raw pointer operations -RawPtr : Type = %Word8Ptr - data Ptr a:Type = MkPtr RawPtr -- Is there a better way to select the right instance for `storageSize`?? @@ -903,6 +906,9 @@ def loadDynBuffer (_:Storable a) ?=> (size, _, bufPtr) = load dbPtr AsList size $ tabFromPtr _ bufPtr +def pushDynBuffer (_:Storable a) ?=> + (buf: DynBuffer a) (x:a) : {State World} Unit = + extendDynBuffer buf $ AsList _ [x] '## Strings and Characters @@ -975,6 +981,17 @@ def either_is_nan (x:Float) (y:Float) : Bool = (isnan x) || (isnan y) FilePath : Type = String data CString = MkCString RawPtr +def nullRawPtr : RawPtr = I64ToRawPtr $ IToI64 0 + +def fromNullableRawPtr (ptr:RawPtr) : Maybe (Ptr a) = + if ptr == nullRawPtr + then Nothing + else Just $ MkPtr ptr + +def cStringPtr (s:CString) : Maybe (Ptr Char) = + (MkCString ptr) = s + fromNullableRawPtr ptr + data StreamMode = ReadMode WriteMode @@ -1046,6 +1063,22 @@ def boundedIter (maxIters:Int) (fallback:a) then Done fallback else body i +def fromCString (s:CString) : {State World} (Maybe String) = + case cStringPtr s of + Nothing -> Nothing + Just ptr -> + Just $ withDynamicBuffer \buf. iter \i. + c = load $ ptr +>> i + if c == '\NUL' + then Done $ loadDynBuffer buf + else + pushDynBuffer buf c + Continue + +def getEnv (name:String) : {State World} Maybe String = + withCString name \(MkCString ptr). + fromCString $ MkCString $ %ffi getenv RawPtr ptr + def fread (stream:Stream ReadMode) : {State World} String = (MkStream stream') = stream -- TODO: allow reading longer files! diff --git a/makefile b/makefile index 2c827bb75..1bba733f3 100644 --- a/makefile +++ b/makefile @@ -111,6 +111,8 @@ quine-tests: $(quine-test-targets) quine-tests-interp: runinterp-eval-tests runinterp-ad-tests-interp runinterp-interp-tests run-%: export DEX_ALLOW_CONTRACTIONS=0 +run-%: export DEX_TEST_MODE=t + run-tests/%: tests/%.dx build misc/check-quine $< $(dex) script --allow-errors run-examples/%: examples/%.dx build diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index d2b655818..c50b14bfa 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -1217,6 +1217,8 @@ instrTypeChecked instr = case instr of case (dt, st) of (PtrType _, PtrType _) -> return () (Scalar _, Scalar _) -> return () + (Scalar Int64Type, PtrType _) -> return () + (PtrType _, Scalar Int64Type) -> return () _ -> throw CompilerErr $ "Can't cast " ++ pprint st ++ " to " ++ pprint dt return dt diff --git a/src/lib/JIT.hs b/src/lib/JIT.hs index 5261b519b..43ab475dd 100644 --- a/src/lib/JIT.hs +++ b/src/lib/JIT.hs @@ -304,6 +304,9 @@ compileInstr instr = case instr of (L.FloatingPointType _, L.IntegerType _) -> emitInstr dt $ L.FPToSI x dt [] (L.IntegerType _, L.FloatingPointType _) -> emitInstr dt $ L.SIToFP x dt [] (L.PointerType _ _, L.PointerType eltTy _) -> castLPtr eltTy x + (L.IntegerType 64 , ptrTy@(L.PointerType _ _)) -> + emitInstr ptrTy $ L.IntToPtr x ptrTy [] + (L.PointerType _ _, L.IntegerType 64) -> emitInstr i64 $ L.PtrToInt x i64 [] _ -> error $ "Unsupported cast" ICall f@(fname:> IFunType cc argTys resultTys) args -> do -- TODO: consider having a separate calling convention specification rather diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 88b7f61c9..76c8aa25c 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -648,6 +648,8 @@ checkFloatBaseType allowVector t = case t of checkValidCast :: Type -> Type -> TypeM () checkValidCast (BaseTy (PtrType _)) (BaseTy (PtrType _)) = return () +checkValidCast (BaseTy (PtrType _)) (BaseTy (Scalar Int64Type)) = return () +checkValidCast (BaseTy (Scalar Int64Type)) (BaseTy (PtrType _)) = return () checkValidCast sourceTy destTy = checkScalarType sourceTy >> checkScalarType destTy where diff --git a/tests/io-tests.dx b/tests/io-tests.dx index 55fd74ed8..d14ba4874 100644 --- a/tests/io-tests.dx +++ b/tests/io-tests.dx @@ -69,3 +69,9 @@ unsafeIO \(). (AsList _ s') = readFile fname sum (for i. W8ToI s.i) == sum (for i. W8ToI s'.i) > True + +:p unsafeIO do getEnv "NOT_AN_ENV_VAR" +> Nothing + +:p unafeIO do getEnv "DEX_TEST_MODE" +> (Just (AsList 1 "t")) From 4ee5f7a7805a958fafb237aea458425b331281e2 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 29 Dec 2020 13:19:52 -0500 Subject: [PATCH 045/105] Use `DEX_TEST_MODE` flag to choose different params when testing. This should let us turn all the examples into tests since they can now run in reasonable time. Includes a workaround for a bug that feels like the same one as #348. --- examples/mcmc.dx | 8 +++++--- examples/raytrace.dx | 4 +++- lib/prelude.dx | 11 +++++++++++ makefile | 1 + tests/io-tests.dx | 5 ++++- 5 files changed, 24 insertions(+), 5 deletions(-) diff --git a/examples/mcmc.dx b/examples/mcmc.dx index 1cf4229c0..44ff113b0 100644 --- a/examples/mcmc.dx +++ b/examples/mcmc.dx @@ -88,14 +88,16 @@ def myLogProb (x:(Fin 2)=>Float) : LogProb = x' = x - [1.5, 2.5] neg $ 0.5 * inner x' [[1.,0.],[0.,20.]] x' -numSamples = 10000 +numSamples = if dex_test_mode () + then 1000 + else 10000 k0 = newKey 1 mhParams = 0.1 mhSamples = runChain randnVec (mhStep mhParams myLogProb) numSamples k0 :p meanAndCovariance mhSamples -> ([1.51656, 2.493105], [[1.037397, 0.011821], [0.011821, 0.053776]]) +> ([0.369159, 2.453517], [[0.575722, 0.08787], [0.08787, 0.125873]]) :html showPlot $ yPlot $ slice (map head mhSamples) 0 (Fin 1000) @@ -105,7 +107,7 @@ hmcParams = (10, 0.1) hmcSamples = runChain randnVec (hmcStep hmcParams myLogProb) numSamples k0 :p meanAndCovariance hmcSamples -> ([1.50457, 2.500021], [[0.973867, 0.003423], [0.003423, 0.050586]]) +> ([1.431633, 2.503093], [[0.964188, 0.005688], [0.005688, 0.049492]]) :html showPlot $ yPlot $ slice (map head hmcSamples) 0 (Fin 1000) diff --git a/examples/raytrace.dx b/examples/raytrace.dx index 45fea7b6b..2f9ee3601 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -308,8 +308,10 @@ defaultCamera = { numPix = 250 , sensorDist = 1.0 } -- We change to a small num pix here to reduce the compute needed for tests -camera = defaultCamera |> setAt #numPix 10 params = defaultParams +camera = if dex_test_mode () + then defaultCamera |> setAt #numPix 10 + else defaultCamera -- %time (MkImage _ _ image) = takePicture params theScene camera diff --git a/lib/prelude.dx b/lib/prelude.dx index e794d5e64..1785f0f0e 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -205,6 +205,8 @@ def isNothing (x:Maybe a) : Bool = case x of Nothing -> True Just _ -> False +def isJust (x:Maybe a) : Bool = not $ isNothing x + data (|) a:Type b:Type = Left a Right b @@ -1079,6 +1081,13 @@ def getEnv (name:String) : {State World} Maybe String = withCString name \(MkCString ptr). fromCString $ MkCString $ %ffi getenv RawPtr ptr +def checkEnv (name:String) : {State World} Bool = + -- This should be just `isJust $ getEnv name` but that sefaults (only if the + -- env var *is* defined), possibly related to bug #348. + withCString name \(MkCString ptr). + resultPtr = %ffi getenv RawPtr ptr + not $ resultPtr == nullRawPtr + def fread (stream:Stream ReadMode) : {State World} String = (MkStream stream') = stream -- TODO: allow reading longer files! @@ -1510,3 +1519,5 @@ def softmax (x: n=>Float) : n=>Float = def evalpoly (_:VSpace v) ?=> (coefficients:n=>v) (x:Float) : v = -- Evaluate a polynomial at x. Same as Numpy's polyval. fold zero \i c. coefficients.i + x .* c + +def dex_test_mode (():Unit) : Bool = unsafeIO do checkEnv "DEX_TEST_MODE" diff --git a/makefile b/makefile index 1bba733f3..a304dcce9 100644 --- a/makefile +++ b/makefile @@ -123,6 +123,7 @@ prop-tests: cbits/libdex.so $(STACK) test $(PROF) update-%: export DEX_ALLOW_CONTRACTIONS=0 +update-%: export DEX_TEST_MODE=t update-all: $(update-test-targets) $(update-example-targets) diff --git a/tests/io-tests.dx b/tests/io-tests.dx index d14ba4874..736c07ff4 100644 --- a/tests/io-tests.dx +++ b/tests/io-tests.dx @@ -73,5 +73,8 @@ unsafeIO \(). :p unsafeIO do getEnv "NOT_AN_ENV_VAR" > Nothing -:p unafeIO do getEnv "DEX_TEST_MODE" +:p unsafeIO do getEnv "DEX_TEST_MODE" > (Just (AsList 1 "t")) + +:p dex_test_mode () +> True From ed270a113c3243405ca30ad3e5147f97e4b2dcab Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 23 Dec 2020 23:08:05 +0000 Subject: [PATCH 046/105] Add Text support to Diagram --- lib/diagram.dx | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index a10f6cc08..29e188ab3 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -7,6 +7,7 @@ data Geom = Circle Float Rectangle Float Float -- width, height Line Point + Text String -- HTML color (no alpha) -- TODO: replace with `Fin 3 => Word8` when we fix #348 @@ -62,6 +63,7 @@ flipY : Diagram -> Diagram = Circle r -> Circle r Rectangle w h -> Rectangle w h Line (x, y) -> Line (x, -y) + Text x -> Text x def scale (s:Float) : (Diagram -> Diagram) = applyTransformation ( \(x,y). (s * x, s * y) ) \geom. case geom of @@ -69,6 +71,7 @@ def scale (s:Float) : (Diagram -> Diagram) = Circle r -> Circle (s * r) Rectangle w h -> Rectangle (s * w) (s * h) Line (x, y) -> Line (s * x, s * y) + Text x -> Text x def moveXY ((offX, offY) : Point) : (Diagram -> Diagram) = applyTransformation (\(x,y). (x + offX, y + offY) ) id @@ -80,6 +83,7 @@ def pointDiagram : Diagram = singletonDefault PointGeom def circle (r:Float) : Diagram = singletonDefault $ Circle r def rect (w:Float) (h:Float) : Diagram = singletonDefault $ Rectangle w h def line (p:Point) : Diagram = singletonDefault $ Line p +def text (x:String) : Diagram = singletonDefault $ Text x def updateGeom (update: GeomStyle -> GeomStyle) (d:Diagram) : Diagram = (MkDiagram (AsList _ objs)) = d @@ -142,11 +146,14 @@ def attrString (attr:GeomStyle) : String = <+> ("stroke-width" <=> (getAt #strokeWidth attr))) def renderGeom (attr:GeomStyle) ((x,y):Point) (geom:Geom) : String = + -- For things that are solid. SVG says they have fill=stroke. + solidAttr = setAt #fillColor (getAt #strokeColor attr) attr + groupEle = \attr. tagBracketsAttr "g" (attrString attr) case geom of PointGeom -> pointAttr = setAt #fillColor (getAt #strokeColor attr) attr - groupEle pointAttr $ selfClosingBrackets $ + groupEle solidAttr $ selfClosingBrackets $ ("circle" <+> "cx" <=> x <.> "cy" <=> y <.> @@ -164,6 +171,11 @@ def renderGeom (attr:GeomStyle) ((x,y):Point) (geom:Geom) : String = "height" <=> h <.> "x" <=> (x - (w/2.0)) <.> "y" <=> (y - (h/2.0))) + Text content -> + textEle = tagBracketsAttr "text" $ + ("x" <=> x <.> + "y" <=> y) + groupEle solidAttr $ textEle content BoundingBox : Type = (Point & Point) @@ -189,10 +201,12 @@ moveX : Float -> Diagram -> Diagram = \x. moveXY (x, 0.0) moveY : Float -> Diagram -> Diagram = \y. moveXY (0.0, y) -- mydiagram : Diagram = --- ( (circle 7.0 |> moveXY (20.0, 20.0) |> setFillColor blue |> setStrokeColor red) --- <> (circle 5.0 |> moveXY (40.0, 40.0)) --- <> (rect 10.0 20.0 |> moveXY (5.0, 10.0) |> setStrokeColor red) --- <> (pointDiagram |> moveXY (15.0, 5.0) |> setStrokeColor red) --- ) +-- ( (circle 7.0 |> moveXY (20.0, 20.0) |> setFillColor blue |> setStrokeColor red) +-- <> (circle 5.0 |> moveXY (40.0, 40.0)) +-- <> (rect 10.0 20.0 |> moveXY (5.0, 10.0) |> setStrokeColor red) +-- <> (text "DexLang" |> moveXY (30.0, 10.0) |> setStrokeColor green) +-- <> (pointDiagram |> moveXY (15.0, 5.0) |> setStrokeColor red) +-- ) -- :html renderSVG mydiagram ((0.0, 0.0), (100.0, 50.0)) + From d4d6ec5f0140399885184d724f315666b86c668d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 29 Dec 2020 19:15:59 +0000 Subject: [PATCH 047/105] center text element --- lib/diagram.dx | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index 29e188ab3..d0577ff4a 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -174,7 +174,10 @@ def renderGeom (attr:GeomStyle) ((x,y):Point) (geom:Geom) : String = Text content -> textEle = tagBracketsAttr "text" $ ("x" <=> x <.> - "y" <=> y) + "y" <=> y <.> + "text-anchor" <=> "middle" <.> -- horizontal center + "dominant-baseline" <=> "middle" -- vertical center + ) groupEle solidAttr $ textEle content BoundingBox : Type = (Point & Point) @@ -209,4 +212,3 @@ moveY : Float -> Diagram -> Diagram = \y. moveXY (0.0, y) -- ) -- :html renderSVG mydiagram ((0.0, 0.0), (100.0, 50.0)) - From db019ee27c51110725fcaa714b306ca9b130ce48 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 29 Dec 2020 19:16:25 +0000 Subject: [PATCH 048/105] Made diagram demos literate comments --- lib/diagram.dx | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index d0577ff4a..4e91ddb9d 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -203,12 +203,24 @@ def renderSVG (d:Diagram) (bounds:BoundingBox) : String = moveX : Float -> Diagram -> Diagram = \x. moveXY (x, 0.0) moveY : Float -> Diagram -> Diagram = \y. moveXY (0.0, y) --- mydiagram : Diagram = --- ( (circle 7.0 |> moveXY (20.0, 20.0) |> setFillColor blue |> setStrokeColor red) --- <> (circle 5.0 |> moveXY (40.0, 40.0)) --- <> (rect 10.0 20.0 |> moveXY (5.0, 10.0) |> setStrokeColor red) --- <> (text "DexLang" |> moveXY (30.0, 10.0) |> setStrokeColor green) --- <> (pointDiagram |> moveXY (15.0, 5.0) |> setStrokeColor red) --- ) - --- :html renderSVG mydiagram ((0.0, 0.0), (100.0, 50.0)) +' A Demo showing all kind of features +``` +mydiagram : Diagram = + ( (circle 7.0 |> moveXY (20.0, 20.0) |> setFillColor blue |> setStrokeColor red) + <> (circle 5.0 |> moveXY (40.0, 40.0)) + <> (rect 10.0 20.0 |> moveXY (5.0, 10.0) |> setStrokeColor red) + <> (text "DexLang" |> moveXY (30.0, 10.0) |> setStrokeColor green) + <> (pointDiagram |> moveXY (15.0, 5.0) |> setStrokeColor red) + ) +:html renderSVG mydiagram ((0.0, 0.0), (100.0, 50.0)) +``` + +' Another demo that shows things are all center aligned: +``` +concentricDiagram : Diagram = ( + (rect 2.0 2.0 |> setFillColor red) + <> (circle 1.0 |> setFillColor blue) + <> (text "DexLang" |> setStrokeColor white) +) |> moveXY (5.0, 5.0) +:html renderSVG concentricDiagram ((0.0, 0.0), (10.0, 10.0)) +``` From 078e7dcb6b63a073e560e3a0133ee853c1d142b5 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 29 Dec 2020 13:49:47 -0500 Subject: [PATCH 049/105] Delete stale tests --- makefile | 17 +- tests/GenExpr.hs | 366 ----------------------------------- tests/PropTests.hs | 52 ----- tests/TestPass.hs | 76 -------- tests/actor-test.hs | 69 ------- tests/ad-tests-interp.dx | 56 ------ tests/flop-tests.dx | 33 ---- tests/include-test.dx | 109 ----------- tests/included.dx | 8 - tests/interp-tests.dx | 18 -- tests/jax-tests.dx | 28 --- tests/simple-include-test.dx | 7 - tests/somedata.dxo | 1 - tests/web-tests.dx | 12 -- 14 files changed, 2 insertions(+), 850 deletions(-) delete mode 100644 tests/GenExpr.hs delete mode 100644 tests/PropTests.hs delete mode 100644 tests/TestPass.hs delete mode 100644 tests/actor-test.hs delete mode 100644 tests/ad-tests-interp.dx delete mode 100644 tests/flop-tests.dx delete mode 100644 tests/include-test.dx delete mode 100644 tests/included.dx delete mode 100644 tests/interp-tests.dx delete mode 100644 tests/jax-tests.dx delete mode 100644 tests/simple-include-test.dx delete mode 100644 tests/somedata.dxo delete mode 100644 tests/web-tests.dx diff --git a/makefile b/makefile index a304dcce9..8f2ed6176 100644 --- a/makefile +++ b/makefile @@ -88,8 +88,7 @@ example-names = mandelbrot pi sierpinski rejection-sampler \ test-names = uexpr-tests adt-tests type-tests eval-tests show-tests \ shadow-tests monad-tests io-tests \ ad-tests parser-tests serialize-tests \ - record-variant-tests simple-include-test \ - typeclass-tests complex-tests trig-tests + record-variant-tests typeclass-tests complex-tests trig-tests lib-names = diagram plot png @@ -104,12 +103,10 @@ doc-example-names = $(example-names:%=doc/examples/%.html) doc-lib-names = $(lib-names:%=doc/lib/%.html) -tests: quine-tests repl-test export-tests +tests: quine-tests repl-test quine-tests: $(quine-test-targets) -quine-tests-interp: runinterp-eval-tests runinterp-ad-tests-interp runinterp-interp-tests - run-%: export DEX_ALLOW_CONTRACTIONS=0 run-%: export DEX_TEST_MODE=t @@ -144,16 +141,6 @@ update-gpu-tests: tests/gpu-tests.dx build $(dex) --backend LLVM-CUDA script --allow-errors $< > $<.tmp mv $<.tmp $< -export-tests: export-test-scalar export-test-array - -export-test-%: build - $(dex) export examples/export/$*.dx examples/export/$*.o - $(CXX) -std=c++11 examples/export/$*.o examples/export/$*.cpp -o examples/export/$* - examples/export/$* - -jax-tests: build - misc/check-quine examples/jax-tests.dx $(dex) --backend JAX script - uexpr-tests: misc/check-quine examples/uexpr-tests.dx $(dex) script diff --git a/tests/GenExpr.hs b/tests/GenExpr.hs deleted file mode 100644 index 53e8a0aef..000000000 --- a/tests/GenExpr.hs +++ /dev/null @@ -1,366 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -module GenExpr (sampleExpr, defaultGenOptions, GenOptions (..) - , testSampleExpr, testSample, makeGenEnv, genSourceBlock, - genUTopDecl, TypeEnv) where - -import Control.Monad -import Control.Monad.Reader -import GHC.Float -import Hedgehog hiding (Var, Command) -import Hedgehog.Internal.Shrink (towards) -import qualified Hedgehog.Gen as Gen -import qualified Hedgehog.Range as Range -import qualified Data.Map.Strict as M -import qualified Data.Set as S -import Lens.Micro.Platform -import Data.Text.Prettyprint.Doc -import Data.String - -import Record -import Env -import Syntax -import PPrint - -testSample :: (Pretty a) => TypeEnv -> GenM a -> Range.Size -> IO () -testSample env m s = - Gen.sample (runReaderT (Gen.resize s (pprint <$> m)) - (makeGenEnv env defaultGenOptions)) - >>= putStrLn - -testSampleExpr :: Int -> IO () -testSampleExpr n = testSample mempty sampleExpr (fromIntegral n) - - --- Variable names associated with a type -type TypeEnv = M.Map SigmaType [Name] - --- Variable names in scope -type ScopeEnv = S.Set Name - -data GenOptions = GenOptions { - tableSize :: Int - , numberSize :: Int - , tupleSize :: Int - , returnTypePref :: Int - } - deriving (Show, Eq, Ord) - -defaultGenOptions :: GenOptions -defaultGenOptions = GenOptions { - tableSize = 10 - , numberSize = 10 - , tupleSize = 5 - , returnTypePref = 2 - } - -data GenEnv = GenEnv { - typeEnv :: TypeEnv - , scopeEnv :: ScopeEnv - , optsEnv :: GenOptions} - deriving (Show, Eq, Ord) - - -makeGenEnv :: TypeEnv -> GenOptions -> GenEnv -makeGenEnv te opts = GenEnv te (S.fromList (concat (M.elems te))) opts - --- lens -typeEnvL :: Lens' GenEnv TypeEnv -typeEnvL = lens typeEnv (\e t -> e{typeEnv = t}) -scopeEnvL :: Lens' GenEnv ScopeEnv -scopeEnvL = lens scopeEnv (\e s -> e{scopeEnv = s}) -optsEnvL :: Lens' GenEnv GenOptions -optsEnvL = lens optsEnv (\e s -> e{optsEnv = s}) -tableSizeL :: Lens' GenEnv Int -tableSizeL = optsEnvL . lens tableSize (\e t -> e{tableSize=t}) -numberSizeL :: Lens' GenEnv Int -numberSizeL = optsEnvL . lens numberSize (\e t -> e{numberSize = t}) -tupleSizeL :: Lens' GenEnv Int -tupleSizeL = optsEnvL . lens tupleSize (\e t -> e{tupleSize = t}) -returnTypePrefL :: Lens' GenEnv Int -returnTypePrefL = optsEnvL . lens returnTypePref (\e t -> e{returnTypePref = t}) - --- utils -setBinding :: (Name, SigmaType) -> GenEnv -> GenEnv -setBinding (v, ty) = - over (typeEnvL . at ty) ((Just [v]) `mappend`) . over scopeEnvL (S.insert v) -setBinding' :: (Name, Type) -> GenEnv -> GenEnv -setBinding' (v, ty) = setBinding (v, Forall [] ty) -setBindings' :: [(Name, Type)] -> GenEnv -> GenEnv -setBindings' vs = foldl (.) id (setBinding' <$> vs) -withBindings :: [(Name, Type)] -> GenM a -> GenM a -withBindings vs g = local (setBindings' vs) g -notShadowed :: Name -> GenM Bool -notShadowed n = view (scopeEnvL . to (S.notMember n)) - -genUntil :: MonadGen m => (a -> m Bool) -> m a -> m a -genUntil f gen = do - x <- gen - isValid <- f x - if isValid then return x else genUntil f gen - -small :: MonadGen m => m a -> m a -small = Gen.scale (`div` 2) - -genSized :: MonadGen m => m a -> m a -> m a -genSized leaf tree = Gen.sized (\n -> if n == 0 then leaf else tree) - -element :: MonadGen m => [a] -> m a -element = Gen.prune . Gen.element - -prefer :: MonadGen m => Int -> m a -> m a -> m a -prefer w p r = Gen.prune (Gen.frequency [(w, p), (1, r)]) - -type GenM a = ReaderT GenEnv Gen a - -allTypes :: Type -> [Type] -allTypes ty = ty : case ty of - ArrType _ t1 t2 -> allTypes t1 ++ allTypes t2 - TabType _ t -> allTypes t - RecType _ ~(Tup ts) -> concatMap allTypes ts - _ -> [] - -preferReturnType :: Type -> GenM Type -> GenM Type -preferReturnType ty b = view returnTypePrefL >>= (\n -> prefer n (element (allTypes ty)) b) - - --- type utils -checkData :: Type -> Bool -checkData ty = case ty of - BaseType _ -> True - TabType _ a -> checkData a - RecType _ r -> all checkData r - IdxSetLit _ -> True - _ -> False - - --- | TODO: StrType -genBaseType :: GenM BaseType -genBaseType = element [IntType, BoolType, RealType] - - -genRecTypeWith :: GenM a -> GenM (Record a) -genRecTypeWith g = Tup <$> record - where - record = view tupleSizeL >>= \n -> Gen.list (Range.linear 2 n) (small g) - -genRecType :: GenM (Record Type) -genRecType = genRecTypeWith genType - -genTabTypeWith :: GenM Type -> GenM Type -genTabTypeWith g = liftM2 TabType genIdxSet (small g) - -genTabType :: GenM Type -genTabType = genTabTypeWith genType - --- types belong to the Data class -genDataType :: GenM Type -genDataType = genSized leaf tree - where - leaf = BaseType <$> genBaseType - tree = Gen.frequency [ - (1, leaf) - , (2, (RecType Cart) <$> genRecTypeWith genDataType) - , (2, genTabTypeWith genDataType) - ] - - -genIdxSet :: GenM IdxSet -genIdxSet = genSized leaf tree - where - lit = view tableSizeL >>= \n -> Gen.integral_ (Range.constant 1 n) - leaf = IdxSetLit <$> lit - tree = Gen.frequency [ - (1, leaf) - -- Tuple index has not been implemented in JIT - -- , (2, RecType <$> genRecTypeWith genIdxSet) - ] - --- TODO: TypeVar, Exists, BoundTVar. -genLeafType :: GenM Type -genLeafType = BaseType <$> genBaseType - --- TODO: Linear type, Tens -genTreeType :: GenM Type -genTreeType = Gen.choice [ - genLeafType - , arr - , genTabType - , (RecType Cart) <$> genRecType - ] - where - sub = small genType - arr = liftM2 (ArrType (Mult NonLin)) sub sub - - -genType :: GenM Type -genType = Gen.shrink shrinkType $ Gen.prune (genSized genLeafType genTreeType) - -shrinkType :: Type -> [Type] -shrinkType = tail . shrinkLis - where - shrinkLis :: Type -> [Type] - shrinkLis ty = case ty of - ArrType lin t1 t2 -> - -- TODO: generate smaller list - liftM2 (ArrType lin) (shrinkLis t1) (shrinkLis t2) ++ shrinkType t1 - TabType idx t -> - liftM2 TabType (shrinkLis idx) (shrinkLis t) ++ shrinkType t - (IdxSetLit n) -> IdxSetLit <$> towards n 1 - _ -> [ty] - - -genPatP :: (Type -> Ann) -> Type -> GenM (UPat, [(Name, Type)]) -genPatP ann ty = case ty of - (RecType _ (Tup as)) -> Gen.frequency [(1, genLeafPat), (2, genTupPat as)] - _ -> genLeafPat - where - genLeafPat = do - n <- genName - return (RecLeaf (n :> (ann ty)), [(n, ty)]) - genTreePat :: [Type] -> GenM ([UPat], [(Name, Type)]) - genTreePat [] = return ([], []) - genTreePat (t:ts) = do - (p1, vs1) <- genPatP ann t - (restp, restv) <- withBindings vs1 (genTreePat ts) -- make sure names are unique - return (p1:restp, vs1 ++ restv) - genTupPat :: [Type] -> GenM (UPat, [(Name, Type)]) - genTupPat ts = do - (ps, vs) <- genTreePat ts - return (RecTree (Tup ps), vs) - - --- | variable or literal value --- -genLit :: BaseType -> GenM (ExprP b) -genLit ty = Lit <$> case ty of - IntType -> - view numberSizeL >>= \n -> IntLit <$> Gen.integral_ (Range.constant (negate n) n) - BoolType -> BoolLit <$> Gen.bool_ - RealType -> do - n <- view (numberSizeL . to fromIntegral) - (RealLit . roundTripDouble) <$> Gen.realFrac_ (Range.constant (negate n) n) - StrType -> error "Str type not implemented" - --- TODO: remove this once we have more control over precision of printed floats -roundTripDouble :: Double -> Double -roundTripDouble x = read (show (double2Float x)) - -genName :: GenM Name -genName = Gen.prune (genUntil notShadowed (fromString <$> str)) - where - strLen = Range.constant 0 5 - strTail = Gen.frequency [(10, Gen.alphaNum), (1, return '\'')] - str = liftM2 (:) Gen.lower (Gen.list strLen strTail) - -genVars :: Type -> GenM [ExprP b] -genVars t = view (typeEnvL . at (Forall [] t) . to (maybe [] id) . to (map (flip Var []))) - -withVars :: Type -> GenM (ExprP b) -> GenM (ExprP b) -withVars t g = do - vs <- genVars t - e <- g - if null vs - then return e - else prefer 3 (Gen.element vs) (return e) -- preference to variable - --- TODO: Linear type -genLam :: Type -> Type -> GenM UExpr -genLam a b = do - (pat, env) <- genPatP Ann a - body <- withBindings env (genExpr b) - return (Lam (Ann (Mult NonLin)) pat body) - - --- table -genTabCon :: Int -> Type -> [GenM UExpr] -genTabCon n ty - | checkData ty = [TabCon NoAnn <$> replicateM n (small (genExpr ty))] - | otherwise = [] - -genFor :: Type -> Type -> GenM UExpr -genFor a b = do - (pat, env) <- small (genPatP Ann a) - body <- withBindings env (small (genExpr b)) - return (For pat body) - -genTable :: IdxSet -> Type -> GenM UExpr -genTable ty@(IdxSetLit n) b = Gen.choice (genFor ty b : genTabCon n b) -genTable ty b = genFor ty b - --- TODO: LetPoly, TAlias, Unpack -genDecl :: Type -> GenM UExpr -genDecl ty = do - -- preference over return type to increase variable usage - declTy <- small (preferReturnType ty genType) - declExpr <- small (genExpr declTy) - (declPat, env) <- small (genPatP (const NoAnn) declTy) - body <- small (withBindings env (genExpr ty)) - return (Decl (LetMono declPat declExpr) body) - -genGet :: Type -> GenM UExpr -genGet ty = do - idxty <- small genIdxSet - idx <- small (genExpr idxty) - body <- small (genExpr (TabType idxty ty)) - return (Get body idx) - - -genApp :: Type -> GenM UExpr -genApp ty = do - argty <- small (preferReturnType ty genType) - fun <- small (genExpr (ArrType (Mult NonLin) argty ty)) - arg <- small (genExpr argty) - return (App fun arg) - --- TODO: Tens -genRecCon :: Record Type -> GenM UExpr -genRecCon ~(Tup ts) = RecCon Cart <$> Tup <$> traverse (small . genExpr) ts - - -genLeafExpr :: Type -> GenM UExpr -genLeafExpr ty = withVars ty $ case ty of - BaseType t -> genLit t - ArrType _ t1 t2 -> genLam t1 t2 - TabType i t -> genTable i t - RecType _ rt -> genRecCon rt - IdxSetLit n -> do - val <- Gen.integral_ (Range.constant 0 (n - 1)) - return $ Annot (PrimOp IntAsIndex [] [Lit (IntLit val)]) ty - _ -> undefined - -genTreeExpr :: Type -> GenM UExpr -genTreeExpr ty = Gen.choice $ case ty of - BaseType{} -> commons - ArrType _ t1 t2 -> genLam t1 t2 : commons - TabType i t -> genTable i t : commons - RecType _ rt -> genRecCon rt : commons - _ -> commons - where - commons = [ - genDecl ty - , genApp ty - , genGet ty - ] - -genExpr :: Type -> GenM UExpr -genExpr ty = genSized (genLeafExpr ty) (genTreeExpr ty) - -sampleExpr :: GenM UExpr -sampleExpr = do - ty <- genDataType - genExpr ty - - -genUTopDecl :: GenM UTopDecl -genUTopDecl = (EvalCmd . Command (EvalExpr Printed)) <$> sampleExpr - -genSourceBlock :: GenM SourceBlock -genSourceBlock = do - topdecl <- UTopDecl <$> genUTopDecl - case topdecl of - ~(UTopDecl (EvalCmd (Command _ e))) -> return (SourceBlock 0 0 (pprint e) topdecl) diff --git a/tests/PropTests.hs b/tests/PropTests.hs deleted file mode 100644 index a149f787a..000000000 --- a/tests/PropTests.hs +++ /dev/null @@ -1,52 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# LANGUAGE OverloadedStrings #-} - -import Control.Monad -import qualified Hedgehog as H -import Control.Monad.Reader -import qualified Data.Map.Strict as M - -import Syntax hiding (Result) -import Parser -import PPrint -import GenExpr -import TestPass - -main :: IO () -main = void tests - -prop_jitEval :: TypeEnv -> Evaluator -> Evaluator -> H.Property -prop_jitEval tenv jit interp = - H.property $ do - srcBlk <- H.forAllWith pprint (runReaderT genSourceBlock (makeGenEnv tenv defaultGenOptions)) - interres <- H.evalIO (interp srcBlk) - H.annotate ("Interpreter result: " ++ pprint interres) - jitres <- H.evalIO (jit srcBlk) - pprint interres H.=== pprint jitres - - -getExpr :: TopDeclP b -> ExprP b -getExpr ~(EvalCmd (Command _ e)) = e - -prop_pprint :: H.Property -prop_pprint = - H.property $ do - expr <- H.forAllWith pprint (runReaderT sampleExpr (makeGenEnv mempty defaultGenOptions)) - H.tripping expr pprintEsc (\s -> (getExpr . stripSrcAnnotTopDecl) <$> parseTopDecl s) - -tests :: IO Bool -tests = do - let prelude = "prelude.dx" - jit <- runTestPass prelude fullPassJit - interp <- runTestPass prelude fullPassInterp - preludeEnv <- loadTypeEnv prelude - let tyEnv = M.fromListWith (++) [(ty, [name]) | (ty, name) <- preludeEnv] - H.checkParallel $ H.Group "TypeCheck" [ - ("prop_jitEval", prop_jitEval tyEnv jit interp) - , ("prop_pprint", prop_pprint) - ] diff --git a/tests/TestPass.hs b/tests/TestPass.hs deleted file mode 100644 index c677709e1..000000000 --- a/tests/TestPass.hs +++ /dev/null @@ -1,76 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -module TestPass (typeCheckPass, fullPassInterp, fullPassJit, - runTestPass, Evaluator, loadTypeEnv) where - -import Data.Void -import Control.Monad.State.Strict -import qualified Data.Map.Strict as M -import Unsafe.Coerce - -import Pass -import DeShadow -import Inference -import Imp -import Syntax -import Type -import JIT -import Flops -import Normalize -import Simplify -import Interpreter -import Parser -import Env - -typeCheckPass :: TopPass SourceBlock TopDecl -typeCheckPass = sourcePass >+> deShadowPass >+> typePass >+> checkTyped - -fullPassInterp :: TopPass SourceBlock Void -fullPassInterp = typeCheckPass >+> interpPass - -fullPassJit :: TopPass SourceBlock Void -fullPassJit = typeCheckPass >+> normalizePass >+> checkNExpr - >+> derivPass >+> checkNExpr - >+> simpPass >+> checkNExpr - >+> impPass >+> checkImp - >+> flopsPass - >+> jitPass - - -type TestFullPass env b = SourceBlock -> TopPassM env b - -evalDecl :: Monoid env => TestFullPass env b -> SourceBlock -> StateT env IO () -evalDecl pass block = do - env <- get - (_, env') <- liftIO (runTopPassM env (pass block)) - modify (<> env') - -loadFile :: (Monoid env) => String -> TestFullPass env b -> IO env -loadFile fname pass = do - source <- readFile fname - let sourceBlocks = parseProg source - execStateT (mapM (evalDecl pass) sourceBlocks) mempty - -type Evaluator = SourceBlock -> IO Result' - -runTestPass :: String -> TopPass SourceBlock Void -> IO Evaluator -runTestPass fname (TopPass pass) = do - env <- loadFile fname pass - let eval source = do - ~(Left res, _) <- runTopPassM env (pass source) - return res - return eval - - -loadTypeEnv :: String -> IO [(SigmaType, Name)] -loadTypeEnv fname = - case sourcePass >+> deShadowPass >+> typePass of - TopPass pass -> do - envs <- loadFile fname pass - let env = (snd (unsafeCoerce envs)) :: TypeEnv - return $ case env of - Env m -> [(ty, name) | (name, L ty) <- M.toList m] diff --git a/tests/actor-test.hs b/tests/actor-test.hs deleted file mode 100644 index 4f9277298..000000000 --- a/tests/actor-test.hs +++ /dev/null @@ -1,69 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -import Actor -import Control.Monad -import Control.Monad.State -import System.IO -import qualified Data.Map.Strict as M - -type Key = String -type Val = String -type StoreID = Int -type ServerMsg = Either Val ClientToServer -type Server a = StateT (M.Map Key (Proc ())) (Actor ServerMsg) a - -data ClientToServer = Write Key Val | Read Key - -inputDriver :: Proc ServerMsg -> Actor () () -inputDriver server = do - command <- liftIO $ getLine - case words command of - ["write", s1, s2] -> server `send` (Right (Write s1 s2)) - ["read" , s ] -> server `send` (Right (Read s)) - _ -> liftIO $ putStrLn "didn't understand command" - loop - where loop = inputDriver server - -outputDriver :: Actor String () -outputDriver = do - receive $ \_ msg -> liftIO $ putStrLn msg - outputDriver - -serverProc :: Server () -serverProc = do - self <- getSelf - input <- spawn NoTrap (inputDriver self) - client <- spawn NoTrap outputDriver - forever $ mainServerLoop client - -storeProc :: Proc ServerMsg -> Val -> Actor () () -storeProc server val = receive $ \_ _ -> do - if length val > 5 then error "oops!" - else server `send` (Left val) >> loop - where loop = storeProc server val - -mainServerLoop :: Proc String -> Server () -mainServerLoop client = receiveAny handleMsg handleErr - where - handleMsg :: UProc -> ServerMsg -> Server () - handleMsg _ msg = case msg of - Left val -> send client val - Right req -> case req of - Write key val -> do - self <- getSelf - store <- spawnLink NoTrap (storeProc self val) - modify $ M.insert key store - Read key -> do - ans <- gets (M.lookup key) - case ans of Nothing -> sorry key - Just store -> store `send` () - handleErr err = client `send` ("Store " ++ show err ++ " down") - sorry key = client `send` ("Store " ++ key ++ " doesn't exist") - - -main :: IO () -main = runActor Trap (evalStateT serverProc mempty) diff --git a/tests/ad-tests-interp.dx b/tests/ad-tests-interp.dx deleted file mode 100644 index ce57d9854..000000000 --- a/tests/ad-tests-interp.dx +++ /dev/null @@ -1,56 +0,0 @@ - - -:p f : Float --o Float - f x = x - transposeLinear f 2.0 -> 2.0 - -:p f : Float --o Float - f x = y = x; y - transposeLinear f 2.0 -> 2.0 - -:p f : Float --o Float - f x = x + x - transposeLinear f 2.0 -> 4.0 - -:p f : Float --o Float - f x = y = 2.0 * x - 3.0 * y + x - transposeLinear f 1.0 -> 7.0 - -:p f : Float --o Float - f x = (2.0 + 3.0) * x - transposeLinear f 1.0 -> 5.0 - -:p f : (Float, Float) --o Float - f z = (x, y) = z - x + y * 2.0 - transposeLinear f 1.0 -> (1.0, 2.0) - -:p f : Float --o (Float, Float) - f x = (x, x * 2.0) - transposeLinear f (1.0, 3.0) -> 7.0 - -:p f x = x * x + 1.0 - jvp f 3.0 2.0 -> 12.0 - -:p f x = x * x + 1.0 - snd (vjp f 3.0) 2.0 -> 12.0 - -:p f : (Float, Float) -> Float - f (x,y) = x * y * 3.0 - jvp f (2.0, 5.0) (1.0, 100.0) -> 615.0 - -:p f : 3=>Float -> 3=>Float - f x = for i. x.i * x.i - jvp f [1.0, 1.5, 2.5] [3.0, 4.0, 1.0] -> [6.0, 12.0, 5.0] diff --git a/tests/flop-tests.dx b/tests/flop-tests.dx deleted file mode 100644 index e500ddc75..000000000 --- a/tests/flop-tests.dx +++ /dev/null @@ -1,33 +0,0 @@ -matmul : i=>j=>Float -> j=>k=>Float -> i=>k=>Float -matmul x y = for i k. sum (for j. x.i.j * y.j.k) - -_, N = unpack range 10 -_, M = unpack range 10 - -k = newKey 0 - -mat : N=>N=>Float -mat = for i j. rand (ixkey (ixkey k i) j) - -:flops matmul mat mat -> %fadd 1 N^3 -> %fmul 1 N^3 -> copy 1 N^3 - --- This should be O(N) but we're instantiating and adding zeros -:flops transposeLinear (llam xs. for i. xs.i) (for i:N. 0.0) -> %%int_to_index_set 1 N^1 -> %eq 1 N^2 -> %fadd 1 N^2 -> %isub 2 N^1 -> %select 1 N^2 -> copy 1 + 1 N^1 - --- This should be O(NM) but we're instantiating and adding zeros -:flops transposeLinear (llam m. for i j. m.j.i) (for i:N j:M. 0.0) -> %%int_to_index_set 1 M^1 N^1 + 1 N^1 -> %eq 1 M^1 N^2 + 1 M^2 N^1 -> %fadd 1 M^1 N^2 + 1 M^2 N^2 -> %isub 2 M^1 N^1 + 2 N^1 -> %select 1 M^1 N^2 + 1 M^2 N^2 -> copy 1 + 1 M^1 N^1 + 2 N^1 diff --git a/tests/include-test.dx b/tests/include-test.dx deleted file mode 100644 index c5127ef2f..000000000 --- a/tests/include-test.dx +++ /dev/null @@ -1,109 +0,0 @@ - -include "included.dx" -> 30 -> 40 - -:p x -> 10 - -load dxo "somedata.dxo" as dat - -:t dat -> (Float, 2, (2=>(3=>Float)), (2=>(Int, Bool))) - -:p dat -> (1.0, 1@2, [[2.0, 1.0, 3.0], [0.0, -10.0, 20.0]], [(1, True), (2, False)]) - -dump dxbo "test-scratch/bin-data-dump.dxbo" dat - -load dxbo "test-scratch/bin-data-dump.dxbo" as dat2 - -:t dat2 -> (Float, 2, (2=>(3=>Float)), (2=>(Int, Bool))) - -:p dat2 -> (1.0, 1@2, [[2.0, 1.0, 3.0], [0.0, -10.0, 20.0]], [(1, True), (2, False)]) - -load dxbo "not-a-file" as notData -> IO error: not-a-file: openFile: does not exist (No such file or directory) - -load dxbo "bad-binary-file.dxbo" as badData -> IO error: unexpected number of buffers: [16,8] vs [8,8,8] -> Validation error -> Claimed header length: 128 -> Claimed total length: 152 -> Actual file length: 128 -> Header data: -> type: ((2=>(1=>Float)), Int) -> bufferSizes: [8,8,8] - -load dxbo "test-scratch/pydata.dxbo" as pydata - -:t pydata -> ( Float -> , Int -> , () -> , Bool -> , Bool -> , (Int, (3=>Float)) -> , (2=>(3=>Float)) -> , (3=>(2=>Float)) -> , Float -> , Float -> , (1=>(1=>(1=>Int))) -> , (4=>Int) -> , (3=>Bool) ) - -:p pydata -> ( 1.2 -> , 12 -> , () -> , True -> , False -> , (-2, [1.0, 2.0, 3.0]) -> , [[10.0, 20.0, 30.0], [0.1, 0.2, 0.3]] -> , [[10.0, 0.1], [20.0, 0.2], [30.0, 0.3]] -> , 1.3 -> , 0.123 -> , [[[1]]] -> , [6, 5, 4, 3] -> , [True, False, True] ) - -dump dxbo "/tmp/stuff.dxbo" pydata - -load dxbo "/tmp/stuff.dxbo" as xs - -:t xs -> ( Float -> , Int -> , () -> , Bool -> , Bool -> , (Int, (3=>Float)) -> , (2=>(3=>Float)) -> , (3=>(2=>Float)) -> , Float -> , Float -> , (1=>(1=>(1=>Int))) -> , (4=>Int) -> , (3=>Bool) ) - -:p xs -> ( 1.2 -> , 12 -> , () -> , True -> , False -> , (-2, [1.0, 2.0, 3.0]) -> , [[10.0, 20.0, 30.0], [0.1, 0.2, 0.3]] -> , [[10.0, 0.1], [20.0, 0.2], [30.0, 0.3]] -> , 1.3 -> , 0.123 -> , [[[1]]] -> , [6, 5, 4, 3] -> , [True, False, True] ) - -load dxbo "examples/dxbo-example.dxbo" as exampleData - -:p exampleData -> (1, 2, [(3, [4, 5]), (6, [7, 8]), (9, [10, 11])]) diff --git a/tests/included.dx b/tests/included.dx deleted file mode 100644 index 14dd0687d..000000000 --- a/tests/included.dx +++ /dev/null @@ -1,8 +0,0 @@ - -x = 10 - -y = 20 - -:p 30 - -:p 40 diff --git a/tests/interp-tests.dx b/tests/interp-tests.dx deleted file mode 100644 index 4b1b3503c..000000000 --- a/tests/interp-tests.dx +++ /dev/null @@ -1,18 +0,0 @@ --- language features implemented in interpreter but not yet in the compiler - -:p 1 -> 1 - -_, M = unpack range 4 - -xs : M => (E n. n=>Int) -xs = for i. _, N = unpack range (asint i) - x = for j:N. asint j - pack x, N, E n. n=>Int - -for i. x, N2 = unpack xs.i -- TODO: underscore type binder - sum x -> [0, 0, 1, 3] - -:p filter (lam x. x > 2.0) [4.0, 0.0, 10.0, 2.0] -> pack [4.0, 10.0], 2, (Ea.(a=>Float)) diff --git a/tests/jax-tests.dx b/tests/jax-tests.dx deleted file mode 100644 index 5b13f6f3a..000000000 --- a/tests/jax-tests.dx +++ /dev/null @@ -1,28 +0,0 @@ - -x = 1.0 + 2.0 - -:p x + 3.0 -> 6.0 - -:p - getAccumulator \ref. - ref += 1.0 - ref += 2.0 -> 3.0 - -:p for i:3. x + 1.0 -> [4.0, 4.0, 4.0] - -xs = for i:4. 2.0 - -:p sum for i. xs.i * xs.i -> 16.0 - -:p float 1 -> 1.0 - -:p for i:3. neg (float 2) -> [-2.0, -2.0, -2.0] - -:p 0.0 + (1.0 + (2.0 + 0.0)) -> 3.0 diff --git a/tests/simple-include-test.dx b/tests/simple-include-test.dx deleted file mode 100644 index 93cff77e6..000000000 --- a/tests/simple-include-test.dx +++ /dev/null @@ -1,7 +0,0 @@ - -include "../tests/included.dx" -> 30 -> 40 - -:p x -> 10 diff --git a/tests/somedata.dxo b/tests/somedata.dxo deleted file mode 100644 index a9be9bb4e..000000000 --- a/tests/somedata.dxo +++ /dev/null @@ -1 +0,0 @@ -(1.0, 1@2, [[2.0, 1.0, 3.0], [0.0, -10.0, 20.0]], [(1, True), (2, False)]) diff --git a/tests/web-tests.dx b/tests/web-tests.dx deleted file mode 100644 index e50511fba..000000000 --- a/tests/web-tests.dx +++ /dev/null @@ -1,12 +0,0 @@ - -_, N = unpack range 7 - -xs = for i:N. float iota.i - -:p 1 + 1.0 - -:p 1.0 + 200.0 - -:plot for i. (xs.i, xs.i * xs.i) - -:plotmat for i:N j:N. rand $ hash (hash 0 iota.i) iota.j From 69ceb2fd8eb12a34415e0882b4f5ac04a2c67f7d Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 29 Dec 2020 13:59:22 -0500 Subject: [PATCH 050/105] Delete stale code. Fixes #342. --- dex.cabal | 2 +- src/lib/Flops.hs | 92 --------------------------------------------- src/lib/Parser.hs | 29 -------------- src/lib/Syntax.hs | 10 ++--- src/lib/TopLevel.hs | 1 - 5 files changed, 5 insertions(+), 129 deletions(-) delete mode 100644 src/lib/Flops.hs diff --git a/dex.cabal b/dex.cabal index 3cf1683c6..a4452d2cb 100644 --- a/dex.cabal +++ b/dex.cabal @@ -33,7 +33,7 @@ library exposed-modules: Env, Syntax, Type, Inference, JIT, LLVMExec, Parser, Util, Imp, Imp.Embed, Imp.Optimize, PPrint, Algebra, Parallelize, Optimize, Serialize - Actor, Cat, Flops, Embed, Export, + Actor, Cat, Embed, Export, RenderHtml, LiveOutput, Simplify, TopLevel, Autodiff, Interpreter, Logging, CUDA, LLVM.JIT, LLVM.Shims diff --git a/src/lib/Flops.hs b/src/lib/Flops.hs deleted file mode 100644 index 7ad8a3bbe..000000000 --- a/src/lib/Flops.hs +++ /dev/null @@ -1,92 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# LANGUAGE OverloadedStrings #-} - -module Flops (impFunctionFlops) where - -import Control.Monad.Reader -import Control.Monad.Writer -import qualified Data.Map.Strict as M -import Data.Text.Prettyprint.Doc hiding (group) - -import Syntax -import Env -import PPrint - -data Term = Term Int [(Name, Int)] deriving (Show, Eq, Ord) -type Count = [Term] -newtype Profile = Profile (M.Map String Count) - -type FlopM a = ReaderT Term (Writer Profile) a - -impFunctionFlops :: ImpFunction -> Profile -impFunctionFlops (FFIFunction _) = mempty -impFunctionFlops (ImpFunction _ _ body) = - snd $ runWriter (runReaderT (flops body) (litTerm 1)) - -flops :: ImpBlock -> FlopM () -flops (ImpBlock statements _) = void $ traverse declFlops statements - -declFlops :: ImpDecl -> FlopM () -declFlops (ImpLet _ instr) = instrFlops instr - -instrFlops :: ImpInstr -> FlopM () -instrFlops instr = case instr of - IFor _ _ size block -> local (mulTerm $ evalSizeExpr size) $ flops block - ICond _ _ _ -> return () -- TODO: Implement - IWhile _ _ -> return () -- TODO: Implement - ILaunch _ _ _ -> return () -- TODO: Implement - IPrimOp op -> do - n <- ask - tell $ Profile $ M.singleton (showPrimName $ OpExpr op) [n] - _ -> return () - -evalSizeExpr :: IExpr -> Term -evalSizeExpr (IVar (v:>_)) = varTerm v -evalSizeExpr expr = error $ "Not implemented: " ++ pprint expr - -litTerm :: Int -> Term -litTerm n = Term n [] - -varTerm :: Name -> Term -varTerm v = Term 1 [(v, 1)] - -mulTerm :: Term -> Term -> Term -mulTerm (Term n xs) (Term n' xs') = Term (n * n') (xs <> xs') - -canonicalizeCount :: Count -> Count -canonicalizeCount terms = - let terms' = groupReduce (+) [(term, coeff) | - Term coeff term <- map canonicalizeTerm terms] - in [Term coeff term | (term, coeff) <- terms'] - -canonicalizeTerm :: Term -> Term -canonicalizeTerm (Term coeff term) = Term coeff (groupReduce (+) term) - -prettyCount :: Count -> Doc ann -prettyCount terms = - hsep $ punctuate " +" $ map pretty terms' - where terms' = canonicalizeCount terms - -groupReduce :: Ord a => (b -> b -> b) -> [(a,b)] -> [(a,b)] -groupReduce f pairs = M.toAscList $ foldr (M.unionWith f) mempty $ - map (uncurry M.singleton) pairs - -instance Semigroup Profile where - Profile m <> Profile m' = Profile $ M.unionWith (<>) m m' - -instance Monoid Profile where - mempty = Profile mempty - mappend = (<>) - -instance Pretty Profile where - pretty (Profile m) = vsep $ [pretty b <+> prettyCount c - | (b, c) <- M.toAscList m] - -instance Pretty Term where - pretty (Term coeff term) = pretty coeff <+> - hsep ([pretty v <> "^" <> pretty pow | (v, pow) <- term]) diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 3c7d50401..722d85892 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -127,41 +127,12 @@ sourceBlock' = proseBlock :: Parser SourceBlock' proseBlock = label "prose block" $ char '\'' >> fmap (ProseBlock . fst) (withSource consumeTillBreak) -loadData :: Parser SourceBlock' -loadData = do - symbol "load" - fmt <- dataFormat - s <- stringLiteral - symbol "as" - b <- patAnn - void eol - return $ LoadData b fmt s - topLevelCommand :: Parser SourceBlock' topLevelCommand = (liftM IncludeSourceFile includeSourceFile) - <|> loadData - <|> dumpData <|> explicitCommand "top-level command" -dataFormat :: Parser DataFormat -dataFormat = do - s <- nameString - case s of - "dxo" -> return DexObject - "dxbo" -> return DexBinaryObject - _ -> fail $ show s ++ " not a recognized data format (one of dxo|dxbo)" - -dumpData :: Parser SourceBlock' -dumpData = do - symbol "dump" - fmt <- dataFormat - s <- stringLiteral - e <- blockOrExpr - void eol - return $ Command (Dump fmt s) (exprAsModule e) - explicitCommand :: Parser SourceBlock' explicitCommand = do cmdName <- char ':' >> nameString diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 510f9fc8d..dcbb2b4a4 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -32,7 +32,7 @@ module Syntax ( UAlt (..), AltP, Alt, Label, LabeledItems (..), labeledSingleton, reflectLabels, withLabels, ExtLabeledItems (..), prefixExtLabeledItems, IScope, BinderInfo (..), Bindings, CUDAKernel (..), BenchStats, - SrcCtx, Result (..), Output (..), OutFormat (..), DataFormat (..), + SrcCtx, Result (..), Output (..), OutFormat (..), Err (..), ErrType (..), Except, throw, throwIf, modifyErr, addContext, addSrcContext, catchIOExcept, liftEitherIO, (-->), (--@), (==>), boundUVars, PassName (..), boundVars, renamingSubst, bindingsAsVars, @@ -474,14 +474,13 @@ data SourceBlock' = RunModule UModule | Command CmdName (Name, UModule) | GetNameType Name | IncludeSourceFile String - | LoadData UPatAnn DataFormat String | ProseBlock String | CommentLine | EmptyLines | UnParseable ReachedEOF String deriving (Show, Generic) -data CmdName = GetType | EvalExpr OutFormat | ExportFun String | Dump DataFormat String +data CmdName = GetType | EvalExpr OutFormat | ExportFun String deriving (Show, Generic) data LogLevel = LogNothing | PrintEvalTime | PrintBench String @@ -608,7 +607,7 @@ monMapLookup (MonMap m) k = case M.lookup k m of Nothing -> mempty -- === passes === data PassName = Parse | TypePass | SynthPass | SimpPass | ImpPass | JitPass - | Flops | LLVMOpt | AsmPass | JAXPass | JAXSimpPass | LLVMEval + | LLVMOpt | AsmPass | JAXPass | JAXSimpPass | LLVMEval | ResultPass | JaxprAndHLO | OptimPass deriving (Ord, Eq, Bounded, Enum) @@ -616,7 +615,7 @@ instance Show PassName where show p = case p of Parse -> "parse" ; TypePass -> "typed" ; SynthPass -> "synth" SimpPass -> "simp" ; ImpPass -> "imp" ; JitPass -> "llvm" - Flops -> "flops" ; LLVMOpt -> "llvmopt" ; AsmPass -> "asm" + LLVMOpt -> "llvmopt" ; AsmPass -> "asm" JAXPass -> "jax" ; JAXSimpPass -> "jsimp"; ResultPass -> "result" LLVMEval -> "llvmeval" ; JaxprAndHLO -> "jaxprhlo"; OptimPass -> "optimized" @@ -638,7 +637,6 @@ data Output = TextOut String deriving (Show, Eq, Generic) data OutFormat = Printed | RenderHtml deriving (Show, Eq, Generic) -data DataFormat = DexObject | DexBinaryObject deriving (Show, Eq, Generic) data Err = Err ErrType SrcCtx String deriving (Show, Eq) instance Exception Err diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index a36c18b86..d7269312d 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -105,7 +105,6 @@ evalSourceBlockM env block = case sbContents block of GetType -> do -- TODO: don't actually evaluate it val <- evalUModuleVal env v m logTop $ TextOut $ pprint $ getType val - Dump _ _ -> error "Not implemented" GetNameType v -> case envLookup env (v:>()) of Just (ty, _) -> logTop (TextOut $ pprint ty) >> return mempty _ -> liftEitherIO $ throw UnboundVarErr $ pprint v From 4a1fe72534dcc276d37693de139eb8fe3297be04 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 29 Dec 2020 14:35:24 -0500 Subject: [PATCH 051/105] Update or delete stale examples --- examples/aspirational.dx | 48 ---------------- examples/bugs.dx | 35 ------------ examples/chol.dx | 88 +++++++++++++---------------- examples/mnist-nearest-neighbors.dx | 3 + examples/sgd.dx | 2 + examples/tutorial-old.dx | 4 ++ lib/prelude.dx | 5 +- makefile | 4 +- 8 files changed, 53 insertions(+), 136 deletions(-) delete mode 100644 examples/aspirational.dx delete mode 100644 examples/bugs.dx diff --git a/examples/aspirational.dx b/examples/aspirational.dx deleted file mode 100644 index ed3c1b2b2..000000000 --- a/examples/aspirational.dx +++ /dev/null @@ -1,48 +0,0 @@ --- === logistic regression === - --- features needed --- * type aliases with type variables --- * index set sum types (and generalization of inL and inR) --- * while loop construct (for fixed-point iter version) --- loop : (a -> Either b a) -> (b, E n. n=>a) --- * unpacking multiple type variables - -type ParamsIdx d = Either d () -- concrete syntax for sum types? -type Params d = (ParamsIdx d)=>Float - -logLogistic : Float -> Float -logLogistic x = log $ 1 / (1 + exp (-x)) - -bool2pm1 : Bool -> Float -bool2pm1 x = select x 1.0 -1.0 - -evalLogreg : Params d -> d=>Float -> Float -evalLogreg params x = - let w.i = x.(L i) -- can we improve this unpacking? - b = x.(R ()) - in logLogistic $ b + vdot w x - -logRegLoss : Params d -> d=>Float -> Bool -> Float -logRegLoss params x y = (evalLogReg params x) * (bool2pm1 y) - --- what about looping until convergence? Need a different looping construct -optimize : (n=>Float -> Float) -> (n=>Float) -optimize f = .. - let lr = 0.1 - scale = 0.1 - nIters = 1000 - x0.i = scale * randn (fanout 0).i - in loopN nIters x0 lam x. let dx = grad f x - in for i. x.i + lr * dx.ix' - -theData : E n d. (n=>d=>Float, n=>Bool) - -(xs, ys), N, D = unpack data - -loss : D=>R -> R -loss params = mean (for i. logRegLoss params xs.i ys.i) - --- cross-validation? minibatches? -optParams : Params D -optParams = optimize 0 logRegLoss - diff --git a/examples/bugs.dx b/examples/bugs.dx deleted file mode 100644 index b0bd31919..000000000 --- a/examples/bugs.dx +++ /dev/null @@ -1,35 +0,0 @@ --- we don't do let generalization on patterns, but this is a problem if --- generalization is required. This fails: -(f, g) = (lam x. x, lam x. x) - - --- printing of tuple-index tables not implemented -x = [1,2,3] -:p for (i,j). iadd x.i x.j - - --- out-of-bounds indexing - need to wrap indices -:p let litArr = [10, 5, 3] - in litArr.(asidx 4) -> 5 - --- polymorphic declarations without explicit types crash the compiler --- (should be a straightforward error message) -f x = x - --- apparently we're treating unbound type aliases as things to infer -x : N -x = 1 - --- need a type class constrain for index sets so that this is an error -:t for i:Int. 1 - --- Bad error message because we lose provenance of the constraint -:t lam x. - z = iadd x 1 - y = fadd x 1.0 - (z, y) -> Type error: -> Expected: Int -> Actual: Float -> In: From subst diff --git a/examples/chol.dx b/examples/chol.dx index 2ef573113..a639d20b6 100644 --- a/examples/chol.dx +++ b/examples/chol.dx @@ -1,77 +1,65 @@ ' # Cholesky Factorization https://en.wikipedia.org/wiki/Cholesky_decomposition -' ### Matrix Math - -eye : n=>n=>Float -eye = for i j. select (i == j) 1.0 0.0 - -mmadd: (n=>m=>Float)->(n=>m=>Float)->(n=>m=>Float) -mmadd x y = for i j. x.i.j + y.i.j - ' ## Cholesky Algorithm -chol : (n=>n=>Float) -> (n=>n=>Float) -chol x = getState (for _ _. 0.0) \buf. - for i. - for j':(...i). +def chol (_:Eq n) ?=> (x:n=>n=>Float) : (n=>n=>Float) = + snd $ withState zero \buf. + for_ i. for j':(..i). j = %inject(j') - row = for k:(..n=>Float -> n=>Float -> n=>Float -trisolveL mat b = getState (for _. 0.0) \buf. - for i. - row = for j:(..n=>Float) (b:n=>Float) : n=>Float = + snd $ withState zero \buf. for i. + row = for j:(..n=>Float -> n=>Float -> n=>Float -trisolveU mat b = getState (for _. 0.0) \buf. - rof i. - row = for j:(i...). mat.i.%inject(j) - xPrev = for j:(i...). getAt buf %inject(j) - putAt buf i $ (b.i - vdot row xPrev) / mat.i.i +def trisolveU (mat:n=>n=>Float) (b:n=>Float) : n=>Float = + snd $ withState zero \buf. rof i. + row = for j:(i..). mat.i.%inject(j) + xPrev = for j:(i..). get (buf!%inject j) + buf!i := (b.i - vdot row xPrev) / mat.i.i -psdsolve : n=>n=>Float -> n=>Float -> n=>Float -psdsolve mat b = +def psdsolve (_:Eq n) ?=> (mat:n=>n=>Float) (b:n=>Float) : n=>Float = l = chol mat trisolveU (transpose l) $ trisolveL l b ' Test -type N = 4 -(k1, k2) = splitKey $ newKey 0 +N = Fin 4 +[k1, k2] = splitKey $ newKey 0 -psd : N=>N=>Float -psd = +psd : N=>N=>Float = a = for i:N j:N. randn $ ixkey2 k1 i j - x = mmp a (transpose a) - mmadd x eye + x = a ** transpose a + x + eye -l : N=>N=>Float -l = chol psd +l : N=>N=>Float = chol psd :p l -> [ [2.021765, 0.0, 0.0, 0.0] -> , [-1.7950183, 1.9901744, 0.0, 0.0] -> , [-0.89788574, 0.18675673, 1.9802661, 0.0] -> , [1.4457518, -0.29644823, 0.72458607, 2.2308075] ] +> [ [2.021765, 0., 0., 0.] +> , [-1.795019, 1.990174, 0., 0.] +> , [-0.897886, 0.186757, 1.980266, 0.] +> , [1.445752, -0.296448, 0.724586, 2.230807] ] -psdReconstructed = l `mmp` transpose l +psdReconstructed = l ** transpose l :p sum for (i, j). sq (psd.i.j - psdReconstructed.i.j) -> 2.4651903e-32 +> 0. -vec : N=>Float -vec = for i. randn $ ixkey k2 i +vec : N=>Float = arb k2 -:p (vec, mvp psd (psdsolve psd vec)) -> ( [1.2112769, 0.23284969, -0.74191034, 0.8833507] -> , [1.2112769, 0.23284969, -0.74191034, 0.8833507] ) +:p (vec, psd **. psdsolve psd vec) +> ( [1.211277, 0.23285, -0.741911, 0.883351] +> , [1.211277, 0.23285, -0.741911, 0.883351] ) diff --git a/examples/mnist-nearest-neighbors.dx b/examples/mnist-nearest-neighbors.dx index 161e6aa8b..4448420f5 100644 --- a/examples/mnist-nearest-neighbors.dx +++ b/examples/mnist-nearest-neighbors.dx @@ -1,3 +1,6 @@ +'# THIS FILE IS STALE + +'(But we plan to update it at some point) load dxbo "scratch/mnist.dxbo" as mnist diff --git a/examples/sgd.dx b/examples/sgd.dx index 754d1f550..40f496c96 100644 --- a/examples/sgd.dx +++ b/examples/sgd.dx @@ -31,5 +31,7 @@ stepsize = 0.01 decay = 0.9 num_iters = 1000 :p sgd stepsize decay num_iters gradfunc x_init +> [1.1, 1.1, 1.1, 1.1] :p optimum +> [1.1, 1.1, 1.1, 1.1] diff --git a/examples/tutorial-old.dx b/examples/tutorial-old.dx index ecd2ee80a..3822158c9 100644 --- a/examples/tutorial-old.dx +++ b/examples/tutorial-old.dx @@ -1,3 +1,7 @@ +'# THIS FILE IS STALE + +'(But we plan to update it at some point) + '# Introduction to the Dex language 'Dex is a functional, statically typed language for array processing. diff --git a/lib/prelude.dx b/lib/prelude.dx index 1785f0f0e..b6a68f844 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -586,6 +586,9 @@ def dot (_:VSpace v) ?=> (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j def inner (x:n=>Float) (mat:n=>m=>Float) (y:m=>Float) : Float = fsum \(i,j). x.i * mat.i.j * y.j +def eye (_:Eq n) ?=> : n=>n=>Float = + for i j. select (i == j) 1.0 0.0 + '## Pseudorandom number generator utilities -- TODO: newtype @@ -1082,7 +1085,7 @@ def getEnv (name:String) : {State World} Maybe String = fromCString $ MkCString $ %ffi getenv RawPtr ptr def checkEnv (name:String) : {State World} Bool = - -- This should be just `isJust $ getEnv name` but that sefaults (only if the + -- This should be just `isJust $ getEnv name` but that segfaults (only if the -- env var *is* defined), possibly related to bug #348. withCString name \(MkCString ptr). resultPtr = %ffi getenv RawPtr ptr diff --git a/makefile b/makefile index 8f2ed6176..ae20b85cf 100644 --- a/makefile +++ b/makefile @@ -79,11 +79,11 @@ build-python: build # --- running tests --- -# TODO: re-enable linear-tests ad-tests include-test chol example-names = mandelbrot pi sierpinski rejection-sampler \ regression brownian_motion particle-swarm-optimizer \ ode-integrator mcmc ctc raytrace particle-filter \ - isomorphisms ode-integrator linear_algebra fluidsim + isomorphisms ode-integrator linear_algebra fluidsim \ + sgd chol test-names = uexpr-tests adt-tests type-tests eval-tests show-tests \ shadow-tests monad-tests io-tests \ From f836ba5688fd1774b18b03956fc6b6597d6467af Mon Sep 17 00:00:00 2001 From: David Duvenaud Date: Wed, 23 Dec 2020 13:37:49 -0500 Subject: [PATCH 052/105] Improved linear algebra example. --- examples/linear_algebra.dx | 86 +++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 39 deletions(-) diff --git a/examples/linear_algebra.dx b/examples/linear_algebra.dx index 7dce43f3c..28f6b7bcc 100644 --- a/examples/linear_algebra.dx +++ b/examples/linear_algebra.dx @@ -3,69 +3,79 @@ def identity_matrix (_:Eq n) ?=> (_:Add a) ?=> (_:Mul a) ?=> : n=>n=>a = for i j. select (i == j) one zero - '### Triangular matrices -def LowerTriMat (n:Type) : Type = i:n=>(..i)=>Float -def UpperTriMat (n:Type) : Type = i:n=>(i..)=>Float +def LowerTriMat (n:Type) (v:Type) : Type = i:n=>(..i)=>v +def UpperTriMat (n:Type) (v:Type) : Type = i:n=>(i..)=>v + +def upperTriDiag (u:UpperTriMat n v) : n=>v = for i. u.i.(0@_) +def lowerTriDiag (l:LowerTriMat n v) : n=>v = for i. l.i.((ordinal i)@_) -def forward_substitute (_:VSpace v) ?=> (a:LowerTriMat n) (b:n=>v) : n=>v = +def forward_substitute (_:VSpace v) ?=> (a:LowerTriMat n Float) (b:n=>v) : n=>v = -- Solves lower triangular linear system (inverse a) **. b snd $ withState zero \sRef. for i:n. s = sum for k:(.. (a:UpperTriMat n) (b:n=>v) : n=>v = +def backward_substitute (_:VSpace v) ?=> (a:UpperTriMat n Float) (b:n=>v) : n=>v = -- Solves upper triangular linear system (inverse a) **. b snd $ withState zero \sRef. rof i:n. s = sum for k:(i..). -- dot product - a.i.((ordinal k)@_) .* (get sRef).(%inject k) + a.i.((ordinal k)@_) .* get sRef!(%inject k) sRef!i := (b.i - s) / a.i.(0@_) -- 0 is the diagonal index -- Todo: get rid of these by writing a dependent indexing (!) operator. -def lowerTriIndex (ref:Ref h (LowerTriMat n)) (i:n) : Ref h ((..i)=>Float) = +def lowerTriIndex (ref:Ref h (LowerTriMat n v)) (i:n) : Ref h ((..i)=>v) = %indexRef ref i -def upperTriIndex (ref:Ref h (UpperTriMat n)) (i:n) : Ref h ((i..)=>Float) = +def upperTriIndex (ref:Ref h (UpperTriMat n v)) (i:n) : Ref h ((i..)=>v) = %indexRef ref i '### Permutations -def Permutation (n:Type) : Type = n=>n -def apply_permutation (permutation: n=>n) (array: n=>t) : n=>t = - for i. array.(permutation.i) -def identity_permutation (n:Type) ?-> : Permutation n = - for i. i +-- The sign of the determinant of a permutation is either 1.0 or -1.0 +PermutationSign = Float +def Permutation (n:Type) : Type = (perm:n=>n & PermutationSign) -'### LU decomposition functions +def apply_permutation ((perm, _):Permutation n) (xs: n=>t) : n=>t = + for i. xs.(perm.i) + +def identity_permutation : Permutation n = + (for i. i, 1.0) + +def swapInPlace (pRef: Ref h (Permutation n)) (i:n) (j:n) : {State h} Unit = + (permRef, signRef) = (fstRef pRef, sndRef pRef) + tempj = get permRef!j + permRef!j := get permRef!i + permRef!i := tempj + signRef := -(get signRef) -Sign = Float -- Either 1.0 or -1.0 +def perToTable ((perm, _):Permutation n) : n=>n = perm +def permSign ((_, sign):Permutation n) : PermutationSign = sign + + + +'### LU decomposition functions -def pivotize (_:Eq n) ?=> (a:n=>n=>Float) : (Permutation n & Sign) = - -- Permutes rows of a matrix to make Gaussian elimination more stable. - -- Returns permutation and the sign of its determinant. - snd $ withState (identity_permutation, 1.0) \stateRef. - (pRef, signRef) = (fstRef stateRef, sndRef stateRef) +def pivotize (_:Eq n) ?=> (a:n=>n=>Float) : Permutation n = + -- Gives a row permutation that makes Gaussian elimination more stable. + snd $ withState identity_permutation \permRef. for j:n. - row_with_largest = argmin for i:(j..). (-(abs a.(%inject i).j)) - row_with_largest = %inject row_with_largest + row_with_largest' = argmin for i:(j..). (-(abs a.(%inject i).j)) + row_with_largest = %inject row_with_largest' case (j == row_with_largest) of True -> () - False -> - tempj = get pRef!j -- Is there a refSwap? - pRef!j := get pRef!row_with_largest - pRef!row_with_largest := tempj - signRef := -(get signRef) + False -> swapInPlace permRef j row_with_largest def lu (_:Eq n) ?=> (a: n=>n=>Float) : - (LowerTriMat n & UpperTriMat n & Permutation n & Sign) = + (LowerTriMat n Float & UpperTriMat n Float & Permutation n) = -- Computes lower, upper, and permuntation matrices from a square matrix, -- such that apply_permutation permutation a == lower ** upper. - (permutation, swapcount) = pivotize a + permutation = pivotize a a = apply_permutation permutation a init_lower = for i:n. for j':(..i). @@ -120,7 +130,7 @@ def lu (_:Eq n) ?=> (a: n=>n=>Float) : ujj = get (upperTriIndex uRef j)!(0@_) lijRef = (lowerTriIndex lRef i'')!((ordinal j)@_) lijRef := (a.i'.j - s) / ujj - (lower, upper, permutation, swapcount) + (lower, upper, permutation) '### General linear algebra functions. @@ -130,7 +140,7 @@ def solve (_:Eq n) ?=> (_:VSpace v) ?=> (a:n=>n=>Float) (b:n=>v) : n=>v = -- that l always has ones on the diagonal. It would just require a -- custom forward_substitute routine that doesn't divide -- by the diagonal entries. - (l, u, perm, _) = lu a + (l, u, perm) = lu a b' = apply_permutation perm b y = forward_substitute l b' backward_substitute u y @@ -139,15 +149,13 @@ def invert (_:Eq n) ?=> (a:n=>n=>Float) : n=>n=>Float = solve a identity_matrix def determinant (_:Eq n) ?=> (a:n=>n=>Float) : Float = - (l, u, perm, permutation_sign) = lu a - -- formerly u.i.i * l.i.i - (prod for i. u.i.(0@_) * l.i.((ordinal i)@_)) * permutation_sign + (l, u, perm) = lu a + prod (for i. (upperTriDiag u).i * (lowerTriDiag l).i) * permSign perm def sign_and_log_determinant (_:Eq n) ?=> (a:n=>n=>Float) : (Float & Float) = - (l, u, perm, permutation_sign) = lu a - -- formerly u.i.i * l.i.i - diags = for i. u.i.(0@_) * l.i.((ordinal i)@_) - sign = permutation_sign * prod for i. sign diags.i + (l, u, perm) = lu a + diags = for i. (upperTriDiag u).i * (lowerTriDiag l).i + sign = (permSign perm) * prod for i. sign diags.i sum_of_log_abs = sum for i. log (abs diags.i) (sign, sum_of_log_abs) From a3d65fe33611ca32d6696d0715898e50972740b9 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 30 Dec 2020 10:32:19 -0500 Subject: [PATCH 053/105] Start adding an exception effect - syntax and parser. This required reorganizing the way we represent effects, since exception effects don't take a "region" parameter (though maybe they should?). --- src/lib/Autodiff.hs | 23 ++++++++++++++-------- src/lib/Embed.hs | 6 +++--- src/lib/Inference.hs | 19 +++++++++++------- src/lib/PPrint.hs | 10 +++++++--- src/lib/Parallelize.hs | 9 +++++++-- src/lib/Parser.hs | 23 +++++++++++++++------- src/lib/Syntax.hs | 37 ++++++++++++++++++++++++----------- src/lib/Type.hs | 44 ++++++++++++++++++++++-------------------- 8 files changed, 109 insertions(+), 62 deletions(-) diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index f47c88de1..0624364d4 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -31,7 +31,9 @@ import GHC.Stack -- `DerivWrt` holds the (out-expr) variables that we're differentiating with -- respect to (including refs but not regions). -data DerivWrt = DerivWrt { activeVars :: Env Type, _activeEffs :: [Effect], rematVars :: Env Type } +data DerivWrt = DerivWrt { activeVars :: Env Type + , _activeEffs :: [Effect] + , rematVars :: Env Type } -- `Tangents` holds the tangent values and the region variables that are -- arguments to the linearized function. data TangentEnv = TangentEnv { tangentVals :: SubstEnv, activeRefs :: [Name], rematVals :: SubstEnv } @@ -301,15 +303,15 @@ linearizeHof env hof = case hof of extendTangentEnv (ref @> ref') [h] $ applyLinToTangents linBody return (ans, lin) - linearizeEffectFun :: EffectName -> Atom -> PrimalM (Atom, Var) - linearizeEffectFun effName ~(BinaryFunVal h ref eff body) = do + linearizeEffectFun :: RWS -> Atom -> PrimalM (Atom, Var) + linearizeEffectFun rws ~(BinaryFunVal h ref eff body) = do h' <- mapM (substEmbed env) h buildLamAux h' (const $ return PureArrow) $ \h''@(Var hVar) -> do let env' = env <> h@>h'' eff' <- substEmbed env' eff ref' <- mapM (substEmbed env') ref buildLamAux ref' (const $ return $ PlainArrow eff') $ \ref''@(Var refVar) -> - extendWrt [refVar] [(effName, varName hVar)] $ + extendWrt [refVar] [RWSEffect rws (varName hVar)] $ (,refVar) <$> (tangentFunAsLambda $ linearizeBlock (env' <> ref@>ref'') body) linearizePrimCon :: Con -> LinA Atom @@ -417,15 +419,16 @@ tangentFunAsLambda :: LinA Atom -> PrimalM Atom tangentFunAsLambda m = do (ans, tanFun) <- runLinA m DerivWrt activeVars effs remats <- ask - let hs = map (Bind . (:>TyKind) . snd) effs + let hs = map (Bind . (:>TyKind) . effectRegion) effs let rematList = envAsVars remats liftM (PairVal ans) $ lift $ do tanLam <- makeLambdas rematList $ \rematArgs -> buildNestedLam PureArrow hs $ \hVals -> do let hVarNames = map (\(Var (v:>_)) -> v) hVals - let effs' = zipWith (\(effName, _) v -> (effName, v)) effs hVarNames + -- TODO: handle exception effect too + let effs' = zipWith (\(RWSEffect rws _) v -> RWSEffect rws v) effs hVarNames -- want to use tangents here, not the original binders - let regionMap = newEnv (map ((:>()) . snd) effs) hVals + let regionMap = newEnv (map ((:>()) . effectRegion) effs) hVals -- TODO: Only bind tangents for free variables? let activeVarBinders = map (Bind . fmap (tangentRefRegion regionMap)) $ envAsVars activeVars buildNestedLam PureArrow activeVarBinders $ \activeVarArgs -> @@ -454,6 +457,10 @@ tangentFunAsLambda m = do RefTy ~(Var h) a -> RefTy (regEnv ! h) $ tangentType a _ -> tangentType ty + effectRegion eff = case eff of + RWSEffect _ h -> h + ExceptionEffect -> error "TODO!" + -- Inverse of tangentFunAsLambda. Should be used inside a returned tangent action. applyLinToTangents :: Atom -> TangentM Atom applyLinToTangents f = do @@ -759,7 +766,7 @@ isLinEff :: EffectRow -> TransposeM Bool isLinEff (EffectRow effs Nothing) = do regions <- asks linRegions return $ not $ null $ effRegions `envIntersect` regions - where effRegions = newEnv (S.map snd effs) (repeat ()) + where effRegions = freeVars $ toList effs isLinEff _ = error "Can't transpose polymorphic effects" emitCTToRef :: Maybe Atom -> Atom -> TransposeM () diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index f580776bf..437e2a32d 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -357,11 +357,11 @@ emitRunState :: MonadEmbed m => Name -> Atom -> (Atom -> m Atom) -> m Atom emitRunState v x0 body = do emit . Hof . RunState x0 =<< mkBinaryEffFun State v (getType x0) body -mkBinaryEffFun :: MonadEmbed m => EffectName -> Name -> Type -> (Atom -> m Atom) -> m Atom -mkBinaryEffFun newEff v ty body = do +mkBinaryEffFun :: MonadEmbed m => RWS -> Name -> Type -> (Atom -> m Atom) -> m Atom +mkBinaryEffFun rws v ty body = do eff <- getAllowedEffects buildLam (Bind ("h":>TyKind)) PureArrow $ \r@(Var (rName:>_)) -> do - let arr = PlainArrow $ extendEffect (newEff, rName) eff + let arr = PlainArrow $ extendEffect (RWSEffect rws rName) eff buildLam (Bind (v:> RefTy r ty)) arr body buildForAnnAux :: MonadEmbed m => ForAnn -> Binder -> (Atom -> m (Atom, a)) -> m (Atom, a) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 2c3f31e82..ecfd2e962 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -158,14 +158,14 @@ checkOrInferRho (WithSrc pos expr) reqTy = do kind' <- checkUType kind piTy <- case pat of Just pat' -> withNameHint ("pat" :: Name) $ buildPi b $ \x -> - withBindPat pat' x $ (,) <$> mapM checkUEff arr <*> checkUType ty + withBindPat pat' x $ (,) <$> mapM checkUEffRow arr <*> checkUType ty where b = case pat' of -- Note: The binder name becomes part of the type, so we -- need to keep the same name used in the pattern. WithSrc _ (UPatBinder (Bind (v:>()))) -> Bind (v:>kind') _ -> Ignore kind' Nothing -> buildPi (Ignore kind') $ const $ - (,) <$> mapM checkUEff arr <*> checkUType ty + (,) <$> mapM checkUEffRow arr <*> checkUType ty matchRequirement piTy UDecl decl body -> do env <- inferUDecl False decl @@ -393,11 +393,9 @@ checkULam (p, ann) body piTy = do $ \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ checkSigma body Suggest $ snd $ applyAbs piTy x -checkUEff :: EffectRow -> UInferM EffectRow -checkUEff (EffectRow effs t) = do - effs' <- liftM S.fromList $ forM (toList effs) $ \(effName, region) -> do - (Var (v:>TyKind)) <- lookupSourceVar (region:>()) - return (effName, v) +checkUEffRow :: EffectRow -> UInferM EffectRow +checkUEffRow (EffectRow effs t) = do + effs' <- liftM S.fromList $ mapM checkUEff $ toList effs t' <- forM t $ \tv -> lookupVarName EffKind tv return $ EffectRow effs' t' where @@ -408,6 +406,13 @@ checkUEff (EffectRow effs t) = do constrainEq ty ty' return v' +checkUEff :: Effect -> UInferM Effect +checkUEff eff = case eff of + RWSEffect rws region -> do + (Var (v:>TyKind)) <- lookupSourceVar (region:>()) + return $ RWSEffect rws v + ExceptionEffect -> return ExceptionEffect + data CaseAltIndex = ConAlt Int | VariantAlt Label Int | VariantTailAlt (LabeledItems ()) diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 61c56ec6a..a07cb7f91 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -665,14 +665,18 @@ spaced xs = hsep $ map p $ toList xs instance Pretty EffectRow where pretty Pure = mempty pretty (EffectRow effs tailVar) = - braces $ hsep (punctuate "," (fmap prettyEff (toList effs))) <> tailStr + braces $ hsep (punctuate "," (map p (toList effs))) <> tailStr where - prettyEff (effName, region) = p effName <+> p region tailStr = case tailVar of Nothing -> mempty Just v -> "|" <> p v -instance Pretty EffectName where +instance Pretty Effect where + pretty eff = case eff of + RWSEffect rws h -> p rws <+> p h + ExceptionEffect -> "Except" + +instance Pretty RWS where pretty eff = case eff of Reader -> "Read" Writer -> "Accum" diff --git a/src/lib/Parallelize.hs b/src/lib/Parallelize.hs index 445591f7a..86aef81d3 100644 --- a/src/lib/Parallelize.hs +++ b/src/lib/Parallelize.hs @@ -71,8 +71,7 @@ parallelTraverseExpr expr = case expr of refs <- gets activeAccs let allowedRegions = foldMap (\(varType -> RefTy (Var reg) _) -> reg @> ()) refs (EffectRow bodyEffs t) <- substEmbedR $ functionEffs fbody - let onlyAllowedEffects = flip all (toList bodyEffs) $ \(eff, reg) -> - eff == Writer && reg `isin` allowedRegions + let onlyAllowedEffects = all (parallelizableEffect allowedRegions) $ toList bodyEffs case t == Nothing && onlyAllowedEffects of True -> do b' <- substEmbedR b @@ -98,6 +97,12 @@ parallelTraverseExpr expr = case expr of disallowRef ~(Var refVar) = modify $ \accEnv -> accEnv { activeAccs = activeAccs accEnv `envDiff` (refVar @> ()) } +parallelizableEffect :: Env () -> Effect -> Bool +parallelizableEffect allowedRegions effect = case effect of + RWSEffect Writer h | h `isin` allowedRegions -> True + -- TODO: we should be able to parallelize the exception effect too + _ -> False + -- Precondition: This is never called with no binders in the loop env buildParallelBlock :: ABlock -> LoopM Atom buildParallelBlock ablock@(ABlock decls result) = do diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 722d85892..34d29c4d2 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -461,15 +461,19 @@ effects :: Parser EffectRow effects = braces someEffects <|> return Pure where someEffects = do - effs <- liftM2 (,) effectName (lowerName <|> upperName) `sepBy` sym "," + effs <- effect `sepBy` sym "," v <- optional $ symbol "|" >> lowerName return $ EffectRow (S.fromList effs) v -effectName :: Parser EffectName -effectName = (keyWord WriteKW $> Writer) - <|> (keyWord ReadKW $> Reader) - <|> (keyWord StateKW $> State) - "effect name (Accum|Read|State)" +effect :: Parser Effect +effect = (RWSEffect <$> rwsName <*> anyCaseName) + <|> (keyWord ExceptKW $> ExceptionEffect) + "effect (Accum h | Read h | State h | Except)" + +rwsName :: Parser RWS +rwsName = (keyWord WriteKW $> Writer) + <|> (keyWord ReadKW $> Reader) + <|> (keyWord StateKW $> State) uLamExpr :: Parser UExpr uLamExpr = do @@ -998,6 +1002,7 @@ type Lexer = Parser data KeyWord = DefKW | ForKW | For_KW | RofKW | Rof_KW | CaseKW | OfKW | ReadKW | WriteKW | StateKW | DataKW | InterfaceKW | InstanceKW | WhereKW | IfKW | ThenKW | ElseKW | DoKW + | ExceptKW upperName :: Lexer Name upperName = liftM mkName $ label "upper-case name" $ lexeme $ @@ -1007,6 +1012,9 @@ lowerName :: Lexer Name lowerName = liftM mkName $ label "lower-case name" $ lexeme $ checkNotKeyword $ (:) <$> lowerChar <*> many nameTailChar +anyCaseName :: Lexer Name +anyCaseName = lowerName <|> upperName + checkNotKeyword :: Parser String -> Parser String checkNotKeyword p = try $ do s <- p @@ -1030,6 +1038,7 @@ keyWord kw = lexeme $ try $ string s >> notFollowedBy nameTailChar ReadKW -> "Read" WriteKW -> "Accum" StateKW -> "State" + ExceptKW -> "Except" DataKW -> "data" InterfaceKW -> "interface" InstanceKW -> "instance" @@ -1038,7 +1047,7 @@ keyWord kw = lexeme $ try $ string s >> notFollowedBy nameTailChar keyWordStrs :: [String] keyWordStrs = ["def", "for", "for_", "rof", "rof_", "case", "of", "llam", - "Read", "Write", "Accum", "data", "interface", + "Read", "Write", "Accum", "Except", "data", "interface", "instance", "where", "if", "then", "else", "do"] fieldLabel :: Lexer Label diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index dcbb2b4a4..55823e08e 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -16,7 +16,7 @@ module Syntax ( Type, Kind, BaseType (..), ScalarBaseType (..), - Effect, EffectName (..), EffectRow (..), + Effect (..), RWS (..), EffectRow (..), ClassName (..), TyQual (..), SrcPos, Var, Binder, Block (..), Decl (..), Expr (..), Atom (..), ArrowP (..), Arrow, PrimTC (..), Abs (..), PrimExpr (..), PrimCon (..), LitVal (..), PrimEffect (..), PrimOp (..), @@ -422,10 +422,11 @@ showPrimName prim = primNameToStr $ fmap (const ()) prim -- === effects === -type Effect = (EffectName, Name) data EffectRow = EffectRow (S.Set Effect) (Maybe Name) deriving (Show, Eq, Generic) -data EffectName = Reader | Writer | State deriving (Show, Eq, Ord, Generic) + +data RWS = Reader | Writer | State deriving (Show, Eq, Ord, Generic) +data Effect = RWSEffect RWS Name | ExceptionEffect deriving (Show, Eq, Ord, Generic) pattern Pure :: EffectRow pattern Pure <- ((\(EffectRow effs t) -> (S.null effs, t)) -> (True, Nothing)) @@ -808,7 +809,11 @@ instance BindsUVars SourceBlock where instance HasUVars EffectRow where freeUVars (EffectRow effs tailVar) = - foldMap (nameAsEnv . snd) effs <> foldMap nameAsEnv tailVar + foldMap freeUVars effs <> foldMap nameAsEnv tailVar + +instance HasUVars Effect where + freeUVars (RWSEffect _ h) = nameAsEnv h + freeUVars (ExceptionEffect) = mempty instance HasUVars a => HasUVars (LabeledItems a) where freeUVars (LabeledItems items) = foldMap freeUVars items @@ -1124,13 +1129,22 @@ instance Subst Module where where Abs decls' bindings' = subst env $ Abs decls bindings instance HasVars EffectRow where - freeVars (EffectRow row t) = - foldMap (\(_,v) -> v@>(TyKind , UnknownBinder)) row - <> foldMap (\v -> v@>(EffKind, UnknownBinder)) t + freeVars (EffectRow row t) = foldMap freeVars row + <> foldMap (\v -> v@>(EffKind, UnknownBinder)) t instance Subst EffectRow where - subst (env, _) (EffectRow row t) = extendEffRow - (S.map (\(effName, v) -> (effName, substName env v)) row) - (substEffTail env t) + subst env (EffectRow row t) = extendEffRow row' t' + where + row' = S.map (subst env) row + t' = substEffTail (fst env) t + +instance HasVars Effect where + freeVars eff = case eff of + RWSEffect _ v -> v@>(TyKind , UnknownBinder) + ExceptionEffect -> mempty +instance Subst Effect where + subst (env,_) eff = case eff of + RWSEffect rws v -> RWSEffect rws (substName env v) + ExceptionEffect -> ExceptionEffect instance HasVars BinderInfo where freeVars binfo = case binfo of @@ -1577,7 +1591,8 @@ instance Store Atom instance Store Expr instance Store Block instance Store Decl -instance Store EffectName +instance Store RWS +instance Store Effect instance Store EffectRow instance Store Direction instance Store UnOp diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 76c8aa25c..2013f5a0f 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -271,10 +271,10 @@ exprEffs expr = case expr of App f _ -> functionEffs f Op op -> case op of PrimEffect ref m -> case m of - MGet -> oneEffect (State, h) - MPut _ -> oneEffect (State, h) - MAsk -> oneEffect (Reader, h) - MTell _ -> oneEffect (Writer, h) + MGet -> oneEffect (RWSEffect State h) + MPut _ -> oneEffect (RWSEffect State h) + MAsk -> oneEffect (RWSEffect Reader h) + MTell _ -> oneEffect (RWSEffect Writer h) where RefTy (Var (h:>_)) _ = getType ref IOAlloc _ _ -> oneEffect ioEffect IOFree _ -> oneEffect ioEffect @@ -288,16 +288,16 @@ exprEffs expr = case expr of While cond body -> functionEffs cond <> functionEffs body Linearize _ -> mempty -- Body has to be a pure function Transpose _ -> mempty -- Body has to be a pure function - RunReader _ f -> handleRunner Reader f - RunWriter f -> handleRunner Writer f - RunState _ f -> handleRunner State f + RunReader _ f -> handleRWSRunner Reader f + RunWriter f -> handleRWSRunner Writer f + RunState _ f -> handleRWSRunner State f PTileReduce _ _ -> mempty RunIO ~(Lam (Abs _ (PlainArrow (EffectRow effs t), _))) -> EffectRow (S.delete ioEffect effs) t Case _ alts _ -> foldMap (\(Abs _ block) -> blockEffs block) alts where - handleRunner effName ~(BinaryFunVal (Bind (h:>_)) _ (EffectRow effs t) _) = - EffectRow (S.delete (effName, h) effs) t + handleRWSRunner rws ~(BinaryFunVal (Bind (h:>_)) _ (EffectRow effs t) _) = + EffectRow (S.delete (RWSEffect rws h) effs) t functionEffs :: Atom -> EffectRow functionEffs f = case getType f of @@ -474,7 +474,9 @@ addExpr x m = modifyErr m $ \e -> case e of checkEffRow :: EffectRow -> TypeM () checkEffRow (EffectRow effs effTail) = do - forM_ effs $ \(_, v) -> Var (v:>TyKind) |: TyKind + forM_ effs $ \eff -> case eff of + RWSEffect _ v -> Var (v:>TyKind) |: TyKind + ExceptionEffect -> return () forM_ effTail $ \v -> do checkWithEnv $ \(env, _) -> case envLookup env (v:>()) of Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v @@ -504,7 +506,7 @@ oneEffect :: Effect -> EffectRow oneEffect eff = EffectRow (S.singleton eff) Nothing ioEffect :: Effect -ioEffect = (State, theWorld) +ioEffect = RWSEffect State theWorld -- === labeled row types === @@ -698,10 +700,10 @@ typeCheckOp op = case op of PrimEffect ref m -> do TC (RefType ~(Just (Var (h':>TyKind))) s) <- typeCheck ref case m of - MGet -> declareEff (State , h') $> s - MPut x -> x|:s >> declareEff (State , h') $> UnitTy - MAsk -> declareEff (Reader, h') $> s - MTell x -> x|:s >> declareEff (Writer, h') $> UnitTy + MGet -> declareEff (RWSEffect State h') $> s + MPut x -> x|:s >> declareEff (RWSEffect State h') $> UnitTy + MAsk -> declareEff (RWSEffect Reader h') $> s + MTell x -> x|:s >> declareEff (RWSEffect Writer h') $> UnitTy IndexRef ref i -> do RefTy h (TabTyAbs a) <- typeCheck ref i |: absArgType a @@ -871,12 +873,12 @@ typeCheckHof hof = case hof of Pi (Abs (Ignore a) (LinArrow, b)) <- typeCheck f return $ b --@ a RunReader r f -> do - (resultTy, readTy) <- checkAction Reader f + (resultTy, readTy) <- checkRWSAction Reader f r |: readTy return resultTy - RunWriter f -> uncurry PairTy <$> checkAction Writer f + RunWriter f -> uncurry PairTy <$> checkRWSAction Writer f RunState s f -> do - (resultTy, stateTy) <- checkAction State f + (resultTy, stateTy) <- checkRWSAction State f s |: stateTy return $ PairTy resultTy stateTy RunIO f -> do @@ -884,12 +886,12 @@ typeCheckHof hof = case hof of extendAllowedEffect ioEffect $ declareEffs eff return resultTy -checkAction :: EffectName -> Atom -> TypeM (Type, Type) -checkAction effName f = do +checkRWSAction :: RWS -> Atom -> TypeM (Type, Type) +checkRWSAction rws f = do BinaryFunTy (Bind regionBinder) refBinder eff resultTy <- typeCheck f regionName:>_ <- return regionBinder let region = Var regionBinder - extendAllowedEffect (effName, regionName) $ declareEffs eff + extendAllowedEffect (RWSEffect rws regionName) $ declareEffs eff checkEq (varAnn regionBinder) TyKind RefTy region' referentTy <- return $ binderAnn refBinder checkEq region' region From d449ebf944feccf1bcc57c3673c65f25d3cfdb55 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 30 Dec 2020 14:22:28 -0500 Subject: [PATCH 054/105] Add `throw` and `catch` primitives and lower them away in simp pass. --- lib/prelude.dx | 13 +++++++++++ makefile | 2 +- src/lib/Simplify.hs | 47 ++++++++++++++++++++++++++++++++++++++-- src/lib/Syntax.hs | 27 +++++++++++++++++++++-- src/lib/Type.hs | 15 +++++++++++-- tests/exception-tests.dx | 35 ++++++++++++++++++++++++++++++ 6 files changed, 132 insertions(+), 7 deletions(-) create mode 100644 tests/exception-tests.dx diff --git a/lib/prelude.dx b/lib/prelude.dx index b6a68f844..f4c265a4f 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -1524,3 +1524,16 @@ def evalpoly (_:VSpace v) ?=> (coefficients:n=>v) (x:Float) : v = fold zero \i c. coefficients.i + x .* c def dex_test_mode (():Unit) : Bool = unsafeIO do checkEnv "DEX_TEST_MODE" + +'## Exception effect + +def catch (f:Unit -> {Except|eff} a) : {|eff} Maybe a = + %catchException f + +def throw (_:Unit) : {Except} a = + %throwException a + +def assert (b:Bool) : {Except} Unit = + if b + then () + else throw () diff --git a/makefile b/makefile index ae20b85cf..bf690edfc 100644 --- a/makefile +++ b/makefile @@ -86,7 +86,7 @@ example-names = mandelbrot pi sierpinski rejection-sampler \ sgd chol test-names = uexpr-tests adt-tests type-tests eval-tests show-tests \ - shadow-tests monad-tests io-tests \ + shadow-tests monad-tests io-tests exception-tests \ ad-tests parser-tests serialize-tests \ record-variant-tests typeclass-tests complex-tests trig-tests diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 1442f81b6..f96c6a32b 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -9,7 +9,6 @@ module Simplify (simplifyModule, simplifyCase, splitSimpModule) where import Control.Monad -import Control.Monad.Identity import Control.Monad.Reader import Data.Maybe import Data.Foldable (toList) @@ -17,6 +16,7 @@ import Data.Functor import Data.List (partition, elemIndex) import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M +import qualified Data.Set as S import Autodiff import Env @@ -27,7 +27,7 @@ import Type import PPrint import Util -type SimplifyM = SubstEmbedT Identity +type SimplifyM = SubstEmbed simplifyModule :: TopEnv -> Module -> Module simplifyModule scope (Module Core decls bindings) = do @@ -479,6 +479,49 @@ simplifyHof hof = case hof of ~(lam', recon) <- simplifyLam lam ans <- emit $ Hof $ RunIO lam' applyRecon recon ans + CatchException lam -> do + ~(Lam (Abs _ (_, body)), Nothing) <- simplifyLam lam + dropSub $ exceptToMaybeBlock body where applyRecon Nothing x = return x applyRecon (Just f) x = f x + +exceptToMaybeBlock :: Block -> SubstEmbed Atom +exceptToMaybeBlock (Block Empty result) = exceptToMaybeExpr result +exceptToMaybeBlock (Block (Nest (Let _ b expr) decls) result) = do + maybeResult <- exceptToMaybeExpr expr + case maybeResult of + -- These two cases are just an optimization + JustAtom _ x -> extendR (b@>x) $ exceptToMaybeBlock $ Block decls result + NothingAtom a -> return $ NothingAtom a + _ -> do + blockTy <- substEmbedR $ getType result + let nothingPath = Abs Empty $ Block Empty $ Atom $ NothingAtom blockTy + b' <- mapM substEmbedR b + justPath <- buildNAbs (Nest b' Empty) $ \[x] -> + extendR (b@>x) $ exceptToMaybeBlock $ Block decls result + emit $ Case maybeResult [nothingPath, justPath] (MaybeTy blockTy) + +exceptToMaybeExpr :: Expr -> SubstEmbed Atom +exceptToMaybeExpr expr = do + case expr of + Case e alts resultTy -> do + e' <- substEmbedR e + resultTy' <- substEmbedR $ MaybeTy resultTy + alts' <- forM alts $ \(Abs bs body) -> do + bs' <- mapM (mapM substEmbedR) bs + buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ exceptToMaybeBlock body + emit $ Case e' alts' resultTy' + Atom x -> substEmbedR $ JustAtom (getType x) x + Op (ThrowException a) -> substEmbedR $ NothingAtom a + _ | not (hasExceptions expr) -> do + x <- substEmbedR expr >>= emit + return $ JustAtom (getType x) x + | otherwise -> + error $ "Unexpected exception-throwing expression: " ++ pprint expr + +hasExceptions :: Expr -> Bool +hasExceptions expr = case t of + Nothing -> ExceptionEffect `S.member` effs + Just _ -> error "Shouldn't have tail left" + where (EffectRow effs t) = exprEffs expr diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 55823e08e..5a8de99c9 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -49,6 +49,7 @@ module Syntax ( varType, binderType, isTabTy, LogLevel (..), IRVariant (..), applyIntBinOp, applyIntCmpOp, applyFloatBinOp, applyFloatUnOp, getIntLit, getFloatLit, sizeOf, vectorWidth, + pattern MaybeTy, pattern JustAtom, pattern NothingAtom, pattern IdxRepTy, pattern IdxRepVal, pattern IIdxRepVal, pattern IIdxRepTy, pattern TagRepTy, pattern TagRepVal, pattern Word8Ty, pattern IntLitExpr, pattern FloatLitExpr, @@ -340,6 +341,7 @@ data PrimOp e = | ToOrdinal e | IdxSetSize e | ThrowError e + | ThrowException e -- Catchable exceptions (unlike `ThrowError`) | CastOp e e -- Type, then value. See Type.hs for valid coercions. -- Extensible record and variant operations: -- Add fields to a record (on the left). Left arg contains values to add. @@ -368,6 +370,7 @@ data PrimHof e = | RunWriter e | RunState e e | RunIO e + | CatchException e | Linearize e | Transpose e | PTileReduce e e -- index set, thread body @@ -1489,7 +1492,25 @@ pattern Unlabeled as <- (_getUnlabeled -> Just as) Just ne -> LabeledItems (M.singleton InternalSingletonLabel ne) Nothing -> NoLabeledItems - -- TODO: Enable once https://gitlab.haskell.org//ghc/ghc/issues/13363 is fixed... +maybeDataDef :: DataDef +maybeDataDef = DataDef (GlobalName "Maybe") (Nest (Bind ("a":>TyKind)) Empty) + [ DataConDef (GlobalName "Nothing") Empty + , DataConDef (GlobalName "Just" ) (Nest (Ignore (Var ("a":>TyKind))) Empty)] + +pattern MaybeTy :: Type -> Type +pattern MaybeTy a = TypeCon MaybeDataDef [a] + +pattern MaybeDataDef :: DataDef +pattern MaybeDataDef <- ((\def -> def == maybeDataDef) -> True) + where MaybeDataDef = maybeDataDef + +pattern NothingAtom :: Type -> Atom +pattern NothingAtom ty = DataCon MaybeDataDef [ty] 0 [] + +pattern JustAtom :: Type -> Atom -> Atom +pattern JustAtom ty x = DataCon MaybeDataDef [ty] 1 [x] + +-- TODO: Enable once https://gitlab.haskell.org//ghc/ghc/issues/13363 is fixed... -- {-# COMPLETE TypeVar, ArrowType, TabTy, Forall, TypeAlias, Effect, NoAnn, TC #-} -- TODO: Can we derive these generically? Or use Show/Read? @@ -1518,7 +1539,8 @@ builtinNames = M.fromList , ("idxSetSize" , OpExpr $ IdxSetSize ()) , ("unsafeFromOrdinal", OpExpr $ UnsafeFromOrdinal () ()) , ("toOrdinal" , OpExpr $ ToOrdinal ()) - , ("throwError" , OpExpr $ ThrowError ()) + , ("throwError" , OpExpr $ ThrowError ()) + , ("throwException" , OpExpr $ ThrowException ()) , ("ask" , OpExpr $ PrimEffect () $ MAsk) , ("tell" , OpExpr $ PrimEffect () $ MTell ()) , ("get" , OpExpr $ PrimEffect () $ MGet) @@ -1533,6 +1555,7 @@ builtinNames = M.fromList , ("runWriter" , HofExpr $ RunWriter ()) , ("runState" , HofExpr $ RunState () ()) , ("runIO" , HofExpr $ RunIO ()) + , ("catchException" , HofExpr $ CatchException ()) , ("tiled" , HofExpr $ Tile 0 () ()) , ("tiledd" , HofExpr $ Tile 1 () ()) , ("TyKind" , TCExpr $ TypeKind) diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 2013f5a0f..21accdc0d 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -276,6 +276,7 @@ exprEffs expr = case expr of MAsk -> oneEffect (RWSEffect Reader h) MTell _ -> oneEffect (RWSEffect Writer h) where RefTy (Var (h:>_)) _ = getType ref + ThrowException _ -> oneEffect ExceptionEffect IOAlloc _ _ -> oneEffect ioEffect IOFree _ -> oneEffect ioEffect PtrLoad _ -> oneEffect ioEffect @@ -427,6 +428,7 @@ instance CoreVariant (PrimTC a) where instance CoreVariant (PrimOp a) where checkVariant e = case e of + ThrowException _ -> goneBy Simp Select _ _ _ -> alwaysAllowed -- TODO: only scalar select after Simp _ -> alwaysAllowed @@ -447,6 +449,7 @@ instance CoreVariant (PrimHof a) where Transpose _ -> goneBy Simp Tile _ _ _ -> alwaysAllowed PTileReduce _ _ -> absentUntil Simp -- really absent until parallelization + CatchException _ -> goneBy Simp -- TODO: namespace restrictions? alwaysAllowed :: VariantM () @@ -758,7 +761,9 @@ typeCheckOp op = case op of i |: TC (IntRange (IdxRepVal 0) (IdxRepVal $ fromIntegral vectorWidth)) return $ BaseTy $ Scalar sb ThrowError ty -> ty|:TyKind $> ty - -- TODO: this should really be a 32 bit integer for unicode code point: but for now is 8 bit ASCII code point + ThrowException ty -> do + declareEff ExceptionEffect + ty|:TyKind $> ty CastOp t@(Var _) _ -> t |: TyKind $> t CastOp destTy e -> do sourceTy <- typeCheck e @@ -882,9 +887,15 @@ typeCheckHof hof = case hof of s |: stateTy return $ PairTy resultTy stateTy RunIO f -> do - FunTy _ eff resultTy <- typeCheck f + FunTy b eff resultTy <- typeCheck f + checkEq (binderAnn b) UnitTy extendAllowedEffect ioEffect $ declareEffs eff return resultTy + CatchException f -> do + FunTy b eff resultTy <- typeCheck f + checkEq (binderAnn b) UnitTy + extendAllowedEffect ExceptionEffect $ declareEffs eff + return $ MaybeTy resultTy checkRWSAction :: RWS -> Atom -> TypeM (Type, Type) checkRWSAction rws f = do diff --git a/tests/exception-tests.dx b/tests/exception-tests.dx new file mode 100644 index 000000000..8ba611385 --- /dev/null +++ b/tests/exception-tests.dx @@ -0,0 +1,35 @@ + + +def checkFloatInUnitInterval (x:Float) : {Except} Float = + assert $ x >= 0.0 + assert $ x <= 1.0 + x + +:p catch do assert False +> Nothing + +:p catch do assert True +> (Just ()) + +:p catch do checkFloatInUnitInterval 1.2 +> Nothing + +:p catch do checkFloatInUnitInterval (-1.2) +> Nothing + +:p catch do checkFloatInUnitInterval 0.2 +> (Just 0.2) + +:p snd $ withState 0 \ref. + catch do + ref := 1 + assert False + ref := 2 +> 1 + +-- Doesn't work yet +-- :p catch do +-- withState 0 \ref. +-- ref := 1 +-- assert False +-- ref := 2 From 003f1e23fc126f38b9addf3e91bad5843642a668 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 30 Dec 2020 15:07:35 -0500 Subject: [PATCH 055/105] Add a helper for emitting case expressions on `Maybe a` scrutinees. --- src/lib/Embed.hs | 12 +++++++++++- src/lib/Simplify.hs | 7 ++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 437e2a32d..04c4ce52b 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -24,7 +24,8 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP buildFor, buildForAux, buildForAnn, buildForAnnAux, emitBlock, unzipTab, isSingletonType, emitDecl, withNameHint, singletonTypeVal, scopedDecls, embedScoped, extendScope, checkEmbed, - embedExtend, unpackConsList, emitRunWriter, emitRunState, + embedExtend, unpackConsList, emitRunWriter, + emitRunState, emitMaybeCase, emitRunReader, tabGet, SubstEmbedT, SubstEmbed, runSubstEmbedT, traverseAtom, ptrOffset, ptrLoad, unsafePtrLoad, evalBlockE, substTraversalDef, @@ -345,6 +346,15 @@ unpackConsList xs = case getType xs of liftM (x:) $ unpackConsList rest _ -> error $ "Not a cons list: " ++ pprint (getType xs) +emitMaybeCase :: MonadEmbed m => Atom -> (m Atom) -> (Atom -> m Atom) -> m Atom +emitMaybeCase scrut nothingCase justCase = do + let (MaybeTy a) = getType scrut + nothingAlt <- buildNAbs Empty $ \[] -> nothingCase + justAlt <- buildNAbs (Nest (Bind ("x":>a)) Empty) $ \[x] -> justCase x + let (Abs _ nothingBody) = nothingAlt + let resultTy = getType nothingBody + emit $ Case scrut [nothingAlt, justAlt] resultTy + emitRunWriter :: MonadEmbed m => Name -> Type -> (Atom -> m Atom) -> m Atom emitRunWriter v ty body = do emit . Hof . RunWriter =<< mkBinaryEffFun Writer v ty body diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index f96c6a32b..3634828eb 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -496,11 +496,8 @@ exceptToMaybeBlock (Block (Nest (Let _ b expr) decls) result) = do NothingAtom a -> return $ NothingAtom a _ -> do blockTy <- substEmbedR $ getType result - let nothingPath = Abs Empty $ Block Empty $ Atom $ NothingAtom blockTy - b' <- mapM substEmbedR b - justPath <- buildNAbs (Nest b' Empty) $ \[x] -> - extendR (b@>x) $ exceptToMaybeBlock $ Block decls result - emit $ Case maybeResult [nothingPath, justPath] (MaybeTy blockTy) + emitMaybeCase maybeResult (return $ NothingAtom blockTy) $ \x -> do + extendR (b@>x) $ exceptToMaybeBlock $ Block decls result exceptToMaybeExpr :: Expr -> SubstEmbed Atom exceptToMaybeExpr expr = do From 2c09957bdc4ea8df7b381c0815b26c80a1568283 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 30 Dec 2020 16:06:21 -0500 Subject: [PATCH 056/105] Handle `for` expressions when lowering away exceptions. We want to preserve opportunities for parallelism even when we have exceptions. This lowering does that by discharding the exception at each index, producing a table of `Maybe`s (which can be evaluted in parallel) which are then combined. This interacts with state in a slightly surprising way. If the loop is stateful, we still run all the iterations (updating the state as we go) even if early iterations fail. It's much easier to implement `catMaybes : (n=>Maybe a) -> Maybe (n=>a)` in the prelude instead of using Embed directly. So I also added a utility for calling predlue-defined functions via Embed. --- lib/prelude.dx | 5 ++--- src/lib/Embed.hs | 23 +++++++++++++++++++++-- src/lib/Simplify.hs | 16 +++++++++++++++- tests/exception-tests.dx | 25 +++++++++++++++++++++++++ 4 files changed, 63 insertions(+), 6 deletions(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index f4c265a4f..6ad8aba3a 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -1443,9 +1443,8 @@ def fromJust (x:Maybe a) : a = case x of Just x' -> x' def anySat (f:a -> Bool) (xs:n=>a) : Bool = any (map f xs) --- In Haskell this would just be `mapM`. The equivalent for us would be having --- an exception effect. -def seqMaybes (xs : n=>Maybe a) : Maybe (n => a) = +-- XXX: we use this internally so it's important to make the type args explicit +def seqMaybes (n:Type) ?-> (a:Type) ?-> (xs : n=>Maybe a) : Maybe (n => a) = -- is it possible to implement this safely? (i.e. without using partial -- functions) case anySat isNothing xs of diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 04c4ce52b..cdc9a097a 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -13,7 +13,7 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildPi, getAllowedEffects, withEffects, modifyAllowedEffects, buildLam, EmbedT, Embed, MonadEmbed, buildScoped, runEmbedT, - runSubstEmbed, runEmbed, getScope, embedLook, + runSubstEmbed, runEmbed, getScope, embedLook, liftEmbed, app, add, mul, sub, neg, div', iadd, imul, isub, idiv, ilt, ieq, @@ -24,7 +24,7 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP buildFor, buildForAux, buildForAnn, buildForAnnAux, emitBlock, unzipTab, isSingletonType, emitDecl, withNameHint, singletonTypeVal, scopedDecls, embedScoped, extendScope, checkEmbed, - embedExtend, unpackConsList, emitRunWriter, + embedExtend, unpackConsList, emitRunWriter, applyPreludeFunction, emitRunState, emitMaybeCase, emitRunReader, tabGet, SubstEmbedT, SubstEmbed, runSubstEmbedT, traverseAtom, ptrOffset, ptrLoad, unsafePtrLoad, @@ -42,6 +42,7 @@ import Control.Monad.Writer hiding (Alt) import Control.Monad.Identity import Control.Monad.State.Strict import Data.Foldable (toList) +import Data.String (fromString) import Data.Tuple (swap) import GHC.Stack @@ -322,6 +323,15 @@ appReduce (Lam (Abs v (_, b))) a = runReaderT (evalBlockE substTraversalDef b) (v @> a) appReduce _ _ = error "appReduce expected a lambda as the first argument" +-- TODO: this would be more convenient if we could add type inference too +applyPreludeFunction :: MonadEmbed m => String -> [Atom] -> m Atom +applyPreludeFunction s xs = do + scope <- getScope + case envLookup scope fname of + Nothing -> error $ "Function not defined yet: " ++ s + Just (ty, _) -> naryApp (Var (fname:>ty)) xs + where fname = GlobalName (fromString s) + appTryReduce :: MonadEmbed m => Atom -> Atom -> m Atom appTryReduce f x = case f of Lam _ -> appReduce f x @@ -387,6 +397,7 @@ buildForAnn ann i body = fst <$> buildForAnnAux ann i (\x -> (,()) <$> body x) buildForAux :: MonadEmbed m => Direction -> Binder -> (Atom -> m (Atom, a)) -> m (Atom, a) buildForAux = buildForAnnAux . RegularFor +-- Do we need this variant? buildFor :: MonadEmbed m => Direction -> Binder -> (Atom -> m Atom) -> m Atom buildFor = buildForAnn . RegularFor @@ -588,6 +599,14 @@ scopedDecls m = do (ans, (_, decls)) <- embedScoped m return (ans, decls) +liftEmbed :: MonadEmbed m => Embed a -> m a +liftEmbed action = do + envR <- embedAsk + envC <- embedLook + let (ans, envC') = runIdentity $ runEmbedT' action (envR, envC) + embedExtend envC' + return ans + -- === generic traversal === type TraversalDef m = (Decl -> m SubstEnv, Expr -> m Expr, Atom -> m Atom) diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 3634828eb..fc787e9d6 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -506,11 +506,15 @@ exceptToMaybeExpr expr = do e' <- substEmbedR e resultTy' <- substEmbedR $ MaybeTy resultTy alts' <- forM alts $ \(Abs bs body) -> do - bs' <- mapM (mapM substEmbedR) bs + bs' <- substEmbedR bs buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ exceptToMaybeBlock body emit $ Case e' alts' resultTy' Atom x -> substEmbedR $ JustAtom (getType x) x Op (ThrowException a) -> substEmbedR $ NothingAtom a + Hof (For ann ~(Lam (Abs b (_, body)))) -> do + b' <- substEmbedR b + maybes <- buildForAnn ann b' $ \i -> extendR (b@>i) $ exceptToMaybeBlock body + catMaybesE maybes _ | not (hasExceptions expr) -> do x <- substEmbedR expr >>= emit return $ JustAtom (getType x) x @@ -522,3 +526,13 @@ hasExceptions expr = case t of Nothing -> ExceptionEffect `S.member` effs Just _ -> error "Shouldn't have tail left" where (EffectRow effs t) = exprEffs expr + +catMaybesE :: MonadEmbed m => Atom -> m Atom +catMaybesE maybes = simplifyEmbed $ do + let (TabTy b (MaybeTy a)) = getType maybes + applyPreludeFunction "seqMaybes" [binderAnn b, a, maybes] + +simplifyEmbed :: MonadEmbed m => m Atom -> m Atom +simplifyEmbed m = do + block <- buildScoped m + liftEmbed $ runReaderT (simplifyBlock block) mempty diff --git a/tests/exception-tests.dx b/tests/exception-tests.dx index 8ba611385..8e62816a4 100644 --- a/tests/exception-tests.dx +++ b/tests/exception-tests.dx @@ -27,9 +27,34 @@ def checkFloatInUnitInterval (x:Float) : {Except} Float = ref := 2 > 1 +:p catch do + for i:(Fin 5). + if ordinal i > 3 + then throw () + else 23 +> Nothing + +:p catch do + for i:(Fin 3). + if ordinal i > 3 + then throw () + else 23 +> (Just [23, 23, 23]) + +-- Is this the result we want? +:p snd $ withState zero \ref. + catch do + for i:(Fin 6). + if (ordinal i `rem` 2) == 0 + then throw () + else () + ref!i := 1 +> [0, 1, 0, 1, 0, 1] + -- Doesn't work yet -- :p catch do -- withState 0 \ref. -- ref := 1 -- assert False -- ref := 2 + From 5be588fc14d981b57028b550934f9c7b1523498b Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 29 Dec 2020 20:20:31 -0500 Subject: [PATCH 057/105] Expose `view`, a "lazy" variant of `for`. Whereas `for` evaluates its body for each index immediately and carries out any effects right there and then, `view` behaves like an ordinary lambda: its body is only evaluated when you ask for the element value at a particular index. Views can't have effects. This is already how tables are represented internally. This change just exposes that representation to the surface language by adding a new parser rule. Using `view` we can write functions like `transpose` and `iota` and be sure they'll only take O(1) time. Often compiler optimizations can achieve the same thing automatically, but having to rely on that makes it hard to reason about performance. --- lib/prelude.dx | 17 ++++++++--------- misc/dex.el | 2 +- src/lib/Parser.hs | 16 +++++++++++++--- tests/eval-tests.dx | 4 ++-- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index b6a68f844..1f2ca1397 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -397,7 +397,7 @@ def Fin (n:Int) : Type = Range 0 n def ordinal (i:a) : Int = %toOrdinal i def size (n:Type) : Int = %idxSetSize n def unsafeFromOrdinal (n : Type) (i : Int) : n = %unsafeFromOrdinal n i -def iota (n:Type) : n=>Int = for i. ordinal i +def iota (n:Type) : n=>Int = view i. ordinal i -- TODO: we want Eq and Ord for all index sets, not just `Fin n` @instance @@ -529,9 +529,9 @@ pi : Float = 3.141592653589793 def id (x:a) : a = x def dup (x:a) : (a & a) = (x, x) def map (f:a->{|eff} b) (xs: n=>a) : {|eff} (n=>b) = for i. f xs.i -def zip (xs:n=>a) (ys:n=>b) : (n=>(a&b)) = for i. (xs.i, ys.i) +def zip (xs:n=>a) (ys:n=>b) : (n=>(a&b)) = view i. (xs.i, ys.i) def unzip (xys:n=>(a&b)) : (n=>a & n=>b) = (map fst xys, map snd xys) -def fanout (n:Type) (x:a) : n=>a = for i. x +def fanout (n:Type) (x:a) : n=>a = view i. x def sq (d:Mul a) ?=> (x:a) : a = x * x def abs (_:Add a) ?=> (_:Ord a) ?=> (x:a) : a = select (x > zero) x (zero - x) def mod (x:Int) (y:Int) : Int = rem (y + rem x y) y @@ -555,7 +555,7 @@ def reduce (identity:a) (combine:(a->a->a)) (xs:n=>a) : a = -- TODO: call this `scan` and call the current `scan` something else def scan' (init:a) (body:n->a->a) : n=>a = snd $ scan init \i x. dup (body i x) -- TODO: allow tables-via-lambda and get rid of this -def fsum (xs:n->Float) : Float = snd $ withAccum \ref. for i. ref += xs i +def fsum (xs:n=>Float) : Float = snd $ withAccum \ref. for i. ref += xs i def sum (_: Add v) ?=> (xs:n=>v) : v = reduce zero (+) xs def prod (_: Mul v) ?=> (xs:n=>v) : v = reduce one (*) xs def mean (n:Type) ?-> (xs:n=>Float) : Float = sum xs / IToF (size n) @@ -571,20 +571,19 @@ def linspace (n:Type) (low:Float) (high:Float) : n=>Float = dx = (high - low) / IToF (size n) for i:n. low + IToF (ordinal i) * dx -def transpose (x:n=>m=>a) : m=>n=>a = for i j. x.j.i -def vdot (x:n=>Float) (y:n=>Float) : Float = fsum \i. x.i * y.i +def transpose (x:n=>m=>a) : m=>n=>a = view i j. x.j.i +def vdot (x:n=>Float) (y:n=>Float) : Float = fsum view i. x.i * y.i def dot (_:VSpace v) ?=> (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j -- matmul. Better symbol to use? `@`? (**) : (l=>m=>Float) -> (m=>n=>Float) -> (l=>n=>Float) = \x y. - y' = transpose y - for i k. fsum \j. x.i.j * y'.k.j + for i k. fsum view j. x.i.j * y.j.k (**.) : (n=>m=>Float) -> (m=>Float) -> (n=>Float) = \mat v. for i. vdot mat.i v (.**) : (m=>Float) -> (n=>m=>Float) -> (n=>Float) = flip (**.) def inner (x:n=>Float) (mat:n=>m=>Float) (y:m=>Float) : Float = - fsum \(i,j). x.i * mat.i.j * y.j + fsum view (i,j). x.i * mat.i.j * y.j def eye (_:Eq n) ?=> : n=>n=>Float = for i j. select (i == j) 1.0 0.0 diff --git a/misc/dex.el b/misc/dex.el index 9381371a5..c7bf86dad 100644 --- a/misc/dex.el +++ b/misc/dex.el @@ -10,7 +10,7 @@ ("^'\\(.\\|\n.\\)*\n\n" . font-lock-comment-face) ("\\w+:" . font-lock-comment-face) ("^:\\w*" . font-lock-preprocessor-face) - ("\\bdef\\b\\|\\bfor\\b\\|\\brof\\b\\|\\bcase\\b\\|\\bdata\\b\\|\\bwhere\\b\\|\\bof\\b\\|\\bif\\b\\|\\bthen\\b\\|\\belse\\b\\|\\binterface\\b\\|\\binstance\\b\\|\\bdo\\b" . + ("\\bdef\\b\\|\\bfor\\b\\|\\brof\\b\\|\\bcase\\b\\|\\bdata\\b\\|\\bwhere\\b\\|\\bof\\b\\|\\bif\\b\\|\\bthen\\b\\|\\belse\\b\\|\\binterface\\b\\|\\binstance\\b\\|\\bdo\\b\\|\\bview\\b" . font-lock-keyword-face) ("--o" . font-lock-variable-name-face) ("[-.,!$^&*:~+/=<>|?\\\\]" . font-lock-variable-name-face) diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 722d85892..593b3357d 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -168,6 +168,7 @@ leafExpr = parens (mayPair $ makeExprParser leafExpr ops) <|> uLit <|> uPiType <|> uLamExpr + <|> uViewExpr <|> uForExpr <|> caseExpr <|> ifExpr @@ -498,6 +499,14 @@ buildFor pos dir binders body = case binders of [] -> body b:bs -> WithSrc (Just pos) $ UFor dir b $ buildFor pos dir bs body +uViewExpr :: Parser UExpr +uViewExpr = do + keyWord ViewKW + bs <- some patAnn + argTerm + body <- blockOrExpr + return $ buildLam (zip bs (repeat TabArrow)) body + uForExpr :: Parser UExpr uForExpr = do ((dir, trailingUnit), pos) <- withPos $ @@ -997,7 +1006,7 @@ type Lexer = Parser data KeyWord = DefKW | ForKW | For_KW | RofKW | Rof_KW | CaseKW | OfKW | ReadKW | WriteKW | StateKW | DataKW | InterfaceKW - | InstanceKW | WhereKW | IfKW | ThenKW | ElseKW | DoKW + | InstanceKW | WhereKW | IfKW | ThenKW | ElseKW | DoKW | ViewKW upperName :: Lexer Name upperName = liftM mkName $ label "upper-case name" $ lexeme $ @@ -1034,12 +1043,13 @@ keyWord kw = lexeme $ try $ string s >> notFollowedBy nameTailChar InterfaceKW -> "interface" InstanceKW -> "instance" WhereKW -> "where" - DoKW -> "do" + DoKW -> "do" + ViewKW -> "view" keyWordStrs :: [String] keyWordStrs = ["def", "for", "for_", "rof", "rof_", "case", "of", "llam", "Read", "Write", "Accum", "data", "interface", - "instance", "where", "if", "then", "else", "do"] + "instance", "where", "if", "then", "else", "do", "view"] fieldLabel :: Lexer Label fieldLabel = label "field label" $ lexeme $ diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx index b1d443bb4..2193d4399 100644 --- a/tests/eval-tests.dx +++ b/tests/eval-tests.dx @@ -404,12 +404,12 @@ litArr = [10, 5, 3] -- Not sure why the ordinary `sum/for` version doesn't work anymore :p n = 3 + 7 - fsum \i:(Fin n). 1.0 + fsum view i:(Fin n). 1.0 > 10. :p n = 4 - fsum \i:(Fin n). 1.0 + fsum view i:(Fin n). 1.0 > 4. :p From 0ecf6bbed5d5cdc86fc5720412d77f76173aaab3 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 30 Dec 2020 17:33:40 -0500 Subject: [PATCH 058/105] Start a pure-Dex parser combinator library --- lib/parser.dx | 27 +++++++++++++++++++++++++++ makefile | 2 +- tests/parser-combinator-tests.dx | 20 ++++++++++++++++++++ 3 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 lib/parser.dx create mode 100644 tests/parser-combinator-tests.dx diff --git a/lib/parser.dx b/lib/parser.dx new file mode 100644 index 000000000..69c935284 --- /dev/null +++ b/lib/parser.dx @@ -0,0 +1,27 @@ + +def ParserHandle (h:Type) : Type = (String & Ref h Int) + +def Parser (a:Type) : Type = h:Type ?-> ParserHandle h -> {Except, State h} a + +def fromOrdinalExc (n:Type) (i:Int) : {Except} n = + if (0 <= i) && (i < size n) + then unsafeFromOrdinal _ i + else throw () + +def indexList (l:List a) (i:Int) : {Except} a = + (AsList n xs) = l + xs.(fromOrdinalExc _ i) + +def pChar (c:Char) : Parser Unit = \(s, posRef). + i = get posRef + c' = indexList s i + assert (c == c') + posRef := i + 1 + +def pEOF : Parser Unit = \(s, posRef). + assert $ get posRef >= listLength s + +def runParser (s:String) (parser:Parser a) : Maybe a = + fst $ withState 0 \pos. + catch $ do + parser (s, pos) diff --git a/makefile b/makefile index bf690edfc..23afc3aa0 100644 --- a/makefile +++ b/makefile @@ -87,7 +87,7 @@ example-names = mandelbrot pi sierpinski rejection-sampler \ test-names = uexpr-tests adt-tests type-tests eval-tests show-tests \ shadow-tests monad-tests io-tests exception-tests \ - ad-tests parser-tests serialize-tests \ + ad-tests parser-tests serialize-tests parser-combinator-tests \ record-variant-tests typeclass-tests complex-tests trig-tests lib-names = diagram plot png diff --git a/tests/parser-combinator-tests.dx b/tests/parser-combinator-tests.dx new file mode 100644 index 000000000..24162136f --- /dev/null +++ b/tests/parser-combinator-tests.dx @@ -0,0 +1,20 @@ + +include "parser.dx" + +parseABC : Parser Unit = \h. + pChar 'A' h + pChar 'B' h + pChar 'C' h + pEOF h + +:p runParser "AAA" parseABC +> Nothing + +:p runParser "ABCABC" parseABC +> Nothing + +:p runParser "AB" parseABC +> Nothing + +:p runParser "ABC" parseABC +> (Just ()) From 9ffb81f889605cf2b136a8a5930216c741796795 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 30 Dec 2020 18:16:04 -0500 Subject: [PATCH 059/105] Add some more parser combinators. We need `many` to do anything useful. I have a version that type checks but it fails at run time because we don't have exceptions with while loops. We also need in-place list updates for it to be efficient. --- lib/parser.dx | 62 ++++++++++++++++++++++++++++---- tests/parser-combinator-tests.dx | 20 +++++++++++ 2 files changed, 76 insertions(+), 6 deletions(-) diff --git a/lib/parser.dx b/lib/parser.dx index 69c935284..5f032c8da 100644 --- a/lib/parser.dx +++ b/lib/parser.dx @@ -1,17 +1,35 @@ -def ParserHandle (h:Type) : Type = (String & Ref h Int) -def Parser (a:Type) : Type = h:Type ?-> ParserHandle h -> {Except, State h} a +'Utilities unrelated to parsing def fromOrdinalExc (n:Type) (i:Int) : {Except} n = if (0 <= i) && (i < size n) then unsafeFromOrdinal _ i else throw () +-- TODO: allow this to happen in-place +-- TODO: if it takes too long to make that possible, start with a bounded version +def push (ref:Ref h (List a)) (x:a) : {State h} Unit = + l = get ref + ref := l <> AsList _ [x] + def indexList (l:List a) (i:Int) : {Except} a = (AsList n xs) = l xs.(fromOrdinalExc _ i) +'The Parser type + +def ParserHandle (h:Type) : Type = (String & Ref h Int) + +def Parser (a:Type) : Type = h:Type ?-> ParserHandle h -> {Except, State h} a + +def runParser (s:String) (parser:Parser a) : Maybe a = + fst $ withState 0 \pos. + catch $ do + parser (s, pos) + +'Primitive combinators + def pChar (c:Char) : Parser Unit = \(s, posRef). i = get posRef c' = indexList s i @@ -21,7 +39,39 @@ def pChar (c:Char) : Parser Unit = \(s, posRef). def pEOF : Parser Unit = \(s, posRef). assert $ get posRef >= listLength s -def runParser (s:String) (parser:Parser a) : Maybe a = - fst $ withState 0 \pos. - catch $ do - parser (s, pos) +def (<|>) (p1:Parser a) (p2:Parser a) : Parser a = \h. + (s, posRef) = h + curPos = get posRef + case catch do p1 h of + Nothing -> + assert $ curPos == get posRef + p2 h + Just ans -> ans + +def return (x:a) : Parser a = \_. x + +'Derived combinators + +def optional (parser:Parser a) : Parser (Maybe a) = + (\h. Just (parser h)) <|> return Nothing + +def parseMany (parser:Parser a) : Parser (List a) = \h. + snd $ withState (AsList _ []) \results. + iter \_. + maybeVal = optional parser h + case maybeVal of + Nothing -> Done () + Just x -> + push results x + Continue + +def bracketed (l:Parser Unit) (r:Parser Unit) (body:Parser a) : Parser a = \h. + l h + ans = body h + r h + ans + +-- This fails. Type inference is unable to unify two region variables. I think +-- it's to do with implicit type application. +-- def parens (parser:Parser Unit) : Parser a = +-- bracketed (pChar '(') (pChar ')') parser diff --git a/tests/parser-combinator-tests.dx b/tests/parser-combinator-tests.dx index 24162136f..6362ca608 100644 --- a/tests/parser-combinator-tests.dx +++ b/tests/parser-combinator-tests.dx @@ -18,3 +18,23 @@ parseABC : Parser Unit = \h. :p runParser "ABC" parseABC > (Just ()) + +def parseTF : Parser Bool = + (\h. + pChar 'T' h + True) <|> (\h. + pChar 'F' h + False) + +def parserTFTriple : Parser (Fin 3=>Bool) = + \h. + ans = for i. parseTF h + pEOF h + ans + +:p runParser "TTF" parserTFTriple +> (Just [True, True, False]) + +:p runParser "TTFX" parserTFTriple +> Nothing + From 59a507d9ffe3ebb90ed2f425d6e7b2ed12225b19 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 30 Dec 2020 22:45:39 -0500 Subject: [PATCH 060/105] Simplify `while`. It doesn't need separate `cond` and `body` functions. We can always just put the body within the cond, under a branch that conditions on whether the cond succeeded. Simplifying it will make it easier to handle exceptions. --- examples/ode-integrator.dx | 19 ++-- lib/prelude.dx | 31 +++--- src/lib/Autodiff.hs | 4 +- src/lib/Imp.hs | 194 ++++++++++++++++++------------------- src/lib/Imp/Embed.hs | 3 +- src/lib/JIT.hs | 13 ++- src/lib/PPrint.hs | 4 +- src/lib/Simplify.hs | 5 +- src/lib/Syntax.hs | 8 +- src/lib/Type.hs | 11 +-- tests/eval-tests.dx | 8 +- 11 files changed, 154 insertions(+), 146 deletions(-) diff --git a/examples/ode-integrator.dx b/examples/ode-integrator.dx index 458d372fa..8edc7d8eb 100644 --- a/examples/ode-integrator.dx +++ b/examples/ode-integrator.dx @@ -112,15 +112,15 @@ def odeint (func: d=>Float -> Time -> d=>Float) atol = 1.4e-8 -- absolute local error tolerance for solver. max_iters = 10000 - integrate_to_next_time = \iter init_carry. - target_t = times.iter + integrate_to_next_time = \i init_carry. + target_t = times.i - stopping_condition = \(_, _, t, dt, _, _). + continue_condition = \(_, _, t, dt, _, _). -- State of solver: (next state, next f, next time, dt, t, interp coeffs) -- def State (v:Type) : Type = (v & v & Time & Time & Time & (Fin 5)=>v) -- This ended up being unnecessary to spell anywhere, but was -- useful for debugging. - (t < target_t) && (dt > 0.0) && (ordinal iter < max_iters) + (t < target_t) && (dt > 0.0) && (ordinal i < max_iters) possible_step = \(z, f, t, dt, last_t, interp_coeff). (next_z, next_f, next_z_error, k) = runge_kutta_step func z f t dt @@ -134,9 +134,14 @@ def odeint (func: d=>Float -> Time -> d=>Float) select (ratio <= 1.0) move_state stay_state -- Take steps until we pass target_t - new_state = snd $ withState init_carry \state. - while (do stopping_condition (get state)) do - state := possible_step (get state) + new_state = snd $ withState init_carry \stateRef. + iter \_. + state = get stateRef + if continue_condition state + then + stateRef := possible_step state + Continue + else Done () (_, _, t, _, last_t, interp_coeff) = new_state -- Interpolate to the target time. diff --git a/lib/prelude.dx b/lib/prelude.dx index 59f85ccaa..6137977ee 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -1030,11 +1030,10 @@ def fwrite (stream:Stream WriteMode) (s:String) : {State World} Unit = def while (eff:Effects) ?-> - (cond: Unit -> {|eff} Bool) - (body: Unit -> {|eff} Unit) + (body: Unit -> {|eff} Bool) : {|eff} Unit = - cond' : Unit -> {|eff} Word8 = \_. BToW8 $ cond () - %while cond' body + body' : Unit -> {|eff} Word8 = \_. BToW8 $ body () + %while body' data IterResult a:Type = Continue @@ -1052,10 +1051,15 @@ def liftState (ref: Ref h c) (f:a -> {|eff} b) (x:a) : {State h|eff} b = -- A little iteration combinator def iter (body: Int -> {|eff} IterResult a) : {|eff} a = result = snd $ withState Nothing \resultRef. withState 0 \i. - while (do isNothing $ get resultRef) do - case liftState resultRef (liftState i body) (get i) of - Continue -> i := get i + 1 - Done result -> resultRef := Just result + while do + continue = isNothing $ get resultRef + if continue + then case liftState resultRef (liftState i body) (get i) of + Continue -> i := get i + 1 + Done result -> resultRef := Just result + else () + continue + case result of Just ans -> ans Nothing -> unreachable () @@ -1465,9 +1469,14 @@ def concat (lists:n=>(List a)) : List a = AsList _ $ fst $ withState 0 \listIdx. fst $ withState 0 \eltIdx. for i:(Fin totalSize). - while (do get eltIdx >= listLength (lists.((get listIdx)@_))) do - eltIdx := 0 - listIdx := get listIdx + 1 + while do + continue = get eltIdx >= listLength (lists.((get listIdx)@_)) + if continue + then + eltIdx := 0 + listIdx := get listIdx + 1 + else () + continue (AsList _ xs) = lists.((get listIdx)@_) eltIdxVal = get eltIdx eltIdx := eltIdxVal + 1 diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index 0624364d4..d48837dfe 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -274,7 +274,7 @@ linearizeHof env hof = case hof of -- TODO: Consider providing an upper bound for the number of while iterations as a hint. -- In the current form the best we can do is try to use some dynamically growing lists, -- but that won't work on the GPU. - While _ _ -> notImplemented + While _ -> notImplemented Linearize _ -> error "Unexpected linearization" Transpose _ -> error "Unexpected transposition" PTileReduce _ _ -> error "Unexpected PTileReduce" @@ -698,7 +698,7 @@ transposeHof hof ct = case hof of transposeAtom s cts RunIO _ -> error "Not implemented" Tile _ _ _ -> notImplemented - While _ _ -> notImplemented + While _ -> notImplemented Linearize _ -> error "Unexpected linearization" Transpose _ -> error "Unexpected transposition" PTileReduce _ _ -> error "Unexpected PTileReduce" diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index c50b14bfa..15a366532 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -363,20 +363,20 @@ toImpHof env (maybeDest, hof) = do let idx = Con $ ParIndexCon idxTy $ toScalarAtom i ithDest <- destGet dest idx void $ translateBlock (env <> b @> idx) (Just ithDest, body) - GPU -> do -- Grid stride loop - iPtr <- alloc IdxRepTy - copyAtom iPtr gtid - cond <- liftM snd $ scopedBlock $ do - i <- destToAtom iPtr - inRange <- (fromScalarAtom i) `iltI` n - return ((), [inRange]) - wbody <- scopedErrBlock $ do - i <- destToAtom iPtr - let idx = Con $ ParIndexCon idxTy i - ithDest <- destGet dest idx - void $ translateBlock (env <> b @> idx) (Just ithDest, body) - copyAtom iPtr . toScalarAtom =<< iaddI (fromScalarAtom i) (fromScalarAtom numThreads) - emitStatement $ IWhile cond wbody + -- GPU -> do -- Grid stride loop + -- iPtr <- alloc IdxRepTy + -- copyAtom iPtr gtid + -- cond <- liftM snd $ scopedBlock $ do + -- i <- destToAtom iPtr + -- inRange <- (fromScalarAtom i) `iltI` n + -- return ((), [inRange]) + -- wbody <- scopedErrBlock $ do + -- i <- destToAtom iPtr + -- let idx = Con $ ParIndexCon idxTy i + -- ithDest <- destGet dest idx + -- void $ translateBlock (env <> b @> idx) (Just ithDest, body) + -- copyAtom iPtr . toScalarAtom =<< iaddI (fromScalarAtom i) (fromScalarAtom numThreads) + -- emitStatement $ IWhile cond wbody destToAtom dest _ -> do n <- indexSetSize idxTy @@ -418,86 +418,85 @@ toImpHof env (maybeDest, hof) = do sDest <- fromEmbed $ indexDestDim d dest idx void $ translateBlock (env <> sb @> idx) (Just sDest, sBody) destToAtom dest - PTileReduce idxTy' ~(BinaryFunVal gtidB nthrB _ body) -> do - idxTy <- impSubst env idxTy' - (mappingDest, finalAccDest) <- destPairUnpack <$> allocDest maybeDest resultTy - let PairTy _ accType = resultTy - (numTileWorkgroups, wgResArr, widIdxTy) <- buildKernel idxTy $ \LaunchInfo{..} buildBody -> do - let widIdxTy = Fin $ toScalarAtom numWorkgroups - let tidIdxTy = Fin $ toScalarAtom workgroupSize - wgResArr <- alloc $ TabTy (Ignore widIdxTy) accType - thrAccArr <- alloc $ TabTy (Ignore widIdxTy) $ TabTy (Ignore tidIdxTy) accType - mappingKernelBody <- buildBody $ \ThreadInfo{..} -> do - let TC (ParIndexRange _ gtid nthr) = threadRange - let scope = freeVars mappingDest - let tileDest = Con $ TabRef $ fst $ flip runSubstEmbed scope $ do - buildLam (Bind $ "hwidx":>threadRange) TabArrow $ \hwidx -> do - indexDest mappingDest =<< (emitOp $ Inject hwidx) - wgAccs <- destGet thrAccArr =<< intToIndex widIdxTy wid - thrAcc <- destGet wgAccs =<< intToIndex tidIdxTy tid - let threadDest = Con $ ConRef $ PairCon tileDest thrAcc - void $ translateBlock (env <> gtidB @> gtid <> nthrB @> nthr) (Just threadDest, body) - wgRes <- destGet wgResArr =<< intToIndex widIdxTy wid - workgroupReduce tid wgRes wgAccs workgroupSize - return (mappingKernelBody, (numWorkgroups, wgResArr, widIdxTy)) - -- TODO: Skip the reduction kernel if unnecessary? - -- TODO: Reduce sequentially in the CPU backend? - -- TODO: Actually we only need the previous-power-of-2 many threads - buildKernel widIdxTy $ \LaunchInfo{..} buildBody -> do - -- We only do a one-level reduciton in the workgroup, so it is correct - -- only if the end up scheduling a single workgroup. - moreThanOneGroup <- (IIdxRepVal 1) `iltI` numWorkgroups - guardBlock moreThanOneGroup $ emitStatement IThrowError - redKernelBody <- buildBody $ \ThreadInfo{..} -> - workgroupReduce tid finalAccDest wgResArr numTileWorkgroups - return (redKernelBody, ()) - PairVal <$> destToAtom mappingDest <*> destToAtom finalAccDest - where - guardBlock cond m = do - block <- scopedErrBlock m - emitStatement $ ICond cond block (ImpBlock mempty mempty) - workgroupReduce tid resDest arrDest elemCount = do - elemCountDown2 <- prevPowerOf2 elemCount - let RawRefTy (TabTy arrIdxB _) = getType arrDest - let arrIdxTy = binderType arrIdxB - offPtr <- alloc IdxRepTy - copyAtom offPtr $ toScalarAtom elemCountDown2 - cond <- liftM snd $ scopedBlock $ do - off <- fromScalarAtom <$> destToAtom offPtr - cond <- emitInstr $ IPrimOp $ ScalarBinOp (ICmp Greater) off (IIdxRepVal 0) - return ((), [cond]) - wbody <- scopedErrBlock $ do - off <- fromScalarAtom <$> destToAtom offPtr - loadIdx <- iaddI tid off - shouldAdd <- bindM2 bandI (tid `iltI` off) (loadIdx `iltI` elemCount) - guardBlock shouldAdd $ do - threadDest <- destGet arrDest =<< intToIndex arrIdxTy tid - addToAtom threadDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy loadIdx - emitStatement ISyncWorkgroup - copyAtom offPtr . toScalarAtom =<< off `idivI` (IIdxRepVal 2) - emitStatement $ IWhile cond wbody - firstThread <- tid `iltI` (IIdxRepVal 1) - guardBlock firstThread $ - copyAtom resDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy tid - -- TODO: Do some popcount tricks? - prevPowerOf2 :: IExpr -> ImpM IExpr - prevPowerOf2 x = do - rPtr <- alloc IdxRepTy - copyAtom rPtr (IdxRepVal 1) - let getNext = imulI (IIdxRepVal 2) . fromScalarAtom =<< destToAtom rPtr - cond <- liftM snd $ scopedBlock $ do - canGrow <- getNext >>= (`iltI` x) - return ((), [canGrow]) - wbody <- scopedErrBlock $ do - copyAtom rPtr . toScalarAtom =<< getNext - emitStatement $ IWhile cond wbody - fromScalarAtom <$> destToAtom rPtr - While ~(Lam (Abs _ (_, cond))) ~(Lam (Abs _ (_, body))) -> do - cond' <- liftM snd $ scopedBlock $ do - ans <- translateBlock env (Nothing, cond) + -- PTileReduce idxTy' ~(BinaryFunVal gtidB nthrB _ body) -> do + -- idxTy <- impSubst env idxTy' + -- (mappingDest, finalAccDest) <- destPairUnpack <$> allocDest maybeDest resultTy + -- let PairTy _ accType = resultTy + -- (numTileWorkgroups, wgResArr, widIdxTy) <- buildKernel idxTy $ \LaunchInfo{..} buildBody -> do + -- let widIdxTy = Fin $ toScalarAtom numWorkgroups + -- let tidIdxTy = Fin $ toScalarAtom workgroupSize + -- wgResArr <- alloc $ TabTy (Ignore widIdxTy) accType + -- thrAccArr <- alloc $ TabTy (Ignore widIdxTy) $ TabTy (Ignore tidIdxTy) accType + -- mappingKernelBody <- buildBody $ \ThreadInfo{..} -> do + -- let TC (ParIndexRange _ gtid nthr) = threadRange + -- let scope = freeVars mappingDest + -- let tileDest = Con $ TabRef $ fst $ flip runSubstEmbed scope $ do + -- buildLam (Bind $ "hwidx":>threadRange) TabArrow $ \hwidx -> do + -- indexDest mappingDest =<< (emitOp $ Inject hwidx) + -- wgAccs <- destGet thrAccArr =<< intToIndex widIdxTy wid + -- thrAcc <- destGet wgAccs =<< intToIndex tidIdxTy tid + -- let threadDest = Con $ ConRef $ PairCon tileDest thrAcc + -- void $ translateBlock (env <> gtidB @> gtid <> nthrB @> nthr) (Just threadDest, body) + -- wgRes <- destGet wgResArr =<< intToIndex widIdxTy wid + -- workgroupReduce tid wgRes wgAccs workgroupSize + -- return (mappingKernelBody, (numWorkgroups, wgResArr, widIdxTy)) + -- -- TODO: Skip the reduction kernel if unnecessary? + -- -- TODO: Reduce sequentially in the CPU backend? + -- -- TODO: Actually we only need the previous-power-of-2 many threads + -- buildKernel widIdxTy $ \LaunchInfo{..} buildBody -> do + -- -- We only do a one-level reduciton in the workgroup, so it is correct + -- -- only if the end up scheduling a single workgroup. + -- moreThanOneGroup <- (IIdxRepVal 1) `iltI` numWorkgroups + -- guardBlock moreThanOneGroup $ emitStatement IThrowError + -- redKernelBody <- buildBody $ \ThreadInfo{..} -> + -- workgroupReduce tid finalAccDest wgResArr numTileWorkgroups + -- return (redKernelBody, ()) + -- PairVal <$> destToAtom mappingDest <*> destToAtom finalAccDest + -- where + -- guardBlock cond m = do + -- block <- scopedErrBlock m + -- emitStatement $ ICond cond block (ImpBlock mempty mempty) + -- workgroupReduce tid resDest arrDest elemCount = do + -- elemCountDown2 <- prevPowerOf2 elemCount + -- let RawRefTy (TabTy arrIdxB _) = getType arrDest + -- let arrIdxTy = binderType arrIdxB + -- offPtr <- alloc IdxRepTy + -- copyAtom offPtr $ toScalarAtom elemCountDown2 + -- cond <- liftM snd $ scopedBlock $ do + -- off <- fromScalarAtom <$> destToAtom offPtr + -- cond <- emitInstr $ IPrimOp $ ScalarBinOp (ICmp Greater) off (IIdxRepVal 0) + -- return ((), [cond]) + -- wbody <- scopedErrBlock $ do + -- off <- fromScalarAtom <$> destToAtom offPtr + -- loadIdx <- iaddI tid off + -- shouldAdd <- bindM2 bandI (tid `iltI` off) (loadIdx `iltI` elemCount) + -- guardBlock shouldAdd $ do + -- threadDest <- destGet arrDest =<< intToIndex arrIdxTy tid + -- addToAtom threadDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy loadIdx + -- emitStatement ISyncWorkgroup + -- copyAtom offPtr . toScalarAtom =<< off `idivI` (IIdxRepVal 2) + -- emitStatement $ IWhile cond wbody + -- firstThread <- tid `iltI` (IIdxRepVal 1) + -- guardBlock firstThread $ + -- copyAtom resDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy tid + -- -- TODO: Do some popcount tricks? + -- prevPowerOf2 :: IExpr -> ImpM IExpr + -- prevPowerOf2 x = do + -- rPtr <- alloc IdxRepTy + -- copyAtom rPtr (IdxRepVal 1) + -- let getNext = imulI (IIdxRepVal 2) . fromScalarAtom =<< destToAtom rPtr + -- cond <- liftM snd $ scopedBlock $ do + -- canGrow <- getNext >>= (`iltI` x) + -- return ((), [canGrow]) + -- wbody <- scopedErrBlock $ do + -- copyAtom rPtr . toScalarAtom =<< getNext + -- emitStatement $ IWhile cond wbody + -- fromScalarAtom <$> destToAtom rPtr + While ~(Lam (Abs _ (_, body))) -> do + body' <- liftM snd $ scopedBlock $ do + ans <- translateBlock env (Nothing, body) return ((), [fromScalarAtom ans]) - body' <- scopedErrBlock $ void $ translateBlock env (Nothing, body) - emitStatement $ IWhile cond' body' + emitStatement $ IWhile body' return UnitVal RunReader r ~(BinaryFunVal _ ref _ body) -> do rDest <- alloc $ getType r @@ -1193,10 +1192,9 @@ instrTypeChecked instr = case instr of assertEq (binderAnn i) (getIType size) $ "Mismatch between the loop iterator and upper bound type" [] <- withTypeEnv (i @> getIType size) $ checkBlock block return [] - IWhile cond body -> do - [condTy] <- checkBlock cond - assertEq (Scalar Word8Type) condTy $ "Not a bool: " ++ pprint cond - [] <- checkBlock body + IWhile body -> do + [condTy] <- checkBlock body + assertEq (Scalar Word8Type) condTy $ "Not a bool: " ++ pprint body return [] ICond predicate consequent alternative -> do predTy <- checkIExpr predicate @@ -1331,7 +1329,7 @@ impInstrTypes instr = case instr of IThrowError -> [] MemCopy _ _ _ -> [] IFor _ _ _ _ -> [] - IWhile _ _ -> [] + IWhile _ -> [] ICond _ _ _ -> [] ILaunch _ _ _ -> [] ISyncWorkgroup -> [] diff --git a/src/lib/Imp/Embed.hs b/src/lib/Imp/Embed.hs index ef437e01a..315a00226 100644 --- a/src/lib/Imp/Embed.hs +++ b/src/lib/Imp/Embed.hs @@ -147,8 +147,7 @@ traverseImpInstr def instr = case instr of b' <- freshIVar b IFor dir (Bind b') <$> traverseIExpr size <*> (extendValSubst (b @> IVar b') $ traverseImpBlock def body) - IWhile cond body -> - IWhile <$> traverseImpBlock def cond <*> traverseImpBlock def body + IWhile body -> IWhile <$> traverseImpBlock def body ICond cond tb fb -> ICond <$> traverseIExpr cond <*> traverseImpBlock def tb <*> traverseImpBlock def fb diff --git a/src/lib/JIT.hs b/src/lib/JIT.hs index 43ab475dd..a9d374269 100644 --- a/src/lib/JIT.hs +++ b/src/lib/JIT.hs @@ -201,8 +201,8 @@ compileInstr instr = case instr of IFor d i n body -> [] <$ do n' <- compileExpr n compileLoop d i n' $ compileVoidBlock body - IWhile cond body -> [] <$ do - compileWhile (head <$> compileBlock cond) (compileVoidBlock body) + IWhile body -> [] <$ do + compileWhile (head <$> compileBlock body) ICond p cons alt -> [] <$ do p' <- compileExpr p >>= (`asIntWidth` i1) compileIf p' (compileVoidBlock cons) (compileVoidBlock alt) @@ -370,14 +370,13 @@ compileIf cond tb fb = do fb finishBlock (L.Br contName []) contName -compileWhile :: Compile Operand -> Compile () -> Compile () -compileWhile compileCond compileBody = do +compileWhile :: Compile Operand -> Compile () +compileWhile compileBody = do loopBlock <- freshName "whileLoop" nextBlock <- freshName "whileCont" - entryCond <- compileCond >>= (`asIntWidth` i1) + entryCond <- compileBody >>= (`asIntWidth` i1) finishBlock (L.CondBr entryCond loopBlock nextBlock []) loopBlock - compileBody - loopCond <- compileCond >>= (`asIntWidth` i1) + loopCond <- compileBody >>= (`asIntWidth` i1) finishBlock (L.CondBr loopCond loopBlock nextBlock []) nextBlock throwRuntimeError :: Compile () diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index a07cb7f91..976d86dba 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -478,9 +478,7 @@ instance Pretty ImpFunction where instance Pretty ImpInstr where pretty (IFor a i n block) = forStr (RegularFor a) <+> p i <+> "<" <+> p n <> nest 4 (hardline <> p block) - pretty (IWhile cond body) = "while" <+> - nest 2 (p cond) <+> "do" <> - nest 4 (hardline <> p body) + pretty (IWhile body) = "while" <+> nest 2 (p body) pretty (ICond predicate cons alt) = "if" <+> p predicate <+> "then" <> nest 2 (hardline <> p cons) <> hardline <> "else" <> nest 2 (hardline <> p alt) diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index fc787e9d6..eabb5801e 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -447,10 +447,9 @@ simplifyHof hof = case hof of ~(fS', Nothing) <- simplifyLam fS emit $ Hof $ Tile d fT' fS' PTileReduce _ _ -> error "Unexpected PTileReduce" - While cond body -> do - ~(cond', Nothing) <- simplifyLam cond + While body -> do ~(body', Nothing) <- simplifyLam body - emit $ Hof $ While cond' body' + emit $ Hof $ While body' Linearize lam -> do ~(lam', Nothing) <- simplifyLam lam scope <- getScope diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 5a8de99c9..277d94693 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -365,7 +365,7 @@ data PrimOp e = data PrimHof e = For ForAnn e | Tile Int e e -- dimension number, tiled body, scalar body - | While e e + | While e | RunReader e e | RunWriter e | RunState e e @@ -529,7 +529,7 @@ data ImpFunction = ImpFunction IFunVar [IBinder] ImpBlock data ImpBlock = ImpBlock (Nest ImpDecl) [IExpr] deriving (Show) data ImpDecl = ImpLet [IBinder] ImpInstr deriving (Show) data ImpInstr = IFor Direction IBinder Size ImpBlock - | IWhile ImpBlock ImpBlock -- cond block, body block + | IWhile ImpBlock | ICond IExpr ImpBlock ImpBlock | IQueryParallelism IFunVar IExpr -- returns the number of available concurrent threads | ISyncWorkgroup @@ -1264,7 +1264,7 @@ instance HasIVars ImpBlock where instance HasIVars ImpInstr where freeIVars i = case i of IFor _ b n p -> freeIVars n <> (freeIVars p `envDiff` (b @> ())) - IWhile c p -> freeIVars c <> freeIVars p + IWhile p -> freeIVars p ICond c t f -> freeIVars c <> freeIVars t <> freeIVars f IQueryParallelism _ s -> freeIVars s ISyncWorkgroup -> mempty @@ -1548,7 +1548,7 @@ builtinNames = M.fromList , ("indexRef" , OpExpr $ IndexRef () ()) , ("inject" , OpExpr $ Inject ()) , ("select" , OpExpr $ Select () () ()) - , ("while" , HofExpr $ While () ()) + , ("while" , HofExpr $ While ()) , ("linearize" , HofExpr $ Linearize ()) , ("linearTranspose" , HofExpr $ Transpose ()) , ("runReader" , HofExpr $ RunReader () ()) diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 21accdc0d..4dd84a817 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -286,7 +286,7 @@ exprEffs expr = case expr of Hof hof -> case hof of For _ f -> functionEffs f Tile _ _ _ -> error "not implemented" - While cond body -> functionEffs cond <> functionEffs body + While body -> functionEffs body Linearize _ -> mempty -- Body has to be a pure function Transpose _ -> mempty -- Body has to be a pure function RunReader _ f -> handleRWSRunner Reader f @@ -440,7 +440,7 @@ instance CoreVariant (PrimCon a) where instance CoreVariant (PrimHof a) where checkVariant e = case e of For _ _ -> alwaysAllowed - While _ _ -> alwaysAllowed + While _ -> alwaysAllowed RunReader _ _ -> alwaysAllowed RunWriter _ -> alwaysAllowed RunState _ _ -> alwaysAllowed @@ -863,13 +863,10 @@ typeCheckHof hof = case hof of checkEq threadRange (binderType threadRange') -- PTileReduce n mapping : (n=>a, ro) return $ PairTy (TabTy (Ignore n) tileElemTy) accTy - While cond body -> do - Pi (Abs (Ignore UnitTy) (arr , condTy)) <- typeCheck cond - Pi (Abs (Ignore UnitTy) (arr', bodyTy)) <- typeCheck body + While body -> do + Pi (Abs (Ignore UnitTy) (arr , condTy)) <- typeCheck body declareEffs $ arrowEff arr - declareEffs $ arrowEff arr' checkEq (BaseTy $ Scalar Word8Type) condTy - checkEq UnitTy bodyTy return UnitTy Linearize f -> do Pi (Abs (Ignore a) (PlainArrow Pure, b)) <- typeCheck f diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx index 2193d4399..f54b1a192 100644 --- a/tests/eval-tests.dx +++ b/tests/eval-tests.dx @@ -647,8 +647,12 @@ def newtonIter (f: Float -> Float) (x:Float) : Float = def newtonSolve (tol:Float) (f : Float -> Float) (x0:Float) : Float = snd $ withState x0 \x. - while (\(). abs (f $ get x) > tol) \(). - x := newtonIter f $ get x + iter \i. + if abs (f $ get x) <= tol + then Done () + else + x := newtonIter f $ get x + Continue :p newtonSolve 0.001 (\x. sq x - 2.0) 1.0 > 1.414216 From 4425cb369331df6056b501b0db0485c4b1147dc7 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 31 Dec 2020 00:04:21 -0500 Subject: [PATCH 061/105] Handle exceptions under `while` and `runState`. --- lib/prelude.dx | 19 +++++++++++++++---- src/lib/Embed.hs | 10 ++++++++-- src/lib/Simplify.hs | 27 +++++++++++++++++++++++---- 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index 6137977ee..9f0637daf 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -1028,10 +1028,7 @@ def fwrite (stream:Stream WriteMode) (s:String) : {State World} Unit = %ffi fflush Int64 stream' () -def while - (eff:Effects) ?-> - (body: Unit -> {|eff} Bool) - : {|eff} Unit = +def while (eff:Effects) ?-> (body: Unit -> {|eff} Bool) : {|eff} Unit = body' : Unit -> {|eff} Word8 = \_. BToW8 $ body () %while body' @@ -1064,6 +1061,20 @@ def iter (body: Int -> {|eff} IterResult a) : {|eff} a = Just ans -> ans Nothing -> unreachable () +-- XXX: used internally by compiler for exceptional while +def whileMaybe (eff:Effects) -> (body: Unit -> {|eff} (Maybe Word8)) : {|eff} Maybe Unit = + hadError = snd $ withState False \ref. + while do + ans = liftState ref body () + case ans of + Nothing -> + ref := True + False + Just cond -> W8ToB cond + if hadError + then Nothing + else Just () + def boundedIter (maxIters:Int) (fallback:a) (body: Int -> {|eff} IterResult a) : {|eff} a = iter \i. diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index cdc9a097a..731cfce66 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -25,7 +25,7 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP emitBlock, unzipTab, isSingletonType, emitDecl, withNameHint, singletonTypeVal, scopedDecls, embedScoped, extendScope, checkEmbed, embedExtend, unpackConsList, emitRunWriter, applyPreludeFunction, - emitRunState, emitMaybeCase, + emitRunState, emitMaybeCase, emitWhile, emitRunReader, tabGet, SubstEmbedT, SubstEmbed, runSubstEmbedT, traverseAtom, ptrOffset, ptrLoad, unsafePtrLoad, evalBlockE, substTraversalDef, @@ -356,7 +356,13 @@ unpackConsList xs = case getType xs of liftM (x:) $ unpackConsList rest _ -> error $ "Not a cons list: " ++ pprint (getType xs) -emitMaybeCase :: MonadEmbed m => Atom -> (m Atom) -> (Atom -> m Atom) -> m Atom +emitWhile :: MonadEmbed m => m Atom -> m () +emitWhile body = do + eff <- getAllowedEffects + lam <- buildLam (Ignore UnitTy) (PlainArrow eff) $ \_ -> body + void $ emit $ Hof $ While lam + +emitMaybeCase :: MonadEmbed m => Atom -> m Atom -> (Atom -> m Atom) -> m Atom emitMaybeCase scrut nothingCase justCase = do let (MaybeTy a) = getType scrut nothingAlt <- buildNAbs Empty $ \[] -> nothingCase diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index eabb5801e..9f3b256da 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -488,18 +488,19 @@ simplifyHof hof = case hof of exceptToMaybeBlock :: Block -> SubstEmbed Atom exceptToMaybeBlock (Block Empty result) = exceptToMaybeExpr result exceptToMaybeBlock (Block (Nest (Let _ b expr) decls) result) = do + a <- substEmbedR $ getType result maybeResult <- exceptToMaybeExpr expr case maybeResult of -- These two cases are just an optimization JustAtom _ x -> extendR (b@>x) $ exceptToMaybeBlock $ Block decls result - NothingAtom a -> return $ NothingAtom a + NothingAtom _ -> return $ NothingAtom a _ -> do - blockTy <- substEmbedR $ getType result - emitMaybeCase maybeResult (return $ NothingAtom blockTy) $ \x -> do + emitMaybeCase maybeResult (return $ NothingAtom a) $ \x -> do extendR (b@>x) $ exceptToMaybeBlock $ Block decls result exceptToMaybeExpr :: Expr -> SubstEmbed Atom exceptToMaybeExpr expr = do + a <- substEmbedR $ getType expr case expr of Case e alts resultTy -> do e' <- substEmbedR e @@ -509,11 +510,24 @@ exceptToMaybeExpr expr = do buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ exceptToMaybeBlock body emit $ Case e' alts' resultTy' Atom x -> substEmbedR $ JustAtom (getType x) x - Op (ThrowException a) -> substEmbedR $ NothingAtom a + Op (ThrowException _) -> return $ NothingAtom a Hof (For ann ~(Lam (Abs b (_, body)))) -> do b' <- substEmbedR b maybes <- buildForAnn ann b' $ \i -> extendR (b@>i) $ exceptToMaybeBlock body catMaybesE maybes + Hof (RunState s lam) -> do + s' <- substEmbedR s + let BinaryFunVal _ b _ body = lam + result <- emitRunState "ref" s' $ \ref -> + extendR (b@>ref) $ exceptToMaybeBlock body + (maybeAns, newState) <- fromPair result + emitMaybeCase maybeAns (return $ NothingAtom a) $ \ans -> + return $ JustAtom a $ PairVal ans newState + Hof (While ~(Lam (Abs _ (_, body)))) -> do + eff <- getAllowedEffects + lam <- buildLam (Ignore UnitTy) (PlainArrow eff) $ \_ -> + exceptToMaybeBlock body + runMaybeWhile lam _ | not (hasExceptions expr) -> do x <- substEmbedR expr >>= emit return $ JustAtom (getType x) x @@ -531,6 +545,11 @@ catMaybesE maybes = simplifyEmbed $ do let (TabTy b (MaybeTy a)) = getType maybes applyPreludeFunction "seqMaybes" [binderAnn b, a, maybes] +runMaybeWhile :: MonadEmbed m => Atom -> m Atom +runMaybeWhile lam = simplifyEmbed $ do + let (Pi (Abs _ (PlainArrow eff, _))) = getType lam + applyPreludeFunction "whileMaybe" [Eff eff, lam] + simplifyEmbed :: MonadEmbed m => m Atom -> m Atom simplifyEmbed m = do block <- buildScoped m From 9da58a1a39b81bcaa75b971694f0004289f976cb Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 31 Dec 2020 00:06:36 -0500 Subject: [PATCH 062/105] Make a newtype wrapper for `Parser a`. It's a shame that we still need to pass around the handle. I've tried conjuring some magic to make it implicit but I don't have anything that works yet. --- lib/parser.dx | 55 +++++++++++++++++++------------- tests/parser-combinator-tests.dx | 34 ++++++++++---------- 2 files changed, 51 insertions(+), 38 deletions(-) diff --git a/lib/parser.dx b/lib/parser.dx index 5f032c8da..a3d21b7dd 100644 --- a/lib/parser.dx +++ b/lib/parser.dx @@ -21,57 +21,68 @@ def indexList (l:List a) (i:Int) : {Except} a = def ParserHandle (h:Type) : Type = (String & Ref h Int) -def Parser (a:Type) : Type = h:Type ?-> ParserHandle h -> {Except, State h} a +data Parser a:Type = + MkParser (h:Type ?-> ParserHandle h -> {Except, State h} a) -def runParser (s:String) (parser:Parser a) : Maybe a = +def parse (handle:ParserHandle h) (parser:Parser a) : {Except, State h} a = + (MkParser f) = parser + f handle + +def runParserPartial (s:String) (parser:Parser a) : Maybe a = + (MkParser f) = parser fst $ withState 0 \pos. catch $ do - parser (s, pos) + f (s, pos) 'Primitive combinators -def pChar (c:Char) : Parser Unit = \(s, posRef). +def pChar (c:Char) : Parser Unit = MkParser \(s, posRef). i = get posRef c' = indexList s i assert (c == c') posRef := i + 1 -def pEOF : Parser Unit = \(s, posRef). +def pEOF : Parser Unit = MkParser \(s, posRef). assert $ get posRef >= listLength s -def (<|>) (p1:Parser a) (p2:Parser a) : Parser a = \h. +def (<|>) (p1:Parser a) (p2:Parser a) : Parser a = MkParser \h. (s, posRef) = h curPos = get posRef - case catch do p1 h of + case catch do parse h p1 of Nothing -> assert $ curPos == get posRef - p2 h + parse h p2 Just ans -> ans -def return (x:a) : Parser a = \_. x +def return (x:a) : Parser a = MkParser \_. x + +def runParser (s:String) (parser:Parser a) : Maybe a = + runParserPartial s $ MkParser \h. + ans = parse h parser + _ = parse h pEOF + ans 'Derived combinators -def optional (parser:Parser a) : Parser (Maybe a) = - (\h. Just (parser h)) <|> return Nothing +def optional (p:Parser a) : Parser (Maybe a) = + (MkParser \h. Just (parse h p)) <|> return Nothing -def parseMany (parser:Parser a) : Parser (List a) = \h. +def parseMany (parser:Parser a) : Parser (List a) = MkParser \h. snd $ withState (AsList _ []) \results. iter \_. - maybeVal = optional parser h + maybeVal = parse h $ optional parser case maybeVal of Nothing -> Done () Just x -> push results x Continue -def bracketed (l:Parser Unit) (r:Parser Unit) (body:Parser a) : Parser a = \h. - l h - ans = body h - r h - ans +def bracketed (l:Parser Unit) (r:Parser Unit) (body:Parser a) : Parser a = + MkParser \h. + _ = parse h l + ans = parse h body + _ = parse h r + ans --- This fails. Type inference is unable to unify two region variables. I think --- it's to do with implicit type application. --- def parens (parser:Parser Unit) : Parser a = --- bracketed (pChar '(') (pChar ')') parser +def parens (parser:Parser a) : Parser a = + bracketed (pChar '(') (pChar ')') parser diff --git a/tests/parser-combinator-tests.dx b/tests/parser-combinator-tests.dx index 6362ca608..d4624c7cb 100644 --- a/tests/parser-combinator-tests.dx +++ b/tests/parser-combinator-tests.dx @@ -1,11 +1,10 @@ include "parser.dx" -parseABC : Parser Unit = \h. - pChar 'A' h - pChar 'B' h - pChar 'C' h - pEOF h +parseABC : Parser Unit = MkParser \h. + parse h $ pChar 'A' + parse h $ pChar 'B' + parse h $ pChar 'C' :p runParser "AAA" parseABC > Nothing @@ -19,18 +18,19 @@ parseABC : Parser Unit = \h. :p runParser "ABC" parseABC > (Just ()) +def parseT : Parser Bool = MkParser \h. + parse h $ pChar 'T' + True + +def parseF : Parser Bool = MkParser \h. + parse h $ pChar 'F' + False + def parseTF : Parser Bool = - (\h. - pChar 'T' h - True) <|> (\h. - pChar 'F' h - False) - -def parserTFTriple : Parser (Fin 3=>Bool) = - \h. - ans = for i. parseTF h - pEOF h - ans + parseT <|> parseF + +def parserTFTriple : Parser (Fin 3=>Bool) = MkParser \h. + for i. parse h parseTF :p runParser "TTF" parserTFTriple > (Just [True, True, False]) @@ -38,3 +38,5 @@ def parserTFTriple : Parser (Fin 3=>Bool) = :p runParser "TTFX" parserTFTriple > Nothing +:p runParser "TTFFTT" $ parseMany parseTF +> (Just (AsList 6 [True, True, False, False, True, True])) From 18926b575c891bf3c762f2390ab8223ff97eb2fa Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 31 Dec 2020 00:32:55 -0500 Subject: [PATCH 063/105] Add a parser for integers. Fixes #312. --- lib/parser.dx | 36 ++++++++++++++++++++++++++++++++ tests/parser-combinator-tests.dx | 12 +++++++++++ 2 files changed, 48 insertions(+) diff --git a/lib/parser.dx b/lib/parser.dx index a3d21b7dd..83eae1892 100644 --- a/lib/parser.dx +++ b/lib/parser.dx @@ -62,8 +62,31 @@ def runParser (s:String) (parser:Parser a) : Maybe a = _ = parse h pEOF ans +def parseAny : Parser Char = MkParser \h. + (s, posRef) = h + i = get posRef + c' = indexList s i + posRef := i + 1 + c' + +def try (parser:Parser a) : Parser a = MkParser \h. + (s, posRef) = h + savedPos = get posRef + ans = catch do parse h parser + case ans of + Nothing -> + posRef := savedPos + throw () + Just x -> x + 'Derived combinators +def parseDigit : Parser Int = try $ MkParser \h. + c = parse h $ parseAny + i = W8ToI c - 48 + assert $ 0 <= i && i < 10 + i + def optional (p:Parser a) : Parser (Maybe a) = (MkParser \h. Just (parse h p)) <|> return Nothing @@ -77,6 +100,19 @@ def parseMany (parser:Parser a) : Parser (List a) = MkParser \h. push results x Continue +def parseUnsignedInt : Parser Int = MkParser \h. + (AsList _ digits) = parse h $ parseMany parseDigit + snd $ withState 0 \ref. + for i. + ref := 10 * get ref + digits.i + +def parseInt : Parser Int = MkParser \h. + negSign = parse h $ optional $ pChar '-' + x = parse h $ parseUnsignedInt + case negSign of + Nothing -> x + Just () -> (-1) * x + def bracketed (l:Parser Unit) (r:Parser Unit) (body:Parser a) : Parser a = MkParser \h. _ = parse h l diff --git a/tests/parser-combinator-tests.dx b/tests/parser-combinator-tests.dx index d4624c7cb..983a4da83 100644 --- a/tests/parser-combinator-tests.dx +++ b/tests/parser-combinator-tests.dx @@ -40,3 +40,15 @@ def parserTFTriple : Parser (Fin 3=>Bool) = MkParser \h. :p runParser "TTFFTT" $ parseMany parseTF > (Just (AsList 6 [True, True, False, False, True, True])) + +:p runParser "1021389" $ parseMany parseDigit +> (Just (AsList 7 [1, 0, 2, 1, 3, 8, 9])) + +:p runParser "1389" $ parseInt +> (Just 1389) + +:p runParser "01389" $ parseInt +> (Just 1389) + +:p runParser "-1389" $ parseInt +> (Just -1389) From cf0f508b7388c8f3089da4571e05ae8cea0bab9c Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 31 Dec 2020 01:10:03 -0500 Subject: [PATCH 064/105] Fix Imp code I accidentally left commented out. --- src/lib/Imp.hs | 179 ++++++++++++++++++++------------------- tests/exception-tests.dx | 13 ++- 2 files changed, 97 insertions(+), 95 deletions(-) diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 15a366532..4ad9aa0f8 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -363,20 +363,20 @@ toImpHof env (maybeDest, hof) = do let idx = Con $ ParIndexCon idxTy $ toScalarAtom i ithDest <- destGet dest idx void $ translateBlock (env <> b @> idx) (Just ithDest, body) - -- GPU -> do -- Grid stride loop - -- iPtr <- alloc IdxRepTy - -- copyAtom iPtr gtid - -- cond <- liftM snd $ scopedBlock $ do - -- i <- destToAtom iPtr - -- inRange <- (fromScalarAtom i) `iltI` n - -- return ((), [inRange]) - -- wbody <- scopedErrBlock $ do - -- i <- destToAtom iPtr - -- let idx = Con $ ParIndexCon idxTy i - -- ithDest <- destGet dest idx - -- void $ translateBlock (env <> b @> idx) (Just ithDest, body) - -- copyAtom iPtr . toScalarAtom =<< iaddI (fromScalarAtom i) (fromScalarAtom numThreads) - -- emitStatement $ IWhile cond wbody + GPU -> do -- Grid stride loop + iPtr <- alloc IdxRepTy + copyAtom iPtr gtid + cond <- liftM snd $ scopedBlock $ do + i <- destToAtom iPtr + inRange <- (fromScalarAtom i) `iltI` n + emitWhen inRange $ do + let idx = Con $ ParIndexCon idxTy i + ithDest <- destGet dest idx + void $ translateBlock (env <> b @> idx) (Just ithDest, body) + copyAtom iPtr . toScalarAtom =<< iaddI (fromScalarAtom i) + (fromScalarAtom numThreads) + return ((), [inRange]) + emitStatement $ IWhile cond destToAtom dest _ -> do n <- indexSetSize idxTy @@ -418,80 +418,80 @@ toImpHof env (maybeDest, hof) = do sDest <- fromEmbed $ indexDestDim d dest idx void $ translateBlock (env <> sb @> idx) (Just sDest, sBody) destToAtom dest - -- PTileReduce idxTy' ~(BinaryFunVal gtidB nthrB _ body) -> do - -- idxTy <- impSubst env idxTy' - -- (mappingDest, finalAccDest) <- destPairUnpack <$> allocDest maybeDest resultTy - -- let PairTy _ accType = resultTy - -- (numTileWorkgroups, wgResArr, widIdxTy) <- buildKernel idxTy $ \LaunchInfo{..} buildBody -> do - -- let widIdxTy = Fin $ toScalarAtom numWorkgroups - -- let tidIdxTy = Fin $ toScalarAtom workgroupSize - -- wgResArr <- alloc $ TabTy (Ignore widIdxTy) accType - -- thrAccArr <- alloc $ TabTy (Ignore widIdxTy) $ TabTy (Ignore tidIdxTy) accType - -- mappingKernelBody <- buildBody $ \ThreadInfo{..} -> do - -- let TC (ParIndexRange _ gtid nthr) = threadRange - -- let scope = freeVars mappingDest - -- let tileDest = Con $ TabRef $ fst $ flip runSubstEmbed scope $ do - -- buildLam (Bind $ "hwidx":>threadRange) TabArrow $ \hwidx -> do - -- indexDest mappingDest =<< (emitOp $ Inject hwidx) - -- wgAccs <- destGet thrAccArr =<< intToIndex widIdxTy wid - -- thrAcc <- destGet wgAccs =<< intToIndex tidIdxTy tid - -- let threadDest = Con $ ConRef $ PairCon tileDest thrAcc - -- void $ translateBlock (env <> gtidB @> gtid <> nthrB @> nthr) (Just threadDest, body) - -- wgRes <- destGet wgResArr =<< intToIndex widIdxTy wid - -- workgroupReduce tid wgRes wgAccs workgroupSize - -- return (mappingKernelBody, (numWorkgroups, wgResArr, widIdxTy)) - -- -- TODO: Skip the reduction kernel if unnecessary? - -- -- TODO: Reduce sequentially in the CPU backend? - -- -- TODO: Actually we only need the previous-power-of-2 many threads - -- buildKernel widIdxTy $ \LaunchInfo{..} buildBody -> do - -- -- We only do a one-level reduciton in the workgroup, so it is correct - -- -- only if the end up scheduling a single workgroup. - -- moreThanOneGroup <- (IIdxRepVal 1) `iltI` numWorkgroups - -- guardBlock moreThanOneGroup $ emitStatement IThrowError - -- redKernelBody <- buildBody $ \ThreadInfo{..} -> - -- workgroupReduce tid finalAccDest wgResArr numTileWorkgroups - -- return (redKernelBody, ()) - -- PairVal <$> destToAtom mappingDest <*> destToAtom finalAccDest - -- where - -- guardBlock cond m = do - -- block <- scopedErrBlock m - -- emitStatement $ ICond cond block (ImpBlock mempty mempty) - -- workgroupReduce tid resDest arrDest elemCount = do - -- elemCountDown2 <- prevPowerOf2 elemCount - -- let RawRefTy (TabTy arrIdxB _) = getType arrDest - -- let arrIdxTy = binderType arrIdxB - -- offPtr <- alloc IdxRepTy - -- copyAtom offPtr $ toScalarAtom elemCountDown2 - -- cond <- liftM snd $ scopedBlock $ do - -- off <- fromScalarAtom <$> destToAtom offPtr - -- cond <- emitInstr $ IPrimOp $ ScalarBinOp (ICmp Greater) off (IIdxRepVal 0) - -- return ((), [cond]) - -- wbody <- scopedErrBlock $ do - -- off <- fromScalarAtom <$> destToAtom offPtr - -- loadIdx <- iaddI tid off - -- shouldAdd <- bindM2 bandI (tid `iltI` off) (loadIdx `iltI` elemCount) - -- guardBlock shouldAdd $ do - -- threadDest <- destGet arrDest =<< intToIndex arrIdxTy tid - -- addToAtom threadDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy loadIdx - -- emitStatement ISyncWorkgroup - -- copyAtom offPtr . toScalarAtom =<< off `idivI` (IIdxRepVal 2) - -- emitStatement $ IWhile cond wbody - -- firstThread <- tid `iltI` (IIdxRepVal 1) - -- guardBlock firstThread $ - -- copyAtom resDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy tid - -- -- TODO: Do some popcount tricks? - -- prevPowerOf2 :: IExpr -> ImpM IExpr - -- prevPowerOf2 x = do - -- rPtr <- alloc IdxRepTy - -- copyAtom rPtr (IdxRepVal 1) - -- let getNext = imulI (IIdxRepVal 2) . fromScalarAtom =<< destToAtom rPtr - -- cond <- liftM snd $ scopedBlock $ do - -- canGrow <- getNext >>= (`iltI` x) - -- return ((), [canGrow]) - -- wbody <- scopedErrBlock $ do - -- copyAtom rPtr . toScalarAtom =<< getNext - -- emitStatement $ IWhile cond wbody - -- fromScalarAtom <$> destToAtom rPtr + PTileReduce idxTy' ~(BinaryFunVal gtidB nthrB _ body) -> do + idxTy <- impSubst env idxTy' + (mappingDest, finalAccDest) <- destPairUnpack <$> allocDest maybeDest resultTy + let PairTy _ accType = resultTy + (numTileWorkgroups, wgResArr, widIdxTy) <- buildKernel idxTy $ \LaunchInfo{..} buildBody -> do + let widIdxTy = Fin $ toScalarAtom numWorkgroups + let tidIdxTy = Fin $ toScalarAtom workgroupSize + wgResArr <- alloc $ TabTy (Ignore widIdxTy) accType + thrAccArr <- alloc $ TabTy (Ignore widIdxTy) $ TabTy (Ignore tidIdxTy) accType + mappingKernelBody <- buildBody $ \ThreadInfo{..} -> do + let TC (ParIndexRange _ gtid nthr) = threadRange + let scope = freeVars mappingDest + let tileDest = Con $ TabRef $ fst $ flip runSubstEmbed scope $ do + buildLam (Bind $ "hwidx":>threadRange) TabArrow $ \hwidx -> do + indexDest mappingDest =<< (emitOp $ Inject hwidx) + wgAccs <- destGet thrAccArr =<< intToIndex widIdxTy wid + thrAcc <- destGet wgAccs =<< intToIndex tidIdxTy tid + let threadDest = Con $ ConRef $ PairCon tileDest thrAcc + void $ translateBlock (env <> gtidB @> gtid <> nthrB @> nthr) (Just threadDest, body) + wgRes <- destGet wgResArr =<< intToIndex widIdxTy wid + workgroupReduce tid wgRes wgAccs workgroupSize + return (mappingKernelBody, (numWorkgroups, wgResArr, widIdxTy)) + -- TODO: Skip the reduction kernel if unnecessary? + -- TODO: Reduce sequentially in the CPU backend? + -- TODO: Actually we only need the previous-power-of-2 many threads + buildKernel widIdxTy $ \LaunchInfo{..} buildBody -> do + -- We only do a one-level reduciton in the workgroup, so it is correct + -- only if the end up scheduling a single workgroup. + moreThanOneGroup <- (IIdxRepVal 1) `iltI` numWorkgroups + guardBlock moreThanOneGroup $ emitStatement IThrowError + redKernelBody <- buildBody $ \ThreadInfo{..} -> + workgroupReduce tid finalAccDest wgResArr numTileWorkgroups + return (redKernelBody, ()) + PairVal <$> destToAtom mappingDest <*> destToAtom finalAccDest + where + guardBlock cond m = do + block <- scopedErrBlock m + emitStatement $ ICond cond block (ImpBlock mempty mempty) + workgroupReduce tid resDest arrDest elemCount = do + elemCountDown2 <- prevPowerOf2 elemCount + let RawRefTy (TabTy arrIdxB _) = getType arrDest + let arrIdxTy = binderType arrIdxB + offPtr <- alloc IdxRepTy + copyAtom offPtr $ toScalarAtom elemCountDown2 + let wbody = do + off <- fromScalarAtom <$> destToAtom offPtr + loadIdx <- iaddI tid off + shouldAdd <- bindM2 bandI (tid `iltI` off) (loadIdx `iltI` elemCount) + guardBlock shouldAdd $ do + threadDest <- destGet arrDest =<< intToIndex arrIdxTy tid + addToAtom threadDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy loadIdx + emitStatement ISyncWorkgroup + copyAtom offPtr . toScalarAtom =<< off `idivI` (IIdxRepVal 2) + cond <- liftM snd $ scopedBlock $ do + off <- fromScalarAtom <$> destToAtom offPtr + cond <- emitInstr $ IPrimOp $ ScalarBinOp (ICmp Greater) off (IIdxRepVal 0) + emitWhen cond wbody + return ((), [cond]) + emitStatement $ IWhile cond + firstThread <- tid `iltI` (IIdxRepVal 1) + guardBlock firstThread $ + copyAtom resDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy tid + -- TODO: Do some popcount tricks? + prevPowerOf2 :: IExpr -> ImpM IExpr + prevPowerOf2 x = do + rPtr <- alloc IdxRepTy + copyAtom rPtr (IdxRepVal 1) + let getNext = imulI (IIdxRepVal 2) . fromScalarAtom =<< destToAtom rPtr + cond <- liftM snd $ scopedBlock $ do + canGrow <- getNext >>= (`iltI` x) + emitWhen canGrow $ copyAtom rPtr . toScalarAtom =<< getNext + return ((), [canGrow]) + emitStatement $ IWhile cond + fromScalarAtom <$> destToAtom rPtr While ~(Lam (Abs _ (_, body))) -> do body' <- liftM snd $ scopedBlock $ do ans <- translateBlock env (Nothing, body) @@ -1046,6 +1046,9 @@ alloc ty = makeAllocDest Managed ty handleErrors :: ImpM () -> ImpM () handleErrors m = m `catchError` (const $ emitStatement IThrowError) +emitWhen :: IExpr -> ImpM () -> ImpM () +emitWhen cond doIfTrue = emitSwitch cond [return (), doIfTrue] + -- TODO: Consider targeting LLVM's `switch` instead of chained conditionals. emitSwitch :: IExpr -> [ImpM ()] -> ImpM () emitSwitch testIdx = rec 0 diff --git a/tests/exception-tests.dx b/tests/exception-tests.dx index 8e62816a4..687742559 100644 --- a/tests/exception-tests.dx +++ b/tests/exception-tests.dx @@ -51,10 +51,9 @@ def checkFloatInUnitInterval (x:Float) : {Except} Float = ref!i := 1 > [0, 1, 0, 1, 0, 1] --- Doesn't work yet --- :p catch do --- withState 0 \ref. --- ref := 1 --- assert False --- ref := 2 - +:p catch do + withState 0 \ref. + ref := 1 + assert False + ref := 2 +> Nothing From 58c2e794fb75ec25176383cedf5db1b4eff772d0 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 31 Dec 2020 13:06:58 -0500 Subject: [PATCH 065/105] New `run/with/yield` naming convention for effect runners. I still find it hard to remember whether to write `snd $ withState ...` or `fst $ withState` when I'm interested in the final state instead of the result of the body. Haskell's MTL convention is `run/eval/exec` for returning everything/result/state but I've never found that easy to remember. This change implements the following convention. run* runs the effect in the most general way". For `State`, that means that you get back both the result of the body and the final state with* only returns the result of the body. This convention can go beyond effects, and include things like `withFile`. If you use something whose name starts with `with`, you can be sure that the result of the whole `with*` expression is the same as the result of the body. yield* only gives the final state or accumulation. --- examples/chol.dx | 6 ++-- examples/fluidsim.dx | 4 +-- examples/linear_algebra.dx | 8 ++--- examples/mcmc.dx | 2 +- examples/ode-integrator.dx | 6 ++-- examples/particle-filter.dx | 4 +-- examples/raytrace.dx | 12 +++---- examples/sgd.dx | 2 +- examples/tiled-matmul.dx | 2 +- lib/parser.dx | 9 +++-- lib/prelude.dx | 70 +++++++++++++++++++++++++------------ tests/ad-tests.dx | 12 +++---- tests/adt-tests.dx | 8 ++--- tests/eval-tests.dx | 14 ++++---- tests/exception-tests.dx | 6 ++-- tests/gpu-tests.dx | 2 +- tests/monad-tests.dx | 42 +++++++++++----------- tests/parser-tests.dx | 2 +- tests/uexpr-tests.dx | 12 +++---- 19 files changed, 124 insertions(+), 99 deletions(-) diff --git a/examples/chol.dx b/examples/chol.dx index a639d20b6..1d8c53d11 100644 --- a/examples/chol.dx +++ b/examples/chol.dx @@ -4,7 +4,7 @@ https://en.wikipedia.org/wiki/Cholesky_decomposition ' ## Cholesky Algorithm def chol (_:Eq n) ?=> (x:n=>n=>Float) : (n=>n=>Float) = - snd $ withState zero \buf. + yieldState zero \buf. for_ i. for j':(..i). j = %inject(j') row = for k:(.. (x:n=>n=>Float) : (n=>n=>Float) = ' ## PSD solver based on Cholesky decomposition def trisolveL (mat:n=>n=>Float) (b:n=>Float) : n=>Float = - snd $ withState zero \buf. for i. + yieldState zero \buf. for i. row = for j:(..n=>Float) (b:n=>Float) : n=>Float = - snd $ withState zero \buf. rof i. + yieldState zero \buf. rof i. row = for j:(i..). mat.i.%inject(j) xPrev = for j:(i..). get (buf!%inject j) buf!i := (b.i - vdot row xPrev) / mat.i.i diff --git a/examples/fluidsim.dx b/examples/fluidsim.dx index 52f745057..227817517 100644 --- a/examples/fluidsim.dx +++ b/examples/fluidsim.dx @@ -51,7 +51,7 @@ def project (_:VSpace a) ?=> (v: n=>m=>(Fin 2)=>a) : n=>m=>(Fin 2)=>a = div = -0.5 .* h .* (divergence vx vy) - p = snd $ withState zero \state. + p = yieldState zero \state. for i:(Fin 10). state := (1.0 / 4.0) .* (div + add_neighbours_2d (get state)) @@ -97,7 +97,7 @@ def advect (_:VSpace a) ?=> (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = def fluidsim (_: VSpace a) ?=> (num_steps: Int) (color_init: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : (Fin num_steps)=>n=>m=>a = - fst $ withState (color_init, v) \state. + withState (color_init, v) \state. for i:(Fin num_steps). (color, v) = get state v = advect v v -- Move velocities diff --git a/examples/linear_algebra.dx b/examples/linear_algebra.dx index 28f6b7bcc..ec6a6e5e9 100644 --- a/examples/linear_algebra.dx +++ b/examples/linear_algebra.dx @@ -13,7 +13,7 @@ def lowerTriDiag (l:LowerTriMat n v) : n=>v = for i. l.i.((ordinal i)@_) def forward_substitute (_:VSpace v) ?=> (a:LowerTriMat n Float) (b:n=>v) : n=>v = -- Solves lower triangular linear system (inverse a) **. b - snd $ withState zero \sRef. + yieldState zero \sRef. for i:n. s = sum for k:(.. (a:LowerTriMat n Float) (b:n=>v) : n=>v def backward_substitute (_:VSpace v) ?=> (a:UpperTriMat n Float) (b:n=>v) : n=>v = -- Solves upper triangular linear system (inverse a) **. b - snd $ withState zero \sRef. + yieldState zero \sRef. rof i:n. s = sum for k:(i..). -- dot product a.i.((ordinal k)@_) .* get sRef!(%inject k) @@ -63,7 +63,7 @@ def permSign ((_, sign):Permutation n) : PermutationSign = sign def pivotize (_:Eq n) ?=> (a:n=>n=>Float) : Permutation n = -- Gives a row permutation that makes Gaussian elimination more stable. - snd $ withState identity_permutation \permRef. + yieldState identity_permutation \permRef. for j:n. row_with_largest' = argmin for i:(j..). (-(abs a.(%inject i).j)) row_with_largest = %inject row_with_largest' @@ -82,7 +82,7 @@ def lu (_:Eq n) ?=> (a: n=>n=>Float) : select (i == (%inject j')) 1.0 0.0 init_upper = for i:n. for j'':(i..). 0.0 - (lower, upper) = snd $ withState (init_lower, init_upper) \stateRef. + (lower, upper) = yieldState (init_lower, init_upper) \stateRef. lRef = fstRef stateRef uRef = sndRef stateRef diff --git a/examples/mcmc.dx b/examples/mcmc.dx index 44ff113b0..1ba161c85 100644 --- a/examples/mcmc.dx +++ b/examples/mcmc.dx @@ -13,7 +13,7 @@ def runChain (k:Key) : Fin numSamples => a = [k1, k2] = splitKey k - fst $ withState (initialize k1) \s. + withState (initialize k1) \s. for i:(Fin numSamples). x = step (ixkey k2 i) (get s) s := x diff --git a/examples/ode-integrator.dx b/examples/ode-integrator.dx index 8edc7d8eb..f553fca91 100644 --- a/examples/ode-integrator.dx +++ b/examples/ode-integrator.dx @@ -67,10 +67,10 @@ c_error = [35. / 384. - 1951. / 21600., 0., 500. / 1113. - 22642. / 50085., def runge_kutta_step (_:VSpace v) ?=> (func:v->Time->v) (z0:v) (f0:v) (t0:Time) (dt:Time) : (v & v & v & (Fin 7)=>v) = - evals_init = snd $ withState zero \r. + evals_init = yieldState zero \r. r!(0@_) := f0 - evals_filled = snd $ withState evals_init \func_evals. for i:(Fin 6). + evals_filled = yieldState evals_init \func_evals. for i:(Fin 6). cur_evals = for j:(..i). get func_evals!((ordinal j)@_) ti = t0 + dt .* alpha.i zi = z0 + dt .* dot beta.i cur_evals @@ -134,7 +134,7 @@ def odeint (func: d=>Float -> Time -> d=>Float) select (ratio <= 1.0) move_state stay_state -- Take steps until we pass target_t - new_state = snd $ withState init_carry \stateRef. + new_state = yieldState init_carry \stateRef. iter \_. state = get stateRef if continue_condition state diff --git a/examples/particle-filter.dx b/examples/particle-filter.dx index f88d8e541..10291f5e6 100644 --- a/examples/particle-filter.dx +++ b/examples/particle-filter.dx @@ -15,7 +15,7 @@ def simulate (model: Model s v) (t: Int) (key: Key) : Fin t=>(s & v) = (init, dynamics, observe) = model [key, subkey] = splitKey key s0 = sample init subkey - fst $ withState s0 \s_ref . + withState s0 \s_ref . for i. [k1, k2] = splitKey (ixkey key i) s = get s_ref @@ -34,7 +34,7 @@ def filter (init, dynamics, observe) = model [key, init_key] = splitKey key init_particles = for i: (Fin num_particles). sample init (ixkey init_key i) - fst $ withState init_particles \p_ref . + withState init_particles \p_ref . for t: (Fin num_timesteps). p_prev = get p_ref logLikelihoods = for i. snd (observe p_prev.i) obs.t diff --git a/examples/raytrace.dx b/examples/raytrace.dx index 2f9ee3601..066f8e7f5 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -25,7 +25,7 @@ def randuniform (lower:Float) (upper:Float) (k:Key) : Float = lower + (rand k) * (upper - lower) def sampleAveraged (_:VSpace a) ?=> (sample:Key -> a) (n:Int) (k:Key) : a = - snd $ withState zero \total. + yieldState zero \total. for i:(Fin n). total := get total + sample (ixkey k i) / IToF n @@ -174,7 +174,7 @@ def raymarch (scene:Scene n) (ray:Ray) : RayMarchResult = tol = 0.01 startLength = 10.0 * tol -- trying to escape the current surface (rayOrigin, rayDir) = ray - fst $ withState (10.0 * tol) \rayLength. + withState (10.0 * tol) \rayLength. boundedIter maxIters HitNothing \_. rayPos = rayOrigin + get rayLength .* rayDir (obj, d) = sdScene scene $ rayPos @@ -212,7 +212,7 @@ def sampleLightRadiance (surfNor, surf) = osurf (rayPos, _) = inRay (MkScene objs) = scene - snd $ withAccum \radiance. + yieldAccum \radiance. for i. case objs.i of PassiveObject _ _ -> () Light lightPos hw _ -> @@ -227,9 +227,9 @@ def sampleLightRadiance def trace (params:Params) (scene:Scene n) (initRay:Ray) (k:Key) : Color = noFilter = [1.0, 1.0, 1.0] - snd $ withAccum \radiance. - withState noFilter \filter. - withState initRay \ray. + yieldAccum \radiance. + runState noFilter \filter. + runState initRay \ray. boundedIter (getAt #maxBounces params) () \i. case raymarch scene $ get ray of HitNothing -> Done () diff --git a/examples/sgd.dx b/examples/sgd.dx index 40f496c96..3e5a5575a 100644 --- a/examples/sgd.dx +++ b/examples/sgd.dx @@ -10,7 +10,7 @@ def sgd_step (dict: VSpace a) ?=> (step_size: Float) (decay: Float) (gradfunc: a -- In-place optimization loop. def sgd (dict: VSpace a) ?=> (step_size:Float) (decay:Float) (num_steps:Int) (gradient: a -> Int -> a) (x0: a) : a = m0 = zero - (x_final, m_final) = snd $ withState (x0, m0) \state. + (x_final, m_final) = yieldState (x0, m0) \state. for i:(Fin num_steps). (x, m) = get state state := sgd_step step_size decay gradient x m (ordinal i) diff --git a/examples/tiled-matmul.dx b/examples/tiled-matmul.dx index 42e3b5753..7238d671e 100644 --- a/examples/tiled-matmul.dx +++ b/examples/tiled-matmul.dx @@ -16,7 +16,7 @@ def matmul (k : Type) ?-> (n : Type) ?-> (m : Type) ?-> vectorTile = Fin VectorWidth colTile = (colVectors & vectorTile) (tile2d (\nt:(Tile n rowTile). \mt:(Tile m colTile). - ct = snd $ withAccum \acc. + ct = yieldAccum \acc. for l:k. for i:rowTile. ail = broadcastVector a.(nt +> i).l diff --git a/lib/parser.dx b/lib/parser.dx index 83eae1892..1190fe84e 100644 --- a/lib/parser.dx +++ b/lib/parser.dx @@ -30,7 +30,7 @@ def parse (handle:ParserHandle h) (parser:Parser a) : {Except, State h} a = def runParserPartial (s:String) (parser:Parser a) : Maybe a = (MkParser f) = parser - fst $ withState 0 \pos. + withState 0 \pos. catch $ do f (s, pos) @@ -91,7 +91,7 @@ def optional (p:Parser a) : Parser (Maybe a) = (MkParser \h. Just (parse h p)) <|> return Nothing def parseMany (parser:Parser a) : Parser (List a) = MkParser \h. - snd $ withState (AsList _ []) \results. + yieldState (AsList _ []) \results. iter \_. maybeVal = parse h $ optional parser case maybeVal of @@ -102,9 +102,8 @@ def parseMany (parser:Parser a) : Parser (List a) = MkParser \h. def parseUnsignedInt : Parser Int = MkParser \h. (AsList _ digits) = parse h $ parseMany parseDigit - snd $ withState 0 \ref. - for i. - ref := 10 * get ref + digits.i + yieldState 0 \ref. + for i. ref := 10 * get ref + digits.i def parseInt : Parser Int = MkParser \h. negSign = parse h $ optional $ pChar '-' diff --git a/lib/prelude.dx b/lib/prelude.dx index 9f0637daf..c21b635c3 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -229,27 +229,53 @@ def (!) (ref:Ref h (n=>a)) (i:n) : Ref h a = %indexRef ref i def fstRef (ref: Ref h (a & b)) : Ref h a = %fstRef ref def sndRef (ref: Ref h (a & b)) : Ref h b = %sndRef ref -def withReader - (eff:Effects) ?-> (a:Type) ?-> (r:Type) ?-> - (init:r) (action: (h:Type ?-> Ref h r -> {Read h|eff} a)) +def runReader + (eff:Effects) ?-> + (init:r) + (action: (h:Type ?-> Ref h r -> {Read h|eff} a)) : {|eff} a = def explicitAction (h':Type) (ref:Ref h' r) : {Read h'|eff} a = action ref %runReader init explicitAction -def withAccum - (eff:Effects) ?-> (a:Type) ?-> (w:Type) ?-> +def withReader + (eff:Effects) ?-> + (init:r) + (action: (h:Type ?-> Ref h r -> {Read h|eff} a)) + : {|eff} a = + runReader init action + +def runAccum + (eff:Effects) ?-> (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) : {|eff} (a & w) = def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = action ref %runWriter explicitAction -def withState - (eff:Effects) ?-> (a:Type) ?-> (s:Type) ?-> +def yieldAccum + (eff:Effects) ?-> + (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) + : {|eff} w = + snd $ runAccum action + +def runState + (eff:Effects) ?-> (init:s) - (action: (h:Type ?-> Ref h s -> {State h |eff} a)) + (action: h:Type ?-> Ref h s -> {State h |eff} a) : {|eff} (a & s) = - def explicitAction (h':Type) (ref:Ref h' s) : {State h'|eff} a = action ref - %runState init explicitAction + def explicitAction (h':Type) (ref:Ref h' s) : {State h'|eff} a = action ref + %runState init explicitAction + +def withState + (eff:Effects) ?-> + (init:s) + (action: h:Type ?-> Ref h s -> {State h |eff} a) + : {|eff} a = fst $ runState init action + +def yieldState + (eff:Effects) ?-> + (init:s) + (action: h:Type ?-> Ref h s -> {State h |eff} a) + : {|eff} s = snd $ runState init action def unsafeIO (f: Unit -> {State World|eff} a) : {|eff} a = %runIO f @@ -310,7 +336,7 @@ def pairOrd (ordA: Ord a)?=> (ordB: Ord b)?=> : Ord (a & b) = def tabEq (n:Type) ?-> (eqA: Eq a) ?=> : Eq (n=>a) = MkEq $ \xs ys. numDifferent : Float = - snd $ withAccum \ref. for i. + yieldAccum \ref. for i. ref += (IToF (BToI (xs.i /= ys.i))) numDifferent == 0.0 @@ -539,7 +565,7 @@ def mod (x:Int) (y:Int) : Int = rem (y + rem x y) y def reindex (ixr: b -> a) (tab: a=>v) : b=>v = for i. tab.(ixr i) def scan (init:a) (body:n->a->(a&b)) : (a & n=>b) = - swap $ withState init \s. for i. + swap $ runState init \s. for i. c = get s (c', y) = body i c s := c' @@ -555,7 +581,7 @@ def reduce (identity:a) (combine:(a->a->a)) (xs:n=>a) : a = -- TODO: call this `scan` and call the current `scan` something else def scan' (init:a) (body:n->a->a) : n=>a = snd $ scan init \i x. dup (body i x) -- TODO: allow tables-via-lambda and get rid of this -def fsum (xs:n=>Float) : Float = snd $ withAccum \ref. for i. ref += xs i +def fsum (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs i def sum (_: Add v) ?=> (xs:n=>v) : v = reduce zero (+) xs def prod (_: Mul v) ?=> (xs:n=>v) : v = reduce one (*) xs def mean (n:Type) ?-> (xs:n=>Float) : Float = sum xs / IToF (size n) @@ -564,7 +590,7 @@ def any (xs:n=>Bool) : Bool = reduce False (||) xs def all (xs:n=>Bool) : Bool = reduce True (&&) xs def applyN (n:Int) (x:a) (f:a -> a) : a = - snd $ withState x \ref. for _:(Fin n). + yieldState x \ref. for _:(Fin n). ref := f (get ref) def linspace (n:Type) (low:Float) (high:Float) : n=>Float = @@ -623,7 +649,7 @@ def randnVec (n:Type) ?-> (k:Key) : n=>Float = for i. randn (ixkey k i) def cumSum (xs: n=>Float) : n=>Float = - fst $ withState 0.0 \total. + withState 0.0 \total. for i. newTotal = get total + xs.i total := newTotal @@ -1047,7 +1073,7 @@ def liftState (ref: Ref h c) (f:a -> {|eff} b) (x:a) : {State h|eff} b = -- A little iteration combinator def iter (body: Int -> {|eff} IterResult a) : {|eff} a = - result = snd $ withState Nothing \resultRef. withState 0 \i. + result = yieldState Nothing \resultRef. withState 0 \i. while do continue = isNothing $ get resultRef if continue @@ -1063,7 +1089,7 @@ def iter (body: Int -> {|eff} IterResult a) : {|eff} a = -- XXX: used internally by compiler for exceptional while def whileMaybe (eff:Effects) -> (body: Unit -> {|eff} (Maybe Word8)) : {|eff} Maybe Unit = - hadError = snd $ withState False \ref. + hadError = yieldState False \ref. while do ans = liftState ref body () case ans of @@ -1233,7 +1259,7 @@ def searchSorted (_:Ord a) ?=> (xs:n=>a) (x:a) : Maybe n = then Nothing else if x < xs.(fromOrdinal _ 0) then Nothing - else fst $ withState 0 \low. fst $ withState (size n) \high. iter \_. + else withState 0 \low. withState (size n) \high. iter \_. numLeft = get high - get low if numLeft == 1 then Done $ Just $ fromOrdinal _ $ get low @@ -1466,7 +1492,7 @@ def seqMaybes (n:Type) ?-> (a:Type) ?-> (xs : n=>Maybe a) : Maybe (n => a) = False -> Just $ map fromJust xs def linearSearch (_:Eq a) ?=> (xs:n=>a) (query:a) : Maybe n = - snd $ withState Nothing \ref. for i. + yieldState Nothing \ref. for i. case xs.i == query of True -> ref := Just i False -> () @@ -1477,8 +1503,8 @@ def listLength ((AsList n _):List a) : Int = n -- TODO: we want this for any monoid but this implementation won't work. def concat (lists:n=>(List a)) : List a = totalSize = sum for i. listLength lists.i - AsList _ $ fst $ withState 0 \listIdx. - fst $ withState 0 \eltIdx. + AsList _ $ withState 0 \listIdx. + withState 0 \eltIdx. for i:(Fin totalSize). while do continue = get eltIdx >= listLength (lists.((get listIdx)@_)) @@ -1494,7 +1520,7 @@ def concat (lists:n=>(List a)) : List a = xs.(eltIdxVal@_) def cumSumLow (xs: n=>Float) : n=>Float = - fst $ withState 0.0 \total. + withState 0.0 \total. for i. oldTotal = get total total := oldTotal + xs.i diff --git a/tests/ad-tests.dx b/tests/ad-tests.dx index d541bc363..6affc69f6 100644 --- a/tests/ad-tests.dx +++ b/tests/ad-tests.dx @@ -1,6 +1,6 @@ -- TODO: use prelude sum instead once we can differentiate state effect -def sum' (xs:n=>Float) : Float = snd $ withAccum \ref. for i. ref += xs.i +def sum' (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs.i :p f : Float -> Float = \x. x @@ -69,7 +69,7 @@ def sum' (xs:n=>Float) : Float = snd $ withAccum \ref. for i. ref += xs.i :p jvp sum' [1., 2.] [10.0, 20.0] > 30. -f : Float -> Float = \x. snd $ withAccum \ref. ref += x +f : Float -> Float = \x. yieldAccum \ref. ref += x :p jvp f 1.0 1.0 > 1. @@ -167,7 +167,7 @@ tripleit : Float --o Float = \x. x + x + x > [2., 4.] myOtherSquare : Float -> Float = - \x. snd $ withAccum \w. w += x * x + \x. yieldAccum \w. w += x * x :p checkDeriv myOtherSquare 3.0 > True @@ -225,7 +225,7 @@ vec = [1.] :p f : Float -> Float = \x. y = x * 2.0 - snd $ withAccum \a. + yieldAccum \a. a += x * 2.0 a += y grad f 1.0 @@ -242,7 +242,7 @@ vec = [1.] :p f : Float -> Float = \x. - snd $ withState x \xr. + yieldState x \xr. for i:(Fin 2). xr := get xr * get xr checkDeriv f 2.0 @@ -297,7 +297,7 @@ vec = [1.] :p f = \c. v = for i:(Fin 2). 2.0 - (c, v) = snd $ withState (c, v) \r. for i:(Fin 2). + (c, v) = yieldState (c, v) \r. for i:(Fin 2). (c, v) = get r r := (c + sum v, v) c diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx index 72ff717d8..1d2d2306e 100644 --- a/tests/adt-tests.dx +++ b/tests/adt-tests.dx @@ -102,7 +102,7 @@ myTab = [MyLeft 1, MyRight 3.5, MyLeft 123, MyLeft 456] > Runtime error :p - snd $ withAccum \ref. + yieldAccum \ref. for i. case myTab.i of MyLeft tmp -> () MyRight val -> ref += 1.0 + val @@ -110,7 +110,7 @@ myTab = [MyLeft 1, MyRight 3.5, MyLeft 123, MyLeft 456] :p -- check that the order of the case alternatives doesn't matter - snd $ withAccum \ref. + yieldAccum \ref. for i. case myTab.i of MyRight val -> ref += 1.0 + val MyLeft tmp -> () @@ -128,7 +128,7 @@ threeCaseTab : (Fin 4)=>ThreeCases = > [(TheIntCase 3), TheEmptyCase, (ThePairCase 2 0.1), TheEmptyCase] :p - snd $ withAccum \ref. + yieldAccum \ref. for i. case threeCaseTab.i of TheEmptyCase -> ref += 1000.0 ThePairCase x y -> ref += 100.0 + y + IToF x @@ -250,7 +250,7 @@ data Graph a:Type = def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = init = for i j. False - snd $ withState init \mRef. + yieldState init \mRef. for i:m. (from, to) = edges.i mRef!from!to := True diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx index f54b1a192..f853dceab 100644 --- a/tests/eval-tests.dx +++ b/tests/eval-tests.dx @@ -502,14 +502,14 @@ litArr = [10, 5, 3] -- > [2.0, 2.0, 2.0] :p - withState 0.0 \ref. for i:(Fin 4). + runState 0.0 \ref. for i:(Fin 4). c = get ref ref := c + 1.0 c > ([0., 1., 2., 3.], 4.) :p - withState 0.0 \ref. rof i:(Fin 4). + runState 0.0 \ref. rof i:(Fin 4). c = get ref ref := c + 1.0 c @@ -646,7 +646,7 @@ def newtonIter (f: Float -> Float) (x:Float) : Float = x - (f x / deriv f x) def newtonSolve (tol:Float) (f : Float -> Float) (x0:Float) : Float = - snd $ withState x0 \x. + yieldState x0 \x. iter \i. if abs (f $ get x) <= tol then Done () @@ -661,7 +661,7 @@ def newtonSolve (tol:Float) (f : Float -> Float) (x0:Float) : Float = -- x = for i:(Fin 3). for j:(Fin 200). 1.0 -- -- Last dimension split to allow for vector loads -- y = for i:(Fin 200). for j:(Fin 4). for h:(Fin VectorWidth). IToF $ (iota _).(i,j,h) --- z = snd $ withAccum \acc. +-- z = yieldAccum \acc. -- for l. -- for i. -- xil = (broadcastVector x.i.l) @@ -689,7 +689,7 @@ def newtonSolve (tol:Float) (f : Float -> Float) (x0:Float) : Float = > [0, 2, 4, 6] :p - f = fst $ withState 1 \ref. + f = withState 1 \ref. x = get ref ref := 3 + x y = get ref @@ -698,7 +698,7 @@ def newtonSolve (tol:Float) (f : Float -> Float) (x0:Float) : Float = > 415 :p - (f, w) = withAccum \ref. + (f, w) = runAccum \ref. ref += 2.0 w = 2 \z. z + w @@ -717,7 +717,7 @@ arr2d.(1@_) > [2, 3] :p - withState (1,2) \ref. + runState (1,2) \ref. r1 = fstRef ref r2 = sndRef ref x = get r1 diff --git a/tests/exception-tests.dx b/tests/exception-tests.dx index 687742559..3df621614 100644 --- a/tests/exception-tests.dx +++ b/tests/exception-tests.dx @@ -20,7 +20,7 @@ def checkFloatInUnitInterval (x:Float) : {Except} Float = :p catch do checkFloatInUnitInterval 0.2 > (Just 0.2) -:p snd $ withState 0 \ref. +:p yieldState 0 \ref. catch do ref := 1 assert False @@ -42,7 +42,7 @@ def checkFloatInUnitInterval (x:Float) : {Except} Float = > (Just [23, 23, 23]) -- Is this the result we want? -:p snd $ withState zero \ref. +:p yieldState zero \ref. catch do for i:(Fin 6). if (ordinal i `rem` 2) == 0 @@ -52,7 +52,7 @@ def checkFloatInUnitInterval (x:Float) : {Except} Float = > [0, 1, 0, 1, 0, 1] :p catch do - withState 0 \ref. + runState 0 \ref. ref := 1 assert False ref := 2 diff --git a/tests/gpu-tests.dx b/tests/gpu-tests.dx index 7116a4f19..9f0ed5831 100644 --- a/tests/gpu-tests.dx +++ b/tests/gpu-tests.dx @@ -27,7 +27,7 @@ testNestedLoops.(4@_).(5@_) -- single GPU thread. It should get lifted to a top-level allocation instead. allocationLiftingTest = for i:(Fin 100). - snd $ withState (for j:(Fin 1000). ordinal i) $ \s. + yieldState (for j:(Fin 1000). ordinal i) $ \s. s!(0@_) := get s!(0@_) + 1 (allocationLiftingTest.(4@_).(0@_), allocationLiftingTest.(4@_).(1@_)) > (5, 4) diff --git a/tests/monad-tests.dx b/tests/monad-tests.dx index 65e459c69..8f9994252 100644 --- a/tests/monad-tests.dx +++ b/tests/monad-tests.dx @@ -1,12 +1,12 @@ :p def m (h:Type) ?-> (ref:Ref h Int) : {State h} Int = get ref - withState 2 m + runState 2 m > (2, 2) :p def m (h:Type) ?-> (ref:Ref h Int) : {State h} Unit = ref := 3 - withState 0 m + runState 0 m > ((), 3) :p @@ -21,7 +21,7 @@ z = get ref ref := (z * 3.0) - withState 1.0 stateAction + runState 1.0 stateAction > ((), 9.) :p @@ -37,8 +37,8 @@ r + 2 withReader 2 \r. - withState True \s. - withAccum \w. + runState True \s. + runAccum \w. rwsAction r w s > ((4, 6.), False) @@ -48,7 +48,7 @@ s!(fromOrdinal _ 2) := 20 x = get (s!(fromOrdinal _ 0)) s!(fromOrdinal _ 1) := x - withState [0,0,0] m + runState [0,0,0] m > ((), [10, 10, 20]) :p withReader [1,2,3] \r . ask r!(fromOrdinal _ 1) @@ -60,7 +60,7 @@ : {Accum wh, State sh} Unit = x = get s w += x - withState 1.0 \s. withAccum \w . m w s + runState 1.0 \s. runAccum \w . m w s > (((), 1.), 1.) def myAction (w:Ref hw Float) (r:Ref hr Float) : {Read hr, Accum hw} Unit = @@ -68,7 +68,7 @@ def myAction (w:Ref hw Float) (r:Ref hr Float) : {Read hr, Accum hw} Unit = w += x w += 2.0 -:p withReader 1.5 \r. withAccum \w. myAction w r +:p withReader 1.5 \r. runAccum \w. myAction w r > ((), 3.5) :p @@ -78,14 +78,14 @@ def myAction (w:Ref hw Float) (r:Ref hr Float) : {Read hr, Accum hw} Unit = w1 += 1.0 w2 += 3.0 w1 += 1.0 - withAccum \w1. withAccum \w2. m w1 w2 + runAccum \w1. runAccum \w2. m w1 w2 > (((), 3.), 2.) def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = s!(fromOrdinal _ 0) := 1 s!(fromOrdinal _ 2) := 2 -:p withState [0,0,0] foom +:p runState [0,0,0] foom > ((), [1, 0, 2]) -- TODO: handle effects returning functions @@ -102,7 +102,7 @@ def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = -- :p -- foo : Float -> (Float, Float) -- foo x = --- (f, ans) = withState x \s. +-- (f, ans) = runState x \s. -- y = get s -- \z. 100.0 * x + 10.0 * y + z -- (f 1.0, ans) @@ -113,7 +113,7 @@ def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = -- :p -- foo : Float -> (Float, Float) -- foo x = --- (f, ans) = withAccumulator \s. +-- (f, ans) = runAccumulator \s. -- s += x -- \y. 10.0 * x + y -- (f 1.0, ans) @@ -121,13 +121,13 @@ def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = -- foo 3.0 -- > (31.0, 3.0) --- TODO: some way to explicitly give type to `withAccum` +-- TODO: some way to explicitly give type to `runAccum` -- (maybe just explicit implicit args) :p withReader 2.0 \r. - withAccum \w. - withAccum \w'. - withState 3 \s. + runAccum \w. + runAccum \w'. + runState 3 \s. x = ask r y = get s w += x @@ -137,7 +137,7 @@ def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = > ((((2., 3), 4), 4.), 2.) def symmetrizeInPlace (mat:n=>n=>Float) : n=>n=>Float = - snd $ withState mat \ref. + yieldState mat \ref. for i j. x = get ref!i!j y = get ref!j!i @@ -151,19 +151,19 @@ symmetrizeInPlace [[1.,2.],[3.,4.]] :p withReader 5 \r. () > () -:p snd $ withAccum \w. +:p yieldAccum \w. for i:(Fin 2). w += 1.0 w += 1.0 > 4. -:p snd $ withAccum \w. +:p yieldAccum \w. for i:(Fin 2). w += 1.0 w += 1.0 > 3. -:p snd $ withAccum \ref. +:p yieldAccum \ref. ref += [1.,2.,3.] ref += [2.,4.,5.] > [3., 6., 8.] @@ -172,5 +172,5 @@ def effectsAtZero (eff:Effects)?-> (f: Int ->{|eff} Unit) : {|eff} Unit = f 0 () -:p withState 0 \ref. effectsAtZero \_. ref := 1 +:p runState 0 \ref. effectsAtZero \_. ref := 1 > ((), 1) diff --git a/tests/parser-tests.dx b/tests/parser-tests.dx index ac388a409..8022eb51c 100644 --- a/tests/parser-tests.dx +++ b/tests/parser-tests.dx @@ -94,7 +94,7 @@ lam4 = \n m ?-> (0@n, 0@m) > [1, 0, 0] :p - withState 5 \ref. + runState 5 \ref. n = get ref for_ i:(Fin n). ref := get ref + 1 diff --git a/tests/uexpr-tests.dx b/tests/uexpr-tests.dx index eb3f0c880..eb13abb1f 100644 --- a/tests/uexpr-tests.dx +++ b/tests/uexpr-tests.dx @@ -100,13 +100,13 @@ myPair = (1, 2.3) > 1 :p - snd $ withState 2 \s. + yieldState 2 \s. x = get s s := x + 3 > 5 :p - snd $ withState 1 \s. + yieldState 1 \s. for i:(Fin 10). x = get s s := x + x @@ -178,7 +178,7 @@ myPair = (1, 2.3) > ^^^ :p - snd $ withState [1,2,3] \xsRef. + yieldState [1,2,3] \xsRef. for i:(Fin 3). xsRef!i := ordinal i > [0, 1, 2] @@ -186,13 +186,13 @@ myPair = (1, 2.3) def passthrough (eff:Effects) ?-> (f:(a -> {|eff} b)) (x:a) : {|eff} b = f x :p - snd $ withState 1 \ref. + yieldState 1 \ref. passthrough (\(). ref := 10) () > 10 :p - withState 0 \r1. - withState 0 \r2. + runState 0 \r1. + runState 0 \r2. r1 := 1 r2 := 2 > (((), 2), 1) From 0b76c128cc3f56b561fa2bba070d340a1fe9d6c8 Mon Sep 17 00:00:00 2001 From: joaogui1 Date: Thu, 31 Dec 2020 16:05:02 -0300 Subject: [PATCH 066/105] clang install instructions --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index c3c3462d9..6879b33b1 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,9 @@ development. Expect monstrous bugs and razor-sharp edges. Contributions welcome! * Ubuntu/Debian: `apt-get install llvm-9-dev` * macOS: `brew install llvm@9` * Make sure `llvm@9` is on your `PATH` before building. Example: `export PATH="$(brew --prefix llvm@9)/bin:$PATH"` + * Install clang (may be installed together with llvm) + * Ubuntu/Debian: `apt-get install clang` + * macOS: `installs with llvm` * Install libpng (often included by default in *nix platforms) * Ubuntu/Debian: `apt-get install libpng-dev` * macOS: `brew install libpng` From 37e31a116d7fcdc6538d5d1d1967e11e19cf86e6 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 31 Dec 2020 14:20:46 -0500 Subject: [PATCH 067/105] Make the `else` clause of `if..then..else` optional. If the `else` clause is absent it parses as a trivial `else` clause, `()`. This is convenient for conditionally executing an effect. Also allow `if..then..else` on a single line. --- examples/raytrace.dx | 4 ++-- lib/prelude.dx | 28 +++++++++------------------- src/lib/Parser.hs | 38 ++++++++++++++++++++++++++++---------- tests/parser-tests.dx | 17 +++++++++++++++++ 4 files changed, 56 insertions(+), 31 deletions(-) diff --git a/examples/raytrace.dx b/examples/raytrace.dx index 066f8e7f5..fc44b7054 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -218,7 +218,7 @@ def sampleLightRadiance Light lightPos hw _ -> (dirToLight, distToLight) = directionAndLength $ lightPos + sampleSquare hw k - rayPos - when (positiveProjection dirToLight surfNor) do + if positiveProjection dirToLight surfNor then -- light on this far side of current surface fracSolidAngle = (relu $ dot dirToLight yHat) * sq hw / (pi * sq distToLight) outRay = (rayPos, dirToLight) @@ -234,7 +234,7 @@ def trace (params:Params) (scene:Scene n) (initRay:Ray) (k:Key) : Color = case raymarch scene $ get ray of HitNothing -> Done () HitLight intensity -> - when (i == 0) do radiance += intensity -- TODO: scale etc + if i == 0 then radiance += intensity -- TODO: scale etc Done () HitObj incidentRay osurf -> [k1, k2] = splitKey $ hash k i diff --git a/lib/prelude.dx b/lib/prelude.dx index c21b635c3..3e00ac2e0 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -911,14 +911,12 @@ def maybeIncreaseBufferSize (_:Storable a) ?=> (MkDynBuffer dbPtr) = buf (size, maxSize, bufPtr) = load dbPtr newSize = sizeDelta + size - if newSize > maxSize - then - -- TODO: maybe this should use integer arithmetic? - newMaxSize = FToI $ pow 2.0 (ceil $ log2 $ IToF newSize) - newBufPtr = malloc newMaxSize - memcpy newBufPtr bufPtr size - store dbPtr (size, newMaxSize, newBufPtr) - else () + if newSize > maxSize then + -- TODO: maybe this should use integer arithmetic? + newMaxSize = FToI $ pow 2.0 (ceil $ log2 $ IToF newSize) + newBufPtr = malloc newMaxSize + memcpy newBufPtr bufPtr size + store dbPtr (size, newMaxSize, newBufPtr) def extendDynBuffer (_:Storable a) ?=> (buf: DynBuffer a) (new:List a) : {State World} Unit = @@ -1062,11 +1060,6 @@ data IterResult a:Type = Continue Done a -def when (cond:Bool) (f:Unit -> {|eff} Unit) : {|eff} Unit = - if cond - then f () - else () - -- TODO: can we improve effect inference so we don't need this? def liftState (ref: Ref h c) (f:a -> {|eff} b) (x:a) : {State h|eff} b = f x @@ -1076,11 +1069,10 @@ def iter (body: Int -> {|eff} IterResult a) : {|eff} a = result = yieldState Nothing \resultRef. withState 0 \i. while do continue = isNothing $ get resultRef - if continue - then case liftState resultRef (liftState i body) (get i) of + if continue then + case liftState resultRef (liftState i body) (get i) of Continue -> i := get i + 1 Done result -> resultRef := Just result - else () continue case result of @@ -1578,6 +1570,4 @@ def throw (_:Unit) : {Except} a = %throwException a def assert (b:Bool) : {Except} Unit = - if b - then () - else throw () + if not b then throw () diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 83527369a..3129b4c0e 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -520,14 +520,14 @@ uForExpr = do <|> (keyWord Rof_KW $> (Rev, True )) e <- buildFor pos dir <$> (some patAnn <* argTerm) <*> blockOrExpr if trailingUnit - then return $ noSrc $ UDecl (ULet PlainLet underscorePat e) unit + then return $ noSrc $ UDecl (ULet PlainLet underscorePat e) $ noSrc unitExpr else return e where underscorePat :: UPatAnn underscorePat = (noSrc $ UPatBinder $ Ignore (), Nothing) - unit :: UExpr - unit = noSrc $ UPrimExpr $ ConExpr UnitCon +unitExpr :: UExpr' +unitExpr = UPrimExpr $ ConExpr UnitCon noSrc :: a -> WithSrc a noSrc = WithSrc Nothing @@ -536,7 +536,7 @@ blockOrExpr :: Parser UExpr blockOrExpr = block <|> expr unitCon :: Parser UExpr -unitCon = withSrc $ symbol "()" $> (UPrimExpr $ ConExpr $ UnitCon) +unitCon = withSrc $ symbol "()" $> unitExpr uTabCon :: Parser UExpr uTabCon = withSrc $ do @@ -608,14 +608,32 @@ ifExpr :: Parser UExpr ifExpr = withSrc $ do keyWord IfKW e <- expr - withIndent $ mayNotBreak $ do - alt1 <- keyWord ThenKW >> blockOrExpr - nextLine - alt2 <- keyWord ElseKW >> blockOrExpr - return $ UCase e - [ UAlt (globalEnumPat "True") alt1 + (alt1, maybeAlt2) <- oneLineThenElse <|> blockThenElse + let alt2 = case maybeAlt2 of + Nothing -> noSrc unitExpr + Just alt -> alt + return $ UCase e + [ UAlt (globalEnumPat "True" ) alt1 , UAlt (globalEnumPat "False") alt2] +oneLineThenElse :: Parser (UExpr, Maybe UExpr) +oneLineThenElse = do + keyWord ThenKW + alt1 <- eitherP block expr + case alt1 of + Left e -> return (e, Nothing) + Right e -> do + alt2 <- optional $ keyWord ElseKW >> blockOrExpr + return (e, alt2) + +blockThenElse :: Parser (UExpr, Maybe UExpr) +blockThenElse = withIndent $ mayNotBreak $ do + alt1 <- keyWord ThenKW >> blockOrExpr + alt2 <- optional $ do + try $ nextLine >> keyWord ElseKW + blockOrExpr + return (alt1, alt2) + globalEnumPat :: Tag -> UPat globalEnumPat s = noSrc $ UPatCon (GlobalName s) Empty diff --git a/tests/parser-tests.dx b/tests/parser-tests.dx index 8022eb51c..93d216354 100644 --- a/tests/parser-tests.dx +++ b/tests/parser-tests.dx @@ -111,3 +111,20 @@ def myInt : {State h} Int = 1 > 107 | def myInt : {State h} Int = 1 > | ^ > Nullary def can't have effects + +:p + yieldAccum \ref. + x = if True then 1. else 3. + if True then ref += x + + if True then + ref += 1. + ref += 2. + + if False then ref += 100. else + ref += 1. + ref += 2. + + if True + then ref += 2. +> 9. From d0f4702207096a7b2dbb1ffdf8ef5690e30326ed Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Sat, 2 Jan 2021 18:56:29 +0100 Subject: [PATCH 068/105] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6879b33b1..06b37702f 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ development. Expect monstrous bugs and razor-sharp edges. Contributions welcome! * Make sure `llvm@9` is on your `PATH` before building. Example: `export PATH="$(brew --prefix llvm@9)/bin:$PATH"` * Install clang (may be installed together with llvm) * Ubuntu/Debian: `apt-get install clang` - * macOS: `installs with llvm` + * macOS: installs with llvm * Install libpng (often included by default in *nix platforms) * Ubuntu/Debian: `apt-get install libpng-dev` * macOS: `brew install libpng` From c004ceb28e1354cd5a079c6aa40f67e0515950bc Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 4 Jan 2021 20:15:58 +0100 Subject: [PATCH 069/105] Make README a little nicer, add pointers to new label system --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 06b37702f..a02161574 100644 --- a/README.md +++ b/README.md @@ -21,8 +21,10 @@ or these example programs: * [Basis function regression](https://google-research.github.io/dex-lang/examples/regression.html) * [Brownian bridge](https://google-research.github.io/dex-lang/examples/brownian_motion.html) -⚠️ Dex is an experimental research project at an early stage of -development. Expect monstrous bugs and razor-sharp edges. Contributions welcome! ⚠️ +🚨 **Dex is an experimental research project at an early stage of +development. Expect monstrous bugs and razor-sharp edges!** + +🤝 **Contributions welcome!** See our issue tracker for [good first issues](https://github.com/google-research/dex-lang/labels/good%20first%20issue), or browse by [tematic labels](https://github.com/google-research/dex-lang/labels). ## Dependencies From 66d225f2619779433beb4e5c737248ba3d3a8595 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 4 Jan 2021 23:26:59 +0000 Subject: [PATCH 070/105] Add completions to REPL (#407) --- src/dex.hs | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/dex.hs b/src/dex.hs index f08f56c25..7de5696fa 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -15,6 +15,7 @@ import System.Posix.Terminal (queryTerminal) import System.Posix.IO (stdOutput) import System.Exit import System.Directory +import Data.List import Syntax import PPrint @@ -25,6 +26,7 @@ import Resources import TopLevel import Parser hiding (Parser) import LiveOutput +import Env (envNames) import Export data ErrorHandling = HaltOnErr | ContinueOnErr @@ -45,8 +47,10 @@ runMode evalMode preludeFile opts = do env <- cached "prelude" key $ evalPrelude opts preludeFile let runEnv m = evalStateT m env case evalMode of - ReplMode prompt -> - runEnv $ runInputT defaultSettings $ forever (replLoop prompt opts) + ReplMode prompt -> do + let filenameAndDexCompletions = completeQuotedWord (Just '\\') "\"'" listFiles dexCompletions + let hasklineSettings = setComplete filenameAndDexCompletions defaultSettings + runEnv $ runInputT hasklineSettings $ forever (replLoop prompt opts) ScriptMode fname fmt _ -> do results <- runEnv $ evalFile opts fname printLitProg fmt results @@ -81,6 +85,20 @@ replLoop prompt opts = do _ -> return () liftIO $ putStrLn $ pprint result +dexCompletions :: CompletionFunc (StateT TopEnv IO) +dexCompletions (line, _) = do + env <- get + let varNames = map pprint $ envNames env + -- note: line and thus word and rest have character order reversed + let (word, rest) = break (== ' ') line + let anywhereKeywords = ["def", "for", "rof", "case", "data", "where", "of", "if", + "then", "else", "interface", "instance", "do", "view"] + let startoflineKeywords = ["%bench \"", ":p", ":t", ":html", ":export"] + let candidates = (if null rest then startoflineKeywords else []) ++ + anywhereKeywords ++ varNames + let completions = map simpleCompletion $ filter ((reverse word) `isPrefixOf`) candidates + return (rest, completions) + liftErrIO :: MonadIO m => Except a -> m a liftErrIO (Left err) = liftIO $ putStrLn (pprint err) >> exitFailure liftErrIO (Right x) = return x From cd7cd095755abc84f1bac0d6cd844031a09e6091 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 31 Dec 2020 17:28:10 -0500 Subject: [PATCH 071/105] Add new syntax for type class constraints on `def` decls. Fixes #370. --- examples/chol.dx | 4 +- examples/ctc.dx | 7 +- examples/fluidsim.dx | 20 +++--- examples/linear_algebra.dx | 26 ++++---- examples/mcmc.dx | 2 +- examples/ode-integrator.dx | 6 +- examples/raytrace.dx | 2 +- examples/sgd.dx | 4 +- lib/diagram.dx | 4 +- lib/plot.dx | 2 +- lib/png.dx | 2 +- lib/prelude.dx | 131 ++++++++++++++++++------------------- src/lib/Parser.hs | 30 ++++++--- 13 files changed, 126 insertions(+), 114 deletions(-) diff --git a/examples/chol.dx b/examples/chol.dx index 1d8c53d11..63473ba91 100644 --- a/examples/chol.dx +++ b/examples/chol.dx @@ -3,7 +3,7 @@ https://en.wikipedia.org/wiki/Cholesky_decomposition ' ## Cholesky Algorithm -def chol (_:Eq n) ?=> (x:n=>n=>Float) : (n=>n=>Float) = +def chol [Eq n] (x:n=>n=>Float) : (n=>n=>Float) = yieldState zero \buf. for_ i. for j':(..i). j = %inject(j') @@ -31,7 +31,7 @@ def trisolveU (mat:n=>n=>Float) (b:n=>Float) : n=>Float = xPrev = for j:(i..). get (buf!%inject j) buf!i := (b.i - vdot row xPrev) / mat.i.i -def psdsolve (_:Eq n) ?=> (mat:n=>n=>Float) (b:n=>Float) : n=>Float = +def psdsolve [Eq n] (mat:n=>n=>Float) (b:n=>Float) : n=>Float = l = chol mat trisolveU (transpose l) $ trisolveL l b diff --git a/examples/ctc.dx b/examples/ctc.dx index d0dc7979a..aa7b5fc77 100644 --- a/examples/ctc.dx +++ b/examples/ctc.dx @@ -48,8 +48,11 @@ def logaddexp (x:Float) (y:Float) : Float = m = max x y m + ( log ( (exp (x - m) + exp (y - m)))) -def ctc (dict: Eq vocab) ?=> (dict2: Eq position) ?=> (dict3: Eq time) ?=> (blank: vocab) - (logits: time=>vocab=>Float) (labels: position=>vocab) : Float = +def ctc [Eq vocab, Eq position, Eq time] + (blank: vocab) + (logits: time=>vocab=>Float) + (labels: position=>vocab) + : Float = -- Computes log p(labels | logits), marginalizing over possible alignments. -- Todo: remove unnecessary implicit type annotations once -- Dex starts putting implicit types in scope. diff --git a/examples/fluidsim.dx b/examples/fluidsim.dx index 227817517..e1ea7e9ac 100644 --- a/examples/fluidsim.dx +++ b/examples/fluidsim.dx @@ -14,10 +14,10 @@ def incwrap (i:n) : n = -- Increment index, wrapping around at ends. def decwrap (i:n) : n = -- Decrement index, wrapping around at ends. asidx $ mod ((ordinal i) - 1) $ size n -def finite_difference_neighbours (_:Add a) ?=> (x:n=>a) : n=>a = +def finite_difference_neighbours [Add a] (x:n=>a) : n=>a = for i. x.(incwrap i) - x.(decwrap i) -def add_neighbours (_:Add a) ?=> (x:n=>a) : n=>a = +def add_neighbours [Add a] (x:n=>a) : n=>a = for i. x.(incwrap i) + x.(decwrap i) def apply_along_axis1 (f:b=>a -> b=>a) (x:b=>c=>a) : b=>c=>a = @@ -26,21 +26,21 @@ def apply_along_axis1 (f:b=>a -> b=>a) (x:b=>c=>a) : b=>c=>a = def apply_along_axis2 (f:c=>a -> c=>a) (x:b=>c=>a) : b=>c=>a = for i. f x.i -def fdx (_:Add a) ?=> (x:n=>m=>a) : (n=>m=>a) = +def fdx [Add a] (x:n=>m=>a) : (n=>m=>a) = apply_along_axis1 finite_difference_neighbours x -def fdy (_:Add a) ?=> (x:n=>m=>a) : (n=>m=>a) = +def fdy [Add a] (x:n=>m=>a) : (n=>m=>a) = apply_along_axis2 finite_difference_neighbours x -def divergence (_:Add a) ?=> (vx:n=>m=>a) (vy:n=>m=>a) : (n=>m=>a) = +def divergence [Add a] (vx:n=>m=>a) (vy:n=>m=>a) : (n=>m=>a) = fdx vx + fdy vy -def add_neighbours_2d (_:Add a) ?=> (x:n=>m=>a) : (n=>m=>a) = +def add_neighbours_2d [Add a] (x:n=>m=>a) : (n=>m=>a) = ax1 = apply_along_axis1 add_neighbours x ax2 = apply_along_axis2 add_neighbours x ax1 + ax2 -def project (_:VSpace a) ?=> (v: n=>m=>(Fin 2)=>a) : n=>m=>(Fin 2)=>a = +def project [VSpace a] (v: n=>m=>(Fin 2)=>a) : n=>m=>(Fin 2)=>a = -- Project the velocity field to be approximately mass-conserving, -- using a few iterations of Gauss-Seidel. h = 1.0 / IToF (size n) @@ -60,13 +60,13 @@ def project (_:VSpace a) ?=> (v: n=>m=>(Fin 2)=>a) : n=>m=>(Fin 2)=>a = for i j. [vx.i.j, vy.i.j] -- pack back into a table. -def bilinear_interp (_:VSpace a) ?=> (right_weight:Float) (bottom_weight:Float) +def bilinear_interp [VSpace a] (right_weight:Float) (bottom_weight:Float) (topleft: a) (bottomleft: a) (topright: a) (bottomright: a) : a = left = (1.0 - right_weight) .* ((1.0 - bottom_weight) .* topleft + bottom_weight .* bottomleft) right = right_weight .* ((1.0 - bottom_weight) .* topright + bottom_weight .* bottomright) left + right -def advect (_:VSpace a) ?=> (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = +def advect [VSpace a] (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = -- Move field f according to x and y velocities (u and v) -- using an implicit Euler integrator. @@ -95,7 +95,7 @@ def advect (_:VSpace a) ?=> (f: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : n=>m=>a = -- A convex weighting of the 4 surrounding cells. bilinear_interp right_weight bottom_weight f.l.t f.l.b f.r.t f.r.b -def fluidsim (_: VSpace a) ?=> (num_steps: Int) (color_init: n=>m=>a) +def fluidsim [ VSpace a] (num_steps: Int) (color_init: n=>m=>a) (v: n=>m=>(Fin 2)=>Float) : (Fin num_steps)=>n=>m=>a = withState (color_init, v) \state. for i:(Fin num_steps). diff --git a/examples/linear_algebra.dx b/examples/linear_algebra.dx index ec6a6e5e9..2d0cffd14 100644 --- a/examples/linear_algebra.dx +++ b/examples/linear_algebra.dx @@ -1,6 +1,6 @@ '## LU Decomposition and Matrix Inversion -def identity_matrix (_:Eq n) ?=> (_:Add a) ?=> (_:Mul a) ?=> : n=>n=>a = +def identity_matrix [Eq n, Add a, Mul a] : n=>n=>a = for i j. select (i == j) one zero '### Triangular matrices @@ -11,7 +11,7 @@ def UpperTriMat (n:Type) (v:Type) : Type = i:n=>(i..)=>v def upperTriDiag (u:UpperTriMat n v) : n=>v = for i. u.i.(0@_) def lowerTriDiag (l:LowerTriMat n v) : n=>v = for i. l.i.((ordinal i)@_) -def forward_substitute (_:VSpace v) ?=> (a:LowerTriMat n Float) (b:n=>v) : n=>v = +def forward_substitute [VSpace v] (a:LowerTriMat n Float) (b:n=>v) : n=>v = -- Solves lower triangular linear system (inverse a) **. b yieldState zero \sRef. for i:n. @@ -19,7 +19,7 @@ def forward_substitute (_:VSpace v) ?=> (a:LowerTriMat n Float) (b:n=>v) : n=>v a.i.((ordinal k)@_) .* get sRef!(%inject k) sRef!i := (b.i - s) / a.i.((ordinal i)@_) -def backward_substitute (_:VSpace v) ?=> (a:UpperTriMat n Float) (b:n=>v) : n=>v = +def backward_substitute [VSpace v] (a:UpperTriMat n Float) (b:n=>v) : n=>v = -- Solves upper triangular linear system (inverse a) **. b yieldState zero \sRef. rof i:n. @@ -61,7 +61,7 @@ def permSign ((_, sign):Permutation n) : PermutationSign = sign '### LU decomposition functions -def pivotize (_:Eq n) ?=> (a:n=>n=>Float) : Permutation n = +def pivotize [Eq n] (a:n=>n=>Float) : Permutation n = -- Gives a row permutation that makes Gaussian elimination more stable. yieldState identity_permutation \permRef. for j:n. @@ -71,7 +71,7 @@ def pivotize (_:Eq n) ?=> (a:n=>n=>Float) : Permutation n = True -> () False -> swapInPlace permRef j row_with_largest -def lu (_:Eq n) ?=> (a: n=>n=>Float) : +def lu [Eq n] (a: n=>n=>Float) : (LowerTriMat n Float & UpperTriMat n Float & Permutation n) = -- Computes lower, upper, and permuntation matrices from a square matrix, -- such that apply_permutation permutation a == lower ** upper. @@ -113,10 +113,10 @@ def lu (_:Eq n) ?=> (a: n=>n=>Float) : ukj = get (upperTriIndex uRef k')!(((ordinal j) - (ordinal k))@_) lik = get (lowerTriIndex lRef i')!((ordinal k)@_) ukj * lik - + uijRef = (upperTriIndex uRef i')!(((ordinal j) - (ordinal i))@_) uijRef := a.(%inject i).j - s - + for i:(j<..). i' = %inject i s = sum for k:(..j). @@ -125,7 +125,7 @@ def lu (_:Eq n) ?=> (a: n=>n=>Float) : ukj = get (upperTriIndex uRef k')!i'' lik = get (lowerTriIndex lRef i')!((ordinal k)@_) ukj * lik - + i'' = ((ordinal i) + (ordinal j) + 1)@_ ujj = get (upperTriIndex uRef j)!(0@_) lijRef = (lowerTriIndex lRef i'')!((ordinal j)@_) @@ -135,7 +135,7 @@ def lu (_:Eq n) ?=> (a: n=>n=>Float) : '### General linear algebra functions. -def solve (_:Eq n) ?=> (_:VSpace v) ?=> (a:n=>n=>Float) (b:n=>v) : n=>v = +def solve [Eq n, VSpace v] (a:n=>n=>Float) (b:n=>v) : n=>v = -- There's a small speedup possible by exploiting the fact -- that l always has ones on the diagonal. It would just require a -- custom forward_substitute routine that doesn't divide @@ -145,18 +145,18 @@ def solve (_:Eq n) ?=> (_:VSpace v) ?=> (a:n=>n=>Float) (b:n=>v) : n=>v = y = forward_substitute l b' backward_substitute u y -def invert (_:Eq n) ?=> (a:n=>n=>Float) : n=>n=>Float = +def invert [Eq n] (a:n=>n=>Float) : n=>n=>Float = solve a identity_matrix -def determinant (_:Eq n) ?=> (a:n=>n=>Float) : Float = +def determinant [Eq n] (a:n=>n=>Float) : Float = (l, u, perm) = lu a prod (for i. (upperTriDiag u).i * (lowerTriDiag l).i) * permSign perm -def sign_and_log_determinant (_:Eq n) ?=> (a:n=>n=>Float) : (Float & Float) = +def sign_and_log_determinant [Eq n] (a:n=>n=>Float) : (Float & Float) = (l, u, perm) = lu a diags = for i. (upperTriDiag u).i * (lowerTriDiag l).i sign = (permSign perm) * prod for i. sign diags.i - sum_of_log_abs = sum for i. log (abs diags.i) + sum_of_log_abs = sum for i. log (abs diags.i) (sign, sum_of_log_abs) diff --git a/examples/mcmc.dx b/examples/mcmc.dx index 1ba161c85..a3bcbd314 100644 --- a/examples/mcmc.dx +++ b/examples/mcmc.dx @@ -55,7 +55,7 @@ def mhStep HMCParams : Type = (Int & Float) -- leapfrog steps, step size def leapfrogIntegrate - (_:VSpace a) ?=> + [VSpace a] ((nsteps, dt): HMCParams) (logProb: a -> LogProb) ((x, p): (a & a)) diff --git a/examples/ode-integrator.dx b/examples/ode-integrator.dx index f553fca91..53e568d5a 100644 --- a/examples/ode-integrator.dx +++ b/examples/ode-integrator.dx @@ -12,7 +12,7 @@ Time = Float def length (x: d=>Float) : Float = sqrt $ sum for i. sq x.i def (./) (x: d=>Float) (y: d=>Float) : d=>Float = for i. x.i / y.i -def fit_4th_order_polynomial (_:VSpace v) ?=> +def fit_4th_order_polynomial [VSpace v] (z0:v) (z1:v) (z_mid:v) (dz0:v) (dz1:v) (dt:Time) : (Fin 5)=>v = -- dz0 and dz1 are gradient evaluations. a = -2. * dt .* dz0 + 2. * dt .* dz1 - 8. .* z0 - 8. .* z1 + 16. .* z_mid @@ -26,7 +26,7 @@ dps_c_mid = [6025192743. /30085553152. /2., 0., 51252292925. /65400821598. /2., -2691868925. /45128329728. /2., 187940372067. /1594534317056. /2., -1776094331. /19743644256. /2., 11237099. /235043384. /2.] -def interp_fit_dopri (_:VSpace v) ?=> +def interp_fit_dopri [VSpace v] (z0:v) (z1:v) (k:(Fin 7)=>v) (dt:Time) : (Fin 5)=>v = -- Fit a polynomial to the results of a Runge-Kutta step. z_mid = z0 + dt .* (dot dps_c_mid k) @@ -64,7 +64,7 @@ c_error = [35. / 384. - 1951. / 21600., 0., 500. / 1113. - 22642. / 50085., 125. / 192. - 451. / 720., -2187. / 6784. + 12231. / 42400., 11. / 84. - 649. / 6300., -1. / 60.] -def runge_kutta_step (_:VSpace v) ?=> (func:v->Time->v) +def runge_kutta_step [VSpace v] (func:v->Time->v) (z0:v) (f0:v) (t0:Time) (dt:Time) : (v & v & v & (Fin 7)=>v) = evals_init = yieldState zero \r. diff --git a/examples/raytrace.dx b/examples/raytrace.dx index fc44b7054..051722f56 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -24,7 +24,7 @@ def directionAndLength (x: d=>Float) : (d=>Float & Float) = def randuniform (lower:Float) (upper:Float) (k:Key) : Float = lower + (rand k) * (upper - lower) -def sampleAveraged (_:VSpace a) ?=> (sample:Key -> a) (n:Int) (k:Key) : a = +def sampleAveraged [VSpace a] (sample:Key -> a) (n:Int) (k:Key) : a = yieldState zero \total. for i:(Fin n). total := get total + sample (ixkey k i) / IToF n diff --git a/examples/sgd.dx b/examples/sgd.dx index 3e5a5575a..bc1a0cb29 100644 --- a/examples/sgd.dx +++ b/examples/sgd.dx @@ -1,14 +1,14 @@ '## Stochastic Gradient Descent with Momentum -def sgd_step (dict: VSpace a) ?=> (step_size: Float) (decay: Float) (gradfunc: a -> Int -> a) (x: a) (m: a) (iter:Int) : (a & a) = +def sgd_step [VSpace a] (step_size: Float) (decay: Float) (gradfunc: a -> Int -> a) (x: a) (m: a) (iter:Int) : (a & a) = g = gradfunc x iter new_m = decay .* m + g new_x = x - step_size .* new_m (new_x, new_m) -- In-place optimization loop. -def sgd (dict: VSpace a) ?=> (step_size:Float) (decay:Float) (num_steps:Int) (gradient: a -> Int -> a) (x0: a) : a = +def sgd [VSpace a] (step_size:Float) (decay:Float) (num_steps:Int) (gradient: a -> Int -> a) (x0: a) : a = m0 = zero (x_final, m_final) = yieldState (x0, m0) \state. for i:(Fin num_steps). diff --git a/lib/diagram.dx b/lib/diagram.dx index 4e91ddb9d..98ae4e60e 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -112,7 +112,7 @@ def quote (s:String) : String = "\"" <.> s <.> "\"" def strSpaceCatUncurried ((s1,s2):(String & String)) : String = s1 <.> " " <.> s2 -def (<+>) (_:Show a) ?=> (_:Show b) ?=> (s1:a) (s2:b) : String = +def (<+>) [Show a, Show b] (s1:a) (s2:b) : String = strSpaceCatUncurried ((show s1), (show s2)) def selfClosingBrackets (s:String) : String = "<" <.> s <.> "/>" @@ -127,7 +127,7 @@ def tagBracketsAttrUncurried ((tag, attr, s):(String & String & String)) : Strin def tagBracketsAttr (tag:String) (attr:String) (s:String) : String = tagBracketsAttrUncurried (tag, attr, s) -def (<=>) (_:Show b) ?=> (attr:String) (val:b) : String = +def (<=>) [Show b] (attr:String) (val:b) : String = attr <.> "=" <.> quote (show val) def htmlColor(cs:HtmlColor) : String = diff --git a/lib/plot.dx b/lib/plot.dx index 0212ad537..4529435fb 100644 --- a/lib/plot.dx +++ b/lib/plot.dx @@ -49,7 +49,7 @@ def getScaled (sd:ScaledData n a) (i:n) : Maybe Float = lowColor = [1.0, 0.5, 0.0] highColor = [0.0, 0.5, 1.0] -def interpolate (_:VSpace a) ?=> (low:a) (high:a) (x:Float) : a = +def interpolate [VSpace a] (low:a) (high:a) (x:Float) : a = x' = clip (0.0, 1.0) x (x' .* low) + ((1.0 - x') .* high) diff --git a/lib/png.dx b/lib/png.dx index 131f7c609..b542449d5 100644 --- a/lib/png.dx +++ b/lib/png.dx @@ -72,7 +72,7 @@ def decodeChunk (chunk : Fin 4 => Char) : Maybe (Fin 3 => Char) = Just base64s -> Just $ base64sToBytes base64s -- TODO: put this in prelude? -def replace (_:Eq a) ?=> ((old,new):(a&a)) (x:a) : a = +def replace [Eq a] ((old,new):(a&a)) (x:a) : a = case x == old of True -> new False -> x diff --git a/lib/prelude.dx b/lib/prelude.dx index 3e00ac2e0..54f7aefef 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -45,8 +45,8 @@ interface Add a:Type where sub : a -> a -> a zero : a -def (+) (d:Add a) ?=> : a -> a -> a = add -def (-) (d:Add a) ?=> : a -> a -> a = sub +def (+) [Add a] : a -> a -> a = add +def (-) [Add a] : a -> a -> a = sub instance float64Add : Add Float64 where add = \x:Float64 y:Float64. %fadd x y @@ -87,7 +87,7 @@ interface Mul a:Type where mul : a -> a -> a one : a -def (*) (d:Mul a) ?=> : a -> a -> a = mul +def (*) [Mul a] : a -> a -> a = mul instance float64Mul : Mul Float64 where mul = \x:Float64 y:Float64. %fmul x y @@ -162,10 +162,10 @@ data VSpace a:Type = MkVSpace (Add a) (Float -> a -> a) @superclass def addFromVSpace (d:VSpace a) : Add a = case d of MkVSpace addDict _ -> addDict -def (.*) (d:VSpace a) ?=> : Float -> a -> a = case d of MkVSpace _ scale -> scale -(*.) : VSpace a ?=> a -> Float -> a = flip (.*) -def (/) (_:VSpace a) ?=> (v:a) (s:Float) : a = (divide 1.0 s) .* v -def neg (_:VSpace a) ?=> (v:a) : a = (-1.0) .* v +def (.*) (d:VSpace a) ?=> : Float -> a -> a = case d of MkVSpace _ scale -> scale +def (*.) [VSpace a] : a -> Float -> a = flip (.*) +def (/) [VSpace a] (v:a) (s:Float) : a = divide 1.0 s .* v +def neg [VSpace a] (v:a) : a = (-1.0) .* v @instance floatVS : VSpace Float = MkVSpace float32Add (*) @instance tabVS : VSpace a ?=> VSpace (n=>a) = MkVSpace tabAdd \s xs. for i. s .* xs.i @@ -292,12 +292,12 @@ data Ord a:Type = MkOrd (Eq a) (a -> a -> Bool) (a -> a -> Bool) -- eq, gt, lt def eqFromOrd (d:Ord a) : Eq a = case d of MkOrd eq _ _ -> eq def (==) (d:Eq a) ?=> (x:a) (y:a) : Bool = case d of MkEq eq -> eq x y -def (/=) (d:Eq a) ?=> (x:a) (y:a) : Bool = not $ x == y +def (/=) [Eq a] (x:a) (y:a) : Bool = not $ x == y def (>) (d:Ord a) ?=> (x:a) (y:a) : Bool = case d of MkOrd _ gt _ -> gt x y def (<) (d:Ord a) ?=> (x:a) (y:a) : Bool = case d of MkOrd _ _ lt -> lt x y -def (<=) (d:Ord a) ?=> (x:a) (y:a) : Bool = x=) (d:Ord a) ?=> (x:a) (y:a) : Bool = x>y || x==y +def (<=) [Ord a] (x:a) (y:a) : Bool = x=) [Ord a] (x:a) (y:a) : Bool = x>y || x==y @instance float64Eq : Eq Float64 = MkEq \x:Float64 y:Float64. W8ToB $ %feq x y @instance float32Eq : Eq Float32 = MkEq \x:Float32 y:Float32. W8ToB $ %feq x y @@ -321,19 +321,18 @@ def (>=) (d:Ord a) ?=> (x:a) (y:a) : Bool = x>y || x==y @instance unitOrd : Ord Unit = (MkOrd unitEq (\x y. False) (\x y. False)) @instance -def pairEq (eqA: Eq a)?=> (eqB: Eq b)?=> : Eq (a & b) = MkEq $ +def pairEq [Eq a, Eq b] : Eq (a & b) = MkEq $ \(x1,x2) (y1,y2). x1 == y1 && x2 == y2 @instance -def pairOrd (ordA: Ord a)?=> (ordB: Ord b)?=> : Ord (a & b) = +def pairOrd [Ord a, Ord b] : Ord (a & b) = pairGt = \(x1,x2) (y1,y2). x1 > y1 || (x1 == y1 && x2 > y2) pairLt = \(x1,x2) (y1,y2). x1 < y1 || (x1 == y1 && x2 < y2) MkOrd pairEq pairGt pairLt - -- TODO: accumulate using the True/&& monoid @instance -def tabEq (n:Type) ?-> (eqA: Eq a) ?=> : Eq (n=>a) = MkEq $ +def tabEq [Eq a] : Eq (n=>a) = MkEq $ \xs ys. numDifferent : Float = yieldAccum \ref. for i. @@ -362,7 +361,7 @@ interface Floating a:Type where pow : a -> a -> a lgamma : a -> a -def lbeta (_ : Add a) ?=> (_ : Floating a) ?=> : a -> a -> a = \x y. lgamma x + lgamma y - lgamma (x + y) +def lbeta [Add a, Floating a] : a -> a -> a = \x y. lgamma x + lgamma y - lgamma (x + y) -- Todo: better numerics for very large and small values. -- Using %exp here to avoid circular definition problems. @@ -468,28 +467,28 @@ instance int32Storable : Storable Int32 where load = int32Load storageSize = const 4 -def unpackPairPtr (_:Storable a) ?=> (_:Storable b) ?=> +def unpackPairPtr [Storable a, Storable b] (pairPtr: Ptr (a & b)) : (Ptr a & Ptr b) = (MkPtr rawPtrX) = pairPtr rawPtrY = %ptrOffset rawPtrX (storageSize (typeVehicle a)) (MkPtr rawPtrX, MkPtr rawPtrY) -def pairStore (_:Storable a) ?=> (_:Storable b) ?=> +def pairStore [Storable a, Storable b] (pairPtr:Ptr (a & b)) ((x, y):(a & b)) : {State World} Unit = (xPtr, yPtr) = unpackPairPtr pairPtr store xPtr x store yPtr y -def pairLoad (_:Storable a) ?=> (_:Storable b) ?=> +def pairLoad [Storable a, Storable b] (pairPtr:Ptr (a & b)) : {State World} (a & b) = (xPtr, yPtr) = unpackPairPtr pairPtr (load xPtr, load yPtr) -def pairStorageSize (_:Storable a) ?=> (_:Storable b) ?=> +def pairStorageSize [Storable a, Storable b] (_:TypeVehicle (a & b)) : Int = storageSize (typeVehicle a) + storageSize (typeVehicle b) -instance pairStorable : Storable a ?=> Storable b ?=> Storable (a & b) where +instance pairStorable : (Storable a) ?=> (Storable b) ?=> Storable (a & b) where store = pairStore load = pairLoad storageSize = pairStorageSize @@ -508,7 +507,7 @@ instance ptrStorable : Storable (Ptr a) where -- TODO: Storable instances for other types -def malloc (_:Storable a) ?=> (n:Int) : {State World} (Ptr a) = +def malloc [Storable a] (n:Int) : {State World} (Ptr a) = numBytes = storageSize (typeVehicle a) * n MkPtr $ %alloc numBytes @@ -516,7 +515,7 @@ def free (ptr:Ptr a) : {State World} Unit = (MkPtr ptr') = ptr %free ptr' -def (+>>) (_:Storable a) ?=> (ptr:Ptr a) (i:Int) : Ptr a = +def (+>>) [Storable a] (ptr:Ptr a) (i:Int) : Ptr a = (MkPtr ptr') = ptr i' = i * storageSize (typeVehicle a) MkPtr $ %ptrOffset ptr' i' @@ -524,28 +523,28 @@ def (+>>) (_:Storable a) ?=> (ptr:Ptr a) (i:Int) : Ptr a = -- TODO: generalize these brackets to allow other effects -- TODO: consider making a Storable instance for tables instead -def storeTab (_:Storable a) ?=> (ptr: Ptr a) (tab:n=>a) : {State World} Unit = +def storeTab [Storable a] (ptr: Ptr a) (tab:n=>a) : {State World} Unit = for_ i. store (ptr +>> ordinal i) tab.i -def memcpy (_:Storable a) ?=> (dest:Ptr a) (src:Ptr a) (n:Int) : {State World} Unit = +def memcpy [Storable a] (dest:Ptr a) (src:Ptr a) (n:Int) : {State World} Unit = for_ i:(Fin n). i' = ordinal i store (dest +>> i') (load $ src +>> i') -def withAlloc (_:Storable a) ?=> +def withAlloc [Storable a] (n:Int) (action: Ptr a -> {State World} b) : {State World} b = ptr = malloc n result = action ptr free ptr result -def withTabPtr (_:Storable a) ?=> +def withTabPtr [Storable a] (xs:n=>a) (action : Ptr a -> {State World} b) : {State World} b = withAlloc (size n) \ptr. for i. store (ptr +>> ordinal i) xs.i action ptr -def tabFromPtr (_:Storable a) ?=> (n:Type) -> (ptr:Ptr a) : {State World} n=>a = +def tabFromPtr [Storable a] (n:Type) -> (ptr:Ptr a) : {State World} n=>a = for i. load $ ptr +>> ordinal i '## Miscellaneous common utilities @@ -558,8 +557,8 @@ def map (f:a->{|eff} b) (xs: n=>a) : {|eff} (n=>b) = for i. f xs.i def zip (xs:n=>a) (ys:n=>b) : (n=>(a&b)) = view i. (xs.i, ys.i) def unzip (xys:n=>(a&b)) : (n=>a & n=>b) = (map fst xys, map snd xys) def fanout (n:Type) (x:a) : n=>a = view i. x -def sq (d:Mul a) ?=> (x:a) : a = x * x -def abs (_:Add a) ?=> (_:Ord a) ?=> (x:a) : a = select (x > zero) x (zero - x) +def sq [Mul a] (x:a) : a = x * x +def abs [Add a, Ord a] (x:a) : a = select (x > zero) x (zero - x) def mod (x:Int) (y:Int) : Int = rem (y + rem x y) y def reindex (ixr: b -> a) (tab: a=>v) : b=>v = for i. tab.(ixr i) @@ -582,9 +581,9 @@ def reduce (identity:a) (combine:(a->a->a)) (xs:n=>a) : a = def scan' (init:a) (body:n->a->a) : n=>a = snd $ scan init \i x. dup (body i x) -- TODO: allow tables-via-lambda and get rid of this def fsum (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs i -def sum (_: Add v) ?=> (xs:n=>v) : v = reduce zero (+) xs -def prod (_: Mul v) ?=> (xs:n=>v) : v = reduce one (*) xs -def mean (n:Type) ?-> (xs:n=>Float) : Float = sum xs / IToF (size n) +def sum [Add v] (xs:n=>v) : v = reduce zero (+) xs +def prod [Mul v] (xs:n=>v) : v = reduce one (*) xs +def mean (xs:n=>Float) : Float = sum xs / IToF (size n) def std (xs:n=>Float) : Float = sqrt $ mean (map sq xs) - sq (mean xs) def any (xs:n=>Bool) : Bool = reduce False (||) xs def all (xs:n=>Bool) : Bool = reduce True (&&) xs @@ -599,7 +598,7 @@ def linspace (n:Type) (low:Float) (high:Float) : n=>Float = def transpose (x:n=>m=>a) : m=>n=>a = view i j. x.j.i def vdot (x:n=>Float) (y:n=>Float) : Float = fsum view i. x.i * y.i -def dot (_:VSpace v) ?=> (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j +def dot [VSpace v] (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j -- matmul. Better symbol to use? `@`? (**) : (l=>m=>Float) -> (m=>n=>Float) -> (l=>n=>Float) = \x y. @@ -611,7 +610,7 @@ def dot (_:VSpace v) ?=> (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j def inner (x:n=>Float) (mat:n=>m=>Float) (y:m=>Float) : Float = fsum view (i,j). x.i * mat.i.j * y.j -def eye (_:Eq n) ?=> : n=>n=>Float = +def eye [Eq n] : n=>n=>Float = for i j. select (i == j) 1.0 0.0 '## Pseudorandom number generator utilities @@ -645,7 +644,7 @@ def randInt (k:Key) : Int = (I64ToI k) `mod` 2147483647 def bern (p:Float) (k:Key) : Bool = rand k < p -def randnVec (n:Type) ?-> (k:Key) : n=>Float = +def randnVec (k:Key) : n=>Float = for i. randn (ixkey k i) def cumSum (xs: n=>Float) : n=>Float = @@ -679,7 +678,7 @@ interface HasDefaultTolerance a:Type where atol : a rtol : a -def (~~) (_:HasAllClose a) ?=> (d:HasDefaultTolerance a) ?=> : a -> a -> Bool = allclose atol rtol +def (~~) [HasAllClose a, HasDefaultTolerance a] : a -> a -> Bool = allclose atol rtol instance allCloseF32 : HasAllClose Float32 where allclose = \atol rtol x y. abs (x - y) <= (atol + rtol * abs y) @@ -758,14 +757,12 @@ def Tile (n : Type) (m : Type) : Type = %IndexSlice n m -- elements of n. In this view (+>) is just function application, while ++> -- is currying followed by function application. We cannot represent currying -- in isolation, because `Tile n (Tile u v)` does not make sense, unlike `Tile n (u & v)`. -def (+>) (l : Type) ?-> (t:Tile n l) (i : l) : n = %sliceOffset t i +def (+>) (t:Tile n l) (i : l) : n = %sliceOffset t i def (++>) (t : Tile n (u & v)) (i : u) : Tile n v = %sliceCurry t i -def tile (l : Type) ?-> - (fTile : (t:(Tile n l) -> {|eff} l=>a)) +def tile (fTile : (t:(Tile n l) -> {|eff} l=>a)) (fScalar : n -> {|eff} a) : {|eff} n=>a = %tiled fTile fScalar -def tile1 (n : Type) ?-> (l : Type) ?-> (m : Type) ?-> - (fTile : (t:(Tile n l) -> {|eff} m=>l=>a)) +def tile1 (fTile : (t:(Tile n l) -> {|eff} m=>l=>a)) (fScalar : n -> {|eff} m=>a) : {|eff} m=>n=>a = %tiledd fTile fScalar -- TODO: This should become just `loadVector $ for i. arr.(t +> i)` @@ -783,7 +780,7 @@ interface Monoid a:Type where mempty : a mcombine : a -> a -> a -- can't use `<>` just for parser reasons? -(<>) : Monoid a ?=> a -> a -> a = mcombine +def (<>) [Monoid a] : a -> a -> a = mcombine '## Length-erased lists @@ -793,7 +790,7 @@ data List a:Type = def unsafeCastTable (m:Type) (xs:n=>a) : m=>a = for i. xs.(unsafeFromOrdinal _ (ordinal i)) -def toList (n:Type) ?-> (xs:n=>a) : List a = +def toList (xs:n=>a) : List a = n' = size n AsList _ $ unsafeCastTable (Fin n') xs @@ -895,7 +892,7 @@ def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = -- TODO: would be nice to be able to use records here data DynBuffer a:Type = MkDynBuffer (Ptr (Int & Int & Ptr a)) -- size, max size, buf ptr -def withDynamicBuffer (_:Storable a) ?=> +def withDynamicBuffer [Storable a] (action: DynBuffer a -> {State World} b) : {State World} b = initMaxSize = 256 withAlloc 1 \dbPtr. @@ -906,7 +903,7 @@ def withDynamicBuffer (_:Storable a) ?=> free bufPtr' result -def maybeIncreaseBufferSize (_:Storable a) ?=> +def maybeIncreaseBufferSize [Storable a] (buf: DynBuffer a) (sizeDelta:Int) : {State World} Unit = (MkDynBuffer dbPtr) = buf (size, maxSize, bufPtr) = load dbPtr @@ -918,7 +915,7 @@ def maybeIncreaseBufferSize (_:Storable a) ?=> memcpy newBufPtr bufPtr size store dbPtr (size, newMaxSize, newBufPtr) -def extendDynBuffer (_:Storable a) ?=> +def extendDynBuffer [Storable a] (buf: DynBuffer a) (new:List a) : {State World} Unit = (AsList n xs) = new maybeIncreaseBufferSize buf n @@ -928,13 +925,13 @@ def extendDynBuffer (_:Storable a) ?=> storeTab (bufPtr +>> size) xs store dbPtr (newSize, maxSize, bufPtr) -def loadDynBuffer (_:Storable a) ?=> +def loadDynBuffer [Storable a] (buf: DynBuffer a) : {State World} (List a) = (MkDynBuffer dbPtr) = buf (size, _, bufPtr) = load dbPtr AsList size $ tabFromPtr _ bufPtr -def pushDynBuffer (_:Storable a) ?=> +def pushDynBuffer [Storable a] (buf: DynBuffer a) (x:a) : {State World} Unit = extendDynBuffer buf $ AsList _ [x] @@ -1194,7 +1191,7 @@ def error (s:String) : a = unsafeIO do print s %throwError a -def todo (a:Type) ?-> : a = error "TODO: implement it!" +def todo : a = error "TODO: implement it!" def fromOrdinal (n:Type) (i:Int) : n = case (0 <= i) && (i < size n) of @@ -1210,7 +1207,7 @@ def castTable (m:Type) (xs:n=>a) : m=>a = False -> error $ "Table size mismatch in cast: " <> show (size m) <> " vs " <> show (size n) -def asidx (n:Type) ?-> (i:Int) : n = fromOrdinal n i +def asidx (i:Int) : n = fromOrdinal n i def (@) (i:Int) (n:Type) : n = fromOrdinal n i def slice (xs:n=>a) (start:Int) (m:Type) : m=>a = @@ -1218,11 +1215,11 @@ def slice (xs:n=>a) (start:Int) (m:Type) : m=>a = def head (xs:n=>a) : a = xs.(0@_) -def tail (n:Type) ?-> (xs:n=>a) (start:Int) : List a = +def tail (xs:n=>a) (start:Int) : List a = numElts = size n - start toList $ slice xs start (Fin numElts) -def randIdx (n:Type) ?-> (k:Key) : n = +def randIdx (k:Key) : n = unif = rand k fromOrdinal n $ FToI $ floor $ unif * IToF (size n) @@ -1246,7 +1243,7 @@ instance finArb : n:Int ?-> Arbitrary (Fin n) where 'Control flow -- returns the highest index `i` such that `xs.i <= x` -def searchSorted (_:Ord a) ?=> (xs:n=>a) (x:a) : Maybe n = +def searchSorted [Ord a] (xs:n=>a) (x:a) : Maybe n = if size n == 0 then Nothing else if x < xs.(fromOrdinal _ 0) @@ -1264,28 +1261,28 @@ def searchSorted (_:Ord a) ?=> (xs:n=>a) (x:a) : Maybe n = 'min / max etc -def minBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x < f y) x y -def maxBy (_:Ord o) ?=> (f:a->o) (x:a) (y:a) : a = select (f x > f y) x y +def minBy [Ord o] (f:a->o) (x:a) (y:a) : a = select (f x < f y) x y +def maxBy [Ord o] (f:a->o) (x:a) (y:a) : a = select (f x > f y) x y -def min (_:Ord o) ?=> (x1: o) -> (x2: o) : o = minBy id x1 x2 -def max (_:Ord o) ?=> (x1: o) -> (x2: o) : o = maxBy id x1 x2 +def min [Ord o] (x1: o) -> (x2: o) : o = minBy id x1 x2 +def max [Ord o] (x1: o) -> (x2: o) : o = maxBy id x1 x2 -def minimumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = +def minimumBy [Ord o] (f:a->o) (xs:n=>a) : a = reduce xs.(0@_) (minBy f) xs -def maximumBy (_:Ord o) ?=> (f:a->o) (xs:n=>a) : a = +def maximumBy [Ord o] (f:a->o) (xs:n=>a) : a = reduce xs.(0@_) (maxBy f) xs -def minimum (_:Ord o) ?=> (xs:n=>o) : o = minimumBy id xs -def maximum (_:Ord o) ?=> (xs:n=>o) : o = maximumBy id xs +def minimum [Ord o] (xs:n=>o) : o = minimumBy id xs +def maximum [Ord o] (xs:n=>o) : o = maximumBy id xs -def argmin (_:Ord o) ?=> (xs:n=>o) : n = +def argmin [Ord o] (xs:n=>o) : n = zeroth = (0@_, xs.(0@_)) compare = \(idx1, x1) (idx2, x2). select (x1 < x2) (idx1, x1) (idx2, x2) zipped = for i. (i, xs.i) fst $ reduce zeroth compare zipped -def clip (_:Ord a) ?=> ((low,high):(a&a)) (x:a) : a = +def clip [Ord a] ((low,high):(a&a)) (x:a) : a = min high $ max low x '## Trigonometric functions @@ -1307,7 +1304,7 @@ def atan_inner (x:Float) : Float = r = r * s r * x + x -def min_and_max (_: Ord a) ?=> (x:a) (y:a) : (a & a) = +def min_and_max [Ord a] (x:a) (y:a) : (a & a) = select (x < y) (x, y) (y, x) -- get both with one comparison. def atan2 (y:Float) (x:Float) : Float = @@ -1461,7 +1458,7 @@ def reverse (x:n=>a) : n=>a = s = size n for i. x.((s - 1 - ordinal i)@_) -def padTo (n:Type) ?-> (m:Type) (x:a) (xs:n=>a) : (m=>a) = +def padTo (m:Type) (x:a) (xs:n=>a) : (m=>a) = n' = size n for i. i' = ordinal i @@ -1483,7 +1480,7 @@ def seqMaybes (n:Type) ?-> (a:Type) ?-> (xs : n=>Maybe a) : Maybe (n => a) = True -> Nothing False -> Just $ map fromJust xs -def linearSearch (_:Eq a) ?=> (xs:n=>a) (query:a) : Maybe n = +def linearSearch [Eq a] (xs:n=>a) (query:a) : Maybe n = yieldState Nothing \ref. for i. case xs.i == query of True -> ref := Just i @@ -1555,7 +1552,7 @@ def softmax (x: n=>Float) : n=>Float = s = sum e for i. e.i / s -def evalpoly (_:VSpace v) ?=> (coefficients:n=>v) (x:Float) : v = +def evalpoly [VSpace v] (coefficients:n=>v) (x:Float) : v = -- Evaluate a polynomial at x. Same as Numpy's polyval. fold zero \i c. coefficients.i + x .* c diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 3129b4c0e..6eab3f678 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -258,7 +258,7 @@ findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ UFor _ _ _ -> error "Unexpected for in type annotation" UHole -> mempty UTypeAnn v ty -> findVarsInAppLHS v <> findVarsInAppLHS ty - UTabCon _ -> error "Unexpected table in type annotation" + UTabCon _ -> mempty UIndexRange low high -> foldMap findVarsInAppLHS low <> foldMap findVarsInAppLHS high UPrimExpr prim -> foldMap findVarsInAppLHS prim @@ -329,7 +329,7 @@ interfaceDef = do ns $ UApp (PlainArrow ()) func (var typeVarName) recordStr = "recordVar" recordPat = ns $ UPatRecord $ Ext (labeledSingleton fLabel (patb - fLabel)) $ Just (ns (UPatBinder (Ignore ()))) + fLabel)) $ Just underscorePat conPat = ns $ UPatCon (mkInterfaceConsName interfaceName) $ toNest [patb recordStr] @@ -430,20 +430,31 @@ funDefLet :: Parser (UExpr -> UDecl) funDefLet = label "function definition" $ mayBreak $ do keyWord DefKW v <- letPat - bs <- many arg + cs <- defClassConstraints + argBinders <- many arg (eff, ty) <- label "result type annotation" $ annot effectiveType - when (null bs && eff /= Pure) $ fail "Nullary def can't have effects" + when (null argBinders && eff /= Pure) $ fail "Nullary def can't have effects" + let bs = map classAsBinder cs ++ argBinders let funTy = buildPiType bs eff ty let letBinder = (v, Just funTy) let lamBinders = flip map bs $ \(p,_, arr) -> ((p,Nothing), arr) return $ \body -> ULet PlainLet letBinder (buildLam lamBinders body) where + classAsBinder :: UType -> (UPat, UType, UArrow) + classAsBinder ty = (underscorePat, ty, ClassArrow) + arg :: Parser (UPat, UType, UArrow) arg = label "def arg" $ do (p, ty) <-parens ((,) <$> pat <*> annot uType) arr <- arrow (return ()) <|> return (PlainArrow ()) return (p, ty, arr) +defClassConstraints :: Parser [UType] +defClassConstraints = + (brackets $ mayNotPair $ uType `sepBy` sym ",") + <|> return [] + "class constraints" + nameAsPat :: Parser Name -> Parser UPat nameAsPat p = withSrc $ (UPatBinder . Bind . (:>())) <$> p @@ -520,11 +531,12 @@ uForExpr = do <|> (keyWord Rof_KW $> (Rev, True )) e <- buildFor pos dir <$> (some patAnn <* argTerm) <*> blockOrExpr if trailingUnit - then return $ noSrc $ UDecl (ULet PlainLet underscorePat e) $ noSrc unitExpr + then return $ noSrc $ UDecl (ULet PlainLet (underscorePat, Nothing) e) $ + noSrc unitExpr else return e - where - underscorePat :: UPatAnn - underscorePat = (noSrc $ UPatBinder $ Ignore (), Nothing) + +underscorePat :: UPat +underscorePat = noSrc $ UPatBinder $ Ignore () unitExpr :: UExpr' unitExpr = UPrimExpr $ ConExpr UnitCon @@ -558,7 +570,7 @@ wrapUStatements statements = case statements of (s, pos):rest -> WithSrc (Just pos) $ case s of Left d -> UDecl d $ wrapUStatements rest Right e -> UDecl d $ wrapUStatements rest - where d = ULet PlainLet (WithSrc (Just pos) (UPatBinder (Ignore ())), Nothing) e + where d = ULet PlainLet (underscorePat, Nothing) e [] -> error "Shouldn't be reachable" uStatement :: Parser UStatement From 599e2ee47c53b355954012c96e9515d3cbdc863c Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 4 Jan 2021 14:00:37 -0500 Subject: [PATCH 072/105] Overhaul new-style interface/instance decls. * Handle superclasses * Remove the need to name instances explicitly * Push down types from interface definitions into instance methods * Improve error messages for missing/duplicated methods As par of this change, I moved the lowering (turning interface/instance decls into data defs and method/super class getters) from the parser to type inference where we have much more context about existing definitions. Fixes #370. --- lib/diagram.dx | 2 +- lib/prelude.dx | 476 +++++++++++++++++++-------------------- src/lib/Embed.hs | 28 ++- src/lib/Env.hs | 2 + src/lib/Imp.hs | 1 + src/lib/Inference.hs | 157 +++++++++++-- src/lib/PPrint.hs | 47 +--- src/lib/Parser.hs | 286 ++++++++++------------- src/lib/Syntax.hs | 57 ++++- src/lib/Type.hs | 9 +- tests/adt-tests.dx | 14 +- tests/io-tests.dx | 2 +- tests/typeclass-tests.dx | 61 ++--- 13 files changed, 619 insertions(+), 523 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index 98ae4e60e..a05fc1cb3 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -36,7 +36,7 @@ defaultGeomStyle : GeomStyle = -- TODO: consider sharing attributes among a set of objects for efficiency data Diagram = MkDiagram (List (GeomStyle & Point & Geom)) -instance monoidDiagram : Monoid Diagram where +instance Monoid Diagram mempty = MkDiagram mempty mcombine = \(MkDiagram d1) (MkDiagram d2). MkDiagram $ d1 <> d2 diff --git a/lib/prelude.dx b/lib/prelude.dx index 54f7aefef..1831ef83e 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -40,7 +40,7 @@ def FToI (x:Float) : Int = internalCast _ x def I64ToRawPtr (x:Int64 ) : RawPtr = internalCast _ x def RawPtrToI64 (x:RawPtr) : Int64 = internalCast _ x -interface Add a:Type where +interface Add a add : a -> a -> a sub : a -> a -> a zero : a @@ -48,97 +48,97 @@ interface Add a:Type where def (+) [Add a] : a -> a -> a = add def (-) [Add a] : a -> a -> a = sub -instance float64Add : Add Float64 where - add = \x:Float64 y:Float64. %fadd x y - sub = \x:Float64 y:Float64. %fsub x y +instance Add Float64 + add = \x y. %fadd x y + sub = \x y. %fsub x y zero = FToF64 0.0 -instance float32Add : Add Float32 where - add = \x:Float32 y:Float32. %fadd x y - sub = \x:Float32 y:Float32. %fsub x y +instance Add Float32 + add = \x y. %fadd x y + sub = \x y. %fsub x y zero = FToF32 0.0 -instance int64Add : Add Int64 where - add = \x:Int64 y:Int64. %iadd x y - sub = \x:Int64 y:Int64. %isub x y +instance Add Int64 + add = \x y. %iadd x y + sub = \x y. %isub x y zero = IToI64 0 -instance int32Add : Add Int32 where - add = \x:Int32 y:Int32. %iadd x y - sub = \x:Int32 y:Int32. %isub x y +instance Add Int32 + add = \x y. %iadd x y + sub = \x y. %isub x y zero = IToI32 0 -instance word8Add : Add Word8 where - add = \x:Word8 y:Word8. %iadd x y - sub = \x:Word8 y:Word8. %isub x y +instance Add Word8 + add = \x y. %iadd x y + sub = \x y. %isub x y zero = IToW8 0 -instance unitAdd : Add Unit where +instance Add Unit add = \x y. () sub = \x y. () zero = () -instance tabAdd : Add a ?=> Add (n=>a) where +instance [Add a] Add (n=>a) add = \xs ys. for i. xs.i + ys.i sub = \xs ys. for i. xs.i - ys.i zero = for _. zero -interface Mul a:Type where +interface Mul a mul : a -> a -> a one : a def (*) [Mul a] : a -> a -> a = mul -instance float64Mul : Mul Float64 where - mul = \x:Float64 y:Float64. %fmul x y +instance Mul Float64 + mul = \x y. %fmul x y one = FToF64 1.0 -instance float32Mul : Mul Float32 where - mul = \x:Float32 y:Float32. %fmul x y +instance Mul Float32 + mul = \x y. %fmul x y one = FToF32 1.0 -instance int64Mul : Mul Int64 where - mul = \x:Int64 y:Int64. %imul x y +instance Mul Int64 + mul = \x y. %imul x y one = IToI64 1 -instance int32Mul : Mul Int32 where - mul = \x:Int32 y:Int32. %imul x y +instance Mul Int32 + mul = \x y. %imul x y one = IToI32 1 -instance word8Mul : Mul Word8 where - mul = \x:Word8 y:Word8. %imul x y +instance Mul Word8 + mul = \x y. %imul x y one = IToW8 1 -instance unitMul : Mul Unit where +instance Mul Unit mul = \x y. () one = () -interface Integral a:Type where - idiv: a->a->a - rem: a->a->a +interface Integral a + idiv : a->a->a + rem : a->a->a -instance int64Integral : Integral Int64 where - idiv = \x:Int64 y:Int64. %idiv x y - rem = \x:Int64 y:Int64. %irem x y +instance Integral Int64 + idiv = \x y. %idiv x y + rem = \x y. %irem x y -instance int32Integral : Integral Int32 where - idiv = \x:Int32 y:Int32. %idiv x y - rem = \x:Int32 y:Int32. %irem x y +instance Integral Int32 + idiv = \x y. %idiv x y + rem = \x y. %irem x y -instance word8Integral : Integral Word8 where - idiv = \x:Word8 y:Word8. %idiv x y - rem = \x:Word8 y:Word8. %irem x y +instance Integral Word8 + idiv = \x y. %idiv x y + rem = \x y. %irem x y -interface Fractional a:Type where +interface Fractional a divide : a -> a -> a -instance float64Fractional : Fractional Float64 where - divide = \x:Float64 y:Float64. %fdiv x y +instance Fractional Float64 + divide = \x y. %fdiv x y -instance float32Fractional : Fractional Float32 where - divide = \x:Float32 y:Float32. %fdiv x y +instance Fractional Float32 + divide = \x y. %fdiv x y '## Basic polymorphic functions and types @@ -157,19 +157,22 @@ const : a -> b -> a = \x _. x '## Vector spaces -data VSpace a:Type = MkVSpace (Add a) (Float -> a -> a) +interface [Add a] VSpace a + scaleVec : Float -> a -> a -@superclass -def addFromVSpace (d:VSpace a) : Add a = case d of MkVSpace addDict _ -> addDict - -def (.*) (d:VSpace a) ?=> : Float -> a -> a = case d of MkVSpace _ scale -> scale -def (*.) [VSpace a] : a -> Float -> a = flip (.*) +def (.*) [VSpace a] : Float -> a -> a = scaleVec +def (*.) [VSpace a] : a -> Float -> a = flip scaleVec def (/) [VSpace a] (v:a) (s:Float) : a = divide 1.0 s .* v def neg [VSpace a] (v:a) : a = (-1.0) .* v -@instance floatVS : VSpace Float = MkVSpace float32Add (*) -@instance tabVS : VSpace a ?=> VSpace (n=>a) = MkVSpace tabAdd \s xs. for i. s .* xs.i -@instance unitVS : VSpace Unit = MkVSpace unitAdd \s u. () +instance VSpace Float + scaleVec = \x y. x * y + +instance [VSpace a] VSpace (n=>a) + scaleVec = \s xs. for i. s .* xs.i + +instance VSpace Unit + scaleVec = \_ _. () '## Boolean type @@ -197,7 +200,7 @@ def not (x:Bool) : Bool = '## Sum types -data Maybe a:Type = +data Maybe a = Nothing Just a @@ -207,7 +210,7 @@ def isNothing (x:Maybe a) : Bool = case x of def isJust (x:Maybe a) : Bool = not $ isNothing x -data (|) a:Type b:Type = +data (|) a b = Left a Right b @@ -285,55 +288,76 @@ def unreachable (():Unit) : a = unsafeIO do '## Type classes -data Eq a:Type = MkEq (a -> a -> Bool) -data Ord a:Type = MkOrd (Eq a) (a -> a -> Bool) (a -> a -> Bool) -- eq, gt, lt - -@superclass -def eqFromOrd (d:Ord a) : Eq a = case d of MkOrd eq _ _ -> eq +interface Eq a + (==) : a -> a -> Bool -def (==) (d:Eq a) ?=> (x:a) (y:a) : Bool = case d of MkEq eq -> eq x y def (/=) [Eq a] (x:a) (y:a) : Bool = not $ x == y -def (>) (d:Ord a) ?=> (x:a) (y:a) : Bool = case d of MkOrd _ gt _ -> gt x y -def (<) (d:Ord a) ?=> (x:a) (y:a) : Bool = case d of MkOrd _ _ lt -> lt x y +interface [Eq a] Ord a + (>) : a -> a -> Bool + (<) : a -> a -> Bool + def (<=) [Ord a] (x:a) (y:a) : Bool = x=) [Ord a] (x:a) (y:a) : Bool = x>y || x==y -@instance float64Eq : Eq Float64 = MkEq \x:Float64 y:Float64. W8ToB $ %feq x y -@instance float32Eq : Eq Float32 = MkEq \x:Float32 y:Float32. W8ToB $ %feq x y -@instance int64Eq : Eq Int64 = MkEq \x:Int64 y:Int64. W8ToB $ %ieq x y -@instance int32Eq : Eq Int32 = MkEq \x:Int32 y:Int32. W8ToB $ %ieq x y -@instance word8Eq : Eq Word8 = MkEq \x:Word8 y:Word8. W8ToB $ %ieq x y -@instance boolEq : Eq Bool = MkEq \x y. BToW8 x == BToW8 y -@instance unitEq : Eq Unit = MkEq \x y. True -@instance rawPtrEq : Eq RawPtr = MkEq \x y. RawPtrToI64 x == RawPtrToI64 y - -@instance float64Ord : Ord Float64 = (MkOrd float64Eq (\x y. W8ToB $ %fgt x y) - (\x y. W8ToB $ %flt x y)) -@instance float32Ord : Ord Float32 = (MkOrd float32Eq (\x y. W8ToB $ %fgt x y) - (\x y. W8ToB $ %flt x y)) -@instance int64Ord : Ord Int64 = (MkOrd int64Eq (\x y. W8ToB $ %igt x y) - (\x y. W8ToB $ %ilt x y)) -@instance int32Ord : Ord Int32 = (MkOrd int32Eq (\x y. W8ToB $ %igt x y) - (\x y. W8ToB $ %ilt x y)) -@instance word8Ord : Ord Word8 = (MkOrd word8Eq (\x y. W8ToB $ %igt x y) - (\x y. W8ToB $ %ilt x y)) -@instance unitOrd : Ord Unit = (MkOrd unitEq (\x y. False) (\x y. False)) - -@instance -def pairEq [Eq a, Eq b] : Eq (a & b) = MkEq $ - \(x1,x2) (y1,y2). x1 == y1 && x2 == y2 - -@instance -def pairOrd [Ord a, Ord b] : Ord (a & b) = - pairGt = \(x1,x2) (y1,y2). x1 > y1 || (x1 == y1 && x2 > y2) - pairLt = \(x1,x2) (y1,y2). x1 < y1 || (x1 == y1 && x2 < y2) - MkOrd pairEq pairGt pairLt +instance Eq Float64 + (==) = \x y. W8ToB $ %feq x y + +instance Eq Float32 + (==) = \x y. W8ToB $ %feq x y + +instance Eq Int64 + (==) = \x y. W8ToB $ %ieq x y + +instance Eq Int32 + (==) = \x y. W8ToB $ %ieq x y + +instance Eq Word8 + (==) = \x y. W8ToB $ %ieq x y + +instance Eq Bool + (==) = \x y. BToW8 x == BToW8 y + +instance Eq Unit + (==) = \x y. True + +instance Eq RawPtr + (==) = \x y. RawPtrToI64 x == RawPtrToI64 y + +instance Ord Float64 + (>) = \x y. W8ToB $ %fgt x y + (<) = \x y. W8ToB $ %flt x y + +instance Ord Float32 + (>) = \x y. W8ToB $ %fgt x y + (<) = \x y. W8ToB $ %flt x y + +instance Ord Int64 + (>) = \x y. W8ToB $ %igt x y + (<) = \x y. W8ToB $ %ilt x y + +instance Ord Int32 + (>) = \x y. W8ToB $ %igt x y + (<) = \x y. W8ToB $ %ilt x y + +instance Ord Word8 + (>) = \x y. W8ToB $ %igt x y + (<) = \x y. W8ToB $ %ilt x y + +instance Ord Unit + (>) = \x y. False + (<) = \x y. False + +instance [Eq a, Eq b] Eq (a & b) + (==) = \(x1,x2) (y1,y2). x1 == y1 && x2 == y2 + +instance [Ord a, Ord b] Ord (a & b) + (>) = \(x1,x2) (y1,y2). x1 > y1 || (x1 == y1 && x2 > y2) + (<) = \(x1,x2) (y1,y2). x1 < y1 || (x1 == y1 && x2 < y2) -- TODO: accumulate using the True/&& monoid -@instance -def tabEq [Eq a] : Eq (n=>a) = MkEq $ - \xs ys. +instance [Eq a] Eq (n=>a) + (==) = \xs ys. numDifferent : Float = yieldAccum \ref. for i. ref += (IToF (BToI (xs.i /= ys.i))) @@ -341,7 +365,7 @@ def tabEq [Eq a] : Eq (n=>a) = MkEq $ '## Transcencendental functions -interface Floating a:Type where +interface Floating a exp : a -> a exp2 : a -> a log : a -> a @@ -375,45 +399,45 @@ def float64_cosh (x:Float64) : Float64 = %fdiv ((%exp x) + (%exp (%fsub (FToF64 def float64_tanh (x:Float64) : Float64 = %fdiv (%fsub (%exp x) (%exp (%fsub (FToF64 0.0) x))) ((%exp x) + (%exp (%fsub (FToF64 0.0) x))) -instance float64Floating : Floating Float64 where - exp = \x:Float64. %exp x - exp2 = \x:Float64. %exp2 x - log = \x:Float64. %log x - log2 = \x:Float64. %log2 x - log10 = \x:Float64. %log10 x - log1p = \x:Float64. %log1p x - sin = \x:Float64. %sin x - cos = \x:Float64. %cos x - tan = \x:Float64. %tan x +instance Floating Float64 + exp = \x. %exp x + exp2 = \x. %exp2 x + log = \x. %log x + log2 = \x. %log2 x + log10 = \x. %log10 x + log1p = \x. %log1p x + sin = \x. %sin x + cos = \x. %cos x + tan = \x. %tan x sinh = float64_sinh cosh = float64_cosh tanh = float64_tanh - floor = \x:Float64. %floor x - ceil = \x:Float64. %ceil x - round = \x:Float64. %round x - sqrt = \x:Float64. %sqrt x - pow = \x:Float64 y:Float64. %fpow x y - lgamma = \x:Float64. %lgamma x - -instance float32Floating : Floating Float32 where - exp = \x:Float32. %exp x - exp2 = \x:Float32. %exp2 x - log = \x:Float32. %log x - log2 = \x:Float32. %log2 x - log10 = \x:Float32. %log10 x - log1p = \x:Float32. %log1p x - sin = \x:Float32. %sin x - cos = \x:Float32. %cos x - tan = \x:Float32. %tan x + floor = \x. %floor x + ceil = \x. %ceil x + round = \x. %round x + sqrt = \x. %sqrt x + pow = \x y. %fpow x y + lgamma = \x. %lgamma x + +instance Floating Float32 + exp = \x. %exp x + exp2 = \x. %exp2 x + log = \x. %log x + log2 = \x. %log2 x + log10 = \x. %log10 x + log1p = \x. %log1p x + sin = \x. %sin x + cos = \x. %cos x + tan = \x. %tan x sinh = float32_sinh cosh = float32_cosh tanh = float32_tanh - floor = \x:Float32. %floor x - ceil = \x:Float32. %ceil x - round = \x:Float32. %round x - sqrt = \x:Float32. %sqrt x - pow = \x:Float32 y:Float32. %fpow x y - lgamma = \x:Float32. %lgamma x + floor = \x. %floor x + ceil = \x. %ceil x + round = \x. %round x + sqrt = \x. %sqrt x + pow = \x y. %fpow x y + lgamma = \x. %lgamma x '## Index set utilities @@ -425,90 +449,66 @@ def unsafeFromOrdinal (n : Type) (i : Int) : n = %unsafeFromOrdinal n i def iota (n:Type) : n=>Int = view i. ordinal i -- TODO: we want Eq and Ord for all index sets, not just `Fin n` -@instance -def finEq (n:Int) ?-> : Eq (Fin n) = MkEq \x y. ordinal x == ordinal y +instance (n:Int) ?-> Eq (Fin n) + (==) = \x y. ordinal x == ordinal y -@instance -def finOrd (n:Int) ?-> : Ord (Fin n) = - MkOrd finEq (\x y. ordinal x > ordinal y) (\x y. ordinal x < ordinal y) +instance (n:Int) ?-> Ord (Fin n) + (>) = \x y. ordinal x > ordinal y + (<) = \x y. ordinal x < ordinal y '## Raw pointer operations -data Ptr a:Type = MkPtr RawPtr +data Ptr a = MkPtr RawPtr -- Is there a better way to select the right instance for `storageSize`?? -data TypeVehicle a:Type = MkTypeVehicle +data TypeVehicle a = MkTypeVehicle def typeVehicle (a:Type) : TypeVehicle a = MkTypeVehicle -interface Storable a:Type where +interface Storable a store : Ptr a -> a -> {State World} Unit load : Ptr a -> {State World} a - storageSize : TypeVehicle a -> Int - --- TODO: we can't inline these into the instance definitions until we change --- type inference to push types down into record constructors or allow `def` in --- instance definitions. -def word8Store ((MkPtr ptr): Ptr Word8) (x:Word8) : {State World} Unit = %ptrStore ptr x -def word8Load ((MkPtr ptr): Ptr Word8) : {State World} Word8 = %ptrLoad ptr - -instance word8Storable : Storable Word8 where - store = word8Store - load = word8Load - storageSize = const 1 - --- TODO: there's a bug preventing us inlining these definitions into the instance -def int32Store ((MkPtr ptr): Ptr Int32) (x:Int32) : {State World} Unit = - %ptrStore (internalCast %Int32Ptr ptr) x -def int32Load ((MkPtr ptr): Ptr Int32) : {State World} Int32 = - %ptrLoad (internalCast %Int32Ptr ptr) - -instance int32Storable : Storable Int32 where - store = int32Store - load = int32Load - storageSize = const 4 + storageSize_ : TypeVehicle a -> Int + +def storageSize (a:Type) -> (d:Storable a) ?=> : Int = + tv : TypeVehicle a = MkTypeVehicle + storageSize_ tv + +instance Storable Word8 + store = \(MkPtr ptr) x. %ptrStore ptr x + load = \(MkPtr ptr) . %ptrLoad ptr + storageSize_ = const 1 + +instance Storable Int32 + store = \(MkPtr ptr) x. %ptrStore (internalCast %Int32Ptr ptr) x + load = \(MkPtr ptr) . %ptrLoad (internalCast %Int32Ptr ptr) + storageSize_ = const 4 def unpackPairPtr [Storable a, Storable b] (pairPtr: Ptr (a & b)) : (Ptr a & Ptr b) = (MkPtr rawPtrX) = pairPtr - rawPtrY = %ptrOffset rawPtrX (storageSize (typeVehicle a)) + rawPtrY = %ptrOffset rawPtrX (storageSize a) (MkPtr rawPtrX, MkPtr rawPtrY) -def pairStore [Storable a, Storable b] - (pairPtr:Ptr (a & b)) ((x, y):(a & b)) : {State World} Unit = - (xPtr, yPtr) = unpackPairPtr pairPtr - store xPtr x - store yPtr y - -def pairLoad [Storable a, Storable b] - (pairPtr:Ptr (a & b)) : {State World} (a & b) = - (xPtr, yPtr) = unpackPairPtr pairPtr - (load xPtr, load yPtr) - -def pairStorageSize [Storable a, Storable b] - (_:TypeVehicle (a & b)) : Int = - storageSize (typeVehicle a) + storageSize (typeVehicle b) - -instance pairStorable : (Storable a) ?=> (Storable b) ?=> Storable (a & b) where - store = pairStore - load = pairLoad - storageSize = pairStorageSize - -def ptrPtrStore ((MkPtr ptr): Ptr (Ptr a)) (x:(Ptr a)) : {State World} Unit = - (MkPtr x') = x - %ptrStore (internalCast %PtrPtr ptr) x' - -def ptrPtrLoad ((MkPtr ptr): Ptr (Ptr a)) : {State World} (Ptr a) = - MkPtr $ %ptrLoad (internalCast %PtrPtr ptr) - -instance ptrStorable : Storable (Ptr a) where - store = ptrPtrStore - load = ptrPtrLoad - storageSize = const 8 -- TODO: something more portable? +instance [Storable a, Storable b] Storable (a & b) + store = \pairPtr (x, y). + (xPtr, yPtr) = unpackPairPtr pairPtr + store xPtr x + store yPtr y + load = \pairPtr. + (xPtr, yPtr) = unpackPairPtr pairPtr + (load xPtr, load yPtr) + storageSize_ = \_. + storageSize a + storageSize b + +instance Storable (Ptr a) + store = \(MkPtr ptr) (MkPtr x). %ptrStore (internalCast %PtrPtr ptr) x + load = \(MkPtr ptr) . MkPtr $ %ptrLoad (internalCast %PtrPtr ptr) + storageSize_ = const 8 -- TODO: something more portable? -- TODO: Storable instances for other types def malloc [Storable a] (n:Int) : {State World} (Ptr a) = - numBytes = storageSize (typeVehicle a) * n + numBytes = storageSize a * n MkPtr $ %alloc numBytes def free (ptr:Ptr a) : {State World} Unit = @@ -517,7 +517,7 @@ def free (ptr:Ptr a) : {State World} Unit = def (+>>) [Storable a] (ptr:Ptr a) (i:Int) : Ptr a = (MkPtr ptr') = ptr - i' = i * storageSize (typeVehicle a) + i' = i * storageSize a MkPtr $ %ptrOffset ptr' i' -- TODO: generalize these brackets to allow other effects @@ -601,7 +601,7 @@ def vdot (x:n=>Float) (y:n=>Float) : Float = fsum view i. x.i * y.i def dot [VSpace v] (s:n=>Float) (vs:n=>v) : v = sum for j. s.j .* vs.j -- matmul. Better symbol to use? `@`? -(**) : (l=>m=>Float) -> (m=>n=>Float) -> (l=>n=>Float) = \x y. +(**) : (l=>m=>Float) -> (m=>n=>Float) -> (l=>n=>Float) = \x y. for i k. fsum view j. x.i.j * y.j.k (**.) : (n=>m=>Float) -> (m=>Float) -> (n=>Float) = \mat v. for i. vdot mat.i v @@ -671,33 +671,33 @@ def deriv (f:Float->Float) (x:Float) : Float = jvp f x 1.0 def derivRev (f:Float->Float) (x:Float) : Float = snd (vjp f x) 1.0 -interface HasAllClose a:Type where +interface HasAllClose a allclose : a -> a -> a -> a -> Bool -interface HasDefaultTolerance a:Type where +interface HasDefaultTolerance a atol : a rtol : a def (~~) [HasAllClose a, HasDefaultTolerance a] : a -> a -> Bool = allclose atol rtol -instance allCloseF32 : HasAllClose Float32 where +instance HasAllClose Float32 allclose = \atol rtol x y. abs (x - y) <= (atol + rtol * abs y) -instance allCloseF64 : HasAllClose Float64 where +instance HasAllClose Float64 allclose = \atol rtol x y. abs (x - y) <= (atol + rtol * abs y) -instance defaultToleranceF32 : HasDefaultTolerance Float32 where +instance HasDefaultTolerance Float32 atol = FToF32 0.00001 rtol = FToF32 0.0001 -instance defaultToleranceF64 : HasDefaultTolerance Float64 where +instance HasDefaultTolerance Float64 atol = FToF64 0.00000001 rtol = FToF64 0.00001 -instance allCloseTable : HasAllClose t ?=> HasDefaultTolerance t ?=> HasAllClose (n=>t) where +instance [HasAllClose t, HasDefaultTolerance t] HasAllClose (n=>t) allclose = \atol rtol a b. all for i:n. (a.i ~~ b.i) -instance defaultToleranceTable : (HasDefaultTolerance t) ?=> HasDefaultTolerance (n=>t) where +instance [HasDefaultTolerance t] HasDefaultTolerance (n=>t) atol = for i. atol rtol = for i. rtol @@ -776,7 +776,7 @@ def tile1 (fTile : (t:(Tile n l) -> {|eff} m=>l=>a)) '## Monoid typeclass -interface Monoid a:Type where +interface Monoid a mempty : a mcombine : a -> a -> a -- can't use `<>` just for parser reasons? @@ -784,7 +784,7 @@ def (<>) [Monoid a] : a -> a -> a = mcombine '## Length-erased lists -data List a:Type = +data List a = AsList n:Int foo:(Fin n => a) def unsafeCastTable (m:Type) (xs:n=>a) : m=>a = @@ -794,7 +794,7 @@ def toList (xs:n=>a) : List a = n' = size n AsList _ $ unsafeCastTable (Fin n') xs -instance monoidList : Monoid (List a) where +instance Monoid (List a) mempty = AsList _ [] mcombine = \x y. (AsList nx xs) = x @@ -808,7 +808,7 @@ instance monoidList : Monoid (List a) where '## Isomorphisms -data Iso a:Type b:Type = MkIso { fwd: a -> b & bwd: b -> a } +data Iso a b = MkIso { fwd: a -> b & bwd: b -> a } def appIso (iso: Iso a b) (x:a) : b = (MkIso {fwd, bwd}) = iso @@ -890,7 +890,7 @@ def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = -- TODO: should we be able to use `Ref World Int` instead of `Ptr Int`? -- TODO: would be nice to be able to use records here -data DynBuffer a:Type = MkDynBuffer (Ptr (Int & Int & Ptr a)) -- size, max size, buf ptr +data DynBuffer a = MkDynBuffer (Ptr (Int & Int & Ptr a)) -- size, max size, buf ptr def withDynamicBuffer [Storable a] (action: DynBuffer a -> {State World} b) : {State World} b = @@ -945,29 +945,29 @@ def stringFromCharPtr (n:Int) (ptr:Ptr Char) : {State World} String = -- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint def codepoint (c:Char) : Int = W8ToI c -interface Show a:Type where +interface Show a show : a -> String -instance showString : Show String where +instance Show String show = id -instance showInt32 : Show Int32 where - show = \x: Int32. unsafeIO do +instance Show Int32 + show = \x. unsafeIO do (n, ptr) = %ffi showInt32 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr -instance showInt64 : Show Int64 where - show = \x: Int64. unsafeIO do +instance Show Int64 + show = \x. unsafeIO do (n, ptr) = %ffi showInt64 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr -instance showFloat32 : Show Float32 where - show = \x: Float32.unsafeIO do +instance Show Float32 + show = \x. unsafeIO do (n, ptr) = %ffi showFloat32 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr -instance showFloat64 : Show Float64 where - show = \x: Float64.unsafeIO do +instance Show Float64 + show = \x. unsafeIO do (n, ptr) = %ffi showFloat64 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr @@ -1053,7 +1053,7 @@ def while (eff:Effects) ?-> (body: Unit -> {|eff} Bool) : {|eff} Unit = body' : Unit -> {|eff} Word8 = \_. BToW8 $ body () %while body' -data IterResult a:Type = +data IterResult a = Continue Done a @@ -1225,19 +1225,19 @@ def randIdx (k:Key) : n = 'Type class for generating example values -interface Arbitrary a:Type where +interface Arbitrary a arb : Key -> a -instance float32Arb : Arbitrary Float32 where +instance Arbitrary Float32 arb = randn -instance in32Arb : Arbitrary Int32 where +instance Arbitrary Int32 arb = \key. FToI $ randn key * 5.0 -instance tabArb : Arbitrary a ?=> Arbitrary (n=>a) where +instance [Arbitrary a] Arbitrary (n=>a) arb = \key. for i. arb $ ixkey key i -instance finArb : n:Int ?-> Arbitrary (Fin n) where +instance (n:Int) ?-> Arbitrary (Fin n) arb = randIdx 'Control flow @@ -1331,28 +1331,28 @@ def atan (x:Float) : Float = atan2 x 1.0 data Complex = MkComplex Float Float -- real, imaginary -instance allCloseComplex : HasAllClose Complex where +instance HasAllClose Complex allclose = \atol rtol (MkComplex a b) (MkComplex c d). (a ~~ c) && (b ~~ d) -instance defaultToleranceComplex : HasDefaultTolerance Complex where +instance HasDefaultTolerance Complex atol = MkComplex atol atol rtol = MkComplex rtol rtol -@instance ComplexEq : Eq Complex = - MkEq \(MkComplex a b) (MkComplex c d). (a == c) && (b == d) +instance Eq Complex + (==) = \(MkComplex a b) (MkComplex c d). (a == c) && (b == d) -instance ComplexAdd : Add Complex where +instance Add Complex add = \(MkComplex a b) (MkComplex c d). MkComplex (a + c) (b + d) sub = \(MkComplex a b) (MkComplex c d). MkComplex (a - c) (b - d) zero = MkComplex 0.0 0.0 -instance ComplexMul : Mul Complex where +instance Mul Complex mul = \(MkComplex a b) (MkComplex c d). MkComplex (a * c - b * d) (a * d + b * c) one = MkComplex 1.0 0.0 -@instance complexVS : VSpace Complex = - MkVSpace ComplexAdd \a:Float (MkComplex c d):Complex. MkComplex (a * c) (a * d) +instance VSpace Complex + scaleVec = \a:Float (MkComplex c d):Complex. MkComplex (a * c) (a * d) -- Todo: Hook up to (/) operator. Might require two-parameter VSpace. def complex_division (MkComplex a b:Complex) (MkComplex c d:Complex): Complex = @@ -1391,7 +1391,7 @@ def complex_tanh (MkComplex a b:Complex) : Complex = den = MkComplex (cosh a * cos b) (sinh a * sin b) complex_division num den -instance ComplexFractional : Fractional Complex where +instance Fractional Complex divide = complex_division def complex_floor (MkComplex re im:Complex) : Complex = @@ -1424,7 +1424,7 @@ def complex_log1p (x:Complex) : Complex = True -> complex_log u False -> divide ((complex_log u) * x) x -instance complexFloating : Floating Complex where +instance Floating Complex exp = complex_exp exp2 = complex_exp2 log = complex_log diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 731cfce66..705d1c50a 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -17,7 +17,7 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP app, add, mul, sub, neg, div', iadd, imul, isub, idiv, ilt, ieq, - fpow, flog, fLitLike, + fpow, flog, fLitLike, recGet, buildImplicitNaryLam, select, substEmbed, substEmbedR, emitUnpack, getUnpacked, fromPair, getFst, getSnd, getFstRef, getSndRef, naryApp, appReduce, appTryReduce, buildAbs, @@ -25,7 +25,7 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP emitBlock, unzipTab, isSingletonType, emitDecl, withNameHint, singletonTypeVal, scopedDecls, embedScoped, extendScope, checkEmbed, embedExtend, unpackConsList, emitRunWriter, applyPreludeFunction, - emitRunState, emitMaybeCase, emitWhile, + emitRunState, emitMaybeCase, emitWhile, buildDataDef, emitRunReader, tabGet, SubstEmbedT, SubstEmbed, runSubstEmbedT, traverseAtom, ptrOffset, ptrLoad, unsafePtrLoad, evalBlockE, substTraversalDef, @@ -42,6 +42,8 @@ import Control.Monad.Writer hiding (Alt) import Control.Monad.Identity import Control.Monad.State.Strict import Data.Foldable (toList) +import Data.List (elemIndex) +import Data.Maybe (fromJust) import Data.String (fromString) import Data.Tuple (swap) import GHC.Stack @@ -188,6 +190,28 @@ buildNAbsAux bs body = do return (fmap Bind vs, result) return (Abs bs' $ wrapDecls decls ans, aux) +buildDataDef :: MonadEmbed m + => Name -> Nest Binder -> ([Atom] -> m [DataConDef]) -> m DataDef +buildDataDef tyConName paramBinders body = do + ((paramBinders', dataDefs), _) <- scopedDecls $ do + vs <- freshNestedBinders paramBinders + result <- body $ map Var $ toList vs + return (fmap Bind vs, result) + return $ DataDef tyConName paramBinders' dataDefs + +buildImplicitNaryLam :: MonadEmbed m => (Nest Binder) -> ([Atom] -> m Atom) -> m Atom +buildImplicitNaryLam Empty body = body [] +buildImplicitNaryLam (Nest b bs) body = + buildLam b ImplicitArrow $ \x -> do + bs' <- substEmbed (b@>x) bs + buildImplicitNaryLam bs' $ \xs -> body $ x:xs + +recGet :: Label -> Atom -> Atom +recGet l x = do + let (RecordTy (Ext r _)) = getType x + let i = fromJust $ elemIndex l $ map fst $ toList $ reflectLabels r + getProjection [i] x + buildScoped :: MonadEmbed m => m Atom -> m Block buildScoped m = do (ans, decls) <- scopedDecls m diff --git a/src/lib/Env.hs b/src/lib/Env.hs index bfb2dd93e..456c613ab 100644 --- a/src/lib/Env.hs +++ b/src/lib/Env.hs @@ -39,6 +39,7 @@ data NameSpace = | InferenceName | SumName | FFIName + | TypeClassGenName -- names generated for type class dictionaries | AbstractedPtrName -- used in `abstractPtrLiterals` in Imp lowering | TopFunctionName -- top-level Imp functions | AllocPtrName -- used for constructing dests in Imp lowering @@ -163,6 +164,7 @@ env ! v = case envLookup env v of isGlobal :: VarP ann -> Bool isGlobal (GlobalName _ :> _) = True isGlobal (GlobalArrayName _ :> _) = True +isGlobal (Name TypeClassGenName _ _ :> _) = True isGlobal _ = False isGlobalBinder :: BinderP ann -> Bool diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 4ad9aa0f8..aa3c94663 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -517,6 +517,7 @@ toImpHof env (maybeDest, hof) = do translateBlock env (maybeDest, body) Linearize _ -> error "Unexpected Linearize" Transpose _ -> error "Unexpected Transpose" + CatchException _ -> error "Unexpected CatchException" data LaunchInfo = LaunchInfo { numWorkgroups :: IExpr, workgroupSize :: IExpr } data ThreadInfo = ThreadInfo { tid :: IExpr, wid :: IExpr, threadRange :: Type } diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index ecfd2e962..5a74edbfd 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -321,8 +321,7 @@ unpackTopPat letAnn pat expr = do bindings <- bindPat pat atom void $ flip traverseNames bindings $ \name val -> do let name' = asGlobal name - scope <- getScope - when (name' `isin` scope) $ throw RepeatedVarErr $ pprint $ name' + checkNotInScope name' emitTo name' letAnn $ Atom val inferUDecl :: Bool -> UDecl -> UInferM SubstEnv @@ -343,29 +342,116 @@ inferUDecl topLevel (ULet letAnn (p, ann) rhs) = do else bindPat p val inferUDecl True (UData tc dcs) = do (tc', paramBs) <- inferUConDef tc - scope <- getScope - when (tc' `isin` scope) $ throw RepeatedVarErr $ pprint $ getName tc' - let paramVars = map (\(Bind v) -> v) $ toList paramBs -- TODO: refresh things properly - (dcs', _) <- embedScoped $ - extendR (newEnv paramBs (map Var paramVars)) $ do - extendScope (foldMap boundVars paramBs) - mapM inferUConDef dcs - let dataDef = DataDef tc' paramBs $ map (uncurry DataConDef) dcs' - let tyConTy = getType $ TypeCon dataDef [] - extendScope $ tc' @> (tyConTy, DataBoundTypeCon dataDef) - forM_ (zip [0..] dcs') $ \(i, (dc,_)) -> do - -- Retrieving scope at every step to avoid duplicate constructor names - scope' <- getScope - when (dc `isin` scope') $ throw RepeatedVarErr $ pprint $ getName dc - let ty = getType $ DataCon dataDef [] i [] - extendScope $ dc @> (ty, DataBoundDataCon dataDef i) + dataDef <- buildDataDef tc' paramBs $ \params -> do + extendR (newEnv paramBs params) $ forM dcs $ \dc -> + uncurry DataConDef <$> inferUConDef dc + checkDataDefShadows dataDef + emitConstructors dataDef + return mempty +inferUDecl True (UInterface superclasses tc methods) = do + (tc', paramBs) <- inferUConDef tc + dataDef <- buildDataDef tc' paramBs $ \params -> do + extendR (newEnv paramBs params) $ do + conName <- freshClassGenName + superclasses' <- mkLabeledItems <$> mapM mkSuperclass superclasses + methods' <- mkLabeledItems <$> mapM mkMethod methods + return $ ClassDictDef conName superclasses' methods' + checkDataDefShadows dataDef + emitConstructors dataDef + emitSuperclassGetters dataDef + emitMethodGetters dataDef return mempty -inferUDecl False (UData _ _) = error "data definitions should be top-level" +inferUDecl True (UInstance instanceTy methods) = do + ty <- checkUType instanceTy + instanceDict <- checkInstance ty methods + let instanceName = Name TypeClassGenName "instance" 0 + void $ emitTo instanceName InstanceLet $ Atom instanceDict + return mempty +inferUDecl False (UData _ _ ) = error "data definitions should be top-level" +inferUDecl False (UInterface _ _ _) = error "interface definitions should be top-level" +inferUDecl False (UInstance _ _ ) = error "instance definitions should be top-level" + +freshClassGenName :: MonadEmbed m => m Name +freshClassGenName = do + scope <- getScope + let v' = genFresh (Name TypeClassGenName "classgen" 0) scope + embedExtend $ asFst $ v' @> (UnitTy, UnknownBinder) + return v' + +mkMethod :: UAnnBinder -> UInferM (Label, Type) +mkMethod (Ignore _) = error "Methods must have names" +mkMethod (Bind (v:>ty)) = do + ty' <- checkUType ty + return (nameToLabel v, ty') + +mkSuperclass :: UType -> UInferM (Label, Type) +mkSuperclass ty = do + ty' <- checkUType ty + -- TODO: think about the scope of these names + l <- freshClassGenName + return (nameToLabel l, ty') + +-- TODO: just make Name and Label the same thing +nameToLabel :: Name -> Label +nameToLabel = pprint + +mkLabeledItems :: [(Label, a)] -> LabeledItems a +mkLabeledItems items = foldMap (uncurry labeledSingleton) items + +emitConstructors :: DataDef -> UInferM () +emitConstructors def@(DataDef tyConName _ dataConDefs) = do + let tyConTy = getType $ TypeCon def [] + checkNotInScope tyConName + extendScope $ tyConName @> (tyConTy, DataBoundTypeCon def) + forM_ (zip [0..] dataConDefs) $ \(i, DataConDef dataConName _) -> do + let dataConTy = getType $ DataCon def [] i [] + checkNotInScope dataConName + extendScope $ dataConName @> (dataConTy, DataBoundDataCon def i) + +emitMethodGetters :: DataDef -> UInferM () +emitMethodGetters def@(DataDef _ paramBs (ClassDictDef _ _ methodTys)) = do + forM_ (getLabels methodTys) $ \l -> do + f <- buildImplicitNaryLam paramBs $ \params -> do + buildLam (Bind ("d":> TypeCon def params)) ClassArrow $ \dict -> do + return $ recGet l $ getProjection [1] dict + let methodName = GlobalName $ fromString l + checkNotInScope methodName + emitTo methodName PlainLet $ Atom f +emitMethodGetters (DataDef _ _ _) = error "Not a class dictionary" + +emitSuperclassGetters :: MonadEmbed m => DataDef -> m () +emitSuperclassGetters def@(DataDef _ paramBs (ClassDictDef _ superclassTys _)) = do + forM_ (getLabels superclassTys) $ \l -> do + f <- buildImplicitNaryLam paramBs $ \params -> do + buildLam (Bind ("d":> TypeCon def params)) PureArrow $ \dict -> do + return $ recGet l $ getProjection [0] dict + getterName <- freshClassGenName + emitTo getterName SuperclassLet $ Atom f +emitSuperclassGetter (DataDef _ _ _) = error "Not a class dictionary" + +checkNotInScope :: Name -> UInferM () +checkNotInScope v = do + scope <- getScope + when (v `isin` scope) $ throw RepeatedVarErr $ pprint v + +checkDataDefShadows :: DataDef -> UInferM () +checkDataDefShadows (DataDef tc _ dataCons) = do + checkShadows $ tc:dcs + where dcs = [dc | DataConDef dc _ <- dataCons] + +checkShadows :: [Name] -> UInferM () +checkShadows vs = do + mapM_ checkNotInScope vs + case repeated vs of + [] -> return () + (v:_) -> throw RepeatedVarErr $ pprint v inferUConDef :: UConDef -> UInferM (Name, Nest Binder) inferUConDef (UConDef v bs) = do (bs', _) <- embedScoped $ checkNestedBinders bs - return (asGlobal v, bs') + let v' = asGlobal v + checkNotInScope v' + return (v', bs') checkNestedBinders :: Nest UAnnBinder -> UInferM (Nest Binder) checkNestedBinders Empty = return Empty @@ -393,6 +479,37 @@ checkULam (p, ann) body piTy = do $ \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ checkSigma body Suggest $ snd $ applyAbs piTy x +checkInstance :: Type -> [(UVar, UExpr)] -> UInferM Atom +checkInstance ty methods = case ty of + TypeCon def@(DataDef className _ _) params -> do + case applyDataDefParams def params of + ClassDictDef _ superclassTys methodTys -> do + methods' <- liftM mkLabeledItems $ forM methods $ \((v:>()), rhs) -> do + let v' = nameToLabel v + case lookupLabel methodTys v' of + Nothing -> throw TypeErr (pprint v ++ " is not a method of " ++ pprint className) + Just methodTy -> do + rhs' <- checkSigma rhs Suggest methodTy + return (v', rhs') + let superclassHoles = fmap (Con . ClassDictHole Nothing) superclassTys + forM_ (reflectLabels methods') $ \(l,i) -> + when (i > 0) $ throw TypeErr $ "Duplicate method: " ++ pprint l + forM_ (reflectLabels methodTys) $ \(l,_) -> + case lookupLabel methods' l of + Nothing -> throw TypeErr $ "Missing method: " ++ pprint l + Just _ -> return () + return $ ClassDictCon def params superclassHoles methods' + _ -> throw TypeErr $ "Not a valid instance: " ++ pprint ty + Pi (Abs b (arrow, bodyTy)) -> do + case arrow of + ImplicitArrow -> return () + ClassArrow -> return () + _ -> throw TypeErr $ "Not a valid arrow for an instance: " ++ pprint arrow + buildLam b arrow $ \x@(Var v) -> do + bodyTy' <- substEmbed (b@>x) bodyTy + checkLeaks [v] $ extendR (b@>x) $ checkInstance bodyTy' methods + _ -> throw TypeErr $ "Not a valid instance type: " ++ pprint ty + checkUEffRow :: EffectRow -> UInferM EffectRow checkUEffRow (EffectRow effs t) = do effs' <- liftM S.fromList $ mapM checkUEff $ toList effs diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 976d86dba..9757b9fea 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -21,7 +21,6 @@ import Data.Foldable (toList) import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M import qualified Data.ByteString.Lazy.Char8 as B -import Data.Maybe (fromMaybe) import Data.String (fromString) import Data.Text.Prettyprint.Doc.Render.Text import Data.Text.Prettyprint.Doc @@ -33,7 +32,6 @@ import Numeric import Env import Syntax -import Util (enumerate) -- Specifies what kinds of operations are allowed to be printed at this point. -- Printing at AppPrec level means that applications can be printed @@ -364,7 +362,7 @@ instance PrettyPrec Atom where "DataConRef" <+> p params <+> p args BoxedRef b ptr size body -> atPrec AppPrec $ "Box" <+> p b <+> "<-" <+> p ptr <+> "[" <> p size <> "]" <+> hardline <> "in" <+> p body - ProjectElt idxs x -> prettyProjection idxs x + ProjectElt idxs x -> atPrec LowestPrec $ "project" <+> p idxs <+> p x instance Pretty DataConRefBinding where pretty = prettyFromPrettyPrec instance PrettyPrec DataConRefBinding where @@ -376,45 +374,6 @@ fromInfix t = do (t'', ')') <- unsnoc t' return t'' -prettyProjection :: NE.NonEmpty Int -> Var -> DocPrec ann -prettyProjection idxs (name :> ty) = prettyPrec uproj where - -- Builds a source expression that performs the given projection. - uproj = UApp (PlainArrow ()) (nosrc ulam) (nosrc uvar) - ulam = ULam (upat, Nothing) (PlainArrow ()) (nosrc $ UVar $ target :> ()) - uvar = UVar $ name :> () - (_, upat, target) = buildProj idxs - - buildProj :: NE.NonEmpty Int -> (Type, UPat, Name) - buildProj (i NE.:| is) = let - -- Lazy Haskell trick: refer to `target` even though this function is - -- responsible for setting it! - (ty', pat', eltName) = case NE.nonEmpty is of - Just is' -> let (x, y, z) = buildProj is' in (x, y, Just z) - Nothing -> (ty, nosrc $ UPatBinder $ Bind $ target :> (), Nothing) - in case ty' of - TypeCon def params -> let - [DataConDef conName bs] = applyDataDefParams def params - b = toList bs !! i - pats = (\(j,_)-> if i == j then pat' else uignore) <$> enumerate bs - hint = case b of - Bind (n :> _) -> n - Ignore _ -> Name SourceName "elt" 0 - in ( binderAnn b, nosrc $ UPatCon conName pats, fromMaybe hint eltName) - RecordTy (NoExt types) -> let - ty'' = toList types !! i - pats = (\(j,_)-> if i == j then pat' else uignore) <$> enumerate types - (fieldName, _) = toList (reflectLabels types) !! i - hint = Name SourceName (fromString fieldName) 0 - in (ty'', nosrc $ UPatRecord $ NoExt pats, fromMaybe hint eltName) - PairTy x _ | i == 0 -> - (x, nosrc $ UPatPair pat' uignore, fromMaybe "a" eltName) - PairTy _ y | i == 1 -> - (y, nosrc $ UPatPair uignore pat', fromMaybe "b" eltName) - _ -> error "Bad projection" - - nosrc = WithSrc Nothing - uignore = nosrc $ UPatBinder $ Ignore () - prettyExtLabeledItems :: (PrettyPrec a, PrettyPrec b) => ExtLabeledItems a b -> Doc ann -> Doc ann -> DocPrec ann prettyExtLabeledItems (Ext (LabeledItems row) rest) separator bindwith = @@ -629,6 +588,10 @@ instance Pretty UDecl where align $ prettyUBinder b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) pretty (UData tyCon dataCons) = "data" <+> p tyCon <+> "where" <> nest 2 (hardline <> prettyLines dataCons) + pretty (UInterface cs def methods) = + "interface" <+> p cs <+> p def <> hardline <> prettyLines methods + pretty (UInstance ty methods) = + "instance" <+> p ty <> hardline <> prettyLines methods instance Pretty UConDef where pretty (UConDef con bs) = p con <+> spaced bs diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 6eab3f678..62dc532f7 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -12,11 +12,11 @@ import Control.Monad import Control.Monad.Combinators.Expr import Control.Monad.Reader import Text.Megaparsec hiding (Label, State) -import Text.Megaparsec.Char hiding (space) +import Text.Megaparsec.Char hiding (space, eol) +import qualified Text.Megaparsec.Char as MC import Data.Char (isLower) import Data.Functor import Data.List.NonEmpty (NonEmpty (..)) -import qualified Data.Map.Strict as M import Data.Void import qualified Data.Set as S import Data.String (fromString) @@ -88,7 +88,7 @@ logLevel :: Parser LogLevel logLevel = do void $ try $ lexeme $ char '%' >> string "passes" passes <- many passName - void eol + eol case passes of [] -> return $ LogAll _ -> return $ LogPasses passes @@ -96,14 +96,14 @@ logLevel = do logTime :: Parser LogLevel logTime = do void $ try $ lexeme $ char '%' >> string "time" - void eol + eol return PrintEvalTime logBench :: Parser LogLevel logBench = do void $ try $ lexeme $ char '%' >> string "bench" benchName <- stringLiteral - void eol + eol return $ PrintBench benchName passName :: Parser PassName @@ -116,13 +116,15 @@ sourceBlock' :: Parser SourceBlock' sourceBlock' = proseBlock <|> topLevelCommand - <|> fmap (declsToModule . (:[])) (topDecl <* eolf) - <|> fmap (declsToModule . (:[])) (interfaceInstance <* eolf) - <|> fmap declsToModule (interfaceDef <* eolf) - <|> fmap (Command (EvalExpr Printed) . exprAsModule) (expr <* eol) + <|> liftM declToModule (topDecl <* eolf) + <|> liftM declToModule (instanceDef <* eolf) + <|> liftM declToModule (interfaceDef <* eolf) + <|> liftM (Command (EvalExpr Printed) . exprAsModule) (expr <* eol) <|> hidden (some eol >> return EmptyLines) <|> hidden (sc >> eol >> return CommentLine) - where declsToModule = RunModule . UModule . toNest + where + declsToModule = RunModule . UModule . toNest + declToModule = declsToModule . (:[]) proseBlock :: Parser SourceBlock' proseBlock = label "prose block" $ char '\'' >> fmap (ProseBlock . fst) (withSource consumeTillBreak) @@ -151,7 +153,7 @@ exprAsModule :: UExpr -> (Name, UModule) exprAsModule e = (asGlobal v, UModule (toNest [d])) where v = mkName "_ans_" - d = ULet PlainLet (WithSrc (srcPos e) (UPatBinder (Bind (v:>()))), Nothing) e + d = ULet PlainLet (WithSrc (srcPos e) (nameToPat v), Nothing) e -- === uexpr === @@ -206,8 +208,7 @@ charExpr :: Char -> UExpr' charExpr c = UPrimExpr $ ConExpr $ Lit $ Word8Lit $ fromIntegral $ fromEnum c uVarOcc :: Parser UExpr -uVarOcc = withSrc $ try $ (UVar . (:>())) <$> (occName <* notFollowedBy (sym ":")) - where occName = upperName <|> lowerName <|> symName +uVarOcc = withSrc $ try $ (UVar . (:>())) <$> (anyName <* notFollowedBy (sym ":")) uHole :: Parser UExpr uHole = withSrc $ underscore $> UHole @@ -222,9 +223,9 @@ topDecl = dataDef <|> topLet topLet :: Parser UDecl topLet = do - lAnn <- (char '@' >> letAnnStr <* (void eol <|> sc)) <|> return PlainLet - ~(ULet _ (p, ann) rhs, pos) <- withPos decl - let (ann', rhs') = addImplicitImplicitArgs pos ann rhs + lAnn <- (char '@' >> letAnnStr <* (eol <|> sc)) <|> return PlainLet + ~(ULet _ (p, ann) rhs) <- decl + let (ann', rhs') = addImplicitImplicitArgs ann rhs return $ ULet lAnn (p, ann') rhs' -- Given a type signature, find all "implicit implicit args": lower-case @@ -273,77 +274,36 @@ findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ UIntLit _ -> mempty UFloatLit _ -> mempty -addImplicitImplicitArgs :: SrcPos -> Maybe UType -> UExpr -> (Maybe UType, UExpr) -addImplicitImplicitArgs _ Nothing e = (Nothing, e) -addImplicitImplicitArgs sourcePos (Just typ) ex = - let (ty', e') = foldr (addImplicitArg sourcePos) (typ, ex) implicitVars +addImplicitImplicitArgs :: Maybe UType -> UExpr -> (Maybe UType, UExpr) +addImplicitImplicitArgs Nothing e = (Nothing, e) +addImplicitImplicitArgs (Just typ) ex = + let (ty', e') = foldr addImplicitArg (typ, ex) implicitVars in (Just ty', e') where implicitVars = findImplicitImplicitArgNames typ - addImplicitArg :: SrcPos -> Name -> (UType, UExpr) -> (UType, UExpr) - addImplicitArg pos v (ty, e) = - ( WithSrc (Just pos) $ UPi (Just uPat, uTyKind) ImplicitArrow ty - , WithSrc (Just pos) $ ULam (uPat, Just uTyKind) ImplicitArrow e) + addImplicitArg :: Name -> (UType, UExpr) -> (UType, UExpr) + addImplicitArg v (ty, e) = + ( ns $ UPi (Just uPat, uTyKind) ImplicitArrow ty + , ns $ ULam (uPat, Just uTyKind) ImplicitArrow e) where - uPat = WithSrc (Just pos) $ UPatBinder $ Bind $ v:>() + uPat = ns $ nameToPat v k = if v == mkName "eff" then EffectRowKind else TypeKind - uTyKind = WithSrc (Just pos) $ UPrimExpr $ TCExpr k + uTyKind = ns $ UPrimExpr $ TCExpr k + +superclassConstraints :: Parser [UType] +superclassConstraints = optionalMonoid $ brackets $ uType `sepBy` sym "," -interfaceDef :: Parser [UDecl] +interfaceDef :: Parser UDecl interfaceDef = do keyWord InterfaceKW - (tyCon, pos) <- withPos tyConDef - keyWord WhereKW - recordFieldsWithSrc <- withSrc $ interfaceRecordFields ":" - let (UConDef interfaceName uAnnBinderNest) = tyCon - record = URecordTy . NoExt <$> recordFieldsWithSrc - consName = mkInterfaceConsName interfaceName - varNames = fmap (\(Bind v) -> varName v) uAnnBinderNest - (WithSrc _ recordFields) = recordFieldsWithSrc - funDefs = mkFunDefs (pos, varNames, interfaceName) recordFields - return $ UData tyCon [UConDef consName (toNest [Ignore record])] : funDefs - where - -- From an interface - -- interface I a:Type b:Type where - -- f : a -> b - -- mkFunDefs generates the equivalent of the following function definition: - -- def f (instance# : I a b) ?=> : a -> b = - -- (I# {f=f,...}) = instance# - -- f - -- where I# is an automatically generated constructor of I. - mkFunDefs - :: (SrcPos, Nest Name, Name) -> LabeledItems UExpr -> [UDecl] - mkFunDefs meta (LabeledItems items) = - fmap (\(name, ty :| []) -> mkOneFunDef meta (name, ty)) $ M.toList items - mkOneFunDef :: (SrcPos, Nest Name, Name) -> (Label, UExpr) -> UDecl - mkOneFunDef (pos, typeVarNames, interfaceName) (fLabel, fType) = - ULet PlainLet (p, ann') rhs' - where - uAnnPat = ( Just $ WithSrc (Just pos) $ UPatBinder $ Bind $ instanceName :> () - , foldl mkUApp (var interfaceName) typeVarNames) - p = patb fLabel - ann = Just $ ns $ UPi uAnnPat ClassArrow fType - - mkUApp func typeVarName = - ns $ UApp (PlainArrow ()) func (var typeVarName) - recordStr = "recordVar" - recordPat = ns $ UPatRecord $ Ext (labeledSingleton fLabel (patb - fLabel)) $ Just underscorePat - conPat = ns $ UPatCon (mkInterfaceConsName interfaceName) - $ toNest [patb recordStr] - - let1 = ULet PlainLet (conPat, Nothing) $ var instanceName - let2 = ULet PlainLet (recordPat, Nothing) $ var $ mkName recordStr - body = ns $ UDecl let1 (ns $ UDecl let2 (var (mkName fLabel))) - rhs = ns $ ULam (patb instanceStr, Nothing) ClassArrow body - (ann', rhs') = addImplicitImplicitArgs pos ann rhs - - ns = WithSrc Nothing - patb s = ns $ UPatBinder $ Bind $ mkName s :> () - instanceStr = mkNoShadowingStr "instance" - instanceName = mkName instanceStr - var name = ns $ UVar $ name :> () + superclasses <- superclassConstraints + tyCon <- tyConDef + methods <- onePerLine $ do + v <- anyName + ty <- annot uType + return $ Bind $ v:>ty + return $ UInterface superclasses tyCon methods dataDef :: Parser UDecl dataDef = do @@ -353,9 +313,15 @@ dataDef = do dataCons <- onePerLine dataConDef return $ UData tyCon dataCons --- TODO: default to `Type` if unannoted tyConDef :: Parser UConDef -tyConDef = UConDef <$> (upperName <|> symName) <*> manyNested namedBinder +tyConDef = do + con <- upperName <|> symName + bs <- manyNested $ label "type constructor parameter" $ do + v <- lowerName + ty <- annot containedExpr <|> return tyKind + return $ Bind $ v :> ty + return $ UConDef con bs + where tyKind = ns $ UPrimExpr $ TCExpr TypeKind -- TODO: dependent types dataConDef :: Parser UConDef @@ -370,52 +336,32 @@ decl = do rhs <- sym "=" >> blockOrExpr return $ lhs rhs -interfaceInstance :: Parser UDecl -interfaceInstance = do +instanceDef :: Parser UDecl +instanceDef = do keyWord InstanceKW - (p, pos) <- withPos letPat - ann <- annot uType - case mkConstructorNameVar ann of - Left err -> fail err - Right constructorNameVar -> do - keyWord WhereKW - record <- withSrc $ (URecord . NoExt) <$> interfaceRecordFields "=" - let constructorCall = constructorNameVar `mkApp` record - (ann', rhs') = addImplicitImplicitArgs pos (Just ann) constructorCall - return $ ULet InstanceLet (p, ann') rhs' + explicitArgs <- many defArg + constraints <- classConstraints + classTy <- uType + let ty = buildPiType explicitArgs Pure $ + foldr addClassConstraint classTy constraints + let ty' = foldr addImplicitArg ty $ findImplicitImplicitArgNames ty + methods <- onePerLine instanceMethod + return $ UInstance ty' methods where - -- Here, we are traversing the type annotation to retrieve the name of - -- the interface and generate its corresponding constructor. A valid type - -- annotation for an instance is composed of: - -- 1) implicit/class arguments - -- 2) a function whose name is the name of the interface applied to 0 or - -- more arguments - mkConstructorNameVar ann = - stripArrows ann >>= stripAppliedArgs >>= buildConstructor - - stripArrows (WithSrc _ (UPi _ arr typ)) - | arr `elem` [ClassArrow, ImplicitArrow] = stripArrows typ - | otherwise = Left ("Met invalid arrow '" ++ pprint arr ++ "' in type " ++ - "annotation of instance. Only class arrows and " ++ - "implicit arrows are allowed.") - stripArrows ann = Right ann - - stripAppliedArgs ann - | (WithSrc _ (UApp _ func _)) <- ann = stripAppliedArgs func - | otherwise = Right ann - - buildConstructor (WithSrc _ (UVar v)) = - Right $ (var . nameToStr . mkInterfaceConsName . varName) v - buildConstructor _ = Left ("Could not extract interface name from type " ++ - "annotation.") - var s = noSrc $ UVar $ mkName s :> () - -interfaceRecordFields :: String -> Parser (LabeledItems UExpr) -interfaceRecordFields bindwith = - fuse <$> onePerLine (do l <- fieldLabel - e <- symbol bindwith *> expr - return $ labeledSingleton l e) - where fuse = foldr (<>) NoLabeledItems + addClassConstraint :: UType -> UType -> UType + addClassConstraint c ty = ns $ UPi (Nothing, c) ClassArrow ty + + addImplicitArg :: Name -> UType -> UType + addImplicitArg v ty = + ns $ UPi (Just (ns $ nameToPat v), uTyKind) ImplicitArrow ty + where uTyKind = ns $ UPrimExpr $ TCExpr TypeKind + +instanceMethod :: Parser (UVar, UExpr) +instanceMethod = do + v <- anyName + sym "=" + rhs <- blockOrExpr + return (v:>(), rhs) simpleLet :: Parser (UExpr -> UDecl) simpleLet = label "let binding" $ do @@ -424,14 +370,14 @@ simpleLet = label "let binding" $ do return $ ULet PlainLet (p, ann) letPat :: Parser UPat -letPat = nameAsPat $ upperName <|> lowerName <|> symName +letPat = withSrc $ nameToPat <$> anyName funDefLet :: Parser (UExpr -> UDecl) funDefLet = label "function definition" $ mayBreak $ do keyWord DefKW v <- letPat - cs <- defClassConstraints - argBinders <- many arg + cs <- classConstraints + argBinders <- many defArg (eff, ty) <- label "result type annotation" $ annot effectiveType when (null argBinders && eff /= Pure) $ fail "Nullary def can't have effects" let bs = map classAsBinder cs ++ argBinders @@ -441,22 +387,17 @@ funDefLet = label "function definition" $ mayBreak $ do return $ \body -> ULet PlainLet letBinder (buildLam lamBinders body) where classAsBinder :: UType -> (UPat, UType, UArrow) - classAsBinder ty = (underscorePat, ty, ClassArrow) + classAsBinder ty = (ns underscorePat, ty, ClassArrow) - arg :: Parser (UPat, UType, UArrow) - arg = label "def arg" $ do - (p, ty) <-parens ((,) <$> pat <*> annot uType) - arr <- arrow (return ()) <|> return (PlainArrow ()) - return (p, ty, arr) +defArg :: Parser (UPat, UType, UArrow) +defArg = label "def arg" $ do + (p, ty) <-parens ((,) <$> pat <*> annot uType) + arr <- arrow (return ()) <|> return (PlainArrow ()) + return (p, ty, arr) -defClassConstraints :: Parser [UType] -defClassConstraints = - (brackets $ mayNotPair $ uType `sepBy` sym ",") - <|> return [] - "class constraints" - -nameAsPat :: Parser Name -> Parser UPat -nameAsPat p = withSrc $ (UPatBinder . Bind . (:>())) <$> p +classConstraints :: Parser [UType] +classConstraints = label "class constraints" $ + optionalMonoid $ brackets $ mayNotPair $ uType `sepBy` sym "," buildPiType :: [(UPat, UType, UArrow)] -> EffectRow -> UType -> UType buildPiType [] Pure ty = ty @@ -531,18 +472,21 @@ uForExpr = do <|> (keyWord Rof_KW $> (Rev, True )) e <- buildFor pos dir <$> (some patAnn <* argTerm) <*> blockOrExpr if trailingUnit - then return $ noSrc $ UDecl (ULet PlainLet (underscorePat, Nothing) e) $ - noSrc unitExpr + then return $ ns $ UDecl (ULet PlainLet (ns underscorePat, Nothing) e) $ + ns unitExpr else return e -underscorePat :: UPat -underscorePat = noSrc $ UPatBinder $ Ignore () +underscorePat :: UPat' +underscorePat = UPatBinder $ Ignore () + +nameToPat :: Name -> UPat' +nameToPat v = UPatBinder (Bind (v:>())) unitExpr :: UExpr' unitExpr = UPrimExpr $ ConExpr UnitCon -noSrc :: a -> WithSrc a -noSrc = WithSrc Nothing +ns :: a -> WithSrc a +ns = WithSrc Nothing blockOrExpr :: Parser UExpr blockOrExpr = block <|> expr @@ -570,7 +514,7 @@ wrapUStatements statements = case statements of (s, pos):rest -> WithSrc (Just pos) $ case s of Left d -> UDecl d $ wrapUStatements rest Right e -> UDecl d $ wrapUStatements rest - where d = ULet PlainLet (underscorePat, Nothing) e + where d = ULet PlainLet (ns underscorePat, Nothing) e [] -> error "Shouldn't be reachable" uStatement :: Parser UStatement @@ -584,16 +528,17 @@ uPiType = withSrc $ UPi <$> piBinderPat <*> arrow effects <*> uType b <- annBinder return $ case b of Bind (n:>a@(WithSrc pos _)) -> - (Just $ WithSrc pos $ UPatBinder $ Bind $ n:>(), a) + (Just $ WithSrc pos $ nameToPat n, a) Ignore a -> (Nothing, a) annBinder :: Parser UAnnBinder annBinder = try $ namedBinder <|> anonBinder namedBinder :: Parser UAnnBinder -namedBinder = label "named annoted binder" $ lowerName - >>= \v -> sym ":" >> containedExpr - >>= \ty -> return $ Bind (v:>ty) +namedBinder = label "named annoted binder" $ do + v <- lowerName + ty <- annot containedExpr + return $ Bind (v:>ty) anonBinder :: Parser UAnnBinder anonBinder = @@ -622,7 +567,7 @@ ifExpr = withSrc $ do e <- expr (alt1, maybeAlt2) <- oneLineThenElse <|> blockThenElse let alt2 = case maybeAlt2 of - Nothing -> noSrc unitExpr + Nothing -> ns unitExpr Just alt -> alt return $ UCase e [ UAlt (globalEnumPat "True" ) alt1 @@ -647,7 +592,7 @@ blockThenElse = withIndent $ mayNotBreak $ do return (alt1, alt2) globalEnumPat :: Tag -> UPat -globalEnumPat s = noSrc $ UPatCon (GlobalName s) Empty +globalEnumPat s = ns $ UPatCon (GlobalName s) Empty onePerLine :: Parser a -> Parser [a] onePerLine p = liftM (:[]) p @@ -667,8 +612,8 @@ leafPat = <|> (variantPat `fallBackTo` recordPat) <|> brackets (UPatTable <$> leafPat `sepBy` sym ",") ) - where pun pos l = WithSrc (Just pos) $ UPatBinder $ Bind (mkName l:>()) - def pos = WithSrc (Just pos) $ UPatBinder $ Ignore () + where pun pos l = WithSrc (Just pos) $ nameToPat $ mkName l + def pos = WithSrc (Just pos) $ underscorePat variantPat = parseVariant leafPat UPatVariant UPatVariantLift recordPat = UPatRecord <$> parseLabeledItems "," "=" leafPat (Just pun) (Just def) @@ -741,9 +686,8 @@ uIsoSugar = withSrc (char '#' *> options) where <|> char '?' *> (variantFieldIso <$> fieldLabel) <|> char '&' *> (recordZipIso <$> fieldLabel) <|> char '|' *> (variantZipIso <$> fieldLabel) - ns = WithSrc Nothing var s = ns $ UVar $ mkName s :> () - patb s = ns $ UPatBinder $ Bind $ mkName s :> () + patb s = ns $ nameToPat $ mkName s plain = PlainArrow () lam p b = ns $ ULam (p, Nothing) plain b recordFieldIso field = @@ -1020,19 +964,6 @@ inpostfix' p op = Postfix $ do mkName :: String -> Name mkName s = Name SourceName (fromString s) 0 -nameToStr :: Name -> String -nameToStr = tagToStr . nameTag - --- This function is used to generate a string that is guaranteed to never shadow --- any user-defined name, as "#" is an invalid identifier character in normal --- source code. -mkNoShadowingStr :: String -> String -mkNoShadowingStr = (++ "#") - -mkInterfaceConsName :: Name -> Name -mkInterfaceConsName = - GlobalName . fromString . mkNoShadowingStr . nameToStr - -- === lexemes === -- These `Lexer` actions must be non-overlapping and never consume input on failure @@ -1054,6 +985,9 @@ lowerName = liftM mkName $ label "lower-case name" $ lexeme $ anyCaseName :: Lexer Name anyCaseName = lowerName <|> upperName +anyName :: Lexer Name +anyName = lowerName <|> upperName <|> symName + checkNotKeyword :: Parser String -> Parser String checkNotKeyword p = try $ do s <- p @@ -1202,6 +1136,9 @@ mayPair p = local (\ctx -> ctx { canPair = True }) p mayNotPair :: Parser a -> Parser a mayNotPair p = local (\ctx -> ctx { canPair = False }) p +optionalMonoid :: Monoid a => Parser a -> Parser a +optionalMonoid p = p <|> return mempty + nameString :: Parser String nameString = lexeme . try $ (:) <$> lowerChar <*> many alphaNumChar @@ -1244,7 +1181,7 @@ withPos p = do nextLine :: Parser () nextLine = do - void eol + eol n <- asks curIndent void $ mayNotBreak $ many $ try (sc >> eol) void $ replicateM n (char ' ') @@ -1261,8 +1198,11 @@ withIndent p = do indent <- liftM length $ some (char ' ') local (\ctx -> ctx { curIndent = curIndent ctx + indent }) $ p +eol :: Parser () +eol = void MC.eol + eolf :: Parser () -eolf = void eol <|> eof +eolf = eol <|> eof failIf :: Bool -> String -> Parser () failIf True s = fail s diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 277d94693..98a303617 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -13,6 +13,7 @@ {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE Rank2Types #-} +{-# LANGUAGE LambdaCase #-} module Syntax ( Type, Kind, BaseType (..), ScalarBaseType (..), @@ -29,8 +30,9 @@ module Syntax ( IExpr (..), IVal, ImpInstr (..), Backend (..), Device (..), IPrimOp, IVar, IBinder, IType, SetVal (..), MonMap (..), LitProg, IFunType (..), IFunVar, CallingConvention (..), IsCUDARequired (..), - UAlt (..), AltP, Alt, Label, LabeledItems (..), labeledSingleton, - reflectLabels, withLabels, ExtLabeledItems (..), prefixExtLabeledItems, + UAlt (..), AltP, Alt, Label, LabeledItems (..), labeledSingleton, lookupLabel, + reflectLabels, withLabels, ExtLabeledItems (..), + prefixExtLabeledItems, getLabels, IScope, BinderInfo (..), Bindings, CUDAKernel (..), BenchStats, SrcCtx, Result (..), Output (..), OutFormat (..), Err (..), ErrType (..), Except, throw, throwIf, modifyErr, addContext, @@ -61,7 +63,9 @@ module Syntax ( pattern TabTy, pattern TabTyAbs, pattern TabVal, pattern TabValA, pattern Pure, pattern BinaryFunTy, pattern BinaryFunVal, pattern Unlabeled, pattern NoExt, pattern LabeledRowKind, - pattern NoLabeledItems, pattern InternalSingletonLabel, pattern EffKind) + pattern NoLabeledItems, pattern InternalSingletonLabel, pattern EffKind, + pattern NestOne, pattern NewTypeCon, pattern BinderAnn, + pattern ClassDictDef, pattern ClassDictCon) where import qualified Data.Map.Strict as M @@ -187,10 +191,18 @@ reflectLabels :: LabeledItems a -> LabeledItems (Label, Int) reflectLabels (LabeledItems items) = LabeledItems $ flip M.mapWithKey items $ \k xs -> fmap (\(i,_) -> (k,i)) (enumerate xs) +getLabels :: LabeledItems a -> [Label] +getLabels labeledItems = map fst $ toList $ reflectLabels labeledItems + withLabels :: LabeledItems a -> LabeledItems (Label, Int, a) withLabels (LabeledItems items) = LabeledItems $ flip M.mapWithKey items $ \k xs -> fmap (\(i,a) -> (k,i,a)) (enumerate xs) +lookupLabel :: LabeledItems a -> Label -> Maybe a +lookupLabel (LabeledItems items) l = case M.lookup l items of + Nothing -> Nothing + Just (x NE.:| _) -> Just x + instance Semigroup (LabeledItems a) where LabeledItems items <> LabeledItems items' = LabeledItems $ M.unionWith (<>) items items' @@ -237,6 +249,8 @@ data UExpr' = UVar UVar data UConDef = UConDef Name (Nest UAnnBinder) deriving (Show, Generic) data UDecl = ULet LetAnn UPatAnn UExpr | UData UConDef [UConDef] + | UInterface [UType] UConDef [UAnnBinder] + | UInstance UType [(UVar, UExpr)] deriving (Show, Generic) type UType = UExpr @@ -784,11 +798,15 @@ instance BindsUVars UPat' where instance HasUVars UDecl where freeUVars (ULet _ p expr) = freeUVars p <> freeUVars expr freeUVars (UData (UConDef _ bs) dataCons) = freeUVars $ Abs bs dataCons + freeUVars (UInterface _ _ _) = mempty -- TODO + freeUVars (UInstance _ _) = mempty -- TODO instance BindsUVars UDecl where boundUVars decl = case decl of - ULet _ (p,_) _ -> boundUVars p - UData tyCon dataCons -> boundUVars tyCon <> foldMap boundUVars dataCons + ULet _ (p,_) _ -> boundUVars p + UData tyCon dataCons -> boundUVars tyCon <> foldMap boundUVars dataCons + UInterface _ _ _ -> mempty + UInstance _ _ -> mempty instance HasUVars UModule where freeUVars (UModule decls) = freeUVars decls @@ -1005,8 +1023,9 @@ applyNaryAbs (Abs (Nest b bs) body) (x:xs) = applyNaryAbs ab xs applyNaryAbs _ _ = error "wrong number of arguments" applyDataDefParams :: DataDef -> [Type] -> [DataConDef] -applyDataDefParams (DataDef _ paramBs cons) params = - applyNaryAbs (Abs paramBs cons) params +applyDataDefParams (DataDef _ bs cons) params + | length params == length (toList bs) = applyNaryAbs (Abs bs cons) params + | otherwise = error $ "Wrong number of parameters: " ++ show (length params) makeAbs :: HasVars a => Binder -> a -> Abs Binder a makeAbs b body | b `isin` freeVars body = Abs b body @@ -1510,6 +1529,30 @@ pattern NothingAtom ty = DataCon MaybeDataDef [ty] 0 [] pattern JustAtom :: Type -> Atom -> Atom pattern JustAtom ty x = DataCon MaybeDataDef [ty] 1 [x] +pattern NestOne :: a -> Nest a +pattern NestOne x = Nest x Empty + +pattern BinderAnn :: a -> BinderP a +pattern BinderAnn x <- ((\case Ignore ann -> ann + Bind (_:>ann) -> ann) -> x) + where BinderAnn x = Ignore x + +pattern NewTypeCon :: Name -> Type -> [DataConDef] +pattern NewTypeCon con ty <- [DataConDef con (NestOne (BinderAnn ty))] + where NewTypeCon con ty = [DataConDef con (NestOne (Ignore ty))] + +pattern ClassDictDef :: Name + -> LabeledItems Type -> LabeledItems Type -> [DataConDef] +pattern ClassDictDef conName superclasses methods = + [DataConDef conName + (Nest (Ignore (RecordTy (NoExt superclasses))) + (Nest (Ignore (RecordTy (NoExt methods))) Empty))] + +pattern ClassDictCon :: DataDef -> [Type] + -> LabeledItems Atom -> LabeledItems Atom -> Atom +pattern ClassDictCon def params superclasses methods = + DataCon def params 0 [Record superclasses, Record methods] + -- TODO: Enable once https://gitlab.haskell.org//ghc/ghc/issues/13363 is fixed... -- {-# COMPLETE TypeVar, ArrowType, TabTy, Forall, TypeAlias, Effect, NoAnn, TC #-} diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 4dd84a817..39af39287 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -92,7 +92,7 @@ checkBindings env ir bs = void $ runTypeCheck (CheckWith (env <> bs, Pure)) $ mapM_ (checkBinding ir) $ envPairs bs checkBinding :: IRVariant -> (Name, (Type, BinderInfo)) -> TypeM () -checkBinding ir (GlobalName v, b@(ty, info)) = +checkBinding ir (v, b@(ty, info)) | isGlobal (v:>()) = addContext ("binding: " ++ pprint (v, b)) $ do ty |: TyKind when (ir >= Evaluated && not (all isGlobal (envAsVars $ freeVars b))) $ @@ -165,8 +165,8 @@ instance HasType Atom where withBinder b $ typeCheck body ProjectElt (i NE.:| is) v -> do ty <- typeCheck $ case NE.nonEmpty is of - Nothing -> Var v - Just is' -> ProjectElt is' v + Nothing -> Var v + Just is' -> ProjectElt is' v case ty of TypeCon def params -> do [DataConDef _ bs'] <- return $ applyDataDefParams def params @@ -184,7 +184,8 @@ instance HasType Atom where PairTy x _ | i == 0 -> return x PairTy _ y | i == 1 -> return y Var _ -> throw CompilerErr $ "Tried to project value of unreduced type " <> pprint ty - _ -> throw TypeErr $ "Only single-member ADTs and record types can be projected. Got " <> pprint ty + _ -> throw TypeErr $ + "Only single-member ADTs and record types can be projected. Got " <> pprint ty <> " " <> pprint v checkDataConRefBindings :: Nest Binder -> Nest DataConRefBinding -> TypeM () diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx index 1d2d2306e..97ad29dc7 100644 --- a/tests/adt-tests.dx +++ b/tests/adt-tests.dx @@ -216,7 +216,7 @@ def catLists (xs:List a) (ys:List a) : List a = def listToTable ((AsList n xs): List a) : (Fin n)=>a = xs :t listToTable -> ((a:Type) ?-> (pat:(List a)) -> (Fin ((\((AsList n _)). n) pat)) => a) +> ((a:Type) ?-> (pat:(List a)) -> (Fin (project [0] pat:(List a))) => a) :p l = AsList _ [1, 2, 3] @@ -228,7 +228,7 @@ def listToTable2 (l: List a) : (Fin (listLength l))=>a = xs :t listToTable2 -> ((a:Type) ?-> (l:(List a)) -> (Fin ((\((AsList n _)). n) l)) => a) +> ((a:Type) ?-> (l:(List a)) -> (Fin (project [0] l:(List a))) => a) :p l = AsList _ [1, 2, 3] @@ -258,7 +258,7 @@ def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = :t graphToAdjacencyMatrix > ((a:Type) > ?-> (pat:(Graph a)) -> -> ((\((MkGraph n _ _ _)). n) pat) => ((\((MkGraph n _ _ _)). n) pat) => Bool) +> -> (project [0] pat:(Graph a)) => (project [0] pat:(Graph a)) => Bool) :p g : Graph Int = MkGraph (Fin 3) [5, 6, 7] (Fin 4) [(0@_, 1@_), (0@_, 2@_), (2@_, 0@_), (1@_, 1@_)] @@ -269,15 +269,15 @@ def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = def pairUnpack ((v, _):(Int & Float)) : Int = v :p pairUnpack -> \pat:(Int32 & Float32). (\(a, _). a) pat +> \pat:(Int32 & Float32). project [0] pat:(Int32 & Float32) def adtUnpack ((MkMyPair v _):MyPair Int Float) : Int = v :p adtUnpack -> \pat:(MyPair Int32 Float32). (\((MkMyPair elt _)). elt) pat +> \pat:(MyPair Int32 Float32). project [0] pat:(MyPair Int32 Float32) def recordUnpack ({a=v, b=_}:{a:Int & b:Float}) : Int = v :p recordUnpack -> \pat:{a: Int32 & b: Float32}. (\{a = a, b = _}. a) pat +> \pat:{a: Int32 & b: Float32}. project [0] pat:{a: Int32 & b: Float32} def nestedUnpack (x:MyPair Int (MyPair (MyIntish & Int) Int)) : Int = (MkMyPair _ (MkMyPair (MkIntish y, _) _)) = x @@ -285,7 +285,7 @@ def nestedUnpack (x:MyPair Int (MyPair (MyIntish & Int) Int)) : Int = :p nestedUnpack > \x:(MyPair Int32 (MyPair (MyIntish & Int32) Int32)). -> (\((MkIntish (((MkMyPair ((MkMyPair _ elt)) _)), _))). elt) x +> project [0, 0, 0, 1] x:(MyPair Int32 (MyPair (MyIntish & Int32) Int32)) :p nestedUnpack (MkMyPair 3 (MkMyPair (MkIntish 4, 5) 6)) > 4 diff --git a/tests/io-tests.dx b/tests/io-tests.dx index 736c07ff4..853ee0027 100644 --- a/tests/io-tests.dx +++ b/tests/io-tests.dx @@ -38,7 +38,7 @@ unsafeIO \(). > 9 is odd > [(), (), (), (), (), (), (), (), (), ()] -:p storageSize (typeVehicle Int) +:p storageSize Int > 4 :p unsafeIO \(). diff --git a/tests/typeclass-tests.dx b/tests/typeclass-tests.dx index 5061d6ece..7970fbfcd 100644 --- a/tests/typeclass-tests.dx +++ b/tests/typeclass-tests.dx @@ -1,38 +1,43 @@ -interface InterfaceTest1 a:Type where + + +interface InterfaceTest1 a InterfaceTest1 : a > Error: variable already defined: InterfaceTest1 -interface InterfaceTest2 typeName:Type where - typeName : typeName -> typeName - -interface InterfaceTest3 _:Type where - foo : Int32 +interface InterfaceTest3 a + foo : a -> Int + foo : a -> Int +> Error: variable already defined: foo -> Parse error:8:26: -> | -> 8 | interface InterfaceTest3 _:Type where -> | ^^^^^ -> unexpected "_:Typ" -> expecting "where" or named annoted binder -interface InterfaceTest4 where +interface InterfaceTest4 a foo : Int + bar : a -> Int + +instance InterfaceTest4 Float + foo = 1 + bar = \_. 1 + foo = 1 +> Type error:Duplicate method: foo + +instance InterfaceTest4 Float + foo = 1 +> Type error:Missing method: bar + +instance InterfaceTest4 Float + baz = 1 +> Type error:baz is not a method of InterfaceTest4 -instance instanceTest4 : InterfaceTest4 where +instance InterfaceTest4 Float foo = 1 + bar = \_. 'x' +> Type error: +> Expected: Int32 +> Actual: Word8 +> +> bar = \_. 'x' +> ^^^ -instance instanceTest4 : InterfaceTest4 x -> InterfaceTest4 (n=>a) where +instance InterfaceTest4 Float foo = 1 + bar = \_. 1 -> Parse error:23:68: -> | -> 23 | instance instanceTest4 : InterfaceTest4 x -> InterfaceTest4 (n=>a) where -> | ^ -> Met invalid arrow '->' in type annotation of instance. Only class arrows and implicit arrows are allowed. -instance instanceTest5 : (..i) where - bar = bar - -> Parse error:31:32: -> | -> 31 | instance instanceTest5 : (..i) where -> | ^ -> Could not extract interface name from type annotation. From 567faa24a4664959dc5715a4c84393b1a0c60e93 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 4 Jan 2021 14:17:43 -0500 Subject: [PATCH 073/105] Enable BlockArguments by default (we use the bracketing pattern a lot) --- dex.cabal | 8 +++-- src/dex.hs | 4 +-- src/lib/Autodiff.hs | 50 ++++++++++++++--------------- src/lib/Cat.hs | 6 ++-- src/lib/Embed.hs | 34 ++++++++++---------- src/lib/Imp.hs | 60 +++++++++++++++++------------------ src/lib/Inference.hs | 56 ++++++++++++++++---------------- src/lib/JIT.hs | 16 +++++----- src/lib/LLVM/JIT.hs | 6 ++-- src/lib/LLVM/Shims.hs | 10 +++--- src/lib/LLVMExec.hs | 72 +++++++++++++++++++++--------------------- src/lib/Logging.hs | 6 ++-- src/lib/Parallelize.hs | 18 +++++------ src/lib/Parser.hs | 10 +++--- src/lib/Serialize.hs | 4 +-- src/lib/Simplify.hs | 46 +++++++++++++-------------- src/lib/Syntax.hs | 16 ++++------ src/lib/TopLevel.hs | 8 ++--- src/lib/Type.hs | 44 +++++++++++++------------- 19 files changed, 237 insertions(+), 237 deletions(-) diff --git a/dex.cabal b/dex.cabal index a4452d2cb..926c5314e 100644 --- a/dex.cabal +++ b/dex.cabal @@ -61,7 +61,8 @@ library cxx-sources: src/lib/dexrt.cpp cxx-options: -std=c++11 -fPIC default-extensions: CPP, DeriveTraversable, TypeApplications, OverloadedStrings, - TupleSections, ScopedTypeVariables, LambdaCase, PatternSynonyms + TupleSections, ScopedTypeVariables, LambdaCase, PatternSynonyms, + BlockArguments pkgconfig-depends: libpng if flag(cuda) include-dirs: /usr/local/cuda/include @@ -82,7 +83,7 @@ executable dex build-depends: dex-resources default-language: Haskell2010 hs-source-dirs: src - default-extensions: CPP, LambdaCase + default-extensions: CPP, LambdaCase, BlockArguments ghc-options: -threaded if flag(optimized) ghc-options: -O3 @@ -101,7 +102,8 @@ foreign-library Dex cc-options: -std=c11 -fPIC ghc-options: -Wall -fPIC -optP-Wno-nonportable-include-path default-language: Haskell2010 - default-extensions: TypeApplications, ScopedTypeVariables, LambdaCase + default-extensions: TypeApplications, ScopedTypeVariables, LambdaCase, + BlockArguments if flag(optimized) ghc-options: -O3 else diff --git a/src/dex.hs b/src/dex.hs index 7de5696fa..cfaf7b000 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -127,7 +127,7 @@ printLitProg TextDoc prog = do isatty <- queryTerminal stdOutput putStr $ foldMap (uncurry (printLitBlock isatty)) prog printLitProg JSONDoc prog = - forM_ prog $ \(_, result) -> case toJSONStr result of + forM_ prog \(_, result) -> case toJSONStr result of "{}" -> return () s -> putStrLn s @@ -164,7 +164,7 @@ parseMode = subparser $ objectFileInfo = argument str (metavar "OBJFILE" <> help "Output path (.o file)") optionList :: [(String, a)] -> ReadM a -optionList opts = eitherReader $ \s -> case lookup s opts of +optionList opts = eitherReader \s -> case lookup s opts of Just x -> Right x Nothing -> Left $ "Bad option. Expected one of: " ++ show (map fst opts) diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index d48837dfe..8a6dbd964 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -44,10 +44,10 @@ newtype LinA a = LinA { runLinA :: PrimalM (a, TangentM a) } linearize :: Scope -> Atom -> Atom linearize scope ~(Lam (Abs b (_, block))) = fst $ flip runEmbed scope $ do - buildLam b PureArrow $ \x@(Var v) -> do + buildLam b PureArrow \x@(Var v) -> do (y, yt) <- flip runReaderT (DerivWrt (varAsEnv v) [] mempty) $ runLinA $ linearizeBlock (b@>x) block -- TODO: check linearity - fLin <- buildLam (fmap tangentType b) LinArrow $ \xt -> runReaderT yt $ TangentEnv (v @> xt) [] mempty + fLin <- buildLam (fmap tangentType b) LinArrow \xt -> runReaderT yt $ TangentEnv (v @> xt) [] mempty fLinChecked <- checkEmbed fLin return $ PairVal y fLinChecked @@ -109,7 +109,7 @@ linearizeExpr env expr = case expr of return (ans, applyLinToTangents linLam) where linearizeInactiveAlt (Abs bs body) = do - buildNAbs bs $ \xs -> tangentFunAsLambda $ linearizeBlock (env <> newEnv bs xs) body + buildNAbs bs \xs -> tangentFunAsLambda $ linearizeBlock (env <> newEnv bs xs) body _ -> LinA $ do expr' <- substEmbed env expr runLinA $ case expr' of @@ -255,10 +255,10 @@ linearizeHof :: SubstEnv -> Hof -> LinA Atom linearizeHof env hof = case hof of For ~(RegularFor d) ~(LamVal i body) -> LinA $ do i' <- mapM (substEmbed env) i - (ansWithLinTab, vi'') <- buildForAux d i' $ \i''@(Var vi'') -> + (ansWithLinTab, vi'') <- buildForAux d i' \i''@(Var vi'') -> (,vi'') <$> (willRemat vi'' $ tangentFunAsLambda $ linearizeBlock (env <> i@>i'') body) (ans, linTab) <- unzipTab ansWithLinTab - return (ans, buildFor d i' $ \i'' -> provideRemat vi'' i'' $ app linTab i'' >>= applyLinToTangents) + return (ans, buildFor d i' \i'' -> provideRemat vi'' i'' $ app linTab i'' >>= applyLinToTangents) Tile _ _ _ -> notImplemented RunWriter lam -> linearizeEff Nothing lam True (const RunWriter) (emitRunWriter "r") Writer RunReader val lam -> linearizeEff (Just val) lam False RunReader (emitRunReader "r") Reader @@ -266,7 +266,7 @@ linearizeHof env hof = case hof of RunIO ~(Lam (Abs _ (arrow, body))) -> LinA $ do arrow' <- substEmbed env arrow -- TODO: consider the possibility of other effects here besides IO - lam <- buildLam (Ignore UnitTy) arrow' $ \_ -> + lam <- buildLam (Ignore UnitTy) arrow' \_ -> tangentFunAsLambda $ linearizeBlock env body result <- emit $ Hof $ RunIO lam (ans, linLam) <- fromPair result @@ -299,18 +299,18 @@ linearizeHof env hof = case hof of let (BinaryFunTy _ b _ _) = getType lam' let RefTy _ wTy = binderType b return $ emitter $ tangentType wTy - valEmitter $ \ref'@(Var (_:> RefTy (Var (h:>_)) _)) -> do + valEmitter \ref'@(Var (_:> RefTy (Var (h:>_)) _)) -> do extendTangentEnv (ref @> ref') [h] $ applyLinToTangents linBody return (ans, lin) linearizeEffectFun :: RWS -> Atom -> PrimalM (Atom, Var) linearizeEffectFun rws ~(BinaryFunVal h ref eff body) = do h' <- mapM (substEmbed env) h - buildLamAux h' (const $ return PureArrow) $ \h''@(Var hVar) -> do + buildLamAux h' (const $ return PureArrow) \h''@(Var hVar) -> do let env' = env <> h@>h'' eff' <- substEmbed env' eff ref' <- mapM (substEmbed env') ref - buildLamAux ref' (const $ return $ PlainArrow eff') $ \ref''@(Var refVar) -> + buildLamAux ref' (const $ return $ PlainArrow eff') \ref''@(Var refVar) -> extendWrt [refVar] [RWSEffect rws (varName hVar)] $ (,refVar) <$> (tangentFunAsLambda $ linearizeBlock (env' <> ref@>ref'') body) @@ -341,7 +341,7 @@ linearizeAtom atom = case atom of Con con -> linearizePrimCon con Lam (Abs i (TabArrow, body)) -> LinA $ do wrt <- ask - return (atom, buildLam i TabArrow $ \i' -> + return (atom, buildLam i TabArrow \i' -> rematPrimal wrt $ linearizeBlock (i@>i') body) DataCon _ _ _ _ -> notImplemented -- Need to synthesize or look up a tangent ADT Record elems -> Record <$> traverse linearizeAtom elems @@ -394,7 +394,7 @@ addTangent x y = case getType x of RecordTy (NoExt tys) -> do elems <- bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y) return $ Record $ restructure elems tys - TabTy b _ -> buildFor Fwd b $ \i -> bindM2 addTangent (tabGet x i) (tabGet y i) + TabTy b _ -> buildFor Fwd b \i -> bindM2 addTangent (tabGet x i) (tabGet y i) TC con -> case con of BaseType (Scalar _) -> emitOp $ ScalarBinOp FAdd x y BaseType (Vector _) -> emitOp $ VectorBinOp FAdd x y @@ -422,8 +422,8 @@ tangentFunAsLambda m = do let hs = map (Bind . (:>TyKind) . effectRegion) effs let rematList = envAsVars remats liftM (PairVal ans) $ lift $ do - tanLam <- makeLambdas rematList $ \rematArgs -> - buildNestedLam PureArrow hs $ \hVals -> do + tanLam <- makeLambdas rematList \rematArgs -> + buildNestedLam PureArrow hs \hVals -> do let hVarNames = map (\(Var (v:>_)) -> v) hVals -- TODO: handle exception effect too let effs' = zipWith (\(RWSEffect rws _) v -> RWSEffect rws v) effs hVarNames @@ -431,8 +431,8 @@ tangentFunAsLambda m = do let regionMap = newEnv (map ((:>()) . effectRegion) effs) hVals -- TODO: Only bind tangents for free variables? let activeVarBinders = map (Bind . fmap (tangentRefRegion regionMap)) $ envAsVars activeVars - buildNestedLam PureArrow activeVarBinders $ \activeVarArgs -> - buildLam (Ignore UnitTy) (PlainArrow $ EffectRow (S.fromList effs') Nothing) $ \_ -> + buildNestedLam PureArrow activeVarBinders \activeVarArgs -> + buildLam (Ignore UnitTy) (PlainArrow $ EffectRow (S.fromList effs') Nothing) \_ -> runReaderT tanFun $ TangentEnv (newEnv (envNames activeVars) activeVarArgs) hVarNames (newEnv rematList $ fmap Var rematArgs) @@ -448,7 +448,7 @@ tangentFunAsLambda m = do return $ Lam $ makeAbs (Bind v) (PureArrow, block) makeLambdas [] f = f [] - makeLambdas (v:vs) f = makeLambda v $ \x -> makeLambdas vs $ \xs -> f (x:xs) + makeLambdas (v:vs) f = makeLambda v \x -> makeLambdas vs \xs -> f (x:xs) -- This doesn't work if we have references inside pairs, tables etc. -- TODO: something more general and cleaner. @@ -544,7 +544,7 @@ type TransposeM a = ReaderT TransposeEnv Embed a transpose :: Scope -> Atom -> Atom transpose scope ~(Lam (Abs b (_, block))) = fst $ flip runEmbed scope $ do - buildLam (Bind $ "ct" :> getType block) LinArrow $ \ct -> do + buildLam (Bind $ "ct" :> getType block) LinArrow \ct -> do snd <$> (flip runReaderT mempty $ withLinVar b $ transposeBlock block ct) transposeBlock :: Block -> Atom -> TransposeM () @@ -590,7 +590,7 @@ transposeExpr expr ct = case expr of void $ emit $ Case e' alts' UnitTy where transposeNonlinAlt (Abs bs body) = - buildNAbs bs $ \xs -> do + buildNAbs bs \xs -> do localNonlinSubst (newEnv bs xs) $ transposeBlock body ct return UnitVal @@ -619,7 +619,7 @@ transposeOp op ct = case op of MPut x -> do transposeAtom x =<< emitEff MGet void $ emitEff $ MPut $ zeroAt $ getType x - TabCon ~(TabTy b _) es -> forM_ (enumerate es) $ \(i, e) -> do + TabCon ~(TabTy b _) es -> forM_ (enumerate es) \(i, e) -> do transposeAtom e =<< tabGet ct =<< intToIndexE (binderType b) (IdxRepVal $ fromIntegral i) IndexRef _ _ -> notImplemented FstRef _ -> notImplemented @@ -675,24 +675,24 @@ linAtomRef a = error $ "Not a linear var: " ++ pprint a transposeHof :: Hof -> Atom -> TransposeM () transposeHof hof ct = case hof of For ~(RegularFor d) ~(Lam (Abs b (_, body))) -> - void $ buildFor (flipDir d) b $ \i -> do + void $ buildFor (flipDir d) b \i -> do ct' <- tabGet ct i localNonlinSubst (b@>i) $ transposeBlock body ct' return UnitVal where flipDir dir = case dir of Fwd -> Rev; Rev -> Fwd RunReader r ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do - (_, ctr) <- (fromPair =<<) $ emitRunWriter "w" (getType r) $ \ref -> do + (_, ctr) <- (fromPair =<<) $ emitRunWriter "w" (getType r) \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ct return UnitVal transposeAtom r ctr RunWriter ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do (ctBody, ctEff) <- fromPair ct - void $ emitRunReader "r" ctEff $ \ref -> do + void $ emitRunReader "r" ctEff \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ctBody return UnitVal RunState s ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do (ctBody, ctState) <- fromPair ct - (_, cts) <- (fromPair =<<) $ emitRunState "s" ctState $ \ref -> do + (_, cts) <- (fromPair =<<) $ emitRunState "s" ctState \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ctBody return UnitVal transposeAtom s cts @@ -715,7 +715,7 @@ transposeAtom atom ct = case atom of DataCon _ _ _ e -> void $ zipWithT transposeAtom e =<< getUnpacked ct Variant _ _ _ _ -> notImplemented TabVal b body -> - void $ buildFor Fwd b $ \i -> do + void $ buildFor Fwd b \i -> do ct' <- tabGet ct i localNonlinSubst (b@>i) $ transposeBlock body ct' return UnitVal @@ -787,7 +787,7 @@ withLinVar :: Binder -> TransposeM a -> TransposeM (a, Atom) withLinVar b body = case singletonTypeVal (binderType b) of Nothing -> flip evalStateT Nothing $ do - ans <- emitRunWriter "ref" (binderType b) $ \ref -> do + ans <- emitRunWriter "ref" (binderType b) \ref -> do lift (localLinRef (b@>Just ref) body) >>= put . Just >> return UnitVal (,) <$> (fromJust <$> get) <*> (getSnd ans) Just x -> (,x) <$> (localLinRef (b@>Nothing) body) -- optimization to avoid accumulating into unit diff --git a/src/lib/Cat.hs b/src/lib/Cat.hs index aa6d703fa..f120df661 100644 --- a/src/lib/Cat.hs +++ b/src/lib/Cat.hs @@ -50,7 +50,7 @@ instance (Monoid env, Monad m) => MonadCat env (CatT env m) where instance MonadCat env m => MonadCat env (StateT s m) where look = lift look extend x = lift $ extend x - scoped m = StateT $ \s -> do + scoped m = StateT \s -> do ((ans, s'), env) <- scoped $ runStateT m s return $ ((ans, env), s') @@ -145,7 +145,7 @@ catTraverse f inj xs env = runCatT (traverse (asCat f inj) xs) env catFoldM :: (Monoid env, Traversable t, Monad m) => (env -> a -> m env) -> env -> t a -> m env -catFoldM f env xs = liftM snd $ flip runCatT env $ forM_ xs $ \x -> do +catFoldM f env xs = liftM snd $ flip runCatT env $ forM_ xs \x -> do cur <- look new <- lift $ f cur x extend new @@ -156,7 +156,7 @@ catFold f env xs = runIdentity $ catFoldM (\e x -> Identity $ f e x) env xs catMapM :: (Monoid env, Traversable t, Monad m) => (env -> a -> m (b, env)) -> env -> t a -> m (t b, env) -catMapM f env xs = flip runCatT env $ forM xs $ \x -> do +catMapM f env xs = flip runCatT env $ forM xs \x -> do cur <- look (y, new) <- lift $ f cur x extend new diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 705d1c50a..c46397d64 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -164,7 +164,7 @@ buildLam b arr body = buildDepEffLam b (const (return arr)) body buildDepEffLam :: MonadEmbed m => Binder -> (Atom -> m Arrow) -> (Atom -> m Atom) -> m Atom -buildDepEffLam b fArr fBody = liftM fst $ buildLamAux b fArr $ \x -> (,()) <$> fBody x +buildDepEffLam b fArr fBody = liftM fst $ buildLamAux b fArr \x -> (,()) <$> fBody x buildLamAux :: MonadEmbed m => Binder -> (Atom -> m Arrow) -> (Atom -> m (Atom, a)) -> m (Atom, a) @@ -180,7 +180,7 @@ buildLamAux b fArr fBody = do return (Lam $ makeAbs b' (arr, wrapDecls decls ans), aux) buildNAbs :: MonadEmbed m => Nest Binder -> ([Atom] -> m Atom) -> m Alt -buildNAbs bs body = liftM fst $ buildNAbsAux bs $ \xs -> (,()) <$> body xs +buildNAbs bs body = liftM fst $ buildNAbsAux bs \xs -> (,()) <$> body xs buildNAbsAux :: MonadEmbed m => Nest Binder -> ([Atom] -> m (Atom, a)) -> m (Alt, a) buildNAbsAux bs body = do @@ -202,9 +202,9 @@ buildDataDef tyConName paramBinders body = do buildImplicitNaryLam :: MonadEmbed m => (Nest Binder) -> ([Atom] -> m Atom) -> m Atom buildImplicitNaryLam Empty body = body [] buildImplicitNaryLam (Nest b bs) body = - buildLam b ImplicitArrow $ \x -> do + buildLam b ImplicitArrow \x -> do bs' <- substEmbed (b@>x) bs - buildImplicitNaryLam bs' $ \xs -> body $ x:xs + buildImplicitNaryLam bs' \xs -> body $ x:xs recGet :: Label -> Atom -> Atom recGet l x = do @@ -383,14 +383,14 @@ unpackConsList xs = case getType xs of emitWhile :: MonadEmbed m => m Atom -> m () emitWhile body = do eff <- getAllowedEffects - lam <- buildLam (Ignore UnitTy) (PlainArrow eff) $ \_ -> body + lam <- buildLam (Ignore UnitTy) (PlainArrow eff) \_ -> body void $ emit $ Hof $ While lam emitMaybeCase :: MonadEmbed m => Atom -> m Atom -> (Atom -> m Atom) -> m Atom emitMaybeCase scrut nothingCase justCase = do let (MaybeTy a) = getType scrut - nothingAlt <- buildNAbs Empty $ \[] -> nothingCase - justAlt <- buildNAbs (Nest (Bind ("x":>a)) Empty) $ \[x] -> justCase x + nothingAlt <- buildNAbs Empty \[] -> nothingCase + justAlt <- buildNAbs (Nest (Bind ("x":>a)) Empty) \[x] -> justCase x let (Abs _ nothingBody) = nothingAlt let resultTy = getType nothingBody emit $ Case scrut [nothingAlt, justAlt] resultTy @@ -410,7 +410,7 @@ emitRunState v x0 body = do mkBinaryEffFun :: MonadEmbed m => RWS -> Name -> Type -> (Atom -> m Atom) -> m Atom mkBinaryEffFun rws v ty body = do eff <- getAllowedEffects - buildLam (Bind ("h":>TyKind)) PureArrow $ \r@(Var (rName:>_)) -> do + buildLam (Bind ("h":>TyKind)) PureArrow \r@(Var (rName:>_)) -> do let arr = PlainArrow $ extendEffect (RWSEffect rws rName) eff buildLam (Bind (v:> RefTy r ty)) arr body @@ -434,16 +434,16 @@ buildFor = buildForAnn . RegularFor buildNestedLam :: MonadEmbed m => Arrow -> [Binder] -> ([Atom] -> m Atom) -> m Atom buildNestedLam _ [] f = f [] buildNestedLam arr (b:bs) f = - buildLam b arr $ \x -> buildNestedLam arr bs $ \xs -> f (x:xs) + buildLam b arr \x -> buildNestedLam arr bs \xs -> f (x:xs) tabGet :: MonadEmbed m => Atom -> Atom -> m Atom tabGet x i = emit $ App x i unzipTab :: MonadEmbed m => Atom -> m (Atom, Atom) unzipTab tab = do - fsts <- buildLam (Bind ("i":>binderType v)) TabArrow $ \i -> + fsts <- buildLam (Bind ("i":>binderType v)) TabArrow \i -> liftM fst $ app tab i >>= fromPair - snds <- buildLam (Bind ("i":>binderType v)) TabArrow $ \i -> + snds <- buildLam (Bind ("i":>binderType v)) TabArrow \i -> liftM snd $ app tab i >>= fromPair return (fsts, snds) where TabTy v _ = getType tab @@ -509,9 +509,9 @@ instance Monad m => MonadEmbed (EmbedT m) where instance MonadEmbed m => MonadEmbed (ReaderT r m) where embedLook = lift embedLook embedExtend x = lift $ embedExtend x - embedScoped m = ReaderT $ \r -> embedScoped $ runReaderT m r + embedScoped m = ReaderT \r -> embedScoped $ runReaderT m r embedAsk = lift embedAsk - embedLocal v m = ReaderT $ \r -> embedLocal v $ runReaderT m r + embedLocal v m = ReaderT \r -> embedLocal v $ runReaderT m r instance MonadEmbed m => MonadEmbed (StateT s m) where embedLook = lift embedLook @@ -710,7 +710,7 @@ traverseExpr def@(_, _, fAtom) expr = case expr of where traverseAlt (Abs bs body) = do bs' <- mapM (mapM fAtom) bs - buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ evalBlockE def body + buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ evalBlockE def body traverseAtom :: forall m . (MonadEmbed m, MonadReader SubstEnv m) => TraversalDef m -> Atom -> m Atom @@ -747,7 +747,7 @@ traverseAtom def@(_, _, fAtom) atom = case atom of BoxedRef b ptr size body -> do ptr' <- fAtom ptr size' <- buildScoped $ evalBlockE def size - Abs b' (decls, body') <- buildAbs b $ \x -> + Abs b' (decls, body') <- buildAbs b \x -> extendR (b@>x) $ evalBlockE def (Block Empty $ Atom body) case decls of Empty -> return $ BoxedRef b' ptr' size' body' @@ -765,7 +765,7 @@ traverseAtom def@(_, _, fAtom) atom = case atom of traverseAAlt (Abs bs a) = do bs' <- mapM (mapM fAtom) bs - (Abs bs'' b) <- buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ fAtom a + (Abs bs'' b) <- buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ fAtom a case b of Block Empty (Atom r) -> return $ Abs bs'' r _ -> error "ACase alternative traversal has emitted decls or exprs!" @@ -842,7 +842,7 @@ indexToIntE idx = case getType idx of (offsets, _) <- scanM (\sz prev -> (prev,) <$> iadd sz prev) sizes (IdxRepVal 0) -- Build and apply a case expression alts <- flip mapM (zip (toList offsets) (toList types)) $ - \(offset, subty) -> buildNAbs (toNest [Ignore subty]) $ \[subix] -> do + \(offset, subty) -> buildNAbs (toNest [Ignore subty]) \[subix] -> do i <- indexToIntE subix iadd offset i emit $ Case idx alts IdxRepTy diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index aa3c94663..deae7ab50 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -80,7 +80,7 @@ toImpModule :: TopEnv -> Backend -> CallingConvention -> Name -> (ImpFunction, ImpModule, AtomRecon) toImpModule env backend cc entryName argBinders maybeDest block = do let standaloneFunctions = - for (requiredFunctions env block) $ \(v, f) -> + for (requiredFunctions env block) \(v, f) -> runImpM initCtx inVarScope $ toImpStandalone v f runImpM initCtx inVarScope $ do (reconAtom, impBlock) <- scopedBlock $ translateTopLevel env (maybeDest, block) @@ -98,7 +98,7 @@ toImpModule env backend cc entryName argBinders maybeDest block = do requiredFunctions :: HasVars a => Scope -> a -> [(Name, Atom)] requiredFunctions scope expr = - flip foldMap (transitiveClosure getParents immediateParents) $ \fname -> + flip foldMap (transitiveClosure getParents immediateParents) \fname -> case scope ! fname of (_, LetBound _ (Atom f)) -> [(fname, f)] (_, LamBound _) -> [] @@ -142,7 +142,7 @@ toImpStandalone fname ~(LamVal b body) = do impBlock <- scopedErrBlock $ do arg <- destToAtom argDest void $ translateBlock (b@>arg) (Just outDest, body) - let bs = for ptrSizes $ \(Bind (v:>BaseTy ty), _) -> Bind $ v:>ty + let bs = for ptrSizes \(Bind (v:>BaseTy ty), _) -> Bind $ v:>ty let fTy = IFunType CEntryFun (map binderAnn bs) (impBlockType impBlock) return $ ImpFunction (fname:>fTy) bs impBlock @@ -150,7 +150,7 @@ translateBlock :: SubstEnv -> WithDest Block -> ImpM Atom translateBlock env destBlock = do let (decls, result, copies) = splitDest destBlock env' <- (env<>) <$> catFoldM translateDecl env decls - forM_ copies $ \(dest, atom) -> copyAtom dest =<< impSubst env' atom + forM_ copies \(dest, atom) -> copyAtom dest =<< impSubst env' atom translateExpr env' result translateDecl :: SubstEnv -> WithDest Decl -> ImpM SubstEnv @@ -239,7 +239,7 @@ toImpOp :: WithDest (PrimOp Atom) -> ImpM Atom toImpOp (maybeDest, op) = case op of TabCon (TabTy b _) rows -> do dest <- allocDest maybeDest resultTy - forM_ (zip [0..] rows) $ \(i, row) -> do + forM_ (zip [0..] rows) \(i, row) -> do ithDest <- destGet dest =<< intToIndex (binderType b) (IIdxRepVal i) copyAtom ithDest row destToAtom dest @@ -358,7 +358,7 @@ toImpHof env (maybeDest, hof) = do Select (toScalarAtom isLast) (toScalarAtom elemsUntilEnd) (toScalarAtom usualChunkSize)) - emitLoop "li" Fwd (fromScalarAtom threadChunkSize) $ \li -> do + emitLoop "li" Fwd (fromScalarAtom threadChunkSize) \li -> do i <- li `iaddI` chunkStart let idx = Con $ ParIndexCon idxTy $ toScalarAtom i ithDest <- destGet dest idx @@ -381,7 +381,7 @@ toImpHof env (maybeDest, hof) = do _ -> do n <- indexSetSize idxTy dest <- allocDest maybeDest resultTy - emitLoop (binderNameHint b) d n $ \i -> do + emitLoop (binderNameHint b) d n \i -> do idx <- intToIndex idxTy i ithDest <- destGet dest idx void $ translateBlock (env <> b @> idx) (Just ithDest, body) @@ -389,13 +389,13 @@ toImpHof env (maybeDest, hof) = do For ParallelFor ~fbody@(LamVal b _) -> do idxTy <- impSubst env $ binderType b dest <- allocDest maybeDest resultTy - buildKernel idxTy $ \LaunchInfo{..} buildBody -> do - liftM (,()) $ buildBody $ \ThreadInfo{..} -> do + buildKernel idxTy \LaunchInfo{..} buildBody -> do + liftM (,()) $ buildBody \ThreadInfo{..} -> do let threadBody = fst $ flip runSubstEmbed (freeVars fbody) $ - buildLam (Bind $ "hwidx" :> threadRange) PureArrow $ \hwidx -> + buildLam (Bind $ "hwidx" :> threadRange) PureArrow \hwidx -> appReduce fbody =<< (emitOp $ Inject hwidx) let threadDest = Con $ TabRef $ fst $ flip runSubstEmbed (freeVars dest) $ - buildLam (Bind $ "hwidx" :> threadRange) TabArrow $ \hwidx -> + buildLam (Bind $ "hwidx" :> threadRange) TabArrow \hwidx -> indexDest dest =<< (emitOp $ Inject hwidx) void $ toImpHof env (Just threadDest, For (RegularFor Fwd) threadBody) destToAtom dest @@ -407,12 +407,12 @@ toImpHof env (maybeDest, hof) = do nTiles <- n `idivI` tileLen epilogueOff <- nTiles `imulI` tileLen nEpilogue <- n `isubI` epilogueOff - emitLoop (binderNameHint tb) Fwd nTiles $ \iTile -> do + emitLoop (binderNameHint tb) Fwd nTiles \iTile -> do tileOffset <- toScalarAtom <$> iTile `imulI` tileLen let tileAtom = Con $ IndexSliceVal idxTy tileIdxTy tileOffset tileDest <- fromEmbed $ sliceDestDim d dest tileOffset tileIdxTy void $ translateBlock (env <> tb @> tileAtom) (Just tileDest, tBody) - emitLoop (binderNameHint sb) Fwd nEpilogue $ \iEpi -> do + emitLoop (binderNameHint sb) Fwd nEpilogue \iEpi -> do i <- iEpi `iaddI` epilogueOff idx <- intToIndex idxTy i sDest <- fromEmbed $ indexDestDim d dest idx @@ -422,16 +422,16 @@ toImpHof env (maybeDest, hof) = do idxTy <- impSubst env idxTy' (mappingDest, finalAccDest) <- destPairUnpack <$> allocDest maybeDest resultTy let PairTy _ accType = resultTy - (numTileWorkgroups, wgResArr, widIdxTy) <- buildKernel idxTy $ \LaunchInfo{..} buildBody -> do + (numTileWorkgroups, wgResArr, widIdxTy) <- buildKernel idxTy \LaunchInfo{..} buildBody -> do let widIdxTy = Fin $ toScalarAtom numWorkgroups let tidIdxTy = Fin $ toScalarAtom workgroupSize wgResArr <- alloc $ TabTy (Ignore widIdxTy) accType thrAccArr <- alloc $ TabTy (Ignore widIdxTy) $ TabTy (Ignore tidIdxTy) accType - mappingKernelBody <- buildBody $ \ThreadInfo{..} -> do + mappingKernelBody <- buildBody \ThreadInfo{..} -> do let TC (ParIndexRange _ gtid nthr) = threadRange let scope = freeVars mappingDest let tileDest = Con $ TabRef $ fst $ flip runSubstEmbed scope $ do - buildLam (Bind $ "hwidx":>threadRange) TabArrow $ \hwidx -> do + buildLam (Bind $ "hwidx":>threadRange) TabArrow \hwidx -> do indexDest mappingDest =<< (emitOp $ Inject hwidx) wgAccs <- destGet thrAccArr =<< intToIndex widIdxTy wid thrAcc <- destGet wgAccs =<< intToIndex tidIdxTy tid @@ -443,12 +443,12 @@ toImpHof env (maybeDest, hof) = do -- TODO: Skip the reduction kernel if unnecessary? -- TODO: Reduce sequentially in the CPU backend? -- TODO: Actually we only need the previous-power-of-2 many threads - buildKernel widIdxTy $ \LaunchInfo{..} buildBody -> do + buildKernel widIdxTy \LaunchInfo{..} buildBody -> do -- We only do a one-level reduciton in the workgroup, so it is correct -- only if the end up scheduling a single workgroup. moreThanOneGroup <- (IIdxRepVal 1) `iltI` numWorkgroups guardBlock moreThanOneGroup $ emitStatement IThrowError - redKernelBody <- buildBody $ \ThreadInfo{..} -> + redKernelBody <- buildBody \ThreadInfo{..} -> workgroupReduce tid finalAccDest wgResArr numTileWorkgroups return (redKernelBody, ()) PairVal <$> destToAtom mappingDest <*> destToAtom finalAccDest @@ -548,7 +548,7 @@ buildKernel idxTy f = do LLVMCUDA -> (CUDAKernelLaunch, GPU) LLVMMC -> (MCThreadLaunch , CPU) backend -> error $ "Shouldn't be launching kernels from " ++ show backend - ((kernelBody, aux), env) <- scoped $ f LaunchInfo{..} $ \mkBody -> + ((kernelBody, aux), env) <- scoped $ f LaunchInfo{..} \mkBody -> withDevice dev $ withLevel ThreadLevel $ scopedErrBlock $ do gtid <- iaddI tid =<< imulI wid wsz let threadRange = TC $ ParIndexRange idxTy (toScalarAtom gtid) (toScalarAtom nthr) @@ -581,7 +581,7 @@ type DestM = ReaderT DestEnv (CatT (Env (Type, Block)) Embed) makeDest :: AllocInfo -> Type -> Embed ([(Binder, Atom)], Dest) makeDest allocInfo ty = do (dest, ptrs) <- flip runCatT mempty $ flip runReaderT env $ makeDestRec ty - ptrs' <- forM (envPairs ptrs) $ \(v, (ptrTy, numel)) -> do + ptrs' <- forM (envPairs ptrs) \(v, (ptrTy, numel)) -> do numel' <- emitBlock numel return (Bind (v:>ptrTy), numel') return (ptrs', dest) @@ -598,7 +598,7 @@ makeDestRec ty = case ty of makeDestRec ty makeBoxes (envPairs ptrs) dest else do - lam <- buildLam (Bind ("i":> binderAnn b)) TabArrow $ \(Var i) -> do + lam <- buildLam (Bind ("i":> binderAnn b)) TabArrow \(Var i) -> do bodyTy' <- substEmbed (b@>Var i) bodyTy withEnclosingIdxs (Bind i) $ makeDestRec bodyTy' return $ Con $ TabRef lam @@ -614,7 +614,7 @@ makeDestRec ty = case ty of "Dependent data constructors only allowed for single-constructor types" tag <- rec TagRepTy let dcs' = applyDataDefParams def params - contents <- forM dcs' $ \(DataConDef _ bs) -> forM (toList bs) (rec . binderType) + contents <- forM dcs' \(DataConDef _ bs) -> forM (toList bs) (rec . binderType) return $ Con $ ConRef $ SumAsProd ty tag contents RecordTy (NoExt types) -> (Con . RecordRef) <$> forM types rec VariantTy (NoExt types) -> do @@ -720,7 +720,7 @@ loadDest (DataConRef def params bs) = do loadDest (Con dest) = do case dest of BaseTypeRef ptr -> unsafePtrLoad ptr - TabRef (TabVal b body) -> buildLam b TabArrow $ \i -> do + TabRef (TabVal b body) -> buildLam b TabArrow \i -> do body' <- substEmbed (b@>i) body result <- emitBlock body' loadDest result @@ -744,7 +744,7 @@ loadDataConArgs (Nest (DataConRefBinding b ref) rest) = do indexDestDim :: MonadEmbed m => Int->Dest -> Atom -> m Dest indexDestDim 0 dest i = indexDest dest i -indexDestDim d dest i = buildFor Fwd (Bind ("i":>idxTy)) $ \j -> do +indexDestDim d dest i = buildFor Fwd (Bind ("i":>idxTy)) \j -> do dest' <- indexDest dest j indexDestDim (d-1) dest' i where @@ -757,7 +757,7 @@ indexDest dest _ = error $ pprint dest sliceDestDim :: MonadEmbed m => Int -> Dest -> Atom -> Type -> m Dest sliceDestDim 0 dest i sliceIdxTy = sliceDest dest i sliceIdxTy -sliceDestDim d dest i sliceIdxTy = buildFor Fwd (Bind ("i":>idxTy)) $ \j -> do +sliceDestDim d dest i sliceIdxTy = buildFor Fwd (Bind ("i":>idxTy)) \j -> do dest' <- indexDest dest j sliceDestDim (d-1) dest' i sliceIdxTy where @@ -766,7 +766,7 @@ sliceDestDim d dest i sliceIdxTy = buildFor Fwd (Bind ("i":>idxTy)) $ \j -> do sliceDest :: MonadEmbed m => Dest -> Atom -> Type -> m Dest sliceDest ~(Con (TabRef tab@(TabVal b _))) i sliceIdxTy = (Con . TabRef) <$> do - buildFor Fwd (Bind ("j" :> sliceIdxTy)) $ \j -> do + buildFor Fwd (Bind ("j" :> sliceIdxTy)) \j -> do j' <- indexToIntE j ioff <- iadd j' i vidx <- intToIndexE (binderType b) ioff @@ -790,7 +790,7 @@ makeAllocDestWithPtrs allocTy ty = do backend <- asks impBackend curDev <- asks curDevice (ptrsSizes, dest) <- fromEmbed $ makeDest (backend, curDev, allocTy) ty - (env, ptrs) <- flip foldMapM ptrsSizes $ \(Bind (ptr:>PtrTy ptrTy), size) -> do + (env, ptrs) <- flip foldMapM ptrsSizes \(Bind (ptr:>PtrTy ptrTy), size) -> do ptr' <- emitAlloc ptrTy $ fromScalarAtom size case ptrTy of (Heap _, _) | allocTy == Managed -> extendAlloc ptr' @@ -811,7 +811,7 @@ splitDest (maybeDest, (Block decls ans)) = do let closureCopies = fmap (\(n, (d, t)) -> (d, Var $ n :> t)) (envPairs $ varDests `envDiff` foldMap letBoundVars decls) - let destDecls = flip fmap (toList decls) $ \d -> case d of + let destDecls = flip fmap (toList decls) \d -> case d of Let _ b _ -> (fst <$> varDests `envLookup` b, d) (destDecls, (Nothing, ans), gatherCopies ++ closureCopies) _ -> (fmap (Nothing,) $ toList decls, (maybeDest, ans), []) @@ -939,7 +939,7 @@ zipTabDestAtom f ~dest@(Con (TabRef (TabVal b _))) ~src@(TabVal b' _) = do error $ "Mismatched dimensions: " <> pprint b <> " != " <> pprint b' let idxTy = binderType b n <- indexSetSize idxTy - emitLoop "i" Fwd n $ \i -> do + emitLoop "i" Fwd n \i -> do idx <- intToIndex idxTy i destIndexed <- destGet dest idx srcIndexed <- translateExpr mempty (Nothing, App src idx) @@ -1094,7 +1094,7 @@ restructureScalarOrPairTypeRec ty _ = error $ "Not a scalar or pair: " ++ pprint emitMultiReturnInstr :: ImpInstr -> ImpM [IExpr] emitMultiReturnInstr instr = do - vs <- forM (impInstrTypes instr) $ \ty -> freshVar ("v":>ty) + vs <- forM (impInstrTypes instr) \ty -> freshVar ("v":>ty) emitImpDecl $ ImpLet (map Bind vs) instr return (map IVar vs) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 5a74edbfd..73d673554 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -50,7 +50,7 @@ inferModule :: TopEnv -> UModule -> Except Module inferModule scope (UModule decls) = do ((), (bindings, decls')) <- runUInferM mempty scope $ mapM_ (inferUDecl True) decls - let bindings' = envFilter bindings $ \(_, b) -> case b of + let bindings' = envFilter bindings \(_, b) -> case b of DataBoundTypeCon _ -> True DataBoundDataCon _ _ -> True _ -> False @@ -68,7 +68,7 @@ checkSigma expr reqCon sTy = case sTy of WithSrc _ (ULam b arrow' body) | arrow' == void arrow -> checkULam b body piTy _ -> do - buildLam (Bind ("a":> absArgType piTy)) arrow $ \x@(Var v) -> + buildLam (Bind ("a":> absArgType piTy)) arrow \x@(Var v) -> checkLeaks [v] $ checkSigma expr reqCon $ snd $ applyAbs piTy x _ -> checkOrInferRho expr (reqCon sTy) @@ -157,7 +157,7 @@ checkOrInferRho (WithSrc pos expr) reqTy = do -- TODO: check leaks kind' <- checkUType kind piTy <- case pat of - Just pat' -> withNameHint ("pat" :: Name) $ buildPi b $ \x -> + Just pat' -> withNameHint ("pat" :: Name) $ buildPi b \x -> withBindPat pat' x $ (,) <$> mapM checkUEffRow arr <*> checkUType ty where b = case pat' of -- Note: The binder name becomes part of the type, so we @@ -182,7 +182,7 @@ checkOrInferRho (WithSrc pos expr) reqTy = do case scrutTy' of TypeCon def params -> do let conDefs = applyDataDefParams def params - altsSorted <- forM (enumerate conDefs) $ \(i, DataConDef _ bs) -> do + altsSorted <- forM (enumerate conDefs) \(i, DataConDef _ bs) -> do case lookup (ConAlt i) alts' of Nothing -> return $ Abs (fmap (Ignore . binderType) bs) $ Block Empty $ Op $ ThrowError reqTy' @@ -256,7 +256,7 @@ checkOrInferRho (WithSrc pos expr) reqTy = do val' <- checkSigma val reqCon ty' matchRequirement val' UPrimExpr prim -> do - prim' <- forM prim $ \e -> do + prim' <- forM prim \e -> do e' <- inferRho e scope <- getScope return $ typeReduceAtom scope e' @@ -319,7 +319,7 @@ unpackTopPat :: LetAnn -> UPat -> Expr -> UInferM () unpackTopPat letAnn pat expr = do atom <- emit expr bindings <- bindPat pat atom - void $ flip traverseNames bindings $ \name val -> do + void $ flip traverseNames bindings \name val -> do let name' = asGlobal name checkNotInScope name' emitTo name' letAnn $ Atom val @@ -342,15 +342,15 @@ inferUDecl topLevel (ULet letAnn (p, ann) rhs) = do else bindPat p val inferUDecl True (UData tc dcs) = do (tc', paramBs) <- inferUConDef tc - dataDef <- buildDataDef tc' paramBs $ \params -> do - extendR (newEnv paramBs params) $ forM dcs $ \dc -> + dataDef <- buildDataDef tc' paramBs \params -> do + extendR (newEnv paramBs params) $ forM dcs \dc -> uncurry DataConDef <$> inferUConDef dc checkDataDefShadows dataDef emitConstructors dataDef return mempty inferUDecl True (UInterface superclasses tc methods) = do (tc', paramBs) <- inferUConDef tc - dataDef <- buildDataDef tc' paramBs $ \params -> do + dataDef <- buildDataDef tc' paramBs \params -> do extendR (newEnv paramBs params) $ do conName <- freshClassGenName superclasses' <- mkLabeledItems <$> mapM mkSuperclass superclasses @@ -403,16 +403,16 @@ emitConstructors def@(DataDef tyConName _ dataConDefs) = do let tyConTy = getType $ TypeCon def [] checkNotInScope tyConName extendScope $ tyConName @> (tyConTy, DataBoundTypeCon def) - forM_ (zip [0..] dataConDefs) $ \(i, DataConDef dataConName _) -> do + forM_ (zip [0..] dataConDefs) \(i, DataConDef dataConName _) -> do let dataConTy = getType $ DataCon def [] i [] checkNotInScope dataConName extendScope $ dataConName @> (dataConTy, DataBoundDataCon def i) emitMethodGetters :: DataDef -> UInferM () emitMethodGetters def@(DataDef _ paramBs (ClassDictDef _ _ methodTys)) = do - forM_ (getLabels methodTys) $ \l -> do - f <- buildImplicitNaryLam paramBs $ \params -> do - buildLam (Bind ("d":> TypeCon def params)) ClassArrow $ \dict -> do + forM_ (getLabels methodTys) \l -> do + f <- buildImplicitNaryLam paramBs \params -> do + buildLam (Bind ("d":> TypeCon def params)) ClassArrow \dict -> do return $ recGet l $ getProjection [1] dict let methodName = GlobalName $ fromString l checkNotInScope methodName @@ -421,9 +421,9 @@ emitMethodGetters (DataDef _ _ _) = error "Not a class dictionary" emitSuperclassGetters :: MonadEmbed m => DataDef -> m () emitSuperclassGetters def@(DataDef _ paramBs (ClassDictDef _ superclassTys _)) = do - forM_ (getLabels superclassTys) $ \l -> do - f <- buildImplicitNaryLam paramBs $ \params -> do - buildLam (Bind ("d":> TypeCon def params)) PureArrow $ \dict -> do + forM_ (getLabels superclassTys) \l -> do + f <- buildImplicitNaryLam paramBs \params -> do + buildLam (Bind ("d":> TypeCon def params)) PureArrow \dict -> do return $ recGet l $ getProjection [0] dict getterName <- freshClassGenName emitTo getterName SuperclassLet $ Atom f @@ -468,7 +468,7 @@ inferULam (p, ann) arr body = do argTy <- checkAnn ann -- TODO: worry about binder appearing in arrow? buildLam (Bind $ patNameHint p :> argTy) arr - $ \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ inferSigma body + \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ inferSigma body checkULam :: UPatAnn -> UExpr -> PiType -> UInferM Atom checkULam (p, ann) body piTy = do @@ -476,7 +476,7 @@ checkULam (p, ann) body piTy = do checkAnn ann >>= constrainEq argTy buildDepEffLam (Bind $ patNameHint p :> argTy) ( \x -> return $ fst $ applyAbs piTy x) - $ \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ + \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ checkSigma body Suggest $ snd $ applyAbs piTy x checkInstance :: Type -> [(UVar, UExpr)] -> UInferM Atom @@ -484,7 +484,7 @@ checkInstance ty methods = case ty of TypeCon def@(DataDef className _ _) params -> do case applyDataDefParams def params of ClassDictDef _ superclassTys methodTys -> do - methods' <- liftM mkLabeledItems $ forM methods $ \((v:>()), rhs) -> do + methods' <- liftM mkLabeledItems $ forM methods \((v:>()), rhs) -> do let v' = nameToLabel v case lookupLabel methodTys v' of Nothing -> throw TypeErr (pprint v ++ " is not a method of " ++ pprint className) @@ -492,9 +492,9 @@ checkInstance ty methods = case ty of rhs' <- checkSigma rhs Suggest methodTy return (v', rhs') let superclassHoles = fmap (Con . ClassDictHole Nothing) superclassTys - forM_ (reflectLabels methods') $ \(l,i) -> + forM_ (reflectLabels methods') \(l,i) -> when (i > 0) $ throw TypeErr $ "Duplicate method: " ++ pprint l - forM_ (reflectLabels methodTys) $ \(l,_) -> + forM_ (reflectLabels methodTys) \(l,_) -> case lookupLabel methods' l of Nothing -> throw TypeErr $ "Missing method: " ++ pprint l Just _ -> return () @@ -505,7 +505,7 @@ checkInstance ty methods = case ty of ImplicitArrow -> return () ClassArrow -> return () _ -> throw TypeErr $ "Not a valid arrow for an instance: " ++ pprint arrow - buildLam b arrow $ \x@(Var v) -> do + buildLam b arrow \x@(Var v) -> do bodyTy' <- substEmbed (b@>x) bodyTy checkLeaks [v] $ extendR (b@>x) $ checkInstance bodyTy' methods _ -> throw TypeErr $ "Not a valid instance type: " ++ pprint ty @@ -513,7 +513,7 @@ checkInstance ty methods = case ty of checkUEffRow :: EffectRow -> UInferM EffectRow checkUEffRow (EffectRow effs t) = do effs' <- liftM S.fromList $ mapM checkUEff $ toList effs - t' <- forM t $ \tv -> lookupVarName EffKind tv + t' <- forM t \tv -> lookupVarName EffKind tv return $ EffectRow effs' t' where lookupVarName :: Type -> Name -> UInferM Name @@ -540,7 +540,7 @@ checkCaseAlt reqTy scrutineeTy (UAlt pat body) = do (conIdx, patTys) <- checkCasePat pat scrutineeTy let (subPats, subPatTys) = unzip patTys let bs = zipWith (\p ty -> Bind $ patNameHint p :> ty) subPats subPatTys - alt <- buildNAbs (toNest bs) $ \xs -> + alt <- buildNAbs (toNest bs) \xs -> withBindPats (zip subPats xs) $ checkRho body reqTy return (conIdx, alt) @@ -658,7 +658,7 @@ bindPat' (WithSrc pos pat) val = addSrcContext pos $ case pat of throw TypeErr $ "Incorrect length of table pattern: table index set has " <> pprint (length idxs) <> " elements but there are " <> pprint (length ps) <> " patterns." - flip foldMapM (zip ps idxs) $ \(p, i) -> do + flip foldMapM (zip ps idxs) \(p, i) -> do v <- lift $ emitZonked $ App val i bindPat' p v @@ -883,7 +883,7 @@ runSolverT m = liftM fst $ flip runCatT mempty $ do applyDefaults :: MonadCat SolverEnv m => m () applyDefaults = do vs <- looks unsolved - forM_ (envPairs vs) $ \(v, k) -> case k of + forM_ (envPairs vs) \(v, k) -> case k of EffKind -> addSub v $ Eff Pure _ -> return () where addSub v ty = extend $ SolverEnv mempty (v@>ty) @@ -907,8 +907,8 @@ checkLeaks tvs m = do unless (null $ resultTypeLeaks) $ throw TypeErr $ "Leaked local variable `" ++ pprint (head resultTypeLeaks) ++ "` in result type " ++ pprint (getType ans) - forM_ (solverSub env) $ \ty -> - forM_ tvs $ \tv -> + forM_ (solverSub env) \ty -> + forM_ tvs \tv -> throwIf (tv `occursIn` ty) TypeErr $ "Leaked type variable: " ++ pprint tv extend env return ans diff --git a/src/lib/JIT.hs b/src/lib/JIT.hs index a9d374269..d26ad3a4a 100644 --- a/src/lib/JIT.hs +++ b/src/lib/JIT.hs @@ -94,11 +94,11 @@ compileFunction logger fun@(ImpFunction f bs body) = case cc of (argPtrParam , argPtrOperand ) <- freshParamOpPair attrs $ hostPtrTy i64 (resultPtrParam, resultPtrOperand) <- freshParamOpPair attrs $ hostPtrTy i64 initializeOutputStream streamFDOperand - argOperands <- forM (zip [0..] argTys) $ \(i, ty) -> + argOperands <- forM (zip [0..] argTys) \(i, ty) -> gep argPtrOperand (i64Lit i) >>= castLPtr (scalarTy ty) >>= load when (toBool requiresCUDA) ensureHasCUDAContext results <- extendOperands (newEnv bs argOperands) $ compileBlock body - forM_ (zip [0..] results) $ \(i, x) -> + forM_ (zip [0..] results) \(i, x) -> gep resultPtrOperand (i64Lit i) >>= castLPtr (L.typeOf x) >>= flip store x mainFun <- makeFunction (asLLVMName name) [streamFDParam, argPtrParam, resultPtrParam] (Just $ i64Lit 0) @@ -607,7 +607,7 @@ compileExpr expr = case expr of packArgs :: [Operand] -> Compile Operand packArgs elems = do arr <- alloca (length elems) hostVoidp - forM_ (zip [0..] elems) $ \(i, e) -> do + forM_ (zip [0..] elems) \(i, e) -> do eptr <- alloca 1 $ L.typeOf e store eptr e earr <- gep arr $ i32Lit i @@ -616,7 +616,7 @@ packArgs elems = do unpackArgs :: Operand -> [L.Type] -> Compile [Operand] unpackArgs argArrayPtr types = - forM (zip [0..] types) $ \(i, ty) -> do + forM (zip [0..] types) \(i, ty) -> do argVoidPtr <- gep argArrayPtr $ i64Lit i argPtr <- castLPtr (hostPtrTy ty) argVoidPtr load =<< load argPtr @@ -624,7 +624,7 @@ unpackArgs argArrayPtr types = makeMultiResultAlloc :: [L.Type] -> Compile Operand makeMultiResultAlloc tys = do resultsPtr <- alloca (length tys) hostVoidp - forM_ (zip [0..] tys) $ \(i, ty) -> do + forM_ (zip [0..] tys) \(i, ty) -> do ptr <- alloca 1 ty >>= castVoidPtr resultsPtrOffset <- gep resultsPtr $ i32Lit i store resultsPtrOffset ptr @@ -632,7 +632,7 @@ makeMultiResultAlloc tys = do loadMultiResultAlloc :: [L.Type] -> Operand -> Compile [Operand] loadMultiResultAlloc tys ptr = - forM (zip [0..] tys) $ \(i, ty) -> + forM (zip [0..] tys) \(i, ty) -> gep ptr (i32Lit i) >>= load >>= castLPtr ty >>= load runMCKernel :: ExternFunSpec @@ -894,7 +894,7 @@ runCompile dev m = evalState (runReaderT m env) initState initState = CompileState [] [] [] "start_block" mempty mempty mempty extendOperands :: OperandEnv -> Compile a -> Compile a -extendOperands openv = local $ \env -> env { operandEnv = (operandEnv env) <> openv } +extendOperands openv = local \env -> env { operandEnv = (operandEnv env) <> openv } lookupImpVar :: IVar -> Compile Operand lookupImpVar v = asks ((! v) . operandEnv) @@ -912,7 +912,7 @@ freshName :: Name -> Compile L.Name freshName v = do used <- gets usedNames let v' = genFresh v used - modify $ \s -> s { usedNames = used <> v' @> () } + modify \s -> s { usedNames = used <> v' @> () } return $ nameToLName v' where nameToLName :: Name -> L.Name diff --git a/src/lib/LLVM/JIT.hs b/src/lib/LLVM/JIT.hs index e10228a4c..c73b396be 100644 --- a/src/lib/LLVM/JIT.hs +++ b/src/lib/LLVM/JIT.hs @@ -88,14 +88,14 @@ compileModule moduleJIT@JIT{..} ast compilationPipeline = do resolver <- newSymbolResolver execSession (makeResolver compileLayer) modifyIORef resolvers (M.insert moduleKey resolver) OrcJIT.addModule compileLayer moduleKey llvmModule - moduleDtors <- forM dtorNames $ \dtorName -> do + moduleDtors <- forM dtorNames \dtorName -> do dtorSymbol <- OrcJIT.mangleSymbol compileLayer (fromString dtorName) Right (OrcJIT.JITSymbol dtorAddr _) <- OrcJIT.findSymbol compileLayer dtorSymbol False return $ castPtrToFunPtr $ wordPtrToPtr dtorAddr return NativeModule{..} where makeResolver :: OrcJIT.IRCompileLayer OrcJIT.ObjectLinkingLayer -> OrcJIT.SymbolResolver - makeResolver cl = OrcJIT.SymbolResolver $ \sym -> do + makeResolver cl = OrcJIT.SymbolResolver \sym -> do rsym <- OrcJIT.findSymbol cl sym False -- We look up functions like malloc in the current process -- TODO: Use JITDylibs to avoid inlining addresses as constants: @@ -116,7 +116,7 @@ compileModule moduleJIT@JIT{..} ast compilationPipeline = do -- Unfortunately the JIT layers we use here don't handle the destructors properly, -- so we have to find and call them ourselves. dtorNames = do - let dtorStructs = flip foldMap (LLVM.AST.moduleDefinitions ast) $ \case + let dtorStructs = flip foldMap (LLVM.AST.moduleDefinitions ast) \case LLVM.AST.GlobalDefinition LLVM.AST.GlobalVariable{ name="llvm.global_dtors", diff --git a/src/lib/LLVM/Shims.hs b/src/lib/LLVM/Shims.hs index 860b5540a..e509ac12d 100644 --- a/src/lib/LLVM/Shims.hs +++ b/src/lib/LLVM/Shims.hs @@ -35,7 +35,7 @@ data SymbolResolver = SymbolResolver (FunPtr FFIResolver) (Ptr OrcJIT.FFI.Symbol -- | Create a `FFI.SymbolResolver` that can be used with the JIT. newSymbolResolver :: OrcJIT.ExecutionSession -> OrcJIT.SymbolResolver -> IO SymbolResolver newSymbolResolver (OrcJIT.ExecutionSession session) (OrcJIT.SymbolResolver resolverFn) = do - ffiResolverPtr <- wrapFFIResolver $ \sym res -> do + ffiResolverPtr <- wrapFFIResolver \sym res -> do f <- encodeM =<< resolverFn =<< decodeM sym f res lambdaResolver <- OrcJIT.FFI.createLambdaResolver session ffiResolverPtr @@ -60,10 +60,10 @@ newTargetMachine :: Target.Target newTargetMachine (Target.Target targetFFI) triple cpu features (Target.TargetOptions targetOptFFI) relocModel codeModel cgoLevel = do - SBS.useAsCString triple $ \tripleFFI -> do - BS.useAsCString cpu $ \cpuFFI -> do + SBS.useAsCString triple \tripleFFI -> do + BS.useAsCString cpu \cpuFFI -> do let featuresStr = BS.intercalate "," $ fmap encodeFeature $ M.toList features - BS.useAsCString featuresStr $ \featuresFFI -> do + BS.useAsCString featuresStr \featuresFFI -> do relocModelFFI <- encodeM relocModel codeModelFFI <- encodeM codeModel cgoLevelFFI <- encodeM cgoLevel @@ -79,7 +79,7 @@ newHostTargetMachine relocModel codeModel cgoLevel = do (target, _) <- Target.lookupTarget Nothing triple cpu <- Target.getHostCPUName features <- Target.getHostCPUFeatures - Target.withTargetOptions $ \targetOptions -> + Target.withTargetOptions \targetOptions -> newTargetMachine target triple cpu features targetOptions relocModel codeModel cgoLevel disposeTargetMachine :: Target.TargetMachine -> IO () diff --git a/src/lib/LLVMExec.hs b/src/lib/LLVMExec.hs index b2957cf5b..f4435dd19 100644 --- a/src/lib/LLVMExec.hs +++ b/src/lib/LLVMExec.hs @@ -72,9 +72,9 @@ type DexExitCode = Int compileAndEval :: Logger [Output] -> L.Module -> String -> [LitVal] -> [BaseType] -> IO [LitVal] compileAndEval logger ast fname args resultTypes = do - withPipeToLogger logger $ \fd -> - allocaBytes (length args * cellSize) $ \argsPtr -> - allocaBytes (length resultTypes * cellSize) $ \resultPtr -> do + withPipeToLogger logger \fd -> + allocaBytes (length args * cellSize) \argsPtr -> + allocaBytes (length resultTypes * cellSize) \resultPtr -> do storeLitVals argsPtr args evalTime <- compileOneOff logger ast fname $ checkedCallFunPtr fd argsPtr resultPtr @@ -84,11 +84,11 @@ compileAndEval logger ast fname args resultTypes = do compileAndBench :: Bool -> Logger [Output] -> L.Module -> String -> [LitVal] -> [BaseType] -> IO [LitVal] compileAndBench shouldSyncCUDA logger ast fname args resultTypes = do - withPipeToLogger logger $ \fd -> - allocaBytes (length args * cellSize) $ \argsPtr -> - allocaBytes (length resultTypes * cellSize) $ \resultPtr -> do + withPipeToLogger logger \fd -> + allocaBytes (length args * cellSize) \argsPtr -> + allocaBytes (length resultTypes * cellSize) \resultPtr -> do storeLitVals argsPtr args - compileOneOff logger ast fname $ \fPtr -> do + compileOneOff logger ast fname \fPtr -> do ((avgTime, benchRuns, results), totalTime) <- measureSeconds $ do -- First warmup iteration, which we also use to get the results void $ checkedCallFunPtr fd argsPtr resultPtr fPtr @@ -112,7 +112,7 @@ compileAndBench shouldSyncCUDA logger ast fname args resultTypes = do withPipeToLogger :: Logger [Output] -> (FD -> IO a) -> IO a withPipeToLogger logger writeAction = do result <- snd <$> withPipe - (\h -> readStream h $ \s -> logThis logger [TextOut s]) + (\h -> readStream h \s -> logThis logger [TextOut s]) (\h -> handleToFd h >>= writeAction) case result of Left e -> E.throw e @@ -129,9 +129,9 @@ checkedCallFunPtr fd argsPtr resultPtr fPtr = do compileOneOff :: Logger [Output] -> L.Module -> String -> (DexExecutable -> IO a) -> IO a compileOneOff logger ast name f = do - withHostTargetMachine $ \tm -> - withJIT tm $ \jit -> - withNativeModule jit ast (standardCompilationPipeline logger [name] tm) $ \compiled -> + withHostTargetMachine \tm -> + withJIT tm \jit -> + withNativeModule jit ast (standardCompilationPipeline logger [name] tm) \compiled -> f =<< getFunctionPtr compiled name standardCompilationPipeline :: Logger [Output] -> [String] -> T.TargetMachine -> Mod.Module -> IO () @@ -151,12 +151,12 @@ standardCompilationPipeline logger exports tm m = do -- Each module comes with a list of exported functions exportObjectFile :: FilePath -> [(L.Module, [String])] -> IO () exportObjectFile objFile modules = do - withContext $ \c -> do - withHostTargetMachine $ \tm -> - withBrackets (fmap (toLLVM c) modules) $ \mods -> do - Mod.withModuleFromAST c L.defaultModule $ \exportMod -> do + withContext \c -> do + withHostTargetMachine \tm -> + withBrackets (fmap (toLLVM c) modules) \mods -> do + Mod.withModuleFromAST c L.defaultModule \exportMod -> do void $ foldM linkModules exportMod mods - execLogger Nothing $ \logger -> + execLogger Nothing \logger -> standardCompilationPipeline logger allExports tm exportMod Mod.writeObjectToFile tm (Mod.File objFile) exportMod where @@ -164,14 +164,14 @@ exportObjectFile objFile modules = do toLLVM :: Context -> (L.Module, [String]) -> (Mod.Module -> IO a) -> IO a toLLVM c (ast, exports) cont = do - Mod.withModuleFromAST c ast $ \m -> internalize exports m >> cont m + Mod.withModuleFromAST c ast \m -> internalize exports m >> cont m linkModules a b = a <$ Mod.linkModules a b withBrackets :: [(a -> IO b) -> IO b] -> ([a] -> IO b) -> IO b withBrackets brackets f = go brackets [] where - go (h:t) args = h $ \arg -> go t (arg:args) + go (h:t) args = h \arg -> go t (arg:args) go [] args = f args @@ -179,12 +179,12 @@ exportObjectFile objFile modules = do runDefaultPasses :: T.TargetMachine -> Mod.Module -> IO () runDefaultPasses t m = do - P.withPassManager defaultPasses $ \pm -> void $ P.runPassManager pm m + P.withPassManager defaultPasses \pm -> void $ P.runPassManager pm m -- We are highly dependent on LLVM when it comes to some optimizations such as -- turning a sequence of scalar stores into a vector store, so we execute some -- extra passes to make sure they get simplified correctly. runPasses extraPasses (Just t) m - P.withPassManager defaultPasses $ \pm -> void $ P.runPassManager pm m + P.withPassManager defaultPasses \pm -> void $ P.runPassManager pm m where defaultPasses = P.defaultCuratedPassSetSpec {P.optLevel = Just 3} extraPasses = [ P.SuperwordLevelParallelismVectorize @@ -196,7 +196,7 @@ runPasses passes mt m = do Just t -> Just <$> T.getTargetMachineDataLayout t Nothing -> return Nothing let passSpec = P.PassSetSpec passes dl Nothing mt - P.withPassManager passSpec $ \pm -> void $ P.runPassManager pm m + P.withPassManager passSpec \pm -> void $ P.runPassManager pm m internalize :: [String] -> Mod.Module -> IO () internalize names m = runPasses [P.InternalizeFunctions names, P.GlobalDeadCodeElimination] Nothing m @@ -219,7 +219,7 @@ withHostTargetMachine f = withGPUTargetMachine :: B.ByteString -> (T.TargetMachine -> IO a) -> IO a withGPUTargetMachine computeCapability next = do (tripleTarget, _) <- T.lookupTarget Nothing ptxTargetTriple - T.withTargetOptions $ \topt -> + T.withTargetOptions \topt -> T.withTargetMachine tripleTarget ptxTargetTriple @@ -241,8 +241,8 @@ showAsm :: T.TargetMachine -> Mod.Module -> IO String showAsm t m' = do ctx <- Mod.moduleContext m' -- Uncomment this to dump assembly to a file that can be linked to a C benchmark suite: - -- withModuleClone ctx m' $ \m -> Mod.writeObjectToFile t (Mod.File "asm.o") m - withModuleClone ctx m' $ \m -> unpack <$> Mod.moduleTargetAssembly t m + -- withModuleClone ctx m' \m -> Mod.writeObjectToFile t (Mod.File "asm.o") m + withModuleClone ctx m' \m -> unpack <$> Mod.moduleTargetAssembly t m withModuleClone :: Context -> Mod.Module -> (Mod.Module -> IO a) -> IO a withModuleClone ctx m f = do @@ -291,8 +291,8 @@ ptrArray p = map (\i -> p `plusPtr` (i * cellSize)) [0..] {-# NOINLINE dexrtAST #-} dexrtAST :: L.Module dexrtAST = unsafePerformIO $ do - withContext $ \ctx -> do - Mod.withModuleFromBitcode ctx (("dexrt.c" :: String), dexrtBC) $ \m -> + withContext \ctx -> do + Mod.withModuleFromBitcode ctx (("dexrt.c" :: String), dexrtBC) \m -> stripFunctionAnnotations <$> Mod.moduleAST m where -- We strip the function annotations for dexrt functions, because clang @@ -313,7 +313,7 @@ linkDexrt m = do targetTriple <- Mod.getTargetTriple =<< Mod.readModule m let dexrtTargetAST = dexrtAST { L.moduleDataLayout = dataLayout , L.moduleTargetTriple = targetTriple } - Mod.withModuleFromAST ctx dexrtTargetAST $ \dexrtm -> do + Mod.withModuleFromAST ctx dexrtTargetAST \dexrtm -> do Mod.linkModules m dexrtm runPasses [P.AlwaysInline True] Nothing m @@ -325,21 +325,21 @@ data LLVMKernel = LLVMKernel L.Module compileCUDAKernel :: Logger [Output] -> LLVMKernel -> IO CUDAKernel compileCUDAKernel logger (LLVMKernel ast) = do T.initializeAllTargets - withContext $ \ctx -> - Mod.withModuleFromAST ctx ast $ \m -> do - withGPUTargetMachine (pack arch) $ \tm -> do + withContext \ctx -> + Mod.withModuleFromAST ctx ast \m -> do + withGPUTargetMachine (pack arch) \tm -> do linkLibdevice m standardCompilationPipeline logger ["kernel"] tm m ptx <- Mod.moduleTargetAssembly tm m usePTXAS <- maybe False (=="1") <$> lookupEnv "DEX_USE_PTXAS" if usePTXAS then do - withSystemTempFile "kernel.ptx" $ \ptxPath ptxH -> do + withSystemTempFile "kernel.ptx" \ptxPath ptxH -> do B.hPut ptxH ptx hClose ptxH - withSystemTempFile "kernel.sass" $ \sassPath sassH -> do + withSystemTempFile "kernel.sass" \sassPath sassH -> do let cmd = proc ptxasPath [ptxPath, "-o", sassPath, "-arch=" ++ arch, "-O3"] - withCreateProcess cmd $ \_ _ _ ptxas -> do + withCreateProcess cmd \_ _ _ ptxas -> do code <- waitForProcess ptxas case code of ExitSuccess -> return () @@ -354,7 +354,7 @@ compileCUDAKernel logger (LLVMKernel ast) = do {-# NOINLINE libdevice #-} libdevice :: L.Module libdevice = unsafePerformIO $ do - withContext $ \ctx -> do + withContext \ctx -> do let libdeviceDirectory = "/usr/local/cuda/nvvm/libdevice" [libdeviceFileName] <- listDirectory libdeviceDirectory let libdevicePath = libdeviceDirectory ++ "/" ++ libdeviceFileName @@ -367,8 +367,8 @@ libdevice = unsafePerformIO $ do linkLibdevice :: Mod.Module -> IO () linkLibdevice m = do ctx <- Mod.moduleContext m - Mod.withModuleFromAST ctx zeroNVVMReflect $ \reflectm -> - Mod.withModuleFromAST ctx libdevice $ \ldm -> do + Mod.withModuleFromAST ctx zeroNVVMReflect \reflectm -> + Mod.withModuleFromAST ctx libdevice \ldm -> do Mod.linkModules m ldm Mod.linkModules m reflectm runPasses [P.AlwaysInline True] Nothing m diff --git a/src/lib/Logging.hs b/src/lib/Logging.hs index 1ee82ccdc..37d40fd8a 100644 --- a/src/lib/Logging.hs +++ b/src/lib/Logging.hs @@ -20,7 +20,7 @@ data Logger l = Logger (MVar l) (Maybe Handle) runLogger :: (Monoid l, MonadIO m) => Maybe FilePath -> (Logger l -> m a) -> m (a, l) runLogger maybePath m = do log <- liftIO $ newMVar mempty - logFile <- liftIO $ forM maybePath $ \path -> openFile path WriteMode + logFile <- liftIO $ forM maybePath \path -> openFile path WriteMode ans <- m $ Logger log logFile logged <- liftIO $ readMVar log return (ans, logged) @@ -30,10 +30,10 @@ execLogger maybePath m = fst <$> runLogger maybePath m logThis :: (Pretty l, Monoid l, MonadIO m) => Logger l -> l -> m () logThis (Logger log maybeLogHandle) x = liftIO $ do - forM_ maybeLogHandle $ \h -> do + forM_ maybeLogHandle \h -> do hPutStrLn h $ pprint x hFlush h - modifyMVar_ log $ \cur -> return (cur <> x) + modifyMVar_ log \cur -> return (cur <> x) readLog :: MonadIO m => Logger l -> m l readLog (Logger log _) = liftIO $ readMVar log diff --git a/src/lib/Parallelize.hs b/src/lib/Parallelize.hs index 86aef81d3..e11842020 100644 --- a/src/lib/Parallelize.hs +++ b/src/lib/Parallelize.hs @@ -79,9 +79,9 @@ parallelTraverseExpr expr = case expr of False -> nothingSpecial Hof (RunWriter (BinaryFunVal h b _ body)) -> do ~(RefTy _ accTy) <- traverseAtom substTraversalDef $ binderType b - liftM Atom $ emitRunWriter (binderNameHint b) accTy $ \ref@(Var refVar) -> do + liftM Atom $ emitRunWriter (binderNameHint b) accTy \ref@(Var refVar) -> do let RefTy h' _ = varType refVar - modify $ \accEnv -> accEnv { activeAccs = activeAccs accEnv <> b @> refVar } + modify \accEnv -> accEnv { activeAccs = activeAccs accEnv <> b @> refVar } extendR (h @> h' <> b @> ref) $ evalBlockE parallelTrav body -- TODO: Do some alias analysis. This is not fundamentally hard, but it is a little annoying. -- We would have to track not only the base references, but also all the aliases, along @@ -95,7 +95,7 @@ parallelTraverseExpr expr = case expr of where nothingSpecial = traverseExpr parallelTrav expr disallowRef ~(Var refVar) = - modify $ \accEnv -> accEnv { activeAccs = activeAccs accEnv `envDiff` (refVar @> ()) } + modify \accEnv -> accEnv { activeAccs = activeAccs accEnv `envDiff` (refVar @> ()) } parallelizableEffect :: Env () -> Effect -> Bool parallelizableEffect allowedRegions effect = case effect of @@ -203,7 +203,7 @@ emitLoops buildPureLoop (ABlock decls result) = do let buildBody pari = do is <- unpackConsList pari extendR (newEnv lbs is) $ do - ctxEnv <- flip traverseNames dapps $ \_ (arr, idx) -> + ctxEnv <- flip traverseNames dapps \_ (arr, idx) -> -- XXX: arr is namespaced in the new program foldM appTryReduce arr =<< substEmbedR idx extendR ctxEnv $ evalBlockE appReduceTraversalDef $ Block decls $ Atom result @@ -211,18 +211,18 @@ emitLoops buildPureLoop (ABlock decls result) = do True -> buildPureLoop (Bind $ "pari" :> iterTy) buildBody False -> do body <- do - buildLam (Bind $ "gtid" :> IdxRepTy) PureArrow $ \gtid -> do - buildLam (Bind $ "nthr" :> IdxRepTy) PureArrow $ \nthr -> do + buildLam (Bind $ "gtid" :> IdxRepTy) PureArrow \gtid -> do + buildLam (Bind $ "nthr" :> IdxRepTy) PureArrow \nthr -> do let threadRange = TC $ ParIndexRange iterTy gtid nthr let accTys = mkConsListTy $ fmap (derefType . varType) newRefs - emitRunWriter "refsList" accTys $ \localRefsList -> do + emitRunWriter "refsList" accTys \localRefsList -> do localRefs <- unpackRefConsList localRefsList - buildFor Fwd (Bind $ "tidx" :> threadRange) $ \tidx -> do + buildFor Fwd (Bind $ "tidx" :> threadRange) \tidx -> do pari <- emitOp $ Inject tidx extendR (newEnv oldRefNames localRefs) $ buildBody pari (ans, updateList) <- fromPair =<< (emit $ Hof $ PTileReduce iterTy body) updates <- unpackConsList updateList - forM_ (zip newRefs updates) $ \(ref, update) -> + forM_ (zip newRefs updates) \(ref, update) -> emitOp $ PrimEffect (Var ref) $ MTell update return ans where diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 62dc532f7..a167af975 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -383,8 +383,8 @@ funDefLet = label "function definition" $ mayBreak $ do let bs = map classAsBinder cs ++ argBinders let funTy = buildPiType bs eff ty let letBinder = (v, Just funTy) - let lamBinders = flip map bs $ \(p,_, arr) -> ((p,Nothing), arr) - return $ \body -> ULet PlainLet letBinder (buildLam lamBinders body) + let lamBinders = flip map bs \(p,_, arr) -> ((p,Nothing), arr) + return \body -> ULet PlainLet letBinder (buildLam lamBinders body) where classAsBinder :: UType -> (UPat, UType, UArrow) classAsBinder ty = (ns underscorePat, ty, ClassArrow) @@ -892,7 +892,7 @@ prefixNegOp :: Operator Parser UExpr prefixNegOp = Prefix $ label "negation" $ do ((), pos) <- withPos $ sym "-" let f = WithSrc (Just pos) "neg" - return $ \case + return \case -- Special case: negate literals directly WithSrc litpos (IntLitExpr i) -> WithSrc (joinPos (Just pos) litpos) (IntLitExpr (-i)) @@ -914,7 +914,7 @@ infixArrow :: Parser (UType -> UType -> UType) infixArrow = do notFollowedBy (sym "=>") -- table arrows have special fixity (arr, pos) <- withPos $ arrow effects - return $ \a b -> WithSrc (Just pos) $ UPi (Nothing, a) arr b + return \a b -> WithSrc (Just pos) $ UPi (Nothing, a) arr b mkArrow :: Arrow -> UExpr -> UExpr -> UExpr mkArrow arr a b = joinSrc a b $ UPi (Nothing, a) arr b @@ -959,7 +959,7 @@ inpostfix' :: Parser a -> Parser (a -> Maybe a -> a) -> Operator Parser a inpostfix' p op = Postfix $ do f <- op rest <- optional p - return $ \x -> f x rest + return \x -> f x rest mkName :: String -> Name mkName s = Name SourceName (fromString s) 0 diff --git a/src/lib/Serialize.hs b/src/lib/Serialize.hs index 11b552be6..602fdeed2 100644 --- a/src/lib/Serialize.hs +++ b/src/lib/Serialize.hs @@ -31,7 +31,7 @@ getDexString :: Val -> IO String getDexString (DataCon _ _ 0 [_, xs]) = do let (TabTy b _) = getType xs idxs <- indices $ getType b - forM idxs $ \i -> do + forM idxs \i -> do ~(Con (Lit (Word8Lit c))) <- evalBlock mempty (Block Empty (App xs i)) return $ toEnum $ fromIntegral c getDexString x = error $ "Not a string: " ++ pprint x @@ -49,7 +49,7 @@ prettyVal val = case val of _ -> "@" <> pretty idxSet -- Otherwise, show explicit index set -- Pretty-print elements. idxs <- indices idxSet - elems <- forM idxs $ \idx -> do + elems <- forM idxs \idx -> do atom <- evalBlock mempty $ snd $ applyAbs abs idx case atom of Con (Lit (Word8Lit c)) -> diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 9f3b256da..d1c03e4fd 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -61,7 +61,7 @@ hoistDepDataCons scope (Module Simp decls bindings) = where (bindings', (_, decls')) = flip runEmbed scope $ do mapM_ emitDecl decls - forM bindings $ \(ty, info) -> case info of + forM bindings \(ty, info) -> case info of LetBound ann x | isData ty -> do x' <- emit x return (ty, LetBound ann $ Atom x') _ -> return (ty, info) @@ -89,7 +89,7 @@ simplifyDecl (Let ann b expr) = do simplifyStandalone :: Expr -> SimplifyM Atom simplifyStandalone (Atom (LamVal b body)) = do b' <- mapM substEmbedR b - buildLam b' PureArrow $ \x -> + buildLam b' PureArrow \x -> extendR (b@>x) $ simplifyBlock body simplifyStandalone block = error $ "@noinline decorator applied to non-function" ++ pprint block @@ -139,9 +139,9 @@ simplifyAtom atom = case atom of case simplifyCase e' alts of Just (env, result) -> extendR env $ simplifyAtom result Nothing -> do - alts' <- forM alts $ \(Abs bs a) -> do + alts' <- forM alts \(Abs bs a) -> do bs' <- mapM (mapM substEmbedR) bs - (Abs bs'' b) <- buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ simplifyAtom a + (Abs bs'' b) <- buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ simplifyAtom a case b of Block Empty (Atom r) -> return $ Abs bs'' r _ -> error $ "Nontrivial block in ACase simplification" @@ -192,7 +192,7 @@ simplifyLams numArgs lam = do Left res -> (res, Nothing) Right (dat, (ctx, recon), atomf) -> ( mkConsList $ (toList dat) ++ (toList ctx) - , Just $ \vals -> do + , Just \vals -> do (datEls', ctxEls') <- splitAt (length dat) <$> unpackConsList vals let dat' = restructure datEls' dat let ctx' = restructure ctxEls' ctx @@ -200,7 +200,7 @@ simplifyLams numArgs lam = do ) go n scope ~(Block Empty (Atom (Lam (Abs b (arr, body))))) = do b' <- mapM substEmbedR b - buildLamAux b' (\x -> extendR (b@>x) $ substEmbedR arr) $ \x@(Var v) -> do + buildLamAux b' (\x -> extendR (b@>x) $ substEmbedR arr) \x@(Var v) -> do let scope' = scope <> v @> (varType v, LamBound (void arr)) extendR (b@>x) $ go (n-1) scope' body @@ -278,7 +278,7 @@ separateDataComponent localVars v = do True -> nubCtx t False -> h : (nubCtx t) result = nubCtx $ toList ll - inv ctx' result' = for ll $ \x -> case elemIndex x (toList ctx) of + inv ctx' result' = for ll \x -> case elemIndex x (toList ctx) of Just i -> (toList ctx') !! i Nothing -> result' !! (fromJust $ elemIndex x result) @@ -299,7 +299,7 @@ simplifyExpr expr = case expr of case all isCurriedFun alts of True -> return $ ACase e (fmap appAlt alts) rty' False -> do - let alts' = for alts $ \(Abs bs a) -> Abs bs $ Block Empty (App a x') + let alts' = for alts \(Abs bs a) -> Abs bs $ Block Empty (App a x') dropSub $ simplifyExpr $ Case e alts' rty' where isCurriedFun alt = case alt of @@ -321,16 +321,16 @@ simplifyExpr expr = case expr of Nothing -> do if isData resultTy' then do - alts' <- forM alts $ \(Abs bs body) -> do + alts' <- forM alts \(Abs bs body) -> do bs' <- mapM (mapM substEmbedR) bs - buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ simplifyBlock body + buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ simplifyBlock body emit $ Case e' alts' resultTy' else do -- Construct the blocks of new cases. The results will only get replaced -- later, once we learn the closures of the non-data component of each case. - (alts', facs) <- liftM unzip $ forM alts $ \(Abs bs body) -> do + (alts', facs) <- liftM unzip $ forM alts \(Abs bs body) -> do bs' <- mapM (mapM substEmbedR) bs - buildNAbsAux bs' $ \xs -> do + buildNAbsAux bs' \xs -> do ~(Right fac@(dat, (ctx, _), _)) <- extendR (newEnv bs' xs) $ defunBlock (boundVars bs') body -- NB: The return value here doesn't really matter as we're going to replace it afterwards. return (mkConsList $ toList dat ++ toList ctx, fac) @@ -361,9 +361,9 @@ simplifyExpr expr = case expr of -- a single output. This can probably be made quite a bit faster. -- NB: All the non-data trees have the same structure, so we pick an arbitrary one. nondatTree <- (\(_, (ctx, rec), _) -> rec dat ctx) $ head facs - nondat <- forM (enumerate nondatTree) $ \(i, _) -> do - aalts <- forM facs $ \(_, (ctx, rec), _) -> do - Abs bs' b <- buildNAbs (toNest $ toList $ fmap (Ignore . getType) ctx) $ \ctxVals -> + nondat <- forM (enumerate nondatTree) \(i, _) -> do + aalts <- forM facs \(_, (ctx, rec), _) -> do + Abs bs' b <- buildNAbs (toNest $ toList $ fmap (Ignore . getType) ctx) \ctxVals -> ((!! i) . toList) <$> rec dat (restructure ctxVals ctx) case b of Block Empty (Atom r) -> return $ Abs bs' r @@ -441,7 +441,7 @@ simplifyHof hof = case hof of ans <- emit $ Hof $ For d lam' case recon of Nothing -> return ans - Just f -> buildLam i TabArrow $ \i' -> app ans i' >>= f + Just f -> buildLam i TabArrow \i' -> app ans i' >>= f Tile d fT fS -> do ~(fT', Nothing) <- simplifyLam fT ~(fS', Nothing) <- simplifyLam fS @@ -495,7 +495,7 @@ exceptToMaybeBlock (Block (Nest (Let _ b expr) decls) result) = do JustAtom _ x -> extendR (b@>x) $ exceptToMaybeBlock $ Block decls result NothingAtom _ -> return $ NothingAtom a _ -> do - emitMaybeCase maybeResult (return $ NothingAtom a) $ \x -> do + emitMaybeCase maybeResult (return $ NothingAtom a) \x -> do extendR (b@>x) $ exceptToMaybeBlock $ Block decls result exceptToMaybeExpr :: Expr -> SubstEmbed Atom @@ -505,27 +505,27 @@ exceptToMaybeExpr expr = do Case e alts resultTy -> do e' <- substEmbedR e resultTy' <- substEmbedR $ MaybeTy resultTy - alts' <- forM alts $ \(Abs bs body) -> do + alts' <- forM alts \(Abs bs body) -> do bs' <- substEmbedR bs - buildNAbs bs' $ \xs -> extendR (newEnv bs' xs) $ exceptToMaybeBlock body + buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ exceptToMaybeBlock body emit $ Case e' alts' resultTy' Atom x -> substEmbedR $ JustAtom (getType x) x Op (ThrowException _) -> return $ NothingAtom a Hof (For ann ~(Lam (Abs b (_, body)))) -> do b' <- substEmbedR b - maybes <- buildForAnn ann b' $ \i -> extendR (b@>i) $ exceptToMaybeBlock body + maybes <- buildForAnn ann b' \i -> extendR (b@>i) $ exceptToMaybeBlock body catMaybesE maybes Hof (RunState s lam) -> do s' <- substEmbedR s let BinaryFunVal _ b _ body = lam - result <- emitRunState "ref" s' $ \ref -> + result <- emitRunState "ref" s' \ref -> extendR (b@>ref) $ exceptToMaybeBlock body (maybeAns, newState) <- fromPair result - emitMaybeCase maybeAns (return $ NothingAtom a) $ \ans -> + emitMaybeCase maybeAns (return $ NothingAtom a) \ans -> return $ JustAtom a $ PairVal ans newState Hof (While ~(Lam (Abs _ (_, body)))) -> do eff <- getAllowedEffects - lam <- buildLam (Ignore UnitTy) (PlainArrow eff) $ \_ -> + lam <- buildLam (Ignore UnitTy) (PlainArrow eff) \_ -> exceptToMaybeBlock body runMaybeWhile lam _ | not (hasExceptions expr) -> do diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 98a303617..0fc83b9a2 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -10,10 +10,8 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE StrictData #-} {-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE Rank2Types #-} -{-# LANGUAGE LambdaCase #-} module Syntax ( Type, Kind, BaseType (..), ScalarBaseType (..), @@ -189,14 +187,14 @@ labeledSingleton label value = LabeledItems $ M.singleton label (value NE.:|[]) reflectLabels :: LabeledItems a -> LabeledItems (Label, Int) reflectLabels (LabeledItems items) = LabeledItems $ - flip M.mapWithKey items $ \k xs -> fmap (\(i,_) -> (k,i)) (enumerate xs) + flip M.mapWithKey items \k xs -> fmap (\(i,_) -> (k,i)) (enumerate xs) getLabels :: LabeledItems a -> [Label] getLabels labeledItems = map fst $ toList $ reflectLabels labeledItems withLabels :: LabeledItems a -> LabeledItems (Label, Int, a) withLabels (LabeledItems items) = LabeledItems $ - flip M.mapWithKey items $ \k xs -> fmap (\(i,a) -> (k,i,a)) (enumerate xs) + flip M.mapWithKey items \k xs -> fmap (\(i,a) -> (k,i,a)) (enumerate xs) lookupLabel :: LabeledItems a -> Label -> Maybe a lookupLabel (LabeledItems items) l = case M.lookup l items of @@ -684,10 +682,10 @@ throwIf True e s = throw e s throwIf False _ _ = return () modifyErr :: MonadError e m => m a -> (e -> e) -> m a -modifyErr m f = catchError m $ \e -> throwError (f e) +modifyErr m f = catchError m \e -> throwError (f e) addContext :: MonadError Err m => String -> m a -> m a -addContext s m = modifyErr m $ \(Err e p s') -> Err e p (s' ++ "\n" ++ s) +addContext s m = modifyErr m \(Err e p s') -> Err e p (s' ++ "\n" ++ s) addSrcContext :: MonadError Err m => SrcCtx -> m a -> m a addSrcContext ctx m = modifyErr m updateErr @@ -698,9 +696,9 @@ addSrcContext ctx m = modifyErr m updateErr catchIOExcept :: (MonadIO m , MonadError Err m) => IO a -> m a catchIOExcept m = (liftIO >=> liftEither) $ (liftM Right m) `catches` - [ Handler $ \(e::Err) -> return $ Left e - , Handler $ \(e::IOError) -> return $ Left $ Err DataIOErr Nothing $ show e - , Handler $ \(e::SomeException) -> return $ Left $ Err CompilerErr Nothing $ show e + [ Handler \(e::Err) -> return $ Left e + , Handler \(e::IOError) -> return $ Left $ Err DataIOErr Nothing $ show e + , Handler \(e::SomeException) -> return $ Left $ Err CompilerErr Nothing $ show e ] liftEitherIO :: (Exception e, MonadIO m) => Either e a -> m a diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index d7269312d..9399ecfdf 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -78,7 +78,7 @@ evalSourceBlock opts env block = do Right env' -> return (env' , Result outs' (Right ())) runTopPassM :: Bool -> EvalConfig -> TopPassM a -> IO (Except a, [Output]) -runTopPassM bench opts m = runLogger (logFile opts) $ \logger -> +runTopPassM bench opts m = runLogger (logFile opts) \logger -> runExceptT $ catchIOExcept $ runReaderT m $ TopPassEnv logger bench opts evalSourceBlockM :: TopEnv -> SourceBlock -> TopPassM TopEnv @@ -97,7 +97,7 @@ evalSourceBlockM env block = case sbContents block of logTop $ HtmlOut s ExportFun name -> do f <- evalUModuleVal env v m - void $ traverseLiterals f $ \val -> case val of + void $ traverseLiterals f \val -> case val of PtrLit _ _ -> liftEitherIO $ throw CompilerErr $ "Can't export functions with captured pointers (not implemented)." _ -> return $ Con $ Lit val @@ -119,7 +119,7 @@ processLogs :: LogLevel -> [Output] -> [Output] processLogs logLevel logs = case logLevel of LogAll -> logs LogNothing -> [] - LogPasses passes -> flip filter logs $ \l -> case l of + LogPasses passes -> flip filter logs \l -> case l of PassInfo pass _ | pass `elem` passes -> True | otherwise -> False _ -> False @@ -249,7 +249,7 @@ logTop x = do abstractPtrLiterals :: Block -> ([IBinder], [LitVal], Block) abstractPtrLiterals block = flip evalState mempty $ do - block' <- traverseLiterals block $ \val -> case val of + block' <- traverseLiterals block \val -> case val of PtrLit ty ptr -> do ptrName <- gets $ M.lookup (ty, ptr) . fst case ptrName of diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 39af39287..0e700202d 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -74,7 +74,7 @@ instance Checkable Module where checkValid m@(Module ir decls bindings) = addContext ("Checking module:\n" ++ pprint m) $ asCompilerErr $ do let env = freeVars m - forM_ (envNames env) $ \v -> when (not $ isGlobal $ v:>()) $ + forM_ (envNames env) \v -> when (not $ isGlobal $ v:>()) $ throw CompilerErr $ "Non-global free variable in module: " ++ pprint v addContext "Checking IR variant" $ checkModuleVariant m addContext "Checking body types" $ do @@ -152,7 +152,7 @@ instance HasType Atom where ACase e alts resultTy -> checkCase e alts resultTy DataConRef ~def@(DataDef _ paramBs [DataConDef _ argBs]) params args -> do checkEq (length paramBs) (length params) - forM_ (zip (toList paramBs) (toList params)) $ \(b, param) -> + forM_ (zip (toList paramBs) (toList params)) \(b, param) -> param |: binderAnn b let argBs' = applyNaryAbs (Abs paramBs argBs) params checkDataConRefBindings argBs' args @@ -203,7 +203,7 @@ typeCheckVar v@(name:>annTy) = do annTy |: TyKind when (annTy == EffKind) $ throw CompilerErr "Effect variables should only occur in effect rows" - checkWithEnv $ \(env, _) -> case envLookup env v of + checkWithEnv \(env, _) -> case envLookup env v of Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v Just (ty, _) -> assertEq annTy ty $ "Annotation on var: " ++ pprint name return annTy @@ -227,19 +227,19 @@ instance HasType Expr where checkCase :: HasType b => Atom -> [AltP b] -> Type -> TypeM Type checkCase e alts resultTy = do - checkWithEnv $ \_ -> do + checkWithEnv \_ -> do ety <- typeCheck e case ety of TypeCon def params -> do let cons = applyDataDefParams def params checkEq (length cons) (length alts) - forM_ (zip cons alts) $ \((DataConDef _ bs'), (Abs bs body)) -> do + forM_ (zip cons alts) \((DataConDef _ bs'), (Abs bs body)) -> do checkEq bs' bs resultTy' <- flip (foldr withBinder) bs $ typeCheck body checkEq resultTy resultTy' VariantTy (NoExt types) -> do checkEq (length types) (length alts) - forM_ (zip (toList types) alts) $ \(ty, (Abs bs body)) -> do + forM_ (zip (toList types) alts) \(ty, (Abs bs body)) -> do [b] <- pure $ toList bs checkEq (getType b) ty resultTy' <- flip (foldr withBinder) bs $ typeCheck body @@ -319,7 +319,7 @@ instance HasType Block where instance HasType Binder where typeCheck b = do - checkWithEnv $ \(env, _) -> checkNoShadow env b + checkWithEnv \(env, _) -> checkNoShadow env b let ty = binderType b ty |: TyKind return ty @@ -344,7 +344,7 @@ infixr 7 |: checkEq reqTy ty checkEq :: (Show a, Pretty a, Eq a) => a -> a -> TypeM () -checkEq reqTy ty = checkWithEnv $ \_ -> assertEq reqTy ty "" +checkEq reqTy ty = checkWithEnv \_ -> assertEq reqTy ty "" withBinder :: Binder -> TypeM a -> TypeM a withBinder b m = typeCheck b >> extendTypeEnv (boundVars b) m @@ -407,7 +407,7 @@ instance CoreVariant Expr where Hof e -> checkVariant e >> forM_ e checkVariant Case e alts _ -> do checkVariant e - forM_ alts $ \(Abs _ body) -> checkVariant body + forM_ alts \(Abs _ body) -> checkVariant body instance CoreVariant Decl where -- let annotation restrictions? @@ -470,7 +470,7 @@ goneBy ir = do when (curIR >= ir) $ throw IRVariantErr $ "shouldn't appear after " ++ show ir addExpr :: (Pretty e, MonadError Err m) => e -> m a -> m a -addExpr x m = modifyErr m $ \e -> case e of +addExpr x m = modifyErr m \e -> case e of Err IRVariantErr ctx s -> Err CompilerErr ctx (s ++ ": " ++ pprint x) _ -> e @@ -478,11 +478,11 @@ addExpr x m = modifyErr m $ \e -> case e of checkEffRow :: EffectRow -> TypeM () checkEffRow (EffectRow effs effTail) = do - forM_ effs $ \eff -> case eff of + forM_ effs \eff -> case eff of RWSEffect _ v -> Var (v:>TyKind) |: TyKind ExceptionEffect -> return () - forM_ effTail $ \v -> do - checkWithEnv $ \(env, _) -> case envLookup env (v:>()) of + forM_ effTail \v -> do + checkWithEnv \(env, _) -> case envLookup env (v:>()) of Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v Just (ty, _) -> assertEq EffKind ty "Effect var" @@ -490,7 +490,7 @@ declareEff :: Effect -> TypeM () declareEff eff = declareEffs $ oneEffect eff declareEffs :: EffectRow -> TypeM () -declareEffs effs = checkWithEnv $ \(_, allowedEffects) -> +declareEffs effs = checkWithEnv \(_, allowedEffects) -> checkExtends allowedEffects effs checkExtends :: MonadError Err m => EffectRow -> EffectRow -> m () @@ -499,7 +499,7 @@ checkExtends allowed (EffectRow effs effTail) = do case effTail of Just _ -> assertEq allowedEffTail effTail "" Nothing -> return () - forM_ effs $ \eff -> unless (eff `elem` allowedEffs) $ + forM_ effs \eff -> unless (eff `elem` allowedEffs) $ throw CompilerErr $ "Unexpected effect: " ++ pprint eff ++ "\nAllowed: " ++ pprint allowed @@ -517,8 +517,8 @@ ioEffect = RWSEffect State theWorld checkLabeledRow :: ExtLabeledItems Type Name -> TypeM () checkLabeledRow (Ext items rest) = do mapM_ (|: TyKind) items - forM_ rest $ \v -> do - checkWithEnv $ \(env, _) -> case envLookup env (v:>()) of + forM_ rest \v -> do + checkWithEnv \(env, _) -> case envLookup env (v:>()) of Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v Just (ty, _) -> assertEq LabeledRowKind ty "Labeled row var" @@ -528,7 +528,7 @@ labeledRowDifference :: ExtLabeledItems Type Name labeledRowDifference (Ext (LabeledItems items) rest) (Ext (LabeledItems subitems) subrest) = do -- Check types in the right. - _ <- flip M.traverseWithKey subitems $ \label subtypes -> + _ <- flip M.traverseWithKey subitems \label subtypes -> case M.lookup label items of Just types -> assertEq subtypes (NE.fromList $ NE.take (length subtypes) types) $ @@ -556,7 +556,7 @@ checkWithEnv check = do CheckWith env -> check env updateTypeEnv :: (TypeEnv -> TypeEnv) -> TypeM a -> TypeM a -updateTypeEnv f m = flip local m $ fmap $ \(env, eff) -> (f env, eff) +updateTypeEnv f m = flip local m $ fmap \(env, eff) -> (f env, eff) extendTypeEnv :: TypeEnv -> TypeM a -> TypeM a extendTypeEnv new m = updateTypeEnv (<> new) m @@ -568,7 +568,7 @@ extendAllowedEffect :: Effect -> TypeM () -> TypeM () extendAllowedEffect eff m = updateAllowedEff (extendEffect eff) m updateAllowedEff :: (EffectRow -> EffectRow) -> TypeM a -> TypeM a -updateAllowedEff f m = flip local m $ fmap $ \(env, eff) -> (env, f eff) +updateAllowedEff f m = flip local m $ fmap \(env, eff) -> (env, f eff) withAllowedEff :: EffectRow -> TypeM a -> TypeM a withAllowedEff eff m = updateAllowedEff (const eff) m @@ -687,7 +687,7 @@ typeCheckOp op = case op of ToOrdinal i -> typeCheck i $> IdxRepTy IdxSetSize i -> typeCheck i $> IdxRepTy FFICall _ ansTy args -> do - forM_ args $ \arg -> do + forM_ args \arg -> do argTy <- typeCheck arg case argTy of BaseTy _ -> return () @@ -815,7 +815,7 @@ typeCheckOp op = case op of t |: TyKind x |: Word8Ty (TypeCon (DataDef _ _ dataConDefs) _) <- return t - forM_ dataConDefs $ \(DataConDef _ binders) -> + forM_ dataConDefs \(DataConDef _ binders) -> assertEq binders Empty "Not an enum" return t From 68cefc8cfc5389b67c55c185baea12f9478c2034 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 4 Jan 2021 15:06:12 -0500 Subject: [PATCH 074/105] Infer types of implicit implicit arguments. I guess that makes them implicitly typed implicit implicit arguments. --- examples/mcmc.dx | 3 +- examples/particle-swarm-optimizer.dx | 1 - lib/prelude.dx | 19 ++++------- src/lib/Inference.hs | 21 ++++++------ src/lib/PPrint.hs | 8 +---- src/lib/Parser.hs | 50 ++++++++++++---------------- src/lib/Syntax.hs | 10 +++--- tests/type-tests.dx | 6 ++-- 8 files changed, 49 insertions(+), 69 deletions(-) diff --git a/examples/mcmc.dx b/examples/mcmc.dx index a3bcbd314..d205cdbd7 100644 --- a/examples/mcmc.dx +++ b/examples/mcmc.dx @@ -28,8 +28,7 @@ def propose accept = logDensity proposal > (logDensity cur + log (rand k)) select accept proposal cur -def meanAndCovariance (n:Type) ?-> (d:Type) ?-> - (xs:n=>d=>Float) : (d=>Float & d=>d=>Float) = +def meanAndCovariance (xs:n=>d=>Float) : (d=>Float & d=>d=>Float) = xsMean : d=>Float = (for i. sum for j. xs.j.i) / IToF (size n) xsCov : d=>d=>Float = (for i i'. sum for j. (xs.j.i' - xsMean.i') * diff --git a/examples/particle-swarm-optimizer.dx b/examples/particle-swarm-optimizer.dx index 21b0cab5d..58227779e 100644 --- a/examples/particle-swarm-optimizer.dx +++ b/examples/particle-swarm-optimizer.dx @@ -57,7 +57,6 @@ We have **arguments**: ' **Returns**: the optimal point found with-in the bounds on the input domain of `f`. def optimize - (d:Type) ?-> (np':Int) -- number of particles (niter:Int) -- number of iterations (key:Key) -- random seed diff --git a/lib/prelude.dx b/lib/prelude.dx index 1831ef83e..4918585c0 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -233,7 +233,6 @@ def fstRef (ref: Ref h (a & b)) : Ref h a = %fstRef ref def sndRef (ref: Ref h (a & b)) : Ref h b = %sndRef ref def runReader - (eff:Effects) ?-> (init:r) (action: (h:Type ?-> Ref h r -> {Read h|eff} a)) : {|eff} a = @@ -241,27 +240,23 @@ def runReader %runReader init explicitAction def withReader - (eff:Effects) ?-> (init:r) (action: (h:Type ?-> Ref h r -> {Read h|eff} a)) : {|eff} a = runReader init action def runAccum - (eff:Effects) ?-> (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) : {|eff} (a & w) = def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = action ref %runWriter explicitAction def yieldAccum - (eff:Effects) ?-> (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) : {|eff} w = snd $ runAccum action def runState - (eff:Effects) ?-> (init:s) (action: h:Type ?-> Ref h s -> {State h |eff} a) : {|eff} (a & s) = @@ -269,13 +264,11 @@ def runState %runState init explicitAction def withState - (eff:Effects) ?-> (init:s) (action: h:Type ?-> Ref h s -> {State h |eff} a) : {|eff} a = fst $ runState init action def yieldState - (eff:Effects) ?-> (init:s) (action: h:Type ?-> Ref h s -> {State h |eff} a) : {|eff} s = snd $ runState init action @@ -449,10 +442,10 @@ def unsafeFromOrdinal (n : Type) (i : Int) : n = %unsafeFromOrdinal n i def iota (n:Type) : n=>Int = view i. ordinal i -- TODO: we want Eq and Ord for all index sets, not just `Fin n` -instance (n:Int) ?-> Eq (Fin n) +instance Eq (Fin n) (==) = \x y. ordinal x == ordinal y -instance (n:Int) ?-> Ord (Fin n) +instance Ord (Fin n) (>) = \x y. ordinal x > ordinal y (<) = \x y. ordinal x < ordinal y @@ -625,7 +618,7 @@ def newKey (x:Int) : Key = hash (IToI64 0) x def many (f:Key->a) (k:Key) (i:n) : a = f (hash k (ordinal i)) def ixkey (k:Key) (i:n) : Key = hash k (ordinal i) def ixkey2 (k:Key) (i:n) (j:m) : Key = hash (hash k (ordinal i)) (ordinal j) -def splitKey (n:Int) ?-> (k:Key) : Fin n => Key = for i. ixkey k i +def splitKey (k:Key) : Fin n => Key = for i. ixkey k i def rand (k:Key) : Float = unsafeIO do F64ToF $ %ffi randunif Float64 k def randVec (n:Int) (f: Key -> a) (k: Key) : Fin n => a = for i:(Fin n). f (ixkey k i) @@ -1036,7 +1029,7 @@ def fopen (path:String) (mode:StreamMode) : {State World} (Stream mode) = withCString modeStr \(MkCString modePtr). MkStream $ %ffi fopen RawPtr pathPtr modePtr -def fclose (mode:StreamMode) ?-> (stream:Stream mode) : {State World} Unit = +def fclose (stream:Stream mode) : {State World} Unit = (MkStream stream') = stream %ffi fclose Int64 stream' () @@ -1049,7 +1042,7 @@ def fwrite (stream:Stream WriteMode) (s:String) : {State World} Unit = %ffi fflush Int64 stream' () -def while (eff:Effects) ?-> (body: Unit -> {|eff} Bool) : {|eff} Unit = +def while (body: Unit -> {|eff} Bool) : {|eff} Unit = body' : Unit -> {|eff} Word8 = \_. BToW8 $ body () %while body' @@ -1237,7 +1230,7 @@ instance Arbitrary Int32 instance [Arbitrary a] Arbitrary (n=>a) arb = \key. for i. arb $ ixkey key i -instance (n:Int) ?-> Arbitrary (Fin n) +instance Arbitrary (Fin n) arb = randIdx 'Control flow diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 73d673554..263955d12 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -152,20 +152,20 @@ checkOrInferRho (WithSrc pos expr) reqTy = do addEffects $ arrowEff arr' appVal <- emitZonked $ App fVal xVal' instantiateSigma appVal >>= matchRequirement - UPi (pat, kind) arr ty -> do + UPi (pat, ann) arr ty -> do -- TODO: make sure there's no effect if it's an implicit or table arrow -- TODO: check leaks - kind' <- checkUType kind + ann' <- checkAnn ann piTy <- case pat of - Just pat' -> withNameHint ("pat" :: Name) $ buildPi b \x -> - withBindPat pat' x $ (,) <$> mapM checkUEffRow arr <*> checkUType ty - where b = case pat' of + UnderscoreUPat -> buildPi (Ignore ann') $ const $ + (,) <$> mapM checkUEffRow arr <*> checkUType ty + _ -> withNameHint ("pat" :: Name) $ buildPi b \x -> + withBindPat pat x $ (,) <$> mapM checkUEffRow arr <*> checkUType ty + where b = case pat of -- Note: The binder name becomes part of the type, so we -- need to keep the same name used in the pattern. - WithSrc _ (UPatBinder (Bind (v:>()))) -> Bind (v:>kind') - _ -> Ignore kind' - Nothing -> buildPi (Ignore kind') $ const $ - (,) <$> mapM checkUEffRow arr <*> checkUType ty + WithSrc _ (UPatBinder (Bind (v:>()))) -> Bind (v:>ann') + _ -> Ignore ann' matchRequirement piTy UDecl decl body -> do env <- inferUDecl False decl @@ -526,7 +526,8 @@ checkUEffRow (EffectRow effs t) = do checkUEff :: Effect -> UInferM Effect checkUEff eff = case eff of RWSEffect rws region -> do - (Var (v:>TyKind)) <- lookupSourceVar (region:>()) + (Var (v:>ty)) <- lookupSourceVar (region:>()) + constrainEq TyKind ty return $ RWSEffect rws v ExceptionEffect -> return ExceptionEffect diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 9757b9fea..aa4883ffb 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -549,7 +549,7 @@ instance PrettyPrec UExpr' where where kw = case dir of Fwd -> "for" Rev -> "rof" UPi binder arr ty -> atPrec LowestPrec $ - prettyUPiBinder binder <+> pretty arr <+> pLowest ty + prettyUBinder binder <+> pretty arr <+> pLowest ty UDecl decl body -> atPrec LowestPrec $ align $ p decl <> hardline <> pLowest body UHole -> atPrec ArgPrec "_" @@ -614,12 +614,6 @@ prettyUBinder (pat, ann) = p pat <> annDoc where Just ty -> ":" <> pApp ty Nothing -> mempty -prettyUPiBinder :: UPiPatAnn -> Doc ann -prettyUPiBinder (pat, ann) = patDoc <> p ann where - patDoc = case pat of - Just pat' -> pApp pat' <> ":" - Nothing -> mempty - spaced :: (Foldable f, Pretty a) => f a -> Doc ann spaced xs = hsep $ map p $ toList xs diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index a167af975..8472d5d59 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -251,7 +251,7 @@ findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ -- recursive steps UVar _ -> mempty UPi (p, ann) _ ty -> - findVarsInAppLHS ann <> (findVarsInAppLHS ty `envDiff` boundUVars p) + foldMap findVarsInAppLHS ann <> (findVarsInAppLHS ty `envDiff` boundUVars p) UApp _ f x -> findVarsInAppLHS f <> findVarsInAppLHS x ULam (p, ann) _ x -> foldMap findVarsInAppLHS ann <> (findVarsInAppLHS x `envDiff` boundUVars p) @@ -284,12 +284,9 @@ addImplicitImplicitArgs (Just typ) ex = addImplicitArg :: Name -> (UType, UExpr) -> (UType, UExpr) addImplicitArg v (ty, e) = - ( ns $ UPi (Just uPat, uTyKind) ImplicitArrow ty - , ns $ ULam (uPat, Just uTyKind) ImplicitArrow e) - where - uPat = ns $ nameToPat v - k = if v == mkName "eff" then EffectRowKind else TypeKind - uTyKind = ns $ UPrimExpr $ TCExpr k + ( ns $ UPi (uPat, Nothing) ImplicitArrow ty + , ns $ ULam (uPat, Nothing) ImplicitArrow e) + where uPat = ns $ nameToPat v superclassConstraints :: Parser [UType] superclassConstraints = optionalMonoid $ brackets $ uType `sepBy` sym "," @@ -349,12 +346,11 @@ instanceDef = do return $ UInstance ty' methods where addClassConstraint :: UType -> UType -> UType - addClassConstraint c ty = ns $ UPi (Nothing, c) ClassArrow ty + addClassConstraint c ty = ns $ UPi (UnderscoreUPat, Just c) ClassArrow ty addImplicitArg :: Name -> UType -> UType addImplicitArg v ty = - ns $ UPi (Just (ns $ nameToPat v), uTyKind) ImplicitArrow ty - where uTyKind = ns $ UPrimExpr $ TCExpr TypeKind + ns $ UPi (ns $ nameToPat v, Nothing) ImplicitArrow ty instanceMethod :: Parser (UVar, UExpr) instanceMethod = do @@ -386,26 +382,25 @@ funDefLet = label "function definition" $ mayBreak $ do let lamBinders = flip map bs \(p,_, arr) -> ((p,Nothing), arr) return \body -> ULet PlainLet letBinder (buildLam lamBinders body) where - classAsBinder :: UType -> (UPat, UType, UArrow) - classAsBinder ty = (ns underscorePat, ty, ClassArrow) + classAsBinder :: UType -> (UPat, Maybe UType, UArrow) + classAsBinder ty = (UnderscoreUPat, Just ty, ClassArrow) -defArg :: Parser (UPat, UType, UArrow) +defArg :: Parser (UPat, Maybe UType, UArrow) defArg = label "def arg" $ do (p, ty) <-parens ((,) <$> pat <*> annot uType) arr <- arrow (return ()) <|> return (PlainArrow ()) - return (p, ty, arr) + return (p, Just ty, arr) classConstraints :: Parser [UType] classConstraints = label "class constraints" $ optionalMonoid $ brackets $ mayNotPair $ uType `sepBy` sym "," -buildPiType :: [(UPat, UType, UArrow)] -> EffectRow -> UType -> UType +buildPiType :: [(UPat, Maybe UType, UArrow)] -> EffectRow -> UType -> UType buildPiType [] Pure ty = ty buildPiType [] _ _ = error "shouldn't be possible" -buildPiType ((p, patTy, arr):bs) eff resTy = WithSrc pos $ case bs of - [] -> UPi (Just p, patTy) (fmap (const eff ) arr) resTy - _ -> UPi (Just p, patTy) (fmap (const Pure) arr) $ buildPiType bs eff resTy - where WithSrc pos _ = patTy +buildPiType ((p, patTy, arr):bs) eff resTy = ns case bs of + [] -> UPi (p, patTy) (fmap (const eff ) arr) resTy + _ -> UPi (p, patTy) (fmap (const Pure) arr) $ buildPiType bs eff resTy effectiveType :: Parser (EffectRow, UType) effectiveType = (,) <$> effects <*> uType @@ -472,13 +467,10 @@ uForExpr = do <|> (keyWord Rof_KW $> (Rev, True )) e <- buildFor pos dir <$> (some patAnn <* argTerm) <*> blockOrExpr if trailingUnit - then return $ ns $ UDecl (ULet PlainLet (ns underscorePat, Nothing) e) $ + then return $ ns $ UDecl (ULet PlainLet (UnderscoreUPat, Nothing) e) $ ns unitExpr else return e -underscorePat :: UPat' -underscorePat = UPatBinder $ Ignore () - nameToPat :: Name -> UPat' nameToPat v = UPatBinder (Bind (v:>())) @@ -514,7 +506,7 @@ wrapUStatements statements = case statements of (s, pos):rest -> WithSrc (Just pos) $ case s of Left d -> UDecl d $ wrapUStatements rest Right e -> UDecl d $ wrapUStatements rest - where d = ULet PlainLet (ns underscorePat, Nothing) e + where d = ULet PlainLet (UnderscoreUPat, Nothing) e [] -> error "Shouldn't be reachable" uStatement :: Parser UStatement @@ -528,8 +520,8 @@ uPiType = withSrc $ UPi <$> piBinderPat <*> arrow effects <*> uType b <- annBinder return $ case b of Bind (n:>a@(WithSrc pos _)) -> - (Just $ WithSrc pos $ nameToPat n, a) - Ignore a -> (Nothing, a) + (WithSrc pos $ nameToPat n, Just a) + Ignore a -> (UnderscoreUPat, Just a) annBinder :: Parser UAnnBinder annBinder = try $ namedBinder <|> anonBinder @@ -613,7 +605,7 @@ leafPat = <|> brackets (UPatTable <$> leafPat `sepBy` sym ",") ) where pun pos l = WithSrc (Just pos) $ nameToPat $ mkName l - def pos = WithSrc (Just pos) $ underscorePat + def pos = WithSrc (Just pos) $ UPatBinder (Ignore ()) variantPat = parseVariant leafPat UPatVariant UPatVariantLift recordPat = UPatRecord <$> parseLabeledItems "," "=" leafPat (Just pun) (Just def) @@ -914,10 +906,10 @@ infixArrow :: Parser (UType -> UType -> UType) infixArrow = do notFollowedBy (sym "=>") -- table arrows have special fixity (arr, pos) <- withPos $ arrow effects - return \a b -> WithSrc (Just pos) $ UPi (Nothing, a) arr b + return \a b -> WithSrc (Just pos) $ UPi (UnderscoreUPat, Just a) arr b mkArrow :: Arrow -> UExpr -> UExpr -> UExpr -mkArrow arr a b = joinSrc a b $ UPi (Nothing, a) arr b +mkArrow arr a b = joinSrc a b $ UPi (UnderscoreUPat, Just a) arr b withSrc :: Parser a -> Parser (WithSrc a) withSrc p = do diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 0fc83b9a2..ee0978d15 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -39,7 +39,7 @@ module Syntax ( freeVars, freeUVars, Subst, HasVars, BindsVars, Ptr, PtrType, AddressSpace (..), showPrimName, strToPrimName, primNameToStr, monMapSingle, monMapLookup, Direction (..), Limit (..), - UExpr, UExpr' (..), UType, UPatAnn, UPiPatAnn, UAnnBinder, UVar, + UExpr, UExpr' (..), UType, UPatAnn, UAnnBinder, UVar, UPat, UPat' (..), UModule (..), UDecl (..), UArrow, arrowEff, DataDef (..), DataConDef (..), UConDef (..), Nest (..), toNest, subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, @@ -63,7 +63,7 @@ module Syntax ( pattern Unlabeled, pattern NoExt, pattern LabeledRowKind, pattern NoLabeledItems, pattern InternalSingletonLabel, pattern EffKind, pattern NestOne, pattern NewTypeCon, pattern BinderAnn, - pattern ClassDictDef, pattern ClassDictCon) + pattern ClassDictDef, pattern ClassDictCon, pattern UnderscoreUPat) where import qualified Data.Map.Strict as M @@ -225,7 +225,7 @@ prefixExtLabeledItems items (Ext items' rest) = Ext (items <> items') rest type UExpr = WithSrc UExpr' data UExpr' = UVar UVar | ULam UPatAnn UArrow UExpr - | UPi UPiPatAnn Arrow UType + | UPi UPatAnn Arrow UType | UApp UArrow UExpr UExpr | UDecl UDecl UExpr | UFor Direction UPatAnn UExpr @@ -257,7 +257,6 @@ type UVar = VarP () type UBinder = BinderP () type UPatAnn = (UPat, Maybe UType) -type UPiPatAnn = (Maybe UPat, UType) type UAnnBinder = BinderP UType data UAlt = UAlt UPat UExpr deriving (Show, Generic) @@ -285,6 +284,9 @@ srcPos (WithSrc pos _) = pos instance IsString UExpr' where fromString s = UVar $ Name SourceName (fromString s) 0 :> () +pattern UnderscoreUPat :: UPat +pattern UnderscoreUPat = WithSrc Nothing (UPatBinder (Ignore ())) + -- === primitive constructors and operators === data PrimExpr e = diff --git a/tests/type-tests.dx b/tests/type-tests.dx index b9f1b1f0f..2261e6878 100644 --- a/tests/type-tests.dx +++ b/tests/type-tests.dx @@ -158,11 +158,11 @@ MyPair : Type -> Type = -- TODO: put source annotation on effect for a better message here fEff : Unit -> {| a} a = todo > Type error: -> Expected: EffKind -> Actual: Type +> Expected: Type +> Actual: EffKind > > fEff : Unit -> {| a} a = todo -> ^^^^^^^^^ +> ^^ :p for i:(Fin 7). sum for j:(Fin unboundName). 1.0 From dba12d031f99452c35ed8cd7d303294f2ffa87ec Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 5 Jan 2021 14:47:27 -0500 Subject: [PATCH 075/105] Make fixes suggested in review. --- src/lib/Embed.hs | 6 +++--- src/lib/Inference.hs | 10 +++++----- src/lib/PPrint.hs | 43 ++++++++++++++++++++++++++++++++++++++++++- src/lib/Syntax.hs | 20 ++++++++++---------- tests/adt-tests.dx | 14 +++++++------- 5 files changed, 67 insertions(+), 26 deletions(-) diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index c46397d64..330daa991 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -17,7 +17,7 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP app, add, mul, sub, neg, div', iadd, imul, isub, idiv, ilt, ieq, - fpow, flog, fLitLike, recGet, buildImplicitNaryLam, + fpow, flog, fLitLike, recGetHead, buildImplicitNaryLam, select, substEmbed, substEmbedR, emitUnpack, getUnpacked, fromPair, getFst, getSnd, getFstRef, getSndRef, naryApp, appReduce, appTryReduce, buildAbs, @@ -206,8 +206,8 @@ buildImplicitNaryLam (Nest b bs) body = bs' <- substEmbed (b@>x) bs buildImplicitNaryLam bs' \xs -> body $ x:xs -recGet :: Label -> Atom -> Atom -recGet l x = do +recGetHead :: Label -> Atom -> Atom +recGetHead l x = do let (RecordTy (Ext r _)) = getType x let i = fromJust $ elemIndex l $ map fst $ toList $ reflectLabels r getProjection [i] x diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 263955d12..06cb409af 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -413,7 +413,7 @@ emitMethodGetters def@(DataDef _ paramBs (ClassDictDef _ _ methodTys)) = do forM_ (getLabels methodTys) \l -> do f <- buildImplicitNaryLam paramBs \params -> do buildLam (Bind ("d":> TypeCon def params)) ClassArrow \dict -> do - return $ recGet l $ getProjection [1] dict + return $ recGetHead l $ getProjection [1] dict let methodName = GlobalName $ fromString l checkNotInScope methodName emitTo methodName PlainLet $ Atom f @@ -424,10 +424,10 @@ emitSuperclassGetters def@(DataDef _ paramBs (ClassDictDef _ superclassTys _)) = forM_ (getLabels superclassTys) \l -> do f <- buildImplicitNaryLam paramBs \params -> do buildLam (Bind ("d":> TypeCon def params)) PureArrow \dict -> do - return $ recGet l $ getProjection [0] dict + return $ recGetHead l $ getProjection [0] dict getterName <- freshClassGenName emitTo getterName SuperclassLet $ Atom f -emitSuperclassGetter (DataDef _ _ _) = error "Not a class dictionary" +emitSuperclassGetters (DataDef _ _ _) = error "Not a class dictionary" checkNotInScope :: Name -> UInferM () checkNotInScope v = do @@ -486,7 +486,7 @@ checkInstance ty methods = case ty of ClassDictDef _ superclassTys methodTys -> do methods' <- liftM mkLabeledItems $ forM methods \((v:>()), rhs) -> do let v' = nameToLabel v - case lookupLabel methodTys v' of + case lookupLabelHead methodTys v' of Nothing -> throw TypeErr (pprint v ++ " is not a method of " ++ pprint className) Just methodTy -> do rhs' <- checkSigma rhs Suggest methodTy @@ -495,7 +495,7 @@ checkInstance ty methods = case ty of forM_ (reflectLabels methods') \(l,i) -> when (i > 0) $ throw TypeErr $ "Duplicate method: " ++ pprint l forM_ (reflectLabels methodTys) \(l,_) -> - case lookupLabel methods' l of + case lookupLabelHead methods' l of Nothing -> throw TypeErr $ "Missing method: " ++ pprint l Just _ -> return () return $ ClassDictCon def params superclassHoles methods' diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index aa4883ffb..600d42b38 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -21,6 +21,7 @@ import Data.Foldable (toList) import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M import qualified Data.ByteString.Lazy.Char8 as B +import Data.Maybe (fromMaybe) import Data.String (fromString) import Data.Text.Prettyprint.Doc.Render.Text import Data.Text.Prettyprint.Doc @@ -32,6 +33,7 @@ import Numeric import Env import Syntax +import Util (enumerate) -- Specifies what kinds of operations are allowed to be printed at this point. -- Printing at AppPrec level means that applications can be printed @@ -362,7 +364,7 @@ instance PrettyPrec Atom where "DataConRef" <+> p params <+> p args BoxedRef b ptr size body -> atPrec AppPrec $ "Box" <+> p b <+> "<-" <+> p ptr <+> "[" <> p size <> "]" <+> hardline <> "in" <+> p body - ProjectElt idxs x -> atPrec LowestPrec $ "project" <+> p idxs <+> p x + ProjectElt idxs x -> prettyProjection idxs x instance Pretty DataConRefBinding where pretty = prettyFromPrettyPrec instance PrettyPrec DataConRefBinding where @@ -374,6 +376,45 @@ fromInfix t = do (t'', ')') <- unsnoc t' return t'' +prettyProjection :: NE.NonEmpty Int -> Var -> DocPrec ann +prettyProjection idxs (name :> ty) = prettyPrec uproj where + -- Builds a source expression that performs the given projection. + uproj = UApp (PlainArrow ()) (nosrc ulam) (nosrc uvar) + ulam = ULam (upat, Nothing) (PlainArrow ()) (nosrc $ UVar $ target :> ()) + uvar = UVar $ name :> () + (_, upat, target) = buildProj idxs + + buildProj :: NE.NonEmpty Int -> (Type, UPat, Name) + buildProj (i NE.:| is) = let + -- Lazy Haskell trick: refer to `target` even though this function is + -- responsible for setting it! + (ty', pat', eltName) = case NE.nonEmpty is of + Just is' -> let (x, y, z) = buildProj is' in (x, y, Just z) + Nothing -> (ty, nosrc $ UPatBinder $ Bind $ target :> (), Nothing) + in case ty' of + TypeCon def params -> let + [DataConDef conName bs] = applyDataDefParams def params + b = toList bs !! i + pats = (\(j,_)-> if i == j then pat' else uignore) <$> enumerate bs + hint = case b of + Bind (n :> _) -> n + Ignore _ -> Name SourceName "elt" 0 + in ( binderAnn b, nosrc $ UPatCon conName pats, fromMaybe hint eltName) + RecordTy (NoExt types) -> let + ty'' = toList types !! i + pats = (\(j,_)-> if i == j then pat' else uignore) <$> enumerate types + (fieldName, _) = toList (reflectLabels types) !! i + hint = Name SourceName (fromString fieldName) 0 + in (ty'', nosrc $ UPatRecord $ NoExt pats, fromMaybe hint eltName) + PairTy x _ | i == 0 -> + (x, nosrc $ UPatPair pat' uignore, fromMaybe "a" eltName) + PairTy _ y | i == 1 -> + (y, nosrc $ UPatPair uignore pat', fromMaybe "b" eltName) + _ -> error "Bad projection" + + nosrc = WithSrc Nothing + uignore = nosrc $ UPatBinder $ Ignore () + prettyExtLabeledItems :: (PrettyPrec a, PrettyPrec b) => ExtLabeledItems a b -> Doc ann -> Doc ann -> DocPrec ann prettyExtLabeledItems (Ext (LabeledItems row) rest) separator bindwith = diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index ee0978d15..1f99e550f 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -28,8 +28,8 @@ module Syntax ( IExpr (..), IVal, ImpInstr (..), Backend (..), Device (..), IPrimOp, IVar, IBinder, IType, SetVal (..), MonMap (..), LitProg, IFunType (..), IFunVar, CallingConvention (..), IsCUDARequired (..), - UAlt (..), AltP, Alt, Label, LabeledItems (..), labeledSingleton, lookupLabel, - reflectLabels, withLabels, ExtLabeledItems (..), + UAlt (..), AltP, Alt, Label, LabeledItems (..), labeledSingleton, + lookupLabelHead, reflectLabels, withLabels, ExtLabeledItems (..), prefixExtLabeledItems, getLabels, IScope, BinderInfo (..), Bindings, CUDAKernel (..), BenchStats, SrcCtx, Result (..), Output (..), OutFormat (..), @@ -196,8 +196,8 @@ withLabels :: LabeledItems a -> LabeledItems (Label, Int, a) withLabels (LabeledItems items) = LabeledItems $ flip M.mapWithKey items \k xs -> fmap (\(i,a) -> (k,i,a)) (enumerate xs) -lookupLabel :: LabeledItems a -> Label -> Maybe a -lookupLabel (LabeledItems items) l = case M.lookup l items of +lookupLabelHead :: LabeledItems a -> Label -> Maybe a +lookupLabelHead (LabeledItems items) l = case M.lookup l items of Nothing -> Nothing Just (x NE.:| _) -> Just x @@ -798,8 +798,9 @@ instance BindsUVars UPat' where instance HasUVars UDecl where freeUVars (ULet _ p expr) = freeUVars p <> freeUVars expr freeUVars (UData (UConDef _ bs) dataCons) = freeUVars $ Abs bs dataCons - freeUVars (UInterface _ _ _) = mempty -- TODO - freeUVars (UInstance _ _) = mempty -- TODO + freeUVars (UInterface superclasses tc methods) = + freeUVars $ Abs tc (superclasses, methods) + freeUVars (UInstance ty methods) = mempty -- TODO instance BindsUVars UDecl where boundUVars decl = case decl of @@ -1538,15 +1539,14 @@ pattern BinderAnn x <- ((\case Ignore ann -> ann where BinderAnn x = Ignore x pattern NewTypeCon :: Name -> Type -> [DataConDef] -pattern NewTypeCon con ty <- [DataConDef con (NestOne (BinderAnn ty))] - where NewTypeCon con ty = [DataConDef con (NestOne (Ignore ty))] +pattern NewTypeCon con ty = [DataConDef con (NestOne (BinderAnn ty))] pattern ClassDictDef :: Name -> LabeledItems Type -> LabeledItems Type -> [DataConDef] pattern ClassDictDef conName superclasses methods = [DataConDef conName - (Nest (Ignore (RecordTy (NoExt superclasses))) - (Nest (Ignore (RecordTy (NoExt methods))) Empty))] + (Nest (BinderAnn (RecordTy (NoExt superclasses))) + (Nest (BinderAnn (RecordTy (NoExt methods))) Empty))] pattern ClassDictCon :: DataDef -> [Type] -> LabeledItems Atom -> LabeledItems Atom -> Atom diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx index 97ad29dc7..1d2d2306e 100644 --- a/tests/adt-tests.dx +++ b/tests/adt-tests.dx @@ -216,7 +216,7 @@ def catLists (xs:List a) (ys:List a) : List a = def listToTable ((AsList n xs): List a) : (Fin n)=>a = xs :t listToTable -> ((a:Type) ?-> (pat:(List a)) -> (Fin (project [0] pat:(List a))) => a) +> ((a:Type) ?-> (pat:(List a)) -> (Fin ((\((AsList n _)). n) pat)) => a) :p l = AsList _ [1, 2, 3] @@ -228,7 +228,7 @@ def listToTable2 (l: List a) : (Fin (listLength l))=>a = xs :t listToTable2 -> ((a:Type) ?-> (l:(List a)) -> (Fin (project [0] l:(List a))) => a) +> ((a:Type) ?-> (l:(List a)) -> (Fin ((\((AsList n _)). n) l)) => a) :p l = AsList _ [1, 2, 3] @@ -258,7 +258,7 @@ def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = :t graphToAdjacencyMatrix > ((a:Type) > ?-> (pat:(Graph a)) -> -> (project [0] pat:(Graph a)) => (project [0] pat:(Graph a)) => Bool) +> -> ((\((MkGraph n _ _ _)). n) pat) => ((\((MkGraph n _ _ _)). n) pat) => Bool) :p g : Graph Int = MkGraph (Fin 3) [5, 6, 7] (Fin 4) [(0@_, 1@_), (0@_, 2@_), (2@_, 0@_), (1@_, 1@_)] @@ -269,15 +269,15 @@ def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = def pairUnpack ((v, _):(Int & Float)) : Int = v :p pairUnpack -> \pat:(Int32 & Float32). project [0] pat:(Int32 & Float32) +> \pat:(Int32 & Float32). (\(a, _). a) pat def adtUnpack ((MkMyPair v _):MyPair Int Float) : Int = v :p adtUnpack -> \pat:(MyPair Int32 Float32). project [0] pat:(MyPair Int32 Float32) +> \pat:(MyPair Int32 Float32). (\((MkMyPair elt _)). elt) pat def recordUnpack ({a=v, b=_}:{a:Int & b:Float}) : Int = v :p recordUnpack -> \pat:{a: Int32 & b: Float32}. project [0] pat:{a: Int32 & b: Float32} +> \pat:{a: Int32 & b: Float32}. (\{a = a, b = _}. a) pat def nestedUnpack (x:MyPair Int (MyPair (MyIntish & Int) Int)) : Int = (MkMyPair _ (MkMyPair (MkIntish y, _) _)) = x @@ -285,7 +285,7 @@ def nestedUnpack (x:MyPair Int (MyPair (MyIntish & Int) Int)) : Int = :p nestedUnpack > \x:(MyPair Int32 (MyPair (MyIntish & Int32) Int32)). -> project [0, 0, 0, 1] x:(MyPair Int32 (MyPair (MyIntish & Int32) Int32)) +> (\((MkIntish (((MkMyPair ((MkMyPair _ elt)) _)), _))). elt) x :p nestedUnpack (MkMyPair 3 (MkMyPair (MkIntish 4, 5) 6)) > 4 From 8283de6af03aeb5ca14a0ac458a291fe1b46961f Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 5 Jan 2021 16:50:08 -0500 Subject: [PATCH 076/105] Make scoping of binders in instance declarations more explicit. And fix type inference to handle them properly. --- src/lib/Inference.hs | 82 +++++++++++++++++++++++++------------------- src/lib/PPrint.hs | 7 ++-- src/lib/Parser.hs | 36 +++++++++---------- src/lib/Syntax.hs | 27 ++++++++++----- 4 files changed, 89 insertions(+), 63 deletions(-) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 06cb409af..84c11c9d2 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -361,15 +361,14 @@ inferUDecl True (UInterface superclasses tc methods) = do emitSuperclassGetters dataDef emitMethodGetters dataDef return mempty -inferUDecl True (UInstance instanceTy methods) = do - ty <- checkUType instanceTy - instanceDict <- checkInstance ty methods +inferUDecl True (UInstance argBinders instanceTy methods) = do + instanceDict <- checkInstance argBinders instanceTy methods let instanceName = Name TypeClassGenName "instance" 0 void $ emitTo instanceName InstanceLet $ Atom instanceDict return mempty inferUDecl False (UData _ _ ) = error "data definitions should be top-level" inferUDecl False (UInterface _ _ _) = error "interface definitions should be top-level" -inferUDecl False (UInstance _ _ ) = error "instance definitions should be top-level" +inferUDecl False (UInstance _ _ _) = error "instance definitions should be top-level" freshClassGenName :: MonadEmbed m => m Name freshClassGenName = do @@ -479,36 +478,46 @@ checkULam (p, ann) body piTy = do \x@(Var v) -> checkLeaks [v] $ withBindPat p x $ checkSigma body Suggest $ snd $ applyAbs piTy x -checkInstance :: Type -> [(UVar, UExpr)] -> UInferM Atom -checkInstance ty methods = case ty of - TypeCon def@(DataDef className _ _) params -> do - case applyDataDefParams def params of - ClassDictDef _ superclassTys methodTys -> do - methods' <- liftM mkLabeledItems $ forM methods \((v:>()), rhs) -> do - let v' = nameToLabel v - case lookupLabelHead methodTys v' of - Nothing -> throw TypeErr (pprint v ++ " is not a method of " ++ pprint className) - Just methodTy -> do - rhs' <- checkSigma rhs Suggest methodTy - return (v', rhs') - let superclassHoles = fmap (Con . ClassDictHole Nothing) superclassTys - forM_ (reflectLabels methods') \(l,i) -> - when (i > 0) $ throw TypeErr $ "Duplicate method: " ++ pprint l - forM_ (reflectLabels methodTys) \(l,_) -> - case lookupLabelHead methods' l of - Nothing -> throw TypeErr $ "Missing method: " ++ pprint l - Just _ -> return () - return $ ClassDictCon def params superclassHoles methods' - _ -> throw TypeErr $ "Not a valid instance: " ++ pprint ty - Pi (Abs b (arrow, bodyTy)) -> do - case arrow of - ImplicitArrow -> return () - ClassArrow -> return () - _ -> throw TypeErr $ "Not a valid arrow for an instance: " ++ pprint arrow - buildLam b arrow \x@(Var v) -> do - bodyTy' <- substEmbed (b@>x) bodyTy - checkLeaks [v] $ extendR (b@>x) $ checkInstance bodyTy' methods - _ -> throw TypeErr $ "Not a valid instance type: " ++ pprint ty +checkInstance :: Nest UPatAnnArrow -> UType -> [UMethodDef] -> UInferM Atom +checkInstance Empty ty methods = do + ty' <- checkUType ty + case ty' of + TypeCon def@(DataDef className _ _) params -> + case applyDataDefParams def params of + ClassDictDef _ superclassTys methodTys -> do + let superclassHoles = fmap (Con . ClassDictHole Nothing) superclassTys + methods' <- checkMethodDefs className methodTys methods + return $ ClassDictCon def params superclassHoles methods' + _ -> throw TypeErr $ "Not a valid instance type: " ++ pprint ty + _ -> throw TypeErr $ "Not a valid instance type: " ++ pprint ty +checkInstance (Nest ((p, ann), arrow) rest) ty methods = do + case arrow of + ImplicitArrow -> return () + ClassArrow -> return () + _ -> throw TypeErr $ "Not a valid arrow for an instance: " ++ pprint arrow + argTy <- checkAnn ann + buildLam (Bind $ patNameHint p :> argTy) (fromUArrow arrow) \x@(Var v) -> + checkLeaks [v] $ withBindPat p x $ checkInstance rest ty methods + + +checkMethodDefs :: Name -> LabeledItems Type -> [UMethodDef] + -> UInferM (LabeledItems Atom) +checkMethodDefs className methodTys methods = do + methods' <- liftM mkLabeledItems $ forM methods \(UMethodDef (v:>()) rhs) -> do + let v' = nameToLabel v + case lookupLabelHead methodTys v' of + Nothing -> throw TypeErr $ + pprint v ++ " is not a method of " ++ pprint className + Just methodTy -> do + rhs' <- checkSigma rhs Suggest methodTy + return (v', rhs') + forM_ (reflectLabels methods') \(l,i) -> + when (i > 0) $ throw TypeErr $ "Duplicate method: " ++ pprint l + forM_ (reflectLabels methodTys) \(l,_) -> + case lookupLabelHead methods' l of + Nothing -> throw TypeErr $ "Missing method: " ++ pprint l + Just _ -> return () + return methods' checkUEffRow :: EffectRow -> UInferM EffectRow checkUEffRow (EffectRow effs t) = do @@ -732,13 +741,16 @@ inferTabCon xs reqTy = do return (tabTy, xs') emitZonked $ Op $ TabCon tabTy xs' +fromUArrow :: UArrow -> Arrow +fromUArrow arr = fmap (const Pure) arr + -- Bool flag is just to tweak the reported error message fromPiType :: Bool -> UArrow -> Type -> UInferM PiType fromPiType _ _ (Pi piTy) = return piTy -- TODO: check arrow fromPiType expectPi arr ty = do a <- freshType TyKind b <- freshType TyKind - let piTy = Abs (Ignore a) (fmap (const Pure) arr, b) + let piTy = Abs (Ignore a) (fromUArrow arr, b) if expectPi then constrainEq (Pi piTy) ty else constrainEq ty (Pi piTy) return piTy diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 600d42b38..eecfe641e 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -631,8 +631,11 @@ instance Pretty UDecl where "data" <+> p tyCon <+> "where" <> nest 2 (hardline <> prettyLines dataCons) pretty (UInterface cs def methods) = "interface" <+> p cs <+> p def <> hardline <> prettyLines methods - pretty (UInstance ty methods) = - "instance" <+> p ty <> hardline <> prettyLines methods + pretty (UInstance bs ty methods) = + "instance" <+> p bs <+> p ty <> hardline <> prettyLines methods + +instance Pretty UMethodDef where + pretty (UMethodDef b rhs) = p b <+> "=" <+> p rhs instance Pretty UConDef where pretty (UConDef con bs) = p con <+> spaced bs diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 8472d5d59..b6f9d3fd0 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -259,7 +259,7 @@ findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ UFor _ _ _ -> error "Unexpected for in type annotation" UHole -> mempty UTypeAnn v ty -> findVarsInAppLHS v <> findVarsInAppLHS ty - UTabCon _ -> mempty + UTabCon _ -> error "Unexpected table constructor in type annotation" UIndexRange low high -> foldMap findVarsInAppLHS low <> foldMap findVarsInAppLHS high UPrimExpr prim -> foldMap findVarsInAppLHS prim @@ -339,25 +339,25 @@ instanceDef = do explicitArgs <- many defArg constraints <- classConstraints classTy <- uType - let ty = buildPiType explicitArgs Pure $ - foldr addClassConstraint classTy constraints - let ty' = foldr addImplicitArg ty $ findImplicitImplicitArgNames ty + let implicitArgs = findImplicitImplicitArgNames $ + buildPiType explicitArgs Pure $ + foldr addClassConstraint classTy constraints + let argBinders = + [((ns (nameToPat v), Nothing), ImplicitArrow) | v <- implicitArgs] ++ + explicitArgs ++ + [((UnderscoreUPat, Just c) , ClassArrow ) | c <- constraints] methods <- onePerLine instanceMethod - return $ UInstance ty' methods + return $ UInstance (toNest argBinders) classTy methods where addClassConstraint :: UType -> UType -> UType addClassConstraint c ty = ns $ UPi (UnderscoreUPat, Just c) ClassArrow ty - addImplicitArg :: Name -> UType -> UType - addImplicitArg v ty = - ns $ UPi (ns $ nameToPat v, Nothing) ImplicitArrow ty - -instanceMethod :: Parser (UVar, UExpr) +instanceMethod :: Parser UMethodDef instanceMethod = do v <- anyName sym "=" rhs <- blockOrExpr - return (v:>(), rhs) + return $ UMethodDef (v:>()) rhs simpleLet :: Parser (UExpr -> UDecl) simpleLet = label "let binding" $ do @@ -379,26 +379,26 @@ funDefLet = label "function definition" $ mayBreak $ do let bs = map classAsBinder cs ++ argBinders let funTy = buildPiType bs eff ty let letBinder = (v, Just funTy) - let lamBinders = flip map bs \(p,_, arr) -> ((p,Nothing), arr) + let lamBinders = flip map bs \((p,_), arr) -> ((p,Nothing), arr) return \body -> ULet PlainLet letBinder (buildLam lamBinders body) where - classAsBinder :: UType -> (UPat, Maybe UType, UArrow) - classAsBinder ty = (UnderscoreUPat, Just ty, ClassArrow) + classAsBinder :: UType -> UPatAnnArrow + classAsBinder ty = ((UnderscoreUPat, Just ty), ClassArrow) -defArg :: Parser (UPat, Maybe UType, UArrow) +defArg :: Parser UPatAnnArrow defArg = label "def arg" $ do (p, ty) <-parens ((,) <$> pat <*> annot uType) arr <- arrow (return ()) <|> return (PlainArrow ()) - return (p, Just ty, arr) + return ((p, Just ty), arr) classConstraints :: Parser [UType] classConstraints = label "class constraints" $ optionalMonoid $ brackets $ mayNotPair $ uType `sepBy` sym "," -buildPiType :: [(UPat, Maybe UType, UArrow)] -> EffectRow -> UType -> UType +buildPiType :: [UPatAnnArrow] -> EffectRow -> UType -> UType buildPiType [] Pure ty = ty buildPiType [] _ _ = error "shouldn't be possible" -buildPiType ((p, patTy, arr):bs) eff resTy = ns case bs of +buildPiType (((p, patTy), arr):bs) eff resTy = ns case bs of [] -> UPi (p, patTy) (fmap (const eff ) arr) resTy _ -> UPi (p, patTy) (fmap (const Pure) arr) $ buildPiType bs eff resTy diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 1f99e550f..1d26b02c2 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -40,6 +40,7 @@ module Syntax ( AddressSpace (..), showPrimName, strToPrimName, primNameToStr, monMapSingle, monMapLookup, Direction (..), Limit (..), UExpr, UExpr' (..), UType, UPatAnn, UAnnBinder, UVar, + UMethodDef (..), UPatAnnArrow, UPat, UPat' (..), UModule (..), UDecl (..), UArrow, arrowEff, DataDef (..), DataConDef (..), UConDef (..), Nest (..), toNest, subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, @@ -245,18 +246,21 @@ data UExpr' = UVar UVar deriving (Show, Generic) data UConDef = UConDef Name (Nest UAnnBinder) deriving (Show, Generic) -data UDecl = ULet LetAnn UPatAnn UExpr - | UData UConDef [UConDef] - | UInterface [UType] UConDef [UAnnBinder] - | UInstance UType [(UVar, UExpr)] - deriving (Show, Generic) +data UDecl = + ULet LetAnn UPatAnn UExpr + | UData UConDef [UConDef] + | UInterface [UType] UConDef [UAnnBinder] -- superclasses, constructor, methods + | UInstance (Nest UPatAnnArrow) UType [UMethodDef] -- args, type, methods + deriving (Show, Generic) type UType = UExpr type UArrow = ArrowP () type UVar = VarP () type UBinder = BinderP () +data UMethodDef = UMethodDef UVar UExpr deriving (Show, Generic) -type UPatAnn = (UPat, Maybe UType) +type UPatAnn = (UPat, Maybe UType) +type UPatAnnArrow = (UPatAnn, UArrow) type UAnnBinder = BinderP UType data UAlt = UAlt UPat UExpr deriving (Show, Generic) @@ -800,14 +804,21 @@ instance HasUVars UDecl where freeUVars (UData (UConDef _ bs) dataCons) = freeUVars $ Abs bs dataCons freeUVars (UInterface superclasses tc methods) = freeUVars $ Abs tc (superclasses, methods) - freeUVars (UInstance ty methods) = mempty -- TODO + freeUVars (UInstance bsArrows ty methods) = freeUVars $ Abs bs (ty, methods) + where bs = fmap fst bsArrows + +instance HasUVars UMethodDef where + freeUVars (UMethodDef _ def) = freeUVars def + +instance BindsUVars UPatAnn where + boundUVars (p, _) = boundUVars p instance BindsUVars UDecl where boundUVars decl = case decl of ULet _ (p,_) _ -> boundUVars p UData tyCon dataCons -> boundUVars tyCon <> foldMap boundUVars dataCons UInterface _ _ _ -> mempty - UInstance _ _ -> mempty + UInstance _ _ _ -> mempty instance HasUVars UModule where freeUVars (UModule decls) = freeUVars decls From 6b73c503e906fe52dcb9bf9e0c3c9c16553148d9 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 5 Jan 2021 17:18:16 -0500 Subject: [PATCH 077/105] Add some comments suggested by Adam. --- lib/prelude.dx | 4 ++-- src/lib/Imp.hs | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index 4918585c0..15874c491 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -513,8 +513,6 @@ def (+>>) [Storable a] (ptr:Ptr a) (i:Int) : Ptr a = i' = i * storageSize a MkPtr $ %ptrOffset ptr' i' --- TODO: generalize these brackets to allow other effects - -- TODO: consider making a Storable instance for tables instead def storeTab [Storable a] (ptr: Ptr a) (tab:n=>a) : {State World} Unit = for_ i. store (ptr +>> ordinal i) tab.i @@ -524,6 +522,8 @@ def memcpy [Storable a] (dest:Ptr a) (src:Ptr a) (n:Int) : {State World} Unit = i' = ordinal i store (dest +>> i') (load $ src +>> i') +-- TODO: generalize these brackets to allow other effects +-- TODO: make sure that freeing happens even if there are run-time errors def withAlloc [Storable a] (n:Int) (action: Ptr a -> {State World} b) : {State World} b = ptr = malloc n diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index deae7ab50..0ef075003 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -101,6 +101,8 @@ requiredFunctions scope expr = flip foldMap (transitiveClosure getParents immediateParents) \fname -> case scope ! fname of (_, LetBound _ (Atom f)) -> [(fname, f)] + -- we treat runtime-supplied global constants (e.g. the virtual stdout + -- channel) as lambda-bound. TODO: consider a new annotation. (_, LamBound _) -> [] _ -> error "Shouldn't have other free variables left" where @@ -119,6 +121,7 @@ translateTopLevel topEnv (maybeDest, block) = do Just dest -> return dest handleErrors $ void $ translateBlock mempty (Just outDest, block) resultAtom <- destToAtom outDest + -- Some names in topEnv refer to global constants, like the virtual stdout channel let vsOut = envAsVars $ freeVars resultAtom `envDiff` topEnv let reconAtom = Abs (toNest $ [Bind (v:>ty) | (v:>(ty, _)) <- vsOut]) resultAtom let resultIExprs = case maybeDest of From baa7589300eee31305a2acdcddd163bdbe68bdea Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 5 Jan 2021 20:15:08 -0500 Subject: [PATCH 078/105] Remove `Storable` instance for pairs due to alignment concerns. Rewrote `DynBuffer` to use three pointers instead of a pointer to a triple. We now hit #348 when we query env vars so I had to disable a test. --- lib/prelude.dx | 64 ++++++++++++++++++++++------------------------- tests/io-tests.dx | 11 +++----- 2 files changed, 33 insertions(+), 42 deletions(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index 15874c491..cf457a12b 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -476,23 +476,6 @@ instance Storable Int32 load = \(MkPtr ptr) . %ptrLoad (internalCast %Int32Ptr ptr) storageSize_ = const 4 -def unpackPairPtr [Storable a, Storable b] - (pairPtr: Ptr (a & b)) : (Ptr a & Ptr b) = - (MkPtr rawPtrX) = pairPtr - rawPtrY = %ptrOffset rawPtrX (storageSize a) - (MkPtr rawPtrX, MkPtr rawPtrY) - -instance [Storable a, Storable b] Storable (a & b) - store = \pairPtr (x, y). - (xPtr, yPtr) = unpackPairPtr pairPtr - store xPtr x - store yPtr y - load = \pairPtr. - (xPtr, yPtr) = unpackPairPtr pairPtr - (load xPtr, load yPtr) - storageSize_ = \_. - storageSize a + storageSize b - instance Storable (Ptr a) store = \(MkPtr ptr) (MkPtr x). %ptrStore (internalCast %PtrPtr ptr) x load = \(MkPtr ptr) . MkPtr $ %ptrLoad (internalCast %PtrPtr ptr) @@ -883,45 +866,58 @@ def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = -- TODO: should we be able to use `Ref World Int` instead of `Ptr Int`? -- TODO: would be nice to be able to use records here -data DynBuffer a = MkDynBuffer (Ptr (Int & Int & Ptr a)) -- size, max size, buf ptr +data DynBuffer a = + MkDynBuffer { size : Ptr Int + & maxSize : Ptr Int + & buffer : Ptr (Ptr a) } def withDynamicBuffer [Storable a] (action: DynBuffer a -> {State World} b) : {State World} b = initMaxSize = 256 - withAlloc 1 \dbPtr. - bufPtr = malloc initMaxSize - store dbPtr (0, initMaxSize, bufPtr) - result = action $ MkDynBuffer dbPtr - (_, _, bufPtr') = load dbPtr - free bufPtr' + withAlloc 1 \sizePtr. withAlloc 1 \maxSizePtr. withAlloc 1 \bufferPtr. + store sizePtr 0 + store maxSizePtr initMaxSize + store bufferPtr $ malloc initMaxSize + result = action $ MkDynBuffer { size = sizePtr + , maxSize = maxSizePtr + , buffer = bufferPtr } + + free $ load bufferPtr result def maybeIncreaseBufferSize [Storable a] - (buf: DynBuffer a) (sizeDelta:Int) : {State World} Unit = - (MkDynBuffer dbPtr) = buf - (size, maxSize, bufPtr) = load dbPtr + ((MkDynBuffer db): DynBuffer a) (sizeDelta:Int) : {State World} Unit = + size = load $ getAt #size db + maxSize = load $ getAt #maxSize db + bufPtr = load $ getAt #buffer db newSize = sizeDelta + size if newSize > maxSize then -- TODO: maybe this should use integer arithmetic? newMaxSize = FToI $ pow 2.0 (ceil $ log2 $ IToF newSize) newBufPtr = malloc newMaxSize memcpy newBufPtr bufPtr size - store dbPtr (size, newMaxSize, newBufPtr) + free bufPtr + store (getAt #maxSize db) newMaxSize + store (getAt #buffer db) newBufPtr + +def addAtIntPtr (ptr: Ptr Int) (n:Int) : {State World} Unit = + store ptr (load ptr + n) def extendDynBuffer [Storable a] (buf: DynBuffer a) (new:List a) : {State World} Unit = (AsList n xs) = new maybeIncreaseBufferSize buf n - (MkDynBuffer dbPtr) = buf - (size, maxSize, bufPtr) = load dbPtr - newSize = n + size + (MkDynBuffer db) = buf + bufPtr = load $ getAt #buffer db + size = load $ getAt #size db storeTab (bufPtr +>> size) xs - store dbPtr (newSize, maxSize, bufPtr) + addAtIntPtr (getAt #size db) n def loadDynBuffer [Storable a] (buf: DynBuffer a) : {State World} (List a) = - (MkDynBuffer dbPtr) = buf - (size, _, bufPtr) = load dbPtr + (MkDynBuffer db) = buf + bufPtr = load $ getAt #buffer db + size = load $ getAt #size db AsList size $ tabFromPtr _ bufPtr def pushDynBuffer [Storable a] diff --git a/tests/io-tests.dx b/tests/io-tests.dx index 853ee0027..fdc59c995 100644 --- a/tests/io-tests.dx +++ b/tests/io-tests.dx @@ -41,12 +41,6 @@ unsafeIO \(). :p storageSize Int > 4 -:p unsafeIO \(). - withAlloc 1 \ptr:(Ptr (Int & Int)). - store ptr (4, 3) - load ptr -> (4, 3) - :p unsafeIO \(). withAlloc 1 \ptr:(Ptr Int). store ptr 3 @@ -73,8 +67,9 @@ unsafeIO \(). :p unsafeIO do getEnv "NOT_AN_ENV_VAR" > Nothing -:p unsafeIO do getEnv "DEX_TEST_MODE" -> (Just (AsList 1 "t")) +-- disabled because of bug #348 +-- :p unsafeIO do getEnv "DEX_TEST_MODE" +-- > (Just (AsList 1 "t")) :p dex_test_mode () > True From be2f5523aedfa2daadf38017460e93f0efdb43a5 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 5 Jan 2021 20:36:25 -0500 Subject: [PATCH 079/105] Make `IO` its own distinct effect instead of using `State World`. --- lib/prelude.dx | 73 ++++++++++++++++++++++---------------------- src/lib/Autodiff.hs | 1 + src/lib/Embed.hs | 2 +- src/lib/Inference.hs | 1 + src/lib/PPrint.hs | 1 + src/lib/Parser.hs | 8 +++-- src/lib/Syntax.hs | 16 +++++----- src/lib/Type.hs | 30 +++++++++--------- 8 files changed, 67 insertions(+), 65 deletions(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index cf457a12b..82d8d77d4 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -273,7 +273,7 @@ def yieldState (action: h:Type ?-> Ref h s -> {State h |eff} a) : {|eff} s = snd $ runState init action -def unsafeIO (f: Unit -> {State World|eff} a) : {|eff} a = +def unsafeIO (f: Unit -> {IO|eff} a) : {|eff} a = %runIO f def unreachable (():Unit) : a = unsafeIO do @@ -458,8 +458,8 @@ data TypeVehicle a = MkTypeVehicle def typeVehicle (a:Type) : TypeVehicle a = MkTypeVehicle interface Storable a - store : Ptr a -> a -> {State World} Unit - load : Ptr a -> {State World} a + store : Ptr a -> a -> {IO} Unit + load : Ptr a -> {IO} a storageSize_ : TypeVehicle a -> Int def storageSize (a:Type) -> (d:Storable a) ?=> : Int = @@ -483,11 +483,11 @@ instance Storable (Ptr a) -- TODO: Storable instances for other types -def malloc [Storable a] (n:Int) : {State World} (Ptr a) = +def malloc [Storable a] (n:Int) : {IO} (Ptr a) = numBytes = storageSize a * n MkPtr $ %alloc numBytes -def free (ptr:Ptr a) : {State World} Unit = +def free (ptr:Ptr a) : {IO} Unit = (MkPtr ptr') = ptr %free ptr' @@ -497,10 +497,10 @@ def (+>>) [Storable a] (ptr:Ptr a) (i:Int) : Ptr a = MkPtr $ %ptrOffset ptr' i' -- TODO: consider making a Storable instance for tables instead -def storeTab [Storable a] (ptr: Ptr a) (tab:n=>a) : {State World} Unit = +def storeTab [Storable a] (ptr: Ptr a) (tab:n=>a) : {IO} Unit = for_ i. store (ptr +>> ordinal i) tab.i -def memcpy [Storable a] (dest:Ptr a) (src:Ptr a) (n:Int) : {State World} Unit = +def memcpy [Storable a] (dest:Ptr a) (src:Ptr a) (n:Int) : {IO} Unit = for_ i:(Fin n). i' = ordinal i store (dest +>> i') (load $ src +>> i') @@ -508,19 +508,19 @@ def memcpy [Storable a] (dest:Ptr a) (src:Ptr a) (n:Int) : {State World} Unit = -- TODO: generalize these brackets to allow other effects -- TODO: make sure that freeing happens even if there are run-time errors def withAlloc [Storable a] - (n:Int) (action: Ptr a -> {State World} b) : {State World} b = + (n:Int) (action: Ptr a -> {IO} b) : {IO} b = ptr = malloc n result = action ptr free ptr result def withTabPtr [Storable a] - (xs:n=>a) (action : Ptr a -> {State World} b) : {State World} b = + (xs:n=>a) (action : Ptr a -> {IO} b) : {IO} b = withAlloc (size n) \ptr. for i. store (ptr +>> ordinal i) xs.i action ptr -def tabFromPtr [Storable a] (n:Type) -> (ptr:Ptr a) : {State World} n=>a = +def tabFromPtr [Storable a] (n:Type) -> (ptr:Ptr a) : {IO} n=>a = for i. load $ ptr +>> ordinal i '## Miscellaneous common utilities @@ -864,7 +864,6 @@ def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = 'Dynamic buffer --- TODO: should we be able to use `Ref World Int` instead of `Ptr Int`? -- TODO: would be nice to be able to use records here data DynBuffer a = MkDynBuffer { size : Ptr Int @@ -872,7 +871,7 @@ data DynBuffer a = & buffer : Ptr (Ptr a) } def withDynamicBuffer [Storable a] - (action: DynBuffer a -> {State World} b) : {State World} b = + (action: DynBuffer a -> {IO} b) : {IO} b = initMaxSize = 256 withAlloc 1 \sizePtr. withAlloc 1 \maxSizePtr. withAlloc 1 \bufferPtr. store sizePtr 0 @@ -886,7 +885,7 @@ def withDynamicBuffer [Storable a] result def maybeIncreaseBufferSize [Storable a] - ((MkDynBuffer db): DynBuffer a) (sizeDelta:Int) : {State World} Unit = + ((MkDynBuffer db): DynBuffer a) (sizeDelta:Int) : {IO} Unit = size = load $ getAt #size db maxSize = load $ getAt #maxSize db bufPtr = load $ getAt #buffer db @@ -900,11 +899,11 @@ def maybeIncreaseBufferSize [Storable a] store (getAt #maxSize db) newMaxSize store (getAt #buffer db) newBufPtr -def addAtIntPtr (ptr: Ptr Int) (n:Int) : {State World} Unit = +def addAtIntPtr (ptr: Ptr Int) (n:Int) : {IO} Unit = store ptr (load ptr + n) def extendDynBuffer [Storable a] - (buf: DynBuffer a) (new:List a) : {State World} Unit = + (buf: DynBuffer a) (new:List a) : {IO} Unit = (AsList n xs) = new maybeIncreaseBufferSize buf n (MkDynBuffer db) = buf @@ -914,21 +913,21 @@ def extendDynBuffer [Storable a] addAtIntPtr (getAt #size db) n def loadDynBuffer [Storable a] - (buf: DynBuffer a) : {State World} (List a) = + (buf: DynBuffer a) : {IO} (List a) = (MkDynBuffer db) = buf bufPtr = load $ getAt #buffer db size = load $ getAt #size db AsList size $ tabFromPtr _ bufPtr def pushDynBuffer [Storable a] - (buf: DynBuffer a) (x:a) : {State World} Unit = + (buf: DynBuffer a) (x:a) : {IO} Unit = extendDynBuffer buf $ AsList _ [x] '## Strings and Characters String : Type = List Char -def stringFromCharPtr (n:Int) (ptr:Ptr Char) : {State World} String = +def stringFromCharPtr (n:Int) (ptr:Ptr Char) : {IO} String = AsList n $ tabFromPtr _ ptr -- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint @@ -1013,11 +1012,11 @@ data StreamMode = data Stream mode:StreamMode = MkStream RawPtr -- TODO: check the string contains no nulls -def withCString (s:String) (action: CString -> {State World} a) : {State World} a = +def withCString (s:String) (action: CString -> {IO} a) : {IO} a = (AsList n s') = s <> "\NUL" withTabPtr s' \(MkPtr ptr). action $ MkCString ptr -def fopen (path:String) (mode:StreamMode) : {State World} (Stream mode) = +def fopen (path:String) (mode:StreamMode) : {IO} (Stream mode) = modeStr = case mode of ReadMode -> "r" WriteMode -> "w" @@ -1025,12 +1024,12 @@ def fopen (path:String) (mode:StreamMode) : {State World} (Stream mode) = withCString modeStr \(MkCString modePtr). MkStream $ %ffi fopen RawPtr pathPtr modePtr -def fclose (stream:Stream mode) : {State World} Unit = +def fclose (stream:Stream mode) : {IO} Unit = (MkStream stream') = stream %ffi fclose Int64 stream' () -def fwrite (stream:Stream WriteMode) (s:String) : {State World} Unit = +def fwrite (stream:Stream WriteMode) (s:String) : {IO} Unit = (MkStream stream') = stream (AsList n s') = s withTabPtr s' \(MkPtr ptr). @@ -1086,7 +1085,7 @@ def boundedIter (maxIters:Int) (fallback:a) then Done fallback else body i -def fromCString (s:CString) : {State World} (Maybe String) = +def fromCString (s:CString) : {IO} (Maybe String) = case cStringPtr s of Nothing -> Nothing Just ptr -> @@ -1098,18 +1097,18 @@ def fromCString (s:CString) : {State World} (Maybe String) = pushDynBuffer buf c Continue -def getEnv (name:String) : {State World} Maybe String = +def getEnv (name:String) : {IO} Maybe String = withCString name \(MkCString ptr). fromCString $ MkCString $ %ffi getenv RawPtr ptr -def checkEnv (name:String) : {State World} Bool = +def checkEnv (name:String) : {IO} Bool = -- This should be just `isJust $ getEnv name` but that segfaults (only if the -- env var *is* defined), possibly related to bug #348. withCString name \(MkCString ptr). resultPtr = %ffi getenv RawPtr ptr not $ resultPtr == nullRawPtr -def fread (stream:Stream ReadMode) : {State World} String = +def fread (stream:Stream ReadMode) : {IO} String = (MkStream stream') = stream -- TODO: allow reading longer files! n = 4096 @@ -1124,50 +1123,50 @@ def fread (stream:Stream ReadMode) : {State World} String = else Done () loadDynBuffer buf -def deleteFile (f:FilePath) : {State World} Unit = +def deleteFile (f:FilePath) : {IO} Unit = withCString f \(MkCString ptr). %ffi remove Int64 ptr () def withFile (f:FilePath) (mode:StreamMode) - (action: Stream mode -> {State World} a) - : {State World} a = + (action: Stream mode -> {IO} a) + : {IO} a = stream = fopen f mode result = action stream fclose stream result -def writeFile (f:FilePath) (s:String) : {State World} Unit = +def writeFile (f:FilePath) (s:String) : {IO} Unit = withFile f WriteMode \stream. fwrite stream s -def readFile (f:FilePath) : {State World} String = +def readFile (f:FilePath) : {IO} String = withFile f ReadMode \stream. fread stream -def newTempFile (_:Unit) : {State World} FilePath = +def newTempFile (_:Unit) : {IO} FilePath = withCString "/tmp/dex-XXXXXX" \(MkCString ptr). fd = %ffi mkstemp Int32 ptr %ffi close Int32 fd stringFromCharPtr 15 (MkPtr ptr) -def withTempFile (action: FilePath -> {State World} a) : {State World} a = +def withTempFile (action: FilePath -> {IO} a) : {IO} a = tmpFile = newTempFile () result = action tmpFile deleteFile tmpFile result -def withTempFiles (action: (n=>FilePath) -> {State World} a) : {State World} a = +def withTempFiles (action: (n=>FilePath) -> {IO} a) : {IO} a = tmpFiles = for i. newTempFile () result = action tmpFiles for i. deleteFile tmpFiles.i result -def getOutputStream (_:Unit) : {State World} Stream WriteMode = +def getOutputStream (_:Unit) : {IO} Stream WriteMode = MkStream $ %ptrLoad OUT_STREAM_PTR -def print (s:String) : {State World} Unit = +def print (s:String) : {IO} Unit = fwrite (getOutputStream ()) (s <> "\n") -def shellOut (command:String) : {State World} String = +def shellOut (command:String) : {IO} String = modeStr = "r" withCString command \(MkCString commandPtr). withCString modeStr \(MkCString modePtr). diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index 8a6dbd964..e158a34cb 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -460,6 +460,7 @@ tangentFunAsLambda m = do effectRegion eff = case eff of RWSEffect _ h -> h ExceptionEffect -> error "TODO!" + IOEffect -> error "TODO!" -- Inverse of tangentFunAsLambda. Should be used inside a returned tangent action. applyLinToTangents :: Atom -> TangentM Atom diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 330daa991..283aeeaaf 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -366,7 +366,7 @@ ptrOffset x i = emitOp $ PtrOffset x i unsafePtrLoad :: MonadEmbed m => Atom -> m Atom unsafePtrLoad x = emit $ Hof $ RunIO $ Lam $ Abs (Ignore UnitTy) $ - (PlainArrow (oneEffect ioEffect), Block Empty (Op (PtrLoad x))) + (PlainArrow (oneEffect IOEffect), Block Empty (Op (PtrLoad x))) ptrLoad :: MonadEmbed m => Atom -> m Atom ptrLoad x = emitOp $ PtrLoad x diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 84c11c9d2..ed8f3c1d2 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -539,6 +539,7 @@ checkUEff eff = case eff of constrainEq TyKind ty return $ RWSEffect rws v ExceptionEffect -> return ExceptionEffect + IOEffect -> return IOEffect data CaseAltIndex = ConAlt Int | VariantAlt Label Int diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index eecfe641e..c09da8a2d 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -674,6 +674,7 @@ instance Pretty Effect where pretty eff = case eff of RWSEffect rws h -> p rws <+> p h ExceptionEffect -> "Except" + IOEffect -> "IO" instance Pretty RWS where pretty eff = case eff of diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index b6f9d3fd0..f7b35c7f8 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -416,7 +416,8 @@ effects = braces someEffects <|> return Pure effect :: Parser Effect effect = (RWSEffect <$> rwsName <*> anyCaseName) <|> (keyWord ExceptKW $> ExceptionEffect) - "effect (Accum h | Read h | State h | Except)" + <|> (keyWord IOKW $> IOEffect) + "effect (Accum h | Read h | State h | Except | IO)" rwsName :: Parser RWS rwsName = (keyWord WriteKW $> Writer) @@ -964,7 +965,7 @@ type Lexer = Parser data KeyWord = DefKW | ForKW | For_KW | RofKW | Rof_KW | CaseKW | OfKW | ReadKW | WriteKW | StateKW | DataKW | InterfaceKW | InstanceKW | WhereKW | IfKW | ThenKW | ElseKW | DoKW - | ExceptKW | ViewKW + | ExceptKW | IOKW | ViewKW upperName :: Lexer Name upperName = liftM mkName $ label "upper-case name" $ lexeme $ @@ -1004,6 +1005,7 @@ keyWord kw = lexeme $ try $ string s >> notFollowedBy nameTailChar WriteKW -> "Accum" StateKW -> "State" ExceptKW -> "Except" + IOKW -> "IO" DataKW -> "data" InterfaceKW -> "interface" InstanceKW -> "instance" @@ -1013,7 +1015,7 @@ keyWord kw = lexeme $ try $ string s >> notFollowedBy nameTailChar keyWordStrs :: [String] keyWordStrs = ["def", "for", "for_", "rof", "rof_", "case", "of", "llam", - "Read", "Write", "Accum", "Except", "data", "interface", + "Read", "Write", "Accum", "Except", "IO", "data", "interface", "instance", "where", "if", "then", "else", "do", "view"] fieldLabel :: Lexer Label diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 1d26b02c2..d15a4f997 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -46,7 +46,7 @@ module Syntax ( subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, applyNaryAbs, applyDataDefParams, freshSkolemVar, IndexStructure, mkConsList, mkConsListTy, fromConsList, fromConsListTy, extendEffRow, - getProjection, theWorld, outputStreamPtrName, initTopEnv, + getProjection, outputStreamPtrName, initTopEnv, varType, binderType, isTabTy, LogLevel (..), IRVariant (..), applyIntBinOp, applyIntCmpOp, applyFloatBinOp, applyFloatUnOp, getIntLit, getFloatLit, sizeOf, vectorWidth, @@ -447,22 +447,19 @@ data EffectRow = EffectRow (S.Set Effect) (Maybe Name) deriving (Show, Eq, Generic) data RWS = Reader | Writer | State deriving (Show, Eq, Ord, Generic) -data Effect = RWSEffect RWS Name | ExceptionEffect deriving (Show, Eq, Ord, Generic) +data Effect = RWSEffect RWS Name | ExceptionEffect | IOEffect + deriving (Show, Eq, Ord, Generic) pattern Pure :: EffectRow pattern Pure <- ((\(EffectRow effs t) -> (S.null effs, t)) -> (True, Nothing)) where Pure = mempty -theWorld :: Name -theWorld = GlobalName "World" - outputStreamPtrName :: Name outputStreamPtrName = GlobalName "OUT_STREAM_PTR" initTopEnv :: TopEnv initTopEnv = fold [v @> (ty, LamBound ImplicitArrow) | (v, ty) <- - [ (theWorld , TyKind) - , (outputStreamPtrName , BaseTy $ hostPtrTy $ hostPtrTy $ Scalar Word8Type)]] + [(outputStreamPtrName , BaseTy $ hostPtrTy $ hostPtrTy $ Scalar Word8Type)]] hostPtrTy :: BaseType -> BaseType hostPtrTy ty = PtrType (Heap CPU, ty) @@ -846,7 +843,8 @@ instance HasUVars EffectRow where instance HasUVars Effect where freeUVars (RWSEffect _ h) = nameAsEnv h - freeUVars (ExceptionEffect) = mempty + freeUVars ExceptionEffect = mempty + freeUVars IOEffect = mempty instance HasUVars a => HasUVars (LabeledItems a) where freeUVars (LabeledItems items) = foldMap freeUVars items @@ -1175,10 +1173,12 @@ instance HasVars Effect where freeVars eff = case eff of RWSEffect _ v -> v@>(TyKind , UnknownBinder) ExceptionEffect -> mempty + IOEffect -> mempty instance Subst Effect where subst (env,_) eff = case eff of RWSEffect rws v -> RWSEffect rws (substName env v) ExceptionEffect -> ExceptionEffect + IOEffect -> IOEffect instance HasVars BinderInfo where freeVars binfo = case binfo of diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 0e700202d..72a0bcc98 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -13,7 +13,7 @@ module Type ( isPure, functionEffs, exprEffs, blockEffs, extendEffect, isData, checkBinOp, checkUnOp, checkIntBaseType, checkFloatBaseType, withBinder, isDependent, checkExtends, indexSetConcreteSize, checkNoShadow, traceCheckM, traceCheck, projectLength, - typeReduceBlock, typeReduceAtom, typeReduceExpr, oneEffect, ioEffect) where + typeReduceBlock, typeReduceAtom, typeReduceExpr, oneEffect) where import Prelude hiding (pi) import Control.Monad @@ -278,11 +278,11 @@ exprEffs expr = case expr of MTell _ -> oneEffect (RWSEffect Writer h) where RefTy (Var (h:>_)) _ = getType ref ThrowException _ -> oneEffect ExceptionEffect - IOAlloc _ _ -> oneEffect ioEffect - IOFree _ -> oneEffect ioEffect - PtrLoad _ -> oneEffect ioEffect - PtrStore _ _ -> oneEffect ioEffect - FFICall _ _ _ -> oneEffect ioEffect + IOAlloc _ _ -> oneEffect IOEffect + IOFree _ -> oneEffect IOEffect + PtrLoad _ -> oneEffect IOEffect + PtrStore _ _ -> oneEffect IOEffect + FFICall _ _ _ -> oneEffect IOEffect _ -> Pure Hof hof -> case hof of For _ f -> functionEffs f @@ -295,7 +295,7 @@ exprEffs expr = case expr of RunState _ f -> handleRWSRunner State f PTileReduce _ _ -> mempty RunIO ~(Lam (Abs _ (PlainArrow (EffectRow effs t), _))) -> - EffectRow (S.delete ioEffect effs) t + EffectRow (S.delete IOEffect effs) t Case _ alts _ -> foldMap (\(Abs _ block) -> blockEffs block) alts where handleRWSRunner rws ~(BinaryFunVal (Bind (h:>_)) _ (EffectRow effs t) _) = @@ -481,6 +481,7 @@ checkEffRow (EffectRow effs effTail) = do forM_ effs \eff -> case eff of RWSEffect _ v -> Var (v:>TyKind) |: TyKind ExceptionEffect -> return () + IOEffect -> return () forM_ effTail \v -> do checkWithEnv \(env, _) -> case envLookup env (v:>()) of Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v @@ -509,9 +510,6 @@ extendEffect eff (EffectRow effs t) = EffectRow (S.insert eff effs) t oneEffect :: Effect -> EffectRow oneEffect eff = EffectRow (S.singleton eff) Nothing -ioEffect :: Effect -ioEffect = RWSEffect State theWorld - -- === labeled row types === checkLabeledRow :: ExtLabeledItems Type Name -> TypeM () @@ -693,7 +691,7 @@ typeCheckOp op = case op of BaseTy _ -> return () _ -> throw TypeErr $ "All arguments of FFI calls have to be " ++ "fixed-width base types, but got: " ++ pprint argTy - declareEff ioEffect + declareEff IOEffect return ansTy Inject i -> do TC tc <- typeCheck i @@ -720,11 +718,11 @@ typeCheckOp op = case op of return $ RefTy h b IOAlloc t n -> do n |: IdxRepTy - declareEff ioEffect + declareEff IOEffect return $ PtrTy (Heap CPU, t) IOFree ptr -> do PtrTy _ <- typeCheck ptr - declareEff ioEffect + declareEff IOEffect return UnitTy PtrOffset arr off -> do PtrTy (a, b) <- typeCheck arr @@ -732,12 +730,12 @@ typeCheckOp op = case op of return $ PtrTy (a, b) PtrLoad ptr -> do PtrTy (_, t) <- typeCheck ptr - declareEff ioEffect + declareEff IOEffect return $ BaseTy t PtrStore ptr val -> do PtrTy (_, t) <- typeCheck ptr val |: BaseTy t - declareEff ioEffect + declareEff IOEffect return $ UnitTy SliceOffset s i -> do TC (IndexSlice n l) <- typeCheck s @@ -887,7 +885,7 @@ typeCheckHof hof = case hof of RunIO f -> do FunTy b eff resultTy <- typeCheck f checkEq (binderAnn b) UnitTy - extendAllowedEffect ioEffect $ declareEffs eff + extendAllowedEffect IOEffect $ declareEffs eff return resultTy CatchException f -> do FunTy b eff resultTy <- typeCheck f From 3d1146eb694bb14279a16594a97140196454b62c Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 6 Jan 2021 11:48:14 +0000 Subject: [PATCH 080/105] Continue inlining after application Turns out that the inlining pass had a pretty big gaping hole previously: the inlineable tables originate not only from the `for` loops, but also from applications with table lambdas! Previously we completely ignored the second case, potentially failing to fully reducing a table application, and blowing up the run-time complexity as in #346. --- src/lib/Embed.hs | 5 +++-- src/lib/Optimize.hs | 22 ++++++++++++++-------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 283aeeaaf..f19303e32 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -27,9 +27,10 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP embedExtend, unpackConsList, emitRunWriter, applyPreludeFunction, emitRunState, emitMaybeCase, emitWhile, buildDataDef, emitRunReader, tabGet, SubstEmbedT, SubstEmbed, runSubstEmbedT, - traverseAtom, ptrOffset, ptrLoad, unsafePtrLoad, + ptrOffset, ptrLoad, unsafePtrLoad, evalBlockE, substTraversalDef, - TraversalDef, traverseDecls, traverseDecl, traverseBlock, traverseExpr, + TraversalDef, traverseDecls, traverseDecl, traverseDeclsOpen, + traverseBlock, traverseExpr, traverseAtom, clampPositive, buildNAbs, buildNAbsAux, buildNestedLam, zeroAt, transformModuleAsBlock, dropSub, appReduceTraversalDef, indexSetSizeE, indexToIntE, intToIndexE, freshVarE) where diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 92158b51c..44022ba70 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -95,13 +95,22 @@ inlineModule m = transformModuleAsBlock inlineBlock (computeInlineHints m) inlineTraverseDecl :: Decl -> InlineM SubstEnv inlineTraverseDecl decl = case decl of - -- This is not a super safe condition for inlining, because it might still duplicate work - -- unexpectedly (consider an `arr` that's only used as `for i. 2.0 .* arr`). Still, this is not - -- the way arrays are usually used, so it should be good enough for now. In the future we should - -- strengthen this check to better ensure that each element of the array is used at most once. Let _ b@(BindWithHint CanInline _) expr@(Hof (For _ body)) | isPure expr -> do ~(LamVal ib block) <- traverseAtom inlineTraversalDef body return $ b @> TabVal ib block + -- If `f` turns out to be an inlined table lambda, we expand its block and + -- call ourselves recursively on the block's result expression. This makes + -- it possible for us to e.g. discover that the result is a `for` loop, and + -- match the case above, to continue the inlining process. + Let letAnn letBinder (App f' x') -> do + f <- traverseAtom inlineTraversalDef f' + x <- traverseAtom inlineTraversalDef x' + case f of + TabVal b (Block body result) -> do + dropSub $ extendR (b@>x) $ do + blockEnv <- traverseDeclsOpen substTraversalDef body + extendR blockEnv $ inlineTraverseDecl $ Let letAnn letBinder result + _ -> (letBinder@>) <$> emitTo (binderNameHint letBinder) letAnn (App f x) _ -> traverseDecl inlineTraversalDef decl -- TODO: This is a bit overeager. We should count under how many loops are we. @@ -113,12 +122,9 @@ inlineTraverseExpr expr = case expr of Hof (For d body) -> do newBody <- traverseAtom inlineTraversalDef body case newBody of - -- Trivial bodies -- XXX: The trivial body might be a table lambda, and those could technically -- get quite expensive. But I think this should never be the case in practice. - -- XXX: This doesn't always have to end up being beneficial. If the result is - -- significantly smaller than the intermediates it refers to, then this - -- optimization will waste a bunch of memory by keeping the large intermediates alive. + -- Trivial bodies LamVal ib block@(Block Empty (Atom _)) -> return $ Atom $ TabVal ib block -- Pure broadcasts LamVal ib@(Ignore _) block | blockEffs block == Pure -> do From 72f19dc11dd24f63b283fee7b24e742249417e77 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 6 Jan 2021 12:08:43 +0000 Subject: [PATCH 081/105] Fix the stupid zext bug Who knew that zero-extending two's complement integers could go badly. --- src/lib/JIT.hs | 2 +- tests/show-tests.dx | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/lib/JIT.hs b/src/lib/JIT.hs index d26ad3a4a..a2e724266 100644 --- a/src/lib/JIT.hs +++ b/src/lib/JIT.hs @@ -801,7 +801,7 @@ asIntWidth :: Operand -> L.Type -> Compile Operand asIntWidth op ~expTy@(L.IntegerType expWidth) = case compare expWidth opWidth of LT -> emitInstr expTy $ L.Trunc op expTy [] EQ -> return op - GT -> emitInstr expTy $ L.ZExt op expTy [] + GT -> emitInstr expTy $ L.SExt op expTy [] where ~(L.IntegerType opWidth) = L.typeOf op freshParamOpPair :: [L.ParameterAttribute] -> L.Type -> Compile (Parameter, Operand) diff --git a/tests/show-tests.dx b/tests/show-tests.dx index 251ba4fee..1d0d24098 100644 --- a/tests/show-tests.dx +++ b/tests/show-tests.dx @@ -20,10 +20,8 @@ :p show (IToI64 1234: Int64) > (AsList 4 "1234") --- FIXME(https://github.com/google-research/dex-lang/issues/317): --- Unexpected zext from type conversion of negative Int32 to Int64. :p show (IToI64 (-1234): Int64) -> (AsList 10 "4294966062") +> (AsList 5 "-1234") -- Float32 From de34e9f6414884e60a87140f91cdaf1c2c389099 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 6 Jan 2021 12:30:17 +0000 Subject: [PATCH 082/105] Harden CI checks We end up committing a whole bunch of warnings all the time, which create unnecessary noise for others. This turns them into errors, to make sure that none of those slip through the cracks. Also, any segmentation faults or aborts didn't cause the tests to fail in the past, which should hopefully be fixed now. --- .github/workflows/ci.yaml | 6 +++++- makefile | 9 ++++++--- misc/check-quine | 1 + src/lib/Autodiff.hs | 18 +++++++++++------- src/lib/Type.hs | 2 ++ 5 files changed, 25 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d75782fda..b513e61f0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -51,8 +51,12 @@ jobs: run: rm -rf ~/.stack/setup-exe-cache if: runner.os == 'macOS' + - name: Build, treating warnings as errors + run: make build-ci + if: runner.os == 'Linux' + - name: Build - run: make + run: make build - name: Run tests run: make tests diff --git a/makefile b/makefile index 23afc3aa0..ed75f9f24 100644 --- a/makefile +++ b/makefile @@ -66,14 +66,17 @@ install: dexrt-llvm build-prof: dexrt-llvm $(STACK) build $(PROF) -dexrt-llvm: src/lib/dexrt.bc - # For some reason stack fails to detect modifications to foreign library files -build-python: build +build-python: dexrt-llvm $(STACK) build $(STACK_FLAGS) --force-dirty $(eval STACK_INSTALL_DIR=$(shell stack path --local-install-root)) cp $(STACK_INSTALL_DIR)/lib/libDex.so python/dex/ +build-ci: dexrt-llvm + $(STACK) build $(STACK_FLAGS) --force-dirty --ghc-options "-Werror -fforce-recomp" + +dexrt-llvm: src/lib/dexrt.bc + %.bc: %.cpp clang++ $(CXXFLAGS) -c -emit-llvm $^ -o $@ diff --git a/misc/check-quine b/misc/check-quine index c592d32fd..0bde916df 100755 --- a/misc/check-quine +++ b/misc/check-quine @@ -26,6 +26,7 @@ if ${@:2} $1 > $tmpout 2> $errout ; then misc/check-no-diff $1 $tmpout status=$? else + status=$? cat $tmpout fi diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index e158a34cb..1da5eac39 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -172,6 +172,7 @@ linearizeOp op = case op of VariantLift ts v -> (VariantLift ts <$> la v) `bindLin` emitOp VariantSplit ts v -> (VariantSplit ts <$> la v) `bindLin` emitOp FFICall _ _ _ -> error $ "Can't differentiate through an FFI call" + ThrowException _ -> notImplemented where emitDiscrete = if isTrivialForAD (Op op) then LinA $ withZeroTangent <$> emitOp op @@ -274,7 +275,8 @@ linearizeHof env hof = case hof of -- TODO: Consider providing an upper bound for the number of while iterations as a hint. -- In the current form the best we can do is try to use some dynamically growing lists, -- but that won't work on the GPU. - While _ -> notImplemented + While _ -> notImplemented + CatchException _ -> notImplemented Linearize _ -> error "Unexpected linearization" Transpose _ -> error "Unexpected transposition" PTileReduce _ _ -> error "Unexpected PTileReduce" @@ -649,6 +651,7 @@ transposeOp op ct = case op of FFICall _ _ _ -> notLinear DataConTag _ -> notLinear ToEnum _ _ -> notLinear + ThrowException _ -> notLinear where -- Both nonlinear operations and operations on discrete types, where linearity doesn't make sense notLinear = error $ "Can't transpose a non-linear operation: " ++ pprint op @@ -697,12 +700,13 @@ transposeHof hof ct = case hof of localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ctBody return UnitVal transposeAtom s cts - RunIO _ -> error "Not implemented" - Tile _ _ _ -> notImplemented - While _ -> notImplemented - Linearize _ -> error "Unexpected linearization" - Transpose _ -> error "Unexpected transposition" - PTileReduce _ _ -> error "Unexpected PTileReduce" + Tile _ _ _ -> notImplemented + While _ -> notImplemented + RunIO _ -> notImplemented + CatchException _ -> notImplemented + Linearize _ -> error "Unexpected linearization" + Transpose _ -> error "Unexpected transposition" + PTileReduce _ _ -> error "Unexpected PTileReduce" transposeAtom :: Atom -> Atom -> TransposeM () transposeAtom atom ct = case atom of diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 72a0bcc98..29248533a 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -296,6 +296,8 @@ exprEffs expr = case expr of PTileReduce _ _ -> mempty RunIO ~(Lam (Abs _ (PlainArrow (EffectRow effs t), _))) -> EffectRow (S.delete IOEffect effs) t + CatchException ~(Lam (Abs _ (PlainArrow (EffectRow effs t), _))) -> + EffectRow (S.delete ExceptionEffect effs) t Case _ alts _ -> foldMap (\(Abs _ block) -> blockEffs block) alts where handleRWSRunner rws ~(BinaryFunVal (Bind (h:>_)) _ (EffectRow effs t) _) = From 9e6e470254626bc2d4aaf9fe9fcff4b915817f93 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 6 Jan 2021 12:37:57 +0000 Subject: [PATCH 083/105] Revert #376 to fix segfaults in main Of course the segfaults are due to #348, but I'd like to keep main green. Let's resubmit this once the underlying issue is fixed. --- lib/diagram.dx | 46 +++++++++------------------------------------- 1 file changed, 9 insertions(+), 37 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index a05fc1cb3..4e14b8a89 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -7,7 +7,6 @@ data Geom = Circle Float Rectangle Float Float -- width, height Line Point - Text String -- HTML color (no alpha) -- TODO: replace with `Fin 3 => Word8` when we fix #348 @@ -63,7 +62,6 @@ flipY : Diagram -> Diagram = Circle r -> Circle r Rectangle w h -> Rectangle w h Line (x, y) -> Line (x, -y) - Text x -> Text x def scale (s:Float) : (Diagram -> Diagram) = applyTransformation ( \(x,y). (s * x, s * y) ) \geom. case geom of @@ -71,7 +69,6 @@ def scale (s:Float) : (Diagram -> Diagram) = Circle r -> Circle (s * r) Rectangle w h -> Rectangle (s * w) (s * h) Line (x, y) -> Line (s * x, s * y) - Text x -> Text x def moveXY ((offX, offY) : Point) : (Diagram -> Diagram) = applyTransformation (\(x,y). (x + offX, y + offY) ) id @@ -83,7 +80,6 @@ def pointDiagram : Diagram = singletonDefault PointGeom def circle (r:Float) : Diagram = singletonDefault $ Circle r def rect (w:Float) (h:Float) : Diagram = singletonDefault $ Rectangle w h def line (p:Point) : Diagram = singletonDefault $ Line p -def text (x:String) : Diagram = singletonDefault $ Text x def updateGeom (update: GeomStyle -> GeomStyle) (d:Diagram) : Diagram = (MkDiagram (AsList _ objs)) = d @@ -146,14 +142,11 @@ def attrString (attr:GeomStyle) : String = <+> ("stroke-width" <=> (getAt #strokeWidth attr))) def renderGeom (attr:GeomStyle) ((x,y):Point) (geom:Geom) : String = - -- For things that are solid. SVG says they have fill=stroke. - solidAttr = setAt #fillColor (getAt #strokeColor attr) attr - groupEle = \attr. tagBracketsAttr "g" (attrString attr) case geom of PointGeom -> pointAttr = setAt #fillColor (getAt #strokeColor attr) attr - groupEle solidAttr $ selfClosingBrackets $ + groupEle pointAttr $ selfClosingBrackets $ ("circle" <+> "cx" <=> x <.> "cy" <=> y <.> @@ -171,14 +164,6 @@ def renderGeom (attr:GeomStyle) ((x,y):Point) (geom:Geom) : String = "height" <=> h <.> "x" <=> (x - (w/2.0)) <.> "y" <=> (y - (h/2.0))) - Text content -> - textEle = tagBracketsAttr "text" $ - ("x" <=> x <.> - "y" <=> y <.> - "text-anchor" <=> "middle" <.> -- horizontal center - "dominant-baseline" <=> "middle" -- vertical center - ) - groupEle solidAttr $ textEle content BoundingBox : Type = (Point & Point) @@ -203,24 +188,11 @@ def renderSVG (d:Diagram) (bounds:BoundingBox) : String = moveX : Float -> Diagram -> Diagram = \x. moveXY (x, 0.0) moveY : Float -> Diagram -> Diagram = \y. moveXY (0.0, y) -' A Demo showing all kind of features -``` -mydiagram : Diagram = - ( (circle 7.0 |> moveXY (20.0, 20.0) |> setFillColor blue |> setStrokeColor red) - <> (circle 5.0 |> moveXY (40.0, 40.0)) - <> (rect 10.0 20.0 |> moveXY (5.0, 10.0) |> setStrokeColor red) - <> (text "DexLang" |> moveXY (30.0, 10.0) |> setStrokeColor green) - <> (pointDiagram |> moveXY (15.0, 5.0) |> setStrokeColor red) - ) -:html renderSVG mydiagram ((0.0, 0.0), (100.0, 50.0)) -``` - -' Another demo that shows things are all center aligned: -``` -concentricDiagram : Diagram = ( - (rect 2.0 2.0 |> setFillColor red) - <> (circle 1.0 |> setFillColor blue) - <> (text "DexLang" |> setStrokeColor white) -) |> moveXY (5.0, 5.0) -:html renderSVG concentricDiagram ((0.0, 0.0), (10.0, 10.0)) -``` +-- mydiagram : Diagram = +-- ( (circle 7.0 |> moveXY (20.0, 20.0) |> setFillColor blue |> setStrokeColor red) +-- <> (circle 5.0 |> moveXY (40.0, 40.0)) +-- <> (rect 10.0 20.0 |> moveXY (5.0, 10.0) |> setStrokeColor red) +-- <> (pointDiagram |> moveXY (15.0, 5.0) |> setStrokeColor red) +-- ) + +-- :html renderSVG mydiagram ((0.0, 0.0), (100.0, 50.0)) From 43f7758a1aec346feb2ae524d9b43b6d8f616cc1 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 6 Jan 2021 23:31:49 -0500 Subject: [PATCH 084/105] [README] Fix typo. (#436) tematic -> thematic --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a02161574..47e3e9f12 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ or these example programs: 🚨 **Dex is an experimental research project at an early stage of development. Expect monstrous bugs and razor-sharp edges!** -🤝 **Contributions welcome!** See our issue tracker for [good first issues](https://github.com/google-research/dex-lang/labels/good%20first%20issue), or browse by [tematic labels](https://github.com/google-research/dex-lang/labels). +🤝 **Contributions welcome!** See our issue tracker for [good first issues](https://github.com/google-research/dex-lang/labels/good%20first%20issue), or browse by [thematic labels](https://github.com/google-research/dex-lang/labels). ## Dependencies From 4002ba7970ddf24dc32c6c4a96cb7685046882d2 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 7 Jan 2021 03:05:51 -0500 Subject: [PATCH 085/105] Use llvm-hs/llvm-hs@llvm-9 in cabal.project. Previously, a downstream commit at apaszke/llvm-hs was referenced. That commit has now been merged into llvm-hs/llvm-hs@llvm-9. --- cabal.project | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cabal.project b/cabal.project index 7ab885353..33ff576b5 100644 --- a/cabal.project +++ b/cabal.project @@ -2,12 +2,12 @@ packages: dex.cabal source-repository-package type: git - location: https://github.com/apaszke/llvm-hs - tag: a9a74be1a7c15f3d21b2fffd35a425002ae7736f + location: https://github.com/llvm-hs/llvm-hs + tag: llvm-9 subdir: llvm-hs source-repository-package type: git - location: https://github.com/apaszke/llvm-hs - tag: a9a74be1a7c15f3d21b2fffd35a425002ae7736f + location: https://github.com/llvm-hs/llvm-hs + tag: llvm-9 subdir: llvm-hs-pure From e17e5842ef8704b58722348edd0b4a03a81447e0 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 6 Jan 2021 18:15:41 +0000 Subject: [PATCH 086/105] Basic JAX integration Make it possible to wrap Dex function atoms in JAX primitives. For now the only supported JAX functionality is abstract evaluation (shape inference). Next step will be to make it possible to JIT the functions as XLA custom calls. --- python/dex/__init__.py | 2 +- python/dex/interop/__init__.py | 5 +++ python/dex/interop/jax.py | 79 ++++++++++++++++++++++++++++++++++ python/dex/native_function.py | 2 +- python/tests/jax_test.py | 38 ++++++++++++++++ python/tests/jit_test.py | 1 - src/Dex/Foreign/Context.hs | 10 ++--- 7 files changed, 129 insertions(+), 8 deletions(-) create mode 100644 python/dex/interop/__init__.py create mode 100644 python/dex/interop/jax.py create mode 100644 python/tests/jax_test.py diff --git a/python/dex/__init__.py b/python/dex/__init__.py index 84da18d93..46ba9e088 100644 --- a/python/dex/__init__.py +++ b/python/dex/__init__.py @@ -56,7 +56,7 @@ def eval(expr: str, module=prelude, _env=None): class Atom: - __slots__ = ('_as_parameter_', 'module') + __slots__ = ('__weakref__', '_as_parameter_', 'module') def __init__(self, ptr, module): self._as_parameter_ = ptr diff --git a/python/dex/interop/__init__.py b/python/dex/interop/__init__.py new file mode 100644 index 000000000..6b607710e --- /dev/null +++ b/python/dex/interop/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd diff --git a/python/dex/interop/jax.py b/python/dex/interop/jax.py new file mode 100644 index 000000000..99f88aecd --- /dev/null +++ b/python/dex/interop/jax.py @@ -0,0 +1,79 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from weakref import WeakKeyDictionary +from functools import partial +import numpy as np +import jax + +from .. import Atom +from ..native_function import ScalarType, RectContArrayType + +def primitive(f): + if not isinstance(f, Atom): + raise TypeError("DexPrimitive expects a function atom as an argument") + return partial(dex_call.bind, func_atom=f) + +compiler_cache = WeakKeyDictionary() +def get_compiled(func_atom): + compiled = compiler_cache.get(func_atom, None) + if compiled is None: + compiled = compiler_cache[func_atom] = func_atom.compile() + return compiled + + +dex_call = jax.core.Primitive('dex_call') + +@dex_call.def_impl +def dex_call_impl(*args, func_atom): + return get_compiled(func_atom)(*args) + +@dex_call.def_abstract_eval +def dex_call_abstract_eval(*args, func_atom): + # TODO: Make it possible to get the signature without compiling the function + native_func = get_compiled(func_atom) + arg_sig = native_func.explicit_argument_signature + res_sig = native_func.result_signature + if len(args) != len(arg_sig): + raise RuntimeError(f"Dex function expects {len(arg_sig)} arguments, but was given {len(args)}") + if not all(isinstance(arg, jax.core.ShapedArray) for arg in args): + raise RuntimeError("Cannot perform evaluation of Dex functions without known shapes") + # Check arguments and infer shape parameters + shape_vars = {} + for i, (arg, b) in enumerate(zip(args, arg_sig)): + expected_dtype = np.dtype(b.type.ctype) + if arg.dtype != expected_dtype: + raise RuntimeError(f"dtype mismatch in arg {i}: expected {expected_dtype}, got {arg.dtype}") + if isinstance(b.type, ScalarType): + expected_shape = () + elif isinstance(b.type, RectContArrayType): + expected_shape = b.type.shape + else: + raise AssertionError("Unhandled case!") + if len(arg.shape) != len(expected_shape): + raise RuntimeError(f"rank mismatch in arg {i}: expected {len(expected_shape)}, got {len(arg.shape)}") + inferred_shape = tuple( + size if isinstance(size, int) else shape_vars.setdefault(size, real_size) + for size, real_size in zip(expected_shape, arg.shape)) + if arg.shape != inferred_shape: + raise RuntimeError(f"shape mismatch in arg {i}: expected {inferred_shape}, got {arg.shape}") + # Infer result types + result_avals = [] + for b in res_sig: + dtype = np.dtype(b.type.ctype) + if isinstance(b.type, ScalarType): + shape = () + elif isinstance(b.type, RectContArrayType): + shape = tuple(shape_vars.get(size, size) for size in b.type.shape) + result_avals.append(jax.core.ShapedArray(shape, dtype)) + assert len(result_avals) == 1 # TODO: Make dex_call a multiple_results primitive + return result_avals[0] + +# TODO +# jax.interpreters.xla.backend_specific_translations['cpu'][self.primitive] = ... +# jax.interpreters.batching.primitive_batchers[self.primitive] = ... +# jax.interpreters.ad.primitive_jvps[self.primitive] = ... +# jax.interpreters.ad.primitive_transposes[self.primitive] = ... diff --git a/python/dex/native_function.py b/python/dex/native_function.py index 2277d6f8c..6008e34c4 100644 --- a/python/dex/native_function.py +++ b/python/dex/native_function.py @@ -57,7 +57,7 @@ def unsafe_array_ptr(self, array): def to_ctype(self, array, name_cvalue): if not isinstance(array, np.ndarray): - raise TypeError("Expected a NumPy ndarray for an array argument") + array = np.asarray(array) if array.ndim != len(self.shape): raise ValueError(f"Expected a {len(self.shape)}D array, got {array.ndim}D") expected_dtype = np.dtype(self.ctype) diff --git a/python/tests/jax_test.py b/python/tests/jax_test.py new file mode 100644 index 000000000..47f041c99 --- /dev/null +++ b/python/tests/jax_test.py @@ -0,0 +1,38 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import unittest +import ctypes +import numpy as np +from textwrap import dedent + +# TODO: Write a proper setup.py instead of using this hack... +from pathlib import Path +import sys +sys.path.append(str(Path(__file__).parent.parent)) + +import jax +import jax.numpy as jnp + +import dex +from dex.interop.jax import primitive + +def test_impl_scalar(): + add_two = primitive(dex.eval(r'\x:Float. x + 2.0')) + x = jnp.zeros((), dtype=np.float32) + np.testing.assert_allclose(add_two(x), x + 2.0) + +def test_impl_array(): + add_two = primitive(dex.eval(r'\x:((Fin 10)=>Float). for i. x.i + 2.0')) + x = jnp.arange((10,), dtype=np.float32) + np.testing.assert_allclose(add_two(x), x + 2.0) + +def test_abstract_eval_simple(): + add_two = primitive(dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0')) + x = jax.ShapeDtypeStruct((10,), np.float32) + output_shape = jax.eval_shape(add_two, x) + assert output_shape.shape == (10,) + assert output_shape.dtype == np.int32 diff --git a/python/tests/jit_test.py b/python/tests/jit_test.py index a0594cd34..e299b845b 100644 --- a/python/tests/jit_test.py +++ b/python/tests/jit_test.py @@ -25,7 +25,6 @@ def check_atom(dex_atom, reference, args_iter): ran_any_iter = False for args in args_iter: ran_any_iter = True - print(args) np.testing.assert_allclose(compiled(*args), reference(*args), rtol=1e-4, atol=1e-6) assert ran_any_iter, "Empty argument iterator!" diff --git a/src/Dex/Foreign/Context.hs b/src/Dex/Foreign/Context.hs index 6b0ab72fe..7a0e3cbb1 100644 --- a/src/Dex/Foreign/Context.hs +++ b/src/Dex/Foreign/Context.hs @@ -46,10 +46,10 @@ dexCreateContext = do maybePreludeEnv <- evalPrelude evalConfig preludeSource case maybePreludeEnv of Right preludeEnv -> toStablePtr $ Context evalConfig preludeEnv - Left _ -> setError "Failed to initialize standard library" $> nullPtr + Left err -> nullPtr <$ setError ("Failed to initialize standard library: " ++ pprint err) where evalPrelude :: EvalConfig -> String -> IO (Either Err TopEnv) - evalPrelude opts contents = flip evalStateT mempty $ do + evalPrelude opts contents = flip evalStateT initTopEnv $ do results <- fmap snd <$> evalSource opts contents env <- get return $ env `unlessError` results @@ -83,11 +83,11 @@ dexInsert ctxPtr namePtr atomPtr = do dexEvalExpr :: Ptr Context -> CString -> IO (Ptr Atom) dexEvalExpr ctxPtr sourcePtr = do Context evalConfig env <- fromStablePtr ctxPtr - maybeExpr <- parseExpr <$> peekCString sourcePtr - case maybeExpr of + source <- peekCString sourcePtr + case parseExpr source of Right expr -> do let (v, m) = exprAsModule expr - let block = SourceBlock 0 0 LogNothing "" (RunModule m) Nothing + let block = SourceBlock 0 0 LogNothing source (RunModule m) Nothing (resultEnv, Result [] maybeErr) <- evalSourceBlock evalConfig env block case maybeErr of Right () -> do From 8282129e5e0377785660f964bf19d44dad36c4ee Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 7 Jan 2021 12:38:50 +0000 Subject: [PATCH 087/105] Add a CPU XLA translation rule for the dex_call primitive (JAX interop) This makes it possible to use Dex functions in self-contained XLA executables. Unfortunately, Python is still in the loop, as we use ctypes to generate a trampoline that adapts the Dex calling convention to the one expected by XLA. Still, I don't expect it to be too difficult to generate the right signature from Dex, so we should do that at some point in the future. --- python/dex/interop/jax.py | 83 +++++++++++++++++++++++++++++++++++---- python/tests/jax_test.py | 15 +++++++ 2 files changed, 90 insertions(+), 8 deletions(-) diff --git a/python/dex/interop/jax.py b/python/dex/interop/jax.py index 99f88aecd..df40bb50e 100644 --- a/python/dex/interop/jax.py +++ b/python/dex/interop/jax.py @@ -6,16 +6,21 @@ from weakref import WeakKeyDictionary from functools import partial +from itertools import count +import ctypes import numpy as np + import jax +from jax.lib import xla_client as xc +from jax.interpreters import xla from .. import Atom -from ..native_function import ScalarType, RectContArrayType +from ..native_function import IdxRepTy, ScalarType, RectContArrayType def primitive(f): if not isinstance(f, Atom): raise TypeError("DexPrimitive expects a function atom as an argument") - return partial(dex_call.bind, func_atom=f) + return partial(dex_call_p.bind, func_atom=f) compiler_cache = WeakKeyDictionary() def get_compiled(func_atom): @@ -25,14 +30,15 @@ def get_compiled(func_atom): return compiled -dex_call = jax.core.Primitive('dex_call') +dex_call_p = jax.core.Primitive('dex_call') -@dex_call.def_impl +@dex_call_p.def_impl def dex_call_impl(*args, func_atom): return get_compiled(func_atom)(*args) -@dex_call.def_abstract_eval -def dex_call_abstract_eval(*args, func_atom): +# === abstract evaluation / shape inference === + +def dex_call_abstract_eval_with_shape(*args, func_atom): # TODO: Make it possible to get the signature without compiling the function native_func = get_compiled(func_atom) arg_sig = native_func.explicit_argument_signature @@ -70,10 +76,71 @@ def dex_call_abstract_eval(*args, func_atom): shape = tuple(shape_vars.get(size, size) for size in b.type.shape) result_avals.append(jax.core.ShapedArray(shape, dtype)) assert len(result_avals) == 1 # TODO: Make dex_call a multiple_results primitive - return result_avals[0] + return result_avals[0], shape_vars + +@dex_call_p.def_abstract_eval +def dex_call_abstract_eval(*args, **kwargs): + return dex_call_abstract_eval_with_shape(*args, **kwargs)[0] + +# === xla translation === + +PyCapsule_Destructor = ctypes.CFUNCTYPE(None, ctypes.py_object) +PyCapsule_New = ctypes.pythonapi.PyCapsule_New +PyCapsule_New.restype = ctypes.py_object +PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, PyCapsule_Destructor) + +def make_custom_call_target(func_ptr): + return PyCapsule_New(func_ptr, b"xla._CUSTOM_CALL_TARGET", PyCapsule_Destructor(0)) + +# TODO: Better lifetime management. func_atoms will be quite often created on the fly +# at trace time when different transforms are applied, and I'm pretty sure that +# the XLA executables outlive jaxprs formed by tracing. +custom_call_id = count() +custom_call_cache = {} +def dex_call_cpu_translation(b, *args, func_atom): + xla_shapes = list(map(b.get_shape, args)) + result_aval, shape_vars = dex_call_abstract_eval_with_shape( + *(jax.core.ShapedArray(xshape.dimensions(), xshape.numpy_dtype()) + for xshape in xla_shapes), + func_atom=func_atom) + result_xshape = xc.Shape.array_shape(result_aval.dtype, result_aval.shape) + + custom_call = custom_call_cache.get(func_atom, None) + native = get_compiled(func_atom) + if custom_call is None: + assert len(args) == len(native.explicit_argument_signature) + assert 1 == len(native.result_signature) + custom_call_ctype = ctypes.CFUNCTYPE(None, + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_void_p * len(args))) + @custom_call_ctype + def trampoline(result_ptr, arg_ptr_array): + name_to_cval = {name: IdxRepTy(value) for name, value in shape_vars.items()} + for binder, ptr in zip(native.explicit_argument_signature, arg_ptr_array.contents): + if isinstance(binder.type, ScalarType): + cval = ctypes.cast(ptr, ctypes.POINTER(binder.type.arg_ctype)).contents + elif isinstance(binder.type, RectContArrayType): + cval = ctypes.cast(ptr, binder.type.arg_ctype) + else: + raise AssertionError("Unexpected binder type") + name_to_cval[binder.name] = cval + result_binder = native.result_signature[0] + name_to_cval[result_binder.name] = ctypes.cast(result_ptr, result_binder.type.ref_ctype) + native.callable(*(name_to_cval[name] for name in native.ccall_signature)) + + trampoline_addr = ctypes.c_void_p.from_param(trampoline) + custom_call_name = f"dex_custom_call{next(custom_call_id)}".encode('ascii') + xc.register_custom_call_target(custom_call_name, + make_custom_call_target(trampoline_addr)) + custom_call_cache[func_atom] = (custom_call_name, trampoline) + # TODO: Unregister custom calls at some point? + else: + custom_call_name, *_ = custom_call + return xc.ops.CustomCall(b, custom_call_name, operands=args, shape=result_xshape) + +jax.interpreters.xla.backend_specific_translations['cpu'][dex_call_p] = dex_call_cpu_translation # TODO -# jax.interpreters.xla.backend_specific_translations['cpu'][self.primitive] = ... # jax.interpreters.batching.primitive_batchers[self.primitive] = ... # jax.interpreters.ad.primitive_jvps[self.primitive] = ... # jax.interpreters.ad.primitive_transposes[self.primitive] = ... diff --git a/python/tests/jax_test.py b/python/tests/jax_test.py index 47f041c99..02a4fbff8 100644 --- a/python/tests/jax_test.py +++ b/python/tests/jax_test.py @@ -36,3 +36,18 @@ def test_abstract_eval_simple(): output_shape = jax.eval_shape(add_two, x) assert output_shape.shape == (10,) assert output_shape.dtype == np.int32 + +def test_jit_scalar(): + add_two = primitive(dex.eval(r'\x:Float. x + 2.0')) + x = jnp.zeros((), dtype=np.float32) + np.testing.assert_allclose(jax.jit(add_two)(x), 2.0) + +def test_jit_array(): + add_two = primitive(dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0')) + x = jnp.zeros((10,), dtype=np.float32) + np.testing.assert_allclose(jax.jit(add_two)(x), (x + 2.0).astype(np.int32)) + +def test_jit_scale(): + scale = primitive(dex.eval(r'\x:((Fin 10)=>Float) y:Float. for i. x.i * y')) + x = jnp.arange((10,), dtype=np.float32) + np.testing.assert_allclose(scale(x, 5.0), x * 5.0) From 4036c60babfbdff0a4f5005230274324fc8940aa Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 7 Jan 2021 01:36:30 -0500 Subject: [PATCH 088/105] [NFC] Gardening. - Address some hlint warnings. - Add doc comments to `Dag` in LiveOutput.hs. - Improve dex help option printing. --- dex.cabal | 3 ++- src/dex.hs | 57 ++++++++++++++++++++++++------------------- src/lib/Actor.hs | 6 ++--- src/lib/Cat.hs | 2 +- src/lib/LiveOutput.hs | 19 +++++++++------ src/lib/Parser.hs | 4 +-- src/lib/TopLevel.hs | 7 +++--- 7 files changed, 55 insertions(+), 43 deletions(-) diff --git a/dex.cabal b/dex.cabal index 926c5314e..be781045d 100644 --- a/dex.cabal +++ b/dex.cabal @@ -78,7 +78,8 @@ executable dex main-is: dex.hs other-extensions: OverloadedStrings build-depends: dex, base, haskeline, prettyprinter, mtl, - optparse-applicative, unix, store, bytestring, directory + optparse-applicative, ansi-wl-pprint, + unix, store, bytestring, directory if os(darwin) build-depends: dex-resources default-language: Haskell2010 diff --git a/src/dex.hs b/src/dex.hs index cfaf7b000..2b63048db 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -11,9 +11,10 @@ import System.Exit import Control.Monad import Control.Monad.State.Strict import Options.Applicative +import Text.PrettyPrint.ANSI.Leijen (text, hardline) import System.Posix.Terminal (queryTerminal) import System.Posix.IO (stdOutput) -import System.Exit + import System.Directory import Data.List @@ -59,7 +60,7 @@ runMode evalMode preludeFile opts = do WebMode fname -> runWeb fname opts env WatchMode fname -> runTerminal fname opts env ExportMode dexPath objPath -> do - results <- fmap snd <$> (runEnv $ evalFile opts dexPath) + results <- fmap snd <$> runEnv (evalFile opts dexPath) let outputs = foldMap (\(Result outs _) -> outs) results let errors = foldMap (\case (Result _ (Left err)) -> [err]; _ -> []) results putStr $ foldMap (nonEmptyNewline . pprint) errors @@ -71,7 +72,7 @@ runMode evalMode preludeFile opts = do evalPrelude :: EvalConfig -> Maybe FilePath -> IO TopEnv evalPrelude opts fname = flip execStateT initTopEnv $ do source <- case fname of - Nothing -> return $ preludeSource + Nothing -> return preludeSource Just path -> liftIO $ readFile path result <- evalSource opts source void $ liftErrIO $ mapM (\(_, Result _ r) -> r) result @@ -96,7 +97,7 @@ dexCompletions (line, _) = do let startoflineKeywords = ["%bench \"", ":p", ":t", ":html", ":export"] let candidates = (if null rest then startoflineKeywords else []) ++ anywhereKeywords ++ varNames - let completions = map simpleCompletion $ filter ((reverse word) `isPrefixOf`) candidates + let completions = map simpleCompletion $ filter (reverse word `isPrefixOf`) candidates return (rest, completions) liftErrIO :: MonadIO m => Except a -> m a @@ -131,34 +132,39 @@ printLitProg JSONDoc prog = "{}" -> return () s -> putStrLn s +nonEmptyNewline :: String -> String nonEmptyNewline [] = [] nonEmptyNewline l = l ++ ['\n'] parseOpts :: ParserInfo CmdOpts parseOpts = simpleInfo $ CmdOpts <$> parseMode - <*> (optional $ strOption $ long "prelude" <> metavar "FILE" <> help "Prelude file") + <*> optional (strOption $ long "prelude" <> metavar "FILE" <> help "Prelude file") <*> parseEvalOpts +helpOption :: String -> String -> Mod f a +helpOption optionName options = + helpDoc (Just (text optionName <> hardline <> text options)) + parseMode :: Parser EvalMode parseMode = subparser $ - (command "repl" $ simpleInfo $ - ReplMode <$> (strOption $ long "prompt" <> value ">=> " - <> metavar "STRING" <> help "REPL prompt")) - <> (command "web" $ simpleInfo (WebMode <$> sourceFileInfo )) - <> (command "watch" $ simpleInfo (WatchMode <$> sourceFileInfo )) - <> (command "export" $ simpleInfo (ExportMode <$> sourceFileInfo <*> objectFileInfo)) - <> (command "script" $ simpleInfo (ScriptMode <$> sourceFileInfo - <*> (option - (optionList [ ("literate" , TextDoc) - , ("result-only", ResultOnly) - , ("HTML" , HTMLDoc) - , ("JSON" , JSONDoc)]) - (long "outfmt" <> value TextDoc - <> help "Output format (literate(default)|result-only|HTML|JSON")) - <*> flag HaltOnErr ContinueOnErr ( - long "allow-errors" - <> help "Evaluate programs containing non-fatal type errors"))) + command "repl" (simpleInfo + (ReplMode <$> strOption (long "prompt" <> value ">=> " + <> metavar "STRING" <> help "REPL prompt"))) + <> command "web" (simpleInfo (WebMode <$> sourceFileInfo)) + <> command "watch" (simpleInfo (WatchMode <$> sourceFileInfo)) + <> command "export" (simpleInfo (ExportMode <$> sourceFileInfo <*> objectFileInfo)) + <> command "script" (simpleInfo (ScriptMode <$> sourceFileInfo + <*> option + (optionList [ ("literate" , TextDoc) + , ("result-only", ResultOnly) + , ("HTML" , HTMLDoc) + , ("JSON" , JSONDoc)]) + (long "outfmt" <> value TextDoc <> + helpOption "Output format" "literate (default) | result-only | HTML | JSON") + <*> flag HaltOnErr ContinueOnErr ( + long "allow-errors" + <> help "Evaluate programs containing non-fatal type errors"))) where sourceFileInfo = argument str (metavar "FILE" <> help "Source program") objectFileInfo = argument str (metavar "OBJFILE" <> help "Output path (.o file)") @@ -170,13 +176,14 @@ optionList opts = eitherReader \s -> case lookup s opts of parseEvalOpts :: Parser EvalConfig parseEvalOpts = EvalConfig - <$> (option + <$> option (optionList [ ("LLVM", LLVM) , ("LLVM-CUDA", LLVMCUDA) , ("LLVM-MC", LLVMMC) , ("interp", Interp)]) - (long "backend" <> value LLVM <> help "Backend (LLVM(default)|LLVM-CUDA|interp)")) - <*> (optional $ strOption $ long "logto" + (long "backend" <> value LLVM <> + helpOption "Backend" "LLVM (default) | LLVM-CUDA | LLVM-MC | interp") + <*> optional (strOption $ long "logto" <> metavar "FILE" <> help "File to log to" <> showDefault) diff --git a/src/lib/Actor.hs b/src/lib/Actor.hs index f152b2ba6..fbf0cb6e8 100644 --- a/src/lib/Actor.hs +++ b/src/lib/Actor.hs @@ -43,7 +43,7 @@ runActor (Actor m) = do linksRef <- newIORef [] chan <- newBackChan tid <- myThreadId - let p = (Proc Trap tid (asErrPChan chan)) + let p = Proc Trap tid (asErrPChan chan) runReaderT m (ActorConfig p chan linksRef) subChan :: (a -> b) -> PChan b -> PChan a @@ -123,7 +123,7 @@ receive :: MonadActor msg m => m msg receive = receiveF Just newBackChan :: IO (BackChan a) -newBackChan = liftM2 BackChan (newIORef []) (newChan) +newBackChan = liftM2 BackChan (newIORef []) newChan readBackChan :: BackChan a -> IO a readBackChan (BackChan ptr chan) = do xs <- readIORef ptr @@ -173,6 +173,6 @@ logServer = flip evalStateT (mempty, []) $ forever $ do Push x -> do modify $ onFst (<> x) subscribers <- gets snd - mapM_ (flip send x) subscribers + mapM_ (`send` x) subscribers -- TODO: state machine? diff --git a/src/lib/Cat.hs b/src/lib/Cat.hs index f120df661..01aa6d062 100644 --- a/src/lib/Cat.hs +++ b/src/lib/Cat.hs @@ -41,7 +41,7 @@ instance (Monoid env, Monad m) => MonadCat env (CatT env m) where put (fullState <> x, localState <> x) scoped (CatT m) = CatT $ do originalState <- get - put $ (fst originalState, mempty) + put (fst originalState, mempty) ans <- m newLocalState <- gets snd put originalState diff --git a/src/lib/LiveOutput.hs b/src/lib/LiveOutput.hs index 640fcdb8b..83a302b43 100644 --- a/src/lib/LiveOutput.hs +++ b/src/lib/LiveOutput.hs @@ -92,7 +92,7 @@ sourceBlockToDag block = do -- TODO: Stop forcing dependencies on all preceding blocks. This will require -- an improvement of the analysis above, such that all blocks depend on those -- that contain interface instance definitions. - extend $ (foldMap ((@>n) . Bind) $ envAsVars $ boundUVars block, [n]) + extend (foldMap ((@>n) . Bind) $ envAsVars $ boundUVars block, [n]) case sbContents block of IncludeSourceFile _ -> extend $ asSnd [n] _ -> return () @@ -145,7 +145,7 @@ oneSourceBlock k b = RFragment mempty (M.singleton k b) mempty serveResults :: StreamingBody -> Application serveResults results request respond = do - putStrLn (show $ pathInfo request) + print (pathInfo request) case pathInfo request of [] -> respondWith "static/index.html" "text/html" ["style.css"] -> respondWith "static/style.css" "text/css" @@ -204,7 +204,7 @@ displayResultsTerm reqChan = c <- myChan send reqChan $ subChan Left c void $ spawn Trap $ monitorKeyboard $ subChan Right c - forever $ termDisplayLoop + forever termDisplayLoop termDisplayLoop :: TermDisplayM () termDisplayLoop = do @@ -232,7 +232,7 @@ cropTrailingLines n s = unlines $ reverse $ drop n $ reverse $ lines s renderResults :: RFragment -> Maybe String renderResults (RFragment NotSet _ _) = Nothing renderResults (RFragment (Set ids) blocks results) = - liftM fold $ flip mapM ids $ \i -> do + liftM fold $ forM ids $ \i -> do b <- M.lookup i blocks r <- M.lookup i results return $ printLitBlock True b r @@ -241,7 +241,7 @@ monitorKeyboard :: PChan KeyboardCommand -> Actor () () monitorKeyboard chan = do liftIO $ hSetBuffering stdin NoBuffering forever $ do - c <- liftIO $ getChar + c <- liftIO getChar case c of 'k' -> send chan ScrollUp 'j' -> send chan ScrollDown @@ -274,10 +274,15 @@ onmod fname action = do -- === DAG utils === +-- | A pair of an @a@ and a list of neighbor node ids. type Node a = (a, [NodeId]) -data Dag a = Dag (M.Map NodeId (Node a)) (M.Map (a, [NodeId]) NodeId) --- returns the addition only, not the new DAG +-- | A directed acyclic graph, represented as a bidirectional map from node ids +-- to nodes. +data Dag a = Dag (M.Map NodeId (Node a)) (M.Map (Node a) NodeId) + +-- | Adds a node to a DAG, if it does not already exist. +-- Returns the added node id and a DAG representing the added node. addToDag :: Ord a => Dag a -> Node a -> (NodeId, Dag a) addToDag (Dag _ m) node = case M.lookup node m of diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index f7b35c7f8..cb6574b8b 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -90,7 +90,7 @@ logLevel = do passes <- many passName eol case passes of - [] -> return $ LogAll + [] -> return LogAll _ -> return $ LogPasses passes logTime :: Parser LogLevel @@ -131,7 +131,7 @@ proseBlock = label "prose block" $ char '\'' >> fmap (ProseBlock . fst) (withSou topLevelCommand :: Parser SourceBlock' topLevelCommand = - (liftM IncludeSourceFile includeSourceFile) + liftM IncludeSourceFile includeSourceFile <|> explicitCommand "top-level command" diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 9399ecfdf..63b203099 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -14,7 +14,6 @@ import Control.Monad.Reader import Control.Monad.Except hiding (Except) import Data.Text.Prettyprint.Doc import Data.String -import Data.Maybe import Data.List (partition) import qualified Data.Map.Strict as M @@ -119,7 +118,7 @@ processLogs :: LogLevel -> [Output] -> [Output] processLogs logLevel logs = case logLevel of LogAll -> logs LogNothing -> [] - LogPasses passes -> flip filter logs \l -> case l of + LogPasses passes -> flip filter logs \case PassInfo pass _ | pass `elem` passes -> True | otherwise -> False _ -> False @@ -135,7 +134,7 @@ timesFromLogs logs = (totalTime - totalEvalTime, singleEvalTime, benchStats) case [(t, stats) | EvalTime t stats <- logs] of [] -> (0.0 , 0.0, Nothing) [(t, stats)] -> (total, t , stats) - where total = fromMaybe t $ fmap snd stats + where total = maybe t snd stats _ -> error "Expect at most one result" totalTime = case [tTotal | TotalTime tTotal <- logs] of [] -> 0.0 @@ -221,7 +220,7 @@ evalBackend env block = do withCompileTime :: TopPassM a -> TopPassM a withCompileTime m = do - (ans, t) <- measureSeconds $ m + (ans, t) <- measureSeconds m logTop $ TotalTime t return ans From cf06f3d6426aba47df7c4977b130e0ac6251eb24 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 7 Jan 2021 02:26:02 -0500 Subject: [PATCH 089/105] Gardening. Rename `dex --backend interp` flag to `--backend interpreter`. `interpreter` is a noun and reads more naturally. --- src/dex.hs | 4 ++-- src/lib/Imp.hs | 8 ++++---- src/lib/Syntax.hs | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/dex.hs b/src/dex.hs index 2b63048db..72bc85447 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -180,9 +180,9 @@ parseEvalOpts = EvalConfig (optionList [ ("LLVM", LLVM) , ("LLVM-CUDA", LLVMCUDA) , ("LLVM-MC", LLVMMC) - , ("interp", Interp)]) + , ("interpreter", Interpreter)]) (long "backend" <> value LLVM <> - helpOption "Backend" "LLVM (default) | LLVM-CUDA | LLVM-MC | interp") + helpOption "Backend" "LLVM (default) | LLVM-CUDA | LLVM-MC | interpreter") <*> optional (strOption $ long "logto" <> metavar "FILE" <> help "File to log to" <> showDefault) diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 0ef075003..0039d7baf 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -866,10 +866,10 @@ chooseAddrSpace (backend, curDev, allocTy) numel = case allocTy of else Heap mainDev | otherwise -> Heap mainDev where mainDev = case backend of - LLVM -> CPU - LLVMMC -> CPU - LLVMCUDA -> GPU - Interp -> error "Shouldn't be compiling with interpreter backend" + LLVM -> CPU + LLVMMC -> CPU + LLVMCUDA -> GPU + Interpreter -> error "Shouldn't be compiling with interpreter backend" isSmall :: Block -> Bool isSmall numel = case numel of diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index d15a4f997..4aa7d8d3c 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -559,7 +559,7 @@ data ImpInstr = IFor Direction IBinder Size ImpBlock | IPrimOp IPrimOp deriving (Show) -data Backend = LLVM | LLVMCUDA | LLVMMC | Interp deriving (Show, Eq) +data Backend = LLVM | LLVMCUDA | LLVMMC | Interpreter deriving (Show, Eq) newtype CUDAKernel = CUDAKernel B.ByteString deriving (Show) -- === base types === From 0451c7c83bc79d36b56528c109d5242c836a5b95 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 7 Jan 2021 09:11:53 -0500 Subject: [PATCH 090/105] Add .hlint.yaml to match repository coding style preferences. Ignore "Use fmap": some use sites (e.g. of `liftM`) are intentional and more readable. Consider adding more cases in the future, based on coding style preferences and noise level from hlint-aware IDEs: ``` - ignore: {name: "Use <$>"} - ignore: {name: "Eta reduce"} ``` --- .hlint.yaml | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .hlint.yaml diff --git a/.hlint.yaml b/.hlint.yaml new file mode 100644 index 000000000..44155cbac --- /dev/null +++ b/.hlint.yaml @@ -0,0 +1,2 @@ +- arguments: [--color] +- ignore: {name: "Use fmap"} From 37e4f50f555a1966a48c0ed09449639848bb9cf3 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 7 Jan 2021 23:44:32 -0500 Subject: [PATCH 091/105] Make `--backend` and `--outfmt` options be consistently lowercase. It seems standard for CLIs to use lowercase option names. Changed option names: - `--backend llvm` (was LLVM) - `--backend llvm-cuda` (was LLVM-CUDA) - `--backend llvm-mc` (was LLVM-MC) - `--outfmt html` (was HTML) - `--outfmt json` (was JSON) --- benchmarks/dexbench.py | 6 +++--- makefile | 10 +++++----- src/dex.hs | 14 +++++++------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/benchmarks/dexbench.py b/benchmarks/dexbench.py index b66634c7e..6dbf7391d 100644 --- a/benchmarks/dexbench.py +++ b/benchmarks/dexbench.py @@ -55,15 +55,15 @@ def restore_machine(): def run_benches(lang, backend): if lang == "dex": if backend == "CPU": - backend_args = ["--backend", "LLVM-MC"] + backend_args = ["--backend", "llvm-mc"] env = {} elif backend == "GPU": - backend_args = ["--backend", "LLVM-CUDA"] + backend_args = ["--backend", "llvm-cuda"] env = {"CUDA_LAUNCH_BLOCKING":"1"} else: raise Exception command = (["stack", "exec", "dex", "--"] + backend_args + - ["script", "--outfmt", "JSON", dex_microbench_file]) + ["script", "--outfmt", "json", dex_microbench_file]) elif lang == "jax": if backend == "CPU": env = {"CUDA_VISIBLE_DEVICES":""} diff --git a/makefile b/makefile index ed75f9f24..fdbb5bb3a 100644 --- a/makefile +++ b/makefile @@ -137,11 +137,11 @@ update-examples-%: examples/%.dx build run-gpu-tests: export DEX_ALLOC_CONTRACTIONS=0 run-gpu-tests: tests/gpu-tests.dx build - misc/check-quine $< $(dex) --backend LLVM-CUDA script --allow-errors + misc/check-quine $< $(dex) --backend llvm-cuda script --allow-errors update-gpu-tests: export DEX_ALLOW_CONTRACTIONS=0 update-gpu-tests: tests/gpu-tests.dx build - $(dex) --backend LLVM-CUDA script --allow-errors $< > $<.tmp + $(dex) --backend llvm-cuda script --allow-errors $< > $<.tmp mv $<.tmp $< uexpr-tests: @@ -175,15 +175,15 @@ docs: doc-prelude $(doc-example-names) $(doc-lib-names) $(slow-docs) doc-prelude: lib/prelude.dx mkdir -p doc - $(dex) --prelude /dev/null script lib/prelude.dx --outfmt HTML > doc/prelude.html + $(dex) --prelude /dev/null script lib/prelude.dx --outfmt html > doc/prelude.html doc/examples/%.html: examples/%.dx mkdir -p doc/examples - $(dex) script $^ --outfmt HTML > $@ + $(dex) script $^ --outfmt html > $@ doc/lib/%.html: lib/%.dx mkdir -p doc/lib - $(dex) script $^ --outfmt HTML > $@ + $(dex) script $^ --outfmt html > $@ clean: $(STACK) clean diff --git a/src/dex.hs b/src/dex.hs index 72bc85447..d6102f503 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -158,10 +158,10 @@ parseMode = subparser $ <*> option (optionList [ ("literate" , TextDoc) , ("result-only", ResultOnly) - , ("HTML" , HTMLDoc) - , ("JSON" , JSONDoc)]) + , ("html" , HTMLDoc) + , ("json" , JSONDoc)]) (long "outfmt" <> value TextDoc <> - helpOption "Output format" "literate (default) | result-only | HTML | JSON") + helpOption "Output format" "literate (default) | result-only | html | json") <*> flag HaltOnErr ContinueOnErr ( long "allow-errors" <> help "Evaluate programs containing non-fatal type errors"))) @@ -177,12 +177,12 @@ optionList opts = eitherReader \s -> case lookup s opts of parseEvalOpts :: Parser EvalConfig parseEvalOpts = EvalConfig <$> option - (optionList [ ("LLVM", LLVM) - , ("LLVM-CUDA", LLVMCUDA) - , ("LLVM-MC", LLVMMC) + (optionList [ ("llvm", LLVM) + , ("llvm-cuda", LLVMCUDA) + , ("llvm-mc", LLVMMC) , ("interpreter", Interpreter)]) (long "backend" <> value LLVM <> - helpOption "Backend" "LLVM (default) | LLVM-CUDA | LLVM-MC | interpreter") + helpOption "Backend" "llvm (default) | llvm-cuda | llvm-mc | interpreter") <*> optional (strOption $ long "logto" <> metavar "FILE" <> help "File to log to" <> showDefault) From 08db4cfbfa5905a68a80f7dbca791f72542b06c3 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Fri, 8 Jan 2021 13:11:26 -0500 Subject: [PATCH 092/105] Fix `stack build` on macOS. (#441) Disable `-Wnonportable-include-path` for executable dex in dex.cabal. I am not sure caused this issue to surface. The same issue previously occurred for other targets in dex.cabal. Workaround suggested by discussion: https://github.com/haskell/cabal/issues/4739#issuecomment-359209133 --- dex.cabal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dex.cabal b/dex.cabal index be781045d..5a639a818 100644 --- a/dex.cabal +++ b/dex.cabal @@ -85,7 +85,7 @@ executable dex default-language: Haskell2010 hs-source-dirs: src default-extensions: CPP, LambdaCase, BlockArguments - ghc-options: -threaded + ghc-options: -threaded -optP-Wno-nonportable-include-path if flag(optimized) ghc-options: -O3 else From 1c0a4173496a1e49de295c1ca649044f0a1e0dc0 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 6 Jan 2021 17:14:25 -0500 Subject: [PATCH 093/105] Put `SumAsProd` accesses under switches to avoid illegal dereferencing. Fixes #348. --- src/lib/Imp.hs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 0039d7baf..daf131df5 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -692,6 +692,11 @@ copyAtom (Con dest) src = case (dest, src) of (ConRef (SumAsProd _ tag payload), DataCon _ _ con x) -> do copyAtom tag (TagRepVal $ fromIntegral con) zipWithM_ copyAtom (payload !! con) x + (ConRef (SumAsProd _ tagDest payloadDest), Con (SumAsProd _ tag payload)) -> do + copyAtom tagDest tag + unless (all null payload) $ -- optimization + emitSwitch (fromScalarAtom tag) $ + zipWith (zipWithM_ copyAtom) payloadDest payload (ConRef destCon, Con srcCon) -> zipWithRefConM copyAtom destCon srcCon (RecordRef refs, Record vals) | fmap (const ()) refs == fmap (const ()) vals -> do @@ -832,6 +837,8 @@ splitDest (maybeDest, (Block decls ans)) = do (_, Con (Lit _)) -> tell [(dest, result)] -- This is conservative, in case the type is dependent. We could do better. (DataConRef _ _ _, DataCon _ _ _ _) -> tell [(dest, result)] + -- This is conservative. Without it, we hit bugs like #348 + (Con (ConRef (SumAsProd _ _ _)), Con (SumAsProd _ _ _)) -> tell [(dest, result)] (Con (ConRef destCon), Con srcCon) -> zipWithRefConM gatherVarDests destCon srcCon (Con (RecordRef items), Record items') @@ -952,8 +959,6 @@ zipWithRefConM :: Monad m => (Dest -> Atom -> m ()) -> Con -> Con -> m () zipWithRefConM f destCon srcCon = case (destCon, srcCon) of (PairCon d1 d2, PairCon s1 s2) -> f d1 s1 >> f d2 s2 (UnitCon, UnitCon) -> return () - (SumAsProd _ tagRef xssRef, SumAsProd _ tag xss) -> - f tagRef tag >> zipWithM_ (zipWithM f) xssRef xss (IntRangeVal _ _ iRef, IntRangeVal _ _ i) -> f iRef i (IndexRangeVal _ _ _ iRef, IndexRangeVal _ _ _ i) -> f iRef i _ -> error $ "Unexpected ref/val " ++ pprint (destCon, srcCon) @@ -972,6 +977,10 @@ addToAtom dest src = case (dest, src) of updated <- emitInstr $ IPrimOp $ op FAdd cur x' storeAnywhere ptr' updated (Con (TabRef _), TabVal _ _) -> zipTabDestAtom addToAtom dest src + (Con (ConRef (SumAsProd _ _ payloadDest)), Con (SumAsProd _ tag payload)) -> do + unless (all null payload) $ -- optimization + emitSwitch (fromScalarAtom tag) $ + zipWith (zipWithM_ addToAtom) payloadDest payload (Con (ConRef destCon), Con srcCon) -> zipWithRefConM addToAtom destCon srcCon (Con (RecordRef dests), Record srcs) -> zipWithM_ addToAtom (toList dests) (toList srcs) From ae2492d680d6656464d1de8a890a8235d510526d Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 7 Jan 2021 17:04:18 -0500 Subject: [PATCH 094/105] Small changes that were blocked on #348. --- lib/diagram.dx | 21 +++++++++------------ lib/plot.dx | 3 +-- lib/prelude.dx | 6 +----- tests/io-tests.dx | 5 ++--- 4 files changed, 13 insertions(+), 22 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index 4e14b8a89..d27bb5390 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -8,19 +8,17 @@ data Geom = Rectangle Float Float -- width, height Line Point --- HTML color (no alpha) --- TODO: replace with `Fin 3 => Word8` when we fix #348 -HtmlColor : Type = (Word8 & Word8 & Word8) +HtmlColor : Type = Fin 3 => Word8 def showHex (x:Int32) : String = unsafeIO \(). (n, ptr) = %ffi showHex (Int32 & RawPtr) x stringFromCharPtr n (MkPtr ptr) -black : HtmlColor = (IToW8 0, IToW8 0, IToW8 0) -white : HtmlColor = (IToW8 255, IToW8 255, IToW8 255) -red : HtmlColor = (IToW8 255, IToW8 0, IToW8 0) -green : HtmlColor = (IToW8 0, IToW8 255, IToW8 0) -blue : HtmlColor = (IToW8 0, IToW8 0, IToW8 255) +black : HtmlColor = [IToW8 0, IToW8 0, IToW8 0] +white : HtmlColor = [IToW8 255, IToW8 255, IToW8 255] +red : HtmlColor = [IToW8 255, IToW8 0, IToW8 0] +green : HtmlColor = [IToW8 0, IToW8 255, IToW8 0] +blue : HtmlColor = [IToW8 0, IToW8 0, IToW8 255] GeomStyle : Type = { fillColor : Maybe HtmlColor @@ -127,8 +125,7 @@ def (<=>) [Show b] (attr:String) (val:b) : String = attr <.> "=" <.> quote (show val) def htmlColor(cs:HtmlColor) : String = - (r, g, b) = cs - "#" <> (showHex $ W8ToI r) <> (showHex $ W8ToI g) <> (showHex $ W8ToI b) + "#" <> (concat $ for i. showHex (W8ToI cs.i)) def optionalHtmlColor(c: Maybe HtmlColor) : String = case c of @@ -137,8 +134,8 @@ def optionalHtmlColor(c: Maybe HtmlColor) : String = @noinline def attrString (attr:GeomStyle) : String = - ( -- "stroke" <=> (optionalHtmlColor $ getAt #strokeColor attr) - ("fill" <=> (optionalHtmlColor $ getAt #fillColor attr)) + ( ("stroke" <=> (optionalHtmlColor $ getAt #strokeColor attr)) + <+> ("fill" <=> (optionalHtmlColor $ getAt #fillColor attr)) <+> ("stroke-width" <=> (getAt #strokeWidth attr))) def renderGeom (attr:GeomStyle) ((x,y):Point) (geom:Geom) : String = diff --git a/lib/plot.dx b/lib/plot.dx index 4529435fb..56c36f647 100644 --- a/lib/plot.dx +++ b/lib/plot.dx @@ -54,8 +54,7 @@ def interpolate [VSpace a] (low:a) (high:a) (x:Float) : a = (x' .* low) + ((1.0 - x') .* high) def makeRGBColor (c : Color) : HtmlColor = - [r, g, b] = for i. IToW8 $ FToI $ floor (255.0 * c.i) - (r, g, b) + for i. IToW8 $ FToI $ floor (255.0 * c.i) def colorScale (x:Float) : HtmlColor = makeRGBColor $ interpolate lowColor highColor x diff --git a/lib/prelude.dx b/lib/prelude.dx index 82d8d77d4..596f9ad3f 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -1102,11 +1102,7 @@ def getEnv (name:String) : {IO} Maybe String = fromCString $ MkCString $ %ffi getenv RawPtr ptr def checkEnv (name:String) : {IO} Bool = - -- This should be just `isJust $ getEnv name` but that segfaults (only if the - -- env var *is* defined), possibly related to bug #348. - withCString name \(MkCString ptr). - resultPtr = %ffi getenv RawPtr ptr - not $ resultPtr == nullRawPtr + isJust $ getEnv name def fread (stream:Stream ReadMode) : {IO} String = (MkStream stream') = stream diff --git a/tests/io-tests.dx b/tests/io-tests.dx index fdc59c995..896e2d98c 100644 --- a/tests/io-tests.dx +++ b/tests/io-tests.dx @@ -67,9 +67,8 @@ unsafeIO \(). :p unsafeIO do getEnv "NOT_AN_ENV_VAR" > Nothing --- disabled because of bug #348 --- :p unsafeIO do getEnv "DEX_TEST_MODE" --- > (Just (AsList 1 "t")) +:p unsafeIO do getEnv "DEX_TEST_MODE" +> (Just (AsList 1 "t")) :p dex_test_mode () > True From e979cae84c9b0cd612bed1013cdecf71e8c0d917 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 7 Jan 2021 17:07:02 -0500 Subject: [PATCH 095/105] Fix a shadowing bug by making `emitBlock` freshen its binders. Also add tests for this bug and for #348. --- src/lib/Embed.hs | 15 +++++++++------ tests/adt-tests.dx | 29 +++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index f19303e32..4fd4d6a27 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -109,13 +109,16 @@ emitOp op = emit $ Op op emitUnpack :: MonadEmbed m => Expr -> m [Atom] emitUnpack expr = getUnpacked =<< emit expr --- Assumes the decl binders are already fresh wrt current scope emitBlock :: MonadEmbed m => Block -> m Atom -emitBlock (Block decls result) = do - mapM_ emitDecl decls - case result of - Atom x -> return x - _ -> emit result +emitBlock block = emitBlockRec mempty block + +emitBlockRec :: MonadEmbed m => SubstEnv -> Block -> m Atom +emitBlockRec env (Block (Nest (Let ann b expr) decls) result) = do + expr' <- substEmbed env expr + x <- emitTo (binderNameHint b) ann expr' + emitBlockRec (env <> b@>x) $ Block decls result +emitBlockRec env (Block Empty (Atom atom)) = substEmbed env atom +emitBlockRec env (Block Empty expr) = substEmbed env expr >>= emit freshVarE :: MonadEmbed m => BinderInfo -> Binder -> m Var freshVarE bInfo b = do diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx index 1d2d2306e..683a46228 100644 --- a/tests/adt-tests.dx +++ b/tests/adt-tests.dx @@ -289,3 +289,32 @@ def nestedUnpack (x:MyPair Int (MyPair (MyIntish & Int) Int)) : Int = :p nestedUnpack (MkMyPair 3 (MkMyPair (MkIntish 4, 5) 6)) > 4 + +data MySum = + Foo Float + Bar String + +-- bug #348 +:p + xs = for i:(Fin 3). + if ordinal i < 2 + then Foo 2.0 + else Foo 1.0 + (xs, xs) +> ([(Foo 2.), (Foo 2.), (Foo 1.)], [(Foo 2.), (Foo 2.), (Foo 1.)]) + +data MySum2 = + Foo2 + Bar2 (Fin 3 => Int) + +-- bug #348 +:p concat for i:(Fin 4). AsList _ [(Foo2, Foo2)] +> (AsList 4 [(Foo2, Foo2), (Foo2, Foo2), (Foo2, Foo2), (Foo2, Foo2)]) + +-- reproducer for a shadowing bug +:p concat $ for i:(Fin 2). toList [(Just [0,0,0], Just [0,0,0]), + (Just [0,0,0], Just [0,0,0])] +> (AsList 4 [ ((Just [0, 0, 0]), (Just [0, 0, 0])) +> , ((Just [0, 0, 0]), (Just [0, 0, 0])) +> , ((Just [0, 0, 0]), (Just [0, 0, 0])) +> , ((Just [0, 0, 0]), (Just [0, 0, 0])) ]) From d37b9cb04fe86d0723fe8bfbee979e51162cec06 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 7 Jan 2021 17:07:59 -0500 Subject: [PATCH 096/105] Strengthen the Imp type checks and add a debugging printf to JIT. The new Imp check would have caught the bug from the previous commit. --- src/lib/Imp.hs | 17 +++++++++++++---- src/lib/JIT.hs | 15 +++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index daf131df5..e34529f49 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -282,6 +282,7 @@ toImpOp (maybeDest, op) = case op of IOFree ptr -> do emitStatement $ Free $ fromScalarAtom ptr return UnitVal + PtrOffset arr (IdxRepVal 0) -> returnVal arr PtrOffset arr off -> do buf <- impOffset (fromScalarAtom arr) (fromScalarAtom off) returnVal $ toScalarAtom buf @@ -930,7 +931,8 @@ toScalarType b = BaseTy b fromEmbed :: Subst a => Embed a -> ImpM a fromEmbed m = do scope <- variableScope - let (ans, (_, decls)) = runEmbed m scope + let (ans, (scopeDelta, decls)) = runEmbed m scope + extend $ mempty { envScope = scopeDelta } env <- catFoldM translateDecl mempty $ fmap (Nothing,) decls impSubst env ans @@ -1203,7 +1205,7 @@ checkDecl decl@(ImpLet bs instr) = addContext ctx $ do instrTypeChecked :: ImpInstr -> ImpCheckM [IType] instrTypeChecked instr = case instr of IFor _ i size block -> do - checkInt size + checkIdxRep size checkBinder i assertEq (binderAnn i) (getIType size) $ "Mismatch between the loop iterator and upper bound type" [] <- withTypeEnv (i @> getIType size) $ checkBlock block @@ -1236,14 +1238,16 @@ instrTypeChecked instr = case instr of _ -> throw CompilerErr $ "Can't cast " ++ pprint st ++ " to " ++ pprint dt return dt - Alloc a ty _ -> (:[]) <$> do + + Alloc a ty n -> (:[]) <$> do + checkIdxRep n when (a /= Stack) assertHost return $ PtrType (a, ty) MemCopy dest src numel -> [] <$ do PtrType (_, destTy) <- checkIExpr dest PtrType (_, srcTy) <- checkIExpr src assertEq destTy srcTy "pointer type mismatch" - checkInt numel + checkIdxRep numel Store dest val -> [] <$ do PtrType (addr, ty) <- checkIExpr dest checkAddrAccessible addr @@ -1282,6 +1286,11 @@ checkIExpr expr = case expr of Nothing -> throw CompilerErr $ "Lookup failed: " ++ pprint v Just x -> return x +checkIdxRep :: IExpr -> ImpCheckM () +checkIdxRep expr = do + t <- checkIExpr expr + assertEq IIdxRepTy t $ "Not an index rep tye: " ++ pprint t + checkInt :: IExpr -> ImpCheckM () checkInt expr = do bt <- checkIExpr expr diff --git a/src/lib/JIT.hs b/src/lib/JIT.hs index a2e724266..860829bc3 100644 --- a/src/lib/JIT.hs +++ b/src/lib/JIT.hs @@ -583,6 +583,21 @@ _gpuDebugPrint i32Val = do genericPtrTy ty = L.PointerType ty $ L.AddrSpace 0 vprintfSpec = ExternFunSpec "vprintf" i32 [] [] [genericPtrTy i8, genericPtrTy i8] +-- Takes a single int64 payload. TODO: implement a varargs version +_debugPrintf :: String -> Operand -> Compile () +_debugPrintf fmtStr x = do + let chars = map (C.Int 8) $ map (fromIntegral . fromEnum) fmtStr ++ [0] + let formatStrArr = L.ConstantOperand $ C.Array i8 chars + formatStrPtr <- alloca (length chars) i8 + castLPtr (L.typeOf formatStrArr) formatStrPtr >>= (`store` formatStrArr) + void $ emitExternCall printfSpec [formatStrPtr, x] + where printfSpec = ExternFunSpec "printf" i32 [] [] [hostVoidp, i64] + +_debugPrintfPtr :: String -> Operand -> Compile () +_debugPrintfPtr s x = do + x' <- emitInstr i64 $ L.PtrToInt x i64 [] + _debugPrintf s x' + compileBlock :: ImpBlock -> Compile [Operand] compileBlock (ImpBlock Empty result) = traverse compileExpr result compileBlock (ImpBlock (Nest decl rest) result) = do From d3fd463b1873c1c1f72d77b46d912b5e19ea7f5c Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 7 Jan 2021 17:22:35 -0500 Subject: [PATCH 097/105] Fix conversion of Word8 to hex string for diagram colors. --- lib/diagram.dx | 4 ++-- src/lib/dexrt.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index d27bb5390..000943acc 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -10,7 +10,7 @@ data Geom = HtmlColor : Type = Fin 3 => Word8 -def showHex (x:Int32) : String = unsafeIO \(). +def showHex (x:Word8) : String = unsafeIO \(). (n, ptr) = %ffi showHex (Int32 & RawPtr) x stringFromCharPtr n (MkPtr ptr) @@ -125,7 +125,7 @@ def (<=>) [Show b] (attr:String) (val:b) : String = attr <.> "=" <.> quote (show val) def htmlColor(cs:HtmlColor) : String = - "#" <> (concat $ for i. showHex (W8ToI cs.i)) + "#" <> (concat $ for i. showHex cs.i) def optionalHtmlColor(c: Maybe HtmlColor) : String = case c of diff --git a/src/lib/dexrt.cpp b/src/lib/dexrt.cpp index 389455c40..be89d4028 100644 --- a/src/lib/dexrt.cpp +++ b/src/lib/dexrt.cpp @@ -148,9 +148,9 @@ double randunif(uint64_t keypair) { return out - 1; } -void showHex(char **resultPtr, int x) { +void showHex(char **resultPtr, char x) { auto p = reinterpret_cast(malloc_dex(100)); // TODO: something better! - auto n = sprintf(p, "%02x", x); + auto n = sprintf(p, "%02hhX", x); auto result1Ptr = reinterpret_cast(resultPtr[0]); auto result2Ptr = reinterpret_cast( resultPtr[1]); *result1Ptr = n; From 113bdf7470e6ab2d2b8ece446fd31875d3dace68 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 7 Jan 2021 17:36:50 -0500 Subject: [PATCH 098/105] Revert "Revert #376 to fix segfaults in main" This reverts commit 9e6e470254626bc2d4aaf9fe9fcff4b915817f93. --- lib/diagram.dx | 46 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/lib/diagram.dx b/lib/diagram.dx index 000943acc..bbfff2ef4 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -7,6 +7,7 @@ data Geom = Circle Float Rectangle Float Float -- width, height Line Point + Text String HtmlColor : Type = Fin 3 => Word8 @@ -60,6 +61,7 @@ flipY : Diagram -> Diagram = Circle r -> Circle r Rectangle w h -> Rectangle w h Line (x, y) -> Line (x, -y) + Text x -> Text x def scale (s:Float) : (Diagram -> Diagram) = applyTransformation ( \(x,y). (s * x, s * y) ) \geom. case geom of @@ -67,6 +69,7 @@ def scale (s:Float) : (Diagram -> Diagram) = Circle r -> Circle (s * r) Rectangle w h -> Rectangle (s * w) (s * h) Line (x, y) -> Line (s * x, s * y) + Text x -> Text x def moveXY ((offX, offY) : Point) : (Diagram -> Diagram) = applyTransformation (\(x,y). (x + offX, y + offY) ) id @@ -78,6 +81,7 @@ def pointDiagram : Diagram = singletonDefault PointGeom def circle (r:Float) : Diagram = singletonDefault $ Circle r def rect (w:Float) (h:Float) : Diagram = singletonDefault $ Rectangle w h def line (p:Point) : Diagram = singletonDefault $ Line p +def text (x:String) : Diagram = singletonDefault $ Text x def updateGeom (update: GeomStyle -> GeomStyle) (d:Diagram) : Diagram = (MkDiagram (AsList _ objs)) = d @@ -139,11 +143,14 @@ def attrString (attr:GeomStyle) : String = <+> ("stroke-width" <=> (getAt #strokeWidth attr))) def renderGeom (attr:GeomStyle) ((x,y):Point) (geom:Geom) : String = + -- For things that are solid. SVG says they have fill=stroke. + solidAttr = setAt #fillColor (getAt #strokeColor attr) attr + groupEle = \attr. tagBracketsAttr "g" (attrString attr) case geom of PointGeom -> pointAttr = setAt #fillColor (getAt #strokeColor attr) attr - groupEle pointAttr $ selfClosingBrackets $ + groupEle solidAttr $ selfClosingBrackets $ ("circle" <+> "cx" <=> x <.> "cy" <=> y <.> @@ -161,6 +168,14 @@ def renderGeom (attr:GeomStyle) ((x,y):Point) (geom:Geom) : String = "height" <=> h <.> "x" <=> (x - (w/2.0)) <.> "y" <=> (y - (h/2.0))) + Text content -> + textEle = tagBracketsAttr "text" $ + ("x" <=> x <.> + "y" <=> y <.> + "text-anchor" <=> "middle" <.> -- horizontal center + "dominant-baseline" <=> "middle" -- vertical center + ) + groupEle solidAttr $ textEle content BoundingBox : Type = (Point & Point) @@ -185,11 +200,24 @@ def renderSVG (d:Diagram) (bounds:BoundingBox) : String = moveX : Float -> Diagram -> Diagram = \x. moveXY (x, 0.0) moveY : Float -> Diagram -> Diagram = \y. moveXY (0.0, y) --- mydiagram : Diagram = --- ( (circle 7.0 |> moveXY (20.0, 20.0) |> setFillColor blue |> setStrokeColor red) --- <> (circle 5.0 |> moveXY (40.0, 40.0)) --- <> (rect 10.0 20.0 |> moveXY (5.0, 10.0) |> setStrokeColor red) --- <> (pointDiagram |> moveXY (15.0, 5.0) |> setStrokeColor red) --- ) - --- :html renderSVG mydiagram ((0.0, 0.0), (100.0, 50.0)) +' A Demo showing all kind of features +``` +mydiagram : Diagram = + ( (circle 7.0 |> moveXY (20.0, 20.0) |> setFillColor blue |> setStrokeColor red) + <> (circle 5.0 |> moveXY (40.0, 40.0)) + <> (rect 10.0 20.0 |> moveXY (5.0, 10.0) |> setStrokeColor red) + <> (text "DexLang" |> moveXY (30.0, 10.0) |> setStrokeColor green) + <> (pointDiagram |> moveXY (15.0, 5.0) |> setStrokeColor red) + ) +:html renderSVG mydiagram ((0.0, 0.0), (100.0, 50.0)) +``` + +' Another demo that shows things are all center aligned: +``` +concentricDiagram : Diagram = ( + (rect 2.0 2.0 |> setFillColor red) + <> (circle 1.0 |> setFillColor blue) + <> (text "DexLang" |> setStrokeColor white) +) |> moveXY (5.0, 5.0) +:html renderSVG concentricDiagram ((0.0, 0.0), (10.0, 10.0)) +``` From f46109301c9148e0becd8d3a9108d4602c462dfc Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Fri, 8 Jan 2021 13:30:18 -0500 Subject: [PATCH 099/105] Changes suggested in review. --- src/lib/Imp.hs | 7 ++----- tests/adt-tests.dx | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index e34529f49..c264c5b2b 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -839,13 +839,12 @@ splitDest (maybeDest, (Block decls ans)) = do -- This is conservative, in case the type is dependent. We could do better. (DataConRef _ _ _, DataCon _ _ _ _) -> tell [(dest, result)] -- This is conservative. Without it, we hit bugs like #348 - (Con (ConRef (SumAsProd _ _ _)), Con (SumAsProd _ _ _)) -> tell [(dest, result)] + (Con (ConRef (SumAsProd _ _ _)), _) -> tell [(dest, result)] (Con (ConRef destCon), Con srcCon) -> zipWithRefConM gatherVarDests destCon srcCon (Con (RecordRef items), Record items') | fmap (const ()) items == fmap (const ()) items' -> do zipWithM_ gatherVarDests (toList items) (toList items') - (Con (ConRef (SumAsProd _ _ _)), _) -> tell [(dest, result)] -- TODO (_, ProjectElt _ _) -> tell [(dest, result)] -- TODO: is this reasonable? _ -> unreachable where @@ -931,8 +930,7 @@ toScalarType b = BaseTy b fromEmbed :: Subst a => Embed a -> ImpM a fromEmbed m = do scope <- variableScope - let (ans, (scopeDelta, decls)) = runEmbed m scope - extend $ mempty { envScope = scopeDelta } + let (ans, (_, decls)) = runEmbed m scope env <- catFoldM translateDecl mempty $ fmap (Nothing,) decls impSubst env ans @@ -1238,7 +1236,6 @@ instrTypeChecked instr = case instr of _ -> throw CompilerErr $ "Can't cast " ++ pprint st ++ " to " ++ pprint dt return dt - Alloc a ty n -> (:[]) <$> do checkIdxRep n when (a /= Stack) assertHost diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx index 683a46228..08a28c6d0 100644 --- a/tests/adt-tests.dx +++ b/tests/adt-tests.dx @@ -311,7 +311,7 @@ data MySum2 = :p concat for i:(Fin 4). AsList _ [(Foo2, Foo2)] > (AsList 4 [(Foo2, Foo2), (Foo2, Foo2), (Foo2, Foo2), (Foo2, Foo2)]) --- reproducer for a shadowing bug +-- reproducer for a shadowing bug (PR #440) :p concat $ for i:(Fin 2). toList [(Just [0,0,0], Just [0,0,0]), (Just [0,0,0], Just [0,0,0])] > (AsList 4 [ ((Just [0, 0, 0]), (Just [0, 0, 0])) From f713b3d2ceaae4438e745906bdb0714ddb006a31 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Fri, 8 Jan 2021 13:44:21 -0500 Subject: [PATCH 100/105] Add `make watch` command for automatic rebuilding. (#442) `make watch` invokes `stack build $(STACK_FLAGS) --file-watch`, which watches for source file changes and automatically rebuilds. --- makefile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/makefile b/makefile index fdbb5bb3a..d1b74519d 100644 --- a/makefile +++ b/makefile @@ -60,6 +60,9 @@ tc: dexrt-llvm build: dexrt-llvm $(STACK) build $(STACK_FLAGS) +watch: dexrt-llvm + $(STACK) build $(STACK_FLAGS) --file-watch + install: dexrt-llvm $(STACK) install $(STACK_BIN_PATH) --flag dex:optimized $(STACK_FLAGS) From 0dab3fa21fe6211977477b78381a65612a447e90 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 9 Jan 2021 00:01:41 -0500 Subject: [PATCH 101/105] Remove extra CSS bottom padding in HTML-rendered Dex. (#445) The bottom padding adds unnecessary empty bottom scrolling, which slightly hurts UX. --- static/style.css | 1 - 1 file changed, 1 deletion(-) diff --git a/static/style.css b/static/style.css index 77a7ce208..f978675d4 100644 --- a/static/style.css +++ b/static/style.css @@ -11,7 +11,6 @@ body { font-family: Helvetica, sans-serif; font-size: 100%; color: #333; - padding-bottom: 500px; } .cell { From c3f6a10b111bee0bb19cd7ece231e92f539928b2 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 9 Jan 2021 00:02:58 -0500 Subject: [PATCH 102/105] Add LaTeX rendering for HTML-rendered Dex via KaTeX. (#444) This involves only web frontend changes via KaTeX CSS and JS: https://katex.org. I chose KaTeX over MathJax because KaTeX seems more modern, performant, and prettier. --- examples/latex.dx | 41 +++++++++++++++++++++++++++++++++++++++++ examples/pi.dx | 12 ++++++++++++ static/dynamic.js | 14 ++++++++++++++ static/index.html | 5 +++++ 4 files changed, 72 insertions(+) create mode 100644 examples/latex.dx diff --git a/examples/latex.dx b/examples/latex.dx new file mode 100644 index 000000000..fcb0dc07b --- /dev/null +++ b/examples/latex.dx @@ -0,0 +1,41 @@ +'# $\href{https://katex.org/}{\KaTeX}$ rendering examples + +'This document demonstrates $\KaTeX$ rendering in literate Dex programs. + +'## Random examples + +'$$\text{This is a multiline equation:} \\\\ \textbf{A}\textbf{x} = \textbf{b}$$ + +'$$f(\relax{x}) = \int_{-\infty}^\infty \hat{f}(\xi)\,e^{2 \pi i \xi x} \,d\xi$$ + +'## [Environments](https://katex.org/docs/supported.html#environments) + +'$$\begin{matrix} a & b \\\\ c & d \end{matrix}$$ + +'$$\begin{pmatrix} a & b \\\\ c & d \end{pmatrix}$$ + +'$$\begin{bmatrix} a & b \\\\ c & d \end{bmatrix}$$ + +'$$\def\arraystretch{1.5} \begin{array}{c:c:c} a & b & c \\\\ \hline d & e & f \\\\ \hdashline g & h & i \end{array}$$ + +'$$\begin{aligned} a&=b+c \\\\ d+e&=f \end{aligned}$$ + +'$$\begin{alignedat}{2} 10&x+ &3&y = 2 \\\\ 3&x+&13&y = 4 \end{alignedat}$$ + +'$$\begin{gathered} a=b \\\\ e=b+c \end{gathered}$$ + +'$$x = \begin{cases} a &\text{if } b \\\\ c &\text{if } d \end{cases}$$ + +'$$\begin{rcases} a &\text{if } b \\\\ c &\text{if } d \end{rcases} \Rightarrow \dots$$ + +'## [Layout annotation](https://katex.org/docs/supported.html#annotation) + +'$$\overbrace{a+b+c}^{\text{note}}$$ + +'$$\underbrace{a+b+c}_{\text{note}}$$ + +'$$\xcancel{\text{second-order array combinators}}$$ + +'## [Logic and Set Theory](https://katex.org/docs/supported.html#logic-and-set-theory) + +'$$\begin{aligned} \forall \\; & \texttt{\textbackslash forall} & \complement \\; & \texttt{\textbackslash complement} & \therefore \\; & \texttt{\textbackslash therefore} & \emptyset \\; & \texttt{\textbackslash emptyset} \\\\ \exists \\; & \texttt{\textbackslash exists} & \subset \\; & \texttt{\textbackslash subset} & \because \\; & \texttt{\textbackslash because} & \empty \\; & \texttt{\textbackslash empty} \\\\ \exist \\; & \texttt{\textbackslash exist} & \supset \\; & \texttt{\textbackslash supset} & \mapsto \\; & \texttt{\textbackslash mapsto} & \varnothing \\; & \texttt{\textbackslash varnothing} \\\\ \nexists \\; & \texttt{\textbackslash nexists} & \mid \\; & \texttt{\textbackslash mid} & \to \\; & \texttt{\textbackslash to} & \implies \\; & \texttt{\textbackslash implies} \\\\ \in \\; & \texttt{\textbackslash in} & \land \\; & \texttt{\textbackslash land} & \gets \\; & \texttt{\textbackslash gets} & \impliedby \\; & \texttt{\textbackslash impliedby} \\\\ \isin \\; & \texttt{\textbackslash isin} & \lor \\; & \texttt{\textbackslash lor} & \leftrightarrow \\; & \texttt{\textbackslash leftrightarrow} & \iff \\; & \texttt{\textbackslash iff} \\\\ \notin \\; & \texttt{\textbackslash notin} & \ni \\; & \texttt{\textbackslash ni} & \notni \\; & \texttt{\textbackslash notni} & \neg \\; & \texttt{\textbackslash neg} \\\\ \lnot \\; & \texttt{\textbackslash lnot} \\\\ \end{aligned}$$ diff --git a/examples/pi.dx b/examples/pi.dx index c8cc314e7..ef8175b34 100644 --- a/examples/pi.dx +++ b/examples/pi.dx @@ -1,5 +1,17 @@ '# Monte Carlo estimates of pi +'Consider the unit circle centered at the origin. + +'Consider the first quadrant: the unit circle quadrant and its $1 \times 1$ bounding unit square. + +'$$\text{Area of unit circle quadrant: } \\\\ A_{quadrant} = \frac{\pi r^2}{4} = \frac{\pi}{4}$$ + +'$$\text{Area of unit square: } \\\\ A_{square} = 1$$ + +'$$\text{Compute } \pi \text{ via ratios: } \\\\ \frac{A_{quadrant}}{A_{square}} = \frac{\pi}{4}, \\; \pi = 4 \thinspace \frac{A_{quadrant}}{A_{square}} $$ + +'To compute $\pi$, randomly sample points in the first quadrant unit square to estimate the $\frac{A_{quadrant}}{A_{square}}$ ratio. Then, multiply by $4$. + def estimatePiArea (key:Key) : Float = [k1, k2] = splitKey key x = rand k1 diff --git a/static/dynamic.js b/static/dynamic.js index 41f5a8634..6e6efcb13 100644 --- a/static/dynamic.js +++ b/static/dynamic.js @@ -4,6 +4,18 @@ // license that can be found in the LICENSE file or at // https://developers.google.com/open-source/licenses/bsd +var katexOptions = { + delimiters: [ + {left: "$$", right: "$$", display: true}, + {left: "\\[", right: "\\]", display: true}, + {left: "$", right: "$", display: false}, + {left: "\\(", right: "\\)", display: false} + ], + // Enable commands that load resources or change HTML attributes + // (e.g. hyperlinks): https://katex.org/docs/security.html. + trust: true +}; + var cells = {}; function append_contents(key, contents) { @@ -65,4 +77,6 @@ source.onmessage = function(event) { } Object.assign(cells, new_cells); } + // Render LaTeX equations via KaTeX. + renderMathInElement(body, katexOptions); }; diff --git a/static/index.html b/static/index.html index 5084094db..d1774f2ec 100644 --- a/static/index.html +++ b/static/index.html @@ -4,7 +4,12 @@ Dex Output + + + + + From f7459d64c187e436d421f5fbdba5bc60197203b4 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 9 Jan 2021 01:07:47 -0500 Subject: [PATCH 103/105] Fix stack build warning for dex executable target. (#446) Remove `-threaded` option to fix warning: ``` Warning: 'ghc-options: -threaded' has no effect for libraries. It should only be used for executables. ``` --- dex.cabal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dex.cabal b/dex.cabal index 5a639a818..8c4f5acd7 100644 --- a/dex.cabal +++ b/dex.cabal @@ -85,7 +85,7 @@ executable dex default-language: Haskell2010 hs-source-dirs: src default-extensions: CPP, LambdaCase, BlockArguments - ghc-options: -threaded -optP-Wno-nonportable-include-path + ghc-options: -optP-Wno-nonportable-include-path if flag(optimized) ghc-options: -O3 else From f521e306328f74cf4a36cb80c1ccd56f84f16119 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 9 Jan 2021 03:38:23 -0500 Subject: [PATCH 104/105] Edit tutorial by @srush. - Rebase on top of main branch. - Some code snippets do not yet compile due to syntax changes. This should be straightforward to fix. - Stylistic edits: use consistent terminology and voice. - Remove some second person "you" references. - Use consistent Dex and programming languages terminology and spellings. - Consistently use "array" everywhere: no mentions of "table" - Consistently use "`for` comprehension" - Consistent formatting, punctuation, heading casing (first word uppercase). --- examples/tutorial.dx | 424 +++++++++++++++++++++---------------------- 1 file changed, 208 insertions(+), 216 deletions(-) diff --git a/examples/tutorial.dx b/examples/tutorial.dx index 06622370d..2e73be438 100644 --- a/examples/tutorial.dx +++ b/examples/tutorial.dx @@ -1,104 +1,100 @@ '# Dex Tutorial +' Dex is a functional, statically typed language for array processing. There are + many tools for array processing, from high-level libraries like NumPy and + MATLAB to low-level languages like CUDA. Dex gives you many of the safety and + simplicity benefits of high-level array processing languages, without + requiring that users give up low-level control. -'Dex is a functional, statically typed language for array processing. -There are many tools for array processing from high-level libraries -like NumPy / MATLAB to low-level languages like CUDA. Dex gives you -many benefit of the safety and simplicity benefits of high-level array -processing languages, without requiring that you give up low-level -control. +'## Array comprehensions -'## Array Comprehensions - - -' Before getting into the details of the language, let us begin with -the most useful component of dex, the `for` builder. The best analogy -for this construct is list comprehensions in Python. For instance, in -Python we might write a list comprehension like: +' Before getting into language details, let us begin with the most useful + component of Dex, the `for` builder. The best analogy for this construct is + list comprehensions in Python. For instance, in Python, we might write a + list comprehension like: ' `x = [[1 for j in range(10)] for i in range(5)]` -' In Dex, this construct would be written as, +' In Dex, this construct would be written as: x = for i:(Fin 5). for j:(Fin 10). 1 - -' Once we have an variable we can print it `:p` +' Once we have an variable, we can print it `:p` :p x -' More interestingly, we can also see its type with `:t`. This type -signature tells us that `x` is a two-dimensional array, with first -dimension of size 5 and the second of size 10. +' More interestingly, we can also see its type with `:t`. This type tells us + that `x` is a two-dimensional array, whose first dimension has size 5 and + second dimension has size 10. :t x -' Once we have an array we can use it in new comprehensions. For example, - if say we want to add `5` to each element of the array. In Python, - you might write this as, +' Once we have an array, we can use it in new comprehensions. For example, + let's try to add `5` to each array element. In Python, one might write this as: ' `y = [[x[i][j] for j in range(10)] for i in range(5)]` ' Dex can do something similar. The main superficial difference is the - indexing syntax which uses `.` instead of brackets. + array indexing syntax, which uses `array.i` instead of square brackets for + subscripting. -y = for i:(Fin 5). for j:(Fin 10). x.i.j + 5 +y = for i:(Fin 5). for j:(Fin 10). x.i.j + 5 :p y -' However, we can make this expression nicer. Because `x` has a known type +' However, we can make this expression nicer. Because `x` has a known array type and `i` and `j` index into that type, Dex can infer the range of the loop. - That means that we can safely remove `Fin` statements and get the same result. + That means that we can safely remove the explicit `Fin` type annotations and + get the same result. -y' = for i. for j. x.i.j + 5 +y' = for i. for j. x.i.j + 5 +' We can further reduce this array by applying array reduction functions like + `sum`: -' We can further reduce this array by applying array functions such as `sum`. +z = for i. sum x.i -z = for i. sum x.i +' This style of using `for` to construct type-inferred arrays is central to what + makes Dex powerful. Let's consider another example. This one produces a list of + length 50 in Python. -' This style of using the `for` construct to infer the loop range is - central to what makes Dex powerful. Let's consider another example. - This one produces a list of length 50 in Python. - ' `x = [1 for j in range(10) for i in range(5)]` -' The analogous array construct in Dex is written in - the following form. This produces a one dimension - array of 50 elements. +' The analogous array construct in Dex is written in the following form. It + produces a one-dimensional array of 50 elements. -x2 = for (i, j): (Fin 5 & Fin 10). 1 +x2 = for (i, j): (Fin 5 & Fin 10). 1 +' As before, we can implement "adding 5" to this array using a `for` constructor, + enumerating over each of its elements: -' As before, we can modify this array through another `for` constructor, - which enumerates over each element of `x2`. Or by applying a function. - +y2 = for i. x2.i + 5 -y2 = for i. x2.i + 5 +' And we can apply array functions to the array: :p sum x2 -' But things start to get interesting when we consider the type of this array. - Unlike the Python example that produces a list of length 50. The - Dex array maintains the index type of its construction. In particular - the type of `x2` remembers the original ranges. +' But things start to get interesting when we consider the type of the array. + Unlike the Python example, which produces a list of length 50, the Dex array + Jmaintains the index type of its construction. In particular, the type of the + array remembers the original ranges. :t x2 +'## Typed indexing -'## Typed Indexing - -' The use of typed indices lets you do some really neat things, but it - also breaks some things in counterintuitive ways. Dex use the `.` - syntax for indexing. Critically though, cannot simply index with a - raw integer. +' The use of typed indices lets you do really neat things, but it also breaks + other things in counterintuitive ways. Dex uses the `.` syntax for array + indexing. Critically though, one cannot simply index an array with an integer + literal. r = x.3 -' Instead you need to cast your integer into the index type of the current - shape. This is done with the `@` operator. (If it helps, you can think of `a.i` - as passing index `i` to array `a` the same way `f x` passes arg `x` to function - `f`.) +' Instead, it is necessary to cast the integer into the index type of the + current shape. This type annotation is done with the `@` operator. (If it + helps, you can think of array indexing as function application: `a.i` applies + array `a` with index `i` just like how `f x` applies function `f` with + argument `x`.) :t x @@ -108,63 +104,58 @@ row = x.(3 @ Fin 5) :t row.(5 @ Fin 10) -' This is bit verbose, but you rarely need to do it in practice. Most of the - time, you index with the `for` construct which is able to infer the right indices. +' This explicit annotation is a bit verbose, but it is rarely necessary in + practice. Most of the time, the `for` construct can infer index types. That's why we didn't have to use any `@` symbols when constructing `y2` above. -' Similarly you can't use indices as integers as you might be used to. You need to - cast them out explicitly. - +' Similarly, you cannot use indices as integers as you might be used to. It is + necessary to explicitly annotate index types. x4 = for i:(Fin 5). for j:(Fin 10). i + j +x4 = for i:(Fin 5). for j:(Fin 10). (ordinal i) + (ordinal j) -x4 = for i:(Fin 5). for j:(Fin 10). (ordinal i) + (ordinal j) - - -' As we have seen though, indices do not need to just be integers. We can index with - many different Dex type. For instance `x2` was indexed with a pair of integers (`&` means tuple) - so we need to build a tuple in order to index. +' As we have seen, indices are not limited to only integers. Many different Dex + types are valid index types. For example, we declared array `x2` as having a + pair of integers as its index type (`a & b` means tuple type), so indexing + into `x2` requires creating a tuple value (via `(x, y)`). :t x2 :t x2.(3@Fin 5, 5@Fin 10) -' A lot of algorithms in Dex come down to being able to pack and - unpack these indices. For example, we have seen that it is easy to - sum over one dimension of a 2D array. However, if we have a 1D - array indexed by a pair, we can easily turn it into a 2D array by - constructing it. +' Many algorithms in Dex come down to being able to pack and unpack these + indices. For example, we have seen that it is easy to sum over one dimension + of a 2-D array. However, if we have a 1D array indexed by a pair, we can + easily turn it into a 2D array using two `for` constructors. x3 = for i. for j. x2.(i, j) :t x3 -' Again we rely on type inference in order to avoid explicitly giving -the ranges. +' Again, we rely on type inference in order to avoid explicitly spelling the + ranges. -' ## Functions over Arrays +' ## Defining functions over arrays -' One use case of packing and unpacking array indices is that - it allows us to change the order of the axes. This is useful for - applying functions on arrays. +' One use case of packing and unpacking array indices is that it allows us to + change the order of the axes. This is useful for applying functions on arrays. -' For instance, we saw the `sum` function above which sums over an - axes. We can apply `sum` to `x2` to produce the sum over 50 elements. +' For instance, we saw the `sum` function above which sums over the first axis + of an array. We can apply `sum` to `x2` to produce the sum over 50 elements: :t x2 :p sum x2 -' Alternatively we can apply sum over `x3` to produce the sum over rows. +' Alternatively, we can apply sum over `x3` to produce the sum over rows: :t x3 :p sum x3 -' How do we sum over the columns? In systems like NumPy you would - do this by passing an axis argument to `sum`. Dex doesn't work this - way. To sum over columns, you need to move columns to the front - of the line. Luckily, you already know how to do this. - +' How do we sum over the columns of `x3`? In systems like NumPy, you would do + this by passing an axis argument to `sum`. Dex doesn't work this way. To sum + over columns, you need to move columns to the front of the line. Luckily, we + already know how to do this: using `for` constructors! :t x3 @@ -174,12 +165,11 @@ trans = for j. for i. x3.i.j :p sum trans -' The `sum` function seems to work independently of the index type of the - array. +' The `sum` function works independently of the index type of the array. -' Let's see how we can do this with our own functions. To define a function in - Dex we use the following syntax (there are other ways to do it, but this - one looks pretty close to Python.) +' Let's see how we can define our own array functions. Defining a function in + Dex uses the following syntax. (There are other ways to do it, but this one + looks closest to Python.) def add5 (x:Int32) : Int32 = x + 5 @@ -187,28 +177,25 @@ def add5 (x:Int32) : Int32 = x + 5 :t for i. add5 x2.i - -' We can also write functions with type variables over their inputs. For instance - we if we want to be able to `Add5` to any array. This function binds the type - variable `n` to the index type of the array. - +' We can also write functions with type variables over their inputs. For + instance, we may want to be able to write a function that applies "adds 5" + to arrays with _any_ index type. This is possible by declaring an `n => Int32` + array argument type: this declares the type variable `n` as the index type of + the array argument. def arrAdd5 (x : n => Int32) : n => Int32 = for i. x.i + 5 - -:t arrAdd5 x2 +:t arrAdd5 x2 -' But the function types can help you out even more. - For instance, because index types are sized, you - can use type inference to ensure the arguments passed in - are valid. +' But function types can help you out even more. For instance, since index types + are statically known, type checking can ensure that array arguments have valid + dimensions. This is "shape safety". -' For instance, let's say we want to add two array together. +' For instance, let's write a function adding two 2D arrays with the same shape: :t x :t y - def arrAdd (x : m => n => Int32) (y : m => n => Int32) : m => n => Int32 = for i. for j. x.i.j + y.i.j @@ -216,192 +203,207 @@ def arrAdd (x : m => n => Int32) (y : m => n => Int32) : m => n => Int32 = :t arrAdd x (trans y) -' Here the system type checked for us that they are the same size. +' The type system checked for us that input arrays indeed have the same shape. +'## Writing loops -'## Writing Loops +' Dex is a functional language - but when writing mathematical algorithms, + it is often convenient to temporarily put aside immutability and write + imperative code using mutation. -' Dex is a functional language, but when writing mathematical algorithm - it is often convenient to ignore that fact and write imperative code. +' For example, let's say we want to actually implement the `sum` function + ourselves by accumulating summed values in-place. In Python, implementing this + is not directly possible solely via list comprehensions, so we would write a + loop. -' For example, lets say we now want to actually write the `sum` function - ourselves by accumulating summed values. In Python, We can't do this directly - with list comprehensions, so we would write a loop. - ' `acc = 0` ' `for i in range(10):` -' `acc = acc + x[i]` +' `acc = acc + x[i]` -' Variables are immutable in Dex, so we cannot do this directly. But we can - write very similar code using the `state` effect. Here's what it looks like - with the corresponding Python code. - +' In Dex, values are immutable, so we cannot directly perform mutation. But Dex + includes algebraic effects, which are a purely-functional way to modeling + side-effects like mutation. We can write code that looks like mutation using + the `State` effect, which provides getter and setter functionality (via `get` + and `:=` assignment). Here's what it looks like: -def arrSum (x : a => Int32) : Int32 = +def arrSum (x : n => Int32) : Int32 = -- acc = 0 - initAcc = 0 + init = 0 -- (ignore for now) - snd $ withState initAcc $ \acc. + snd $ withState init $ \acc. -- for i in range for i. -- acc = acc + x[i] acc := (get acc) + x.i - -:p arrSum x2 +:p arrSum x2 -' So even though we are functional, the loop looks quite - similar to the imperative style. However there is one - line which is quite new and a bit scary. Let us look - into that line in a bit more detail. +' So, even though Dex is a functional language, it is possible to write loops + that look similar to ones that truly perform mutation. However, there is one + line which is quite new and a bit scary. Let us look into that line in a bit + more detail. -' First `$`. This symbol is used in Dex the same way it is - used in Haskell, but if you have haven't seen it before it - is a bit strange. It basically takes the place of parens `( )` - when you are too lazy to write them. For example, these two are the same: +' First: `$`. This symbol is used in Dex just like it is used in Haskell, but + if you haven't seen it before, it seems a bit strange. `$` is the function + application operator: it basically replaces of expression-grouping parentheses + `(f x)` when it is inconvenient to write them. For example, the following two + expressions are identical: :t arrSum (x2 + x2) :t arrSum $ x2 + x2 -' Next `\`. This symbol is the lambda operator in Dex. It makes a function - that you can use right away, and behaves like `lambda` in python. - Here the function takes an argument `acc` and returns the expression below (a `for` constructor). +' Next: `\`. This symbol is the lambda sigil in Dex. It is analogous to the + `lambda` keyword in Python, and starts the definition of a function value + (i.e. closure). In `arrSum` above: the lambda takes an argument named `acc` + and returns the body, which is the expression following the `.` (a `for` + constructor in this case). -' Finally, the function `snd` is from the prelude. It returns the second of a pair, nothing fancy. +' Finally, the function `snd` is from the Dex Prelude. It simply returns the + second element of a pair - there is also `fst` for extracting the first + element. :p fst (1, 2) :p snd (1, 2) - -' That leaves `withState`. This function allows you to introduce imperative variables into the computation. - It takes a intial values `initAcc` and a function of a reference to that value `\acc.` It then returns - a pair of the result of that function and the final value. Here's a simple example +' That leaves: `withState`. This function uses the `State` effect, enabling us + to introduce imperative variables into the computation. + `withState` takes an initial value `init` and a body function taking a + "mutable value" reference (`acc` here), and returns a pair of the body + function's result and the final value. Here's a simple example: :p withState 10 $ \ state. state := 30 20 -' The first element in the pair is the function return (`20`) and the second is the final value of the variable (`30`). - -' Finally this is a good point to talk a bit about some of the other operators in Dex. - Here we see two types of equal signs `=` and `:=`. The first is the `let` operator that makes an - immutable assignment. This one is built into the language and can be used anywhere you want. +' The first element of the returned pair is the body function's result (`20`). + The second element is the final value of the variable (`30`). +' Finally: this is a good point to talk a bit about some other operators in Dex. + In the examples above, we see two types of equal sign operators: `=` and `:=`. + The first is the `let` operator that creates an immutable assignment (a + "let-binding"). This one is built into the language and can be used anywhere. q = for i:(Fin 10). - -- Bind a temp variable for some reason + -- Bind a temporary variable `temp`, as an example. temp = (ordinal i) + 10 for j:(Fin 5). temp - -' The other is `:=` which can only be used inside of a `withState` block. It assigns - a value to a mutable reference. To read that value you need to use the `get` function. - or wait until the `withState` returns. +' The other is `:=`, which is an assignment operator that can only be used + when a `State` effect is available (e.g. inside of a body function passed to + `withState`. `ref := x` assigns the value `x` to the mutable reference `ref`. + Reading the value in `ref` is possible via the `get` function. or via using + the final result returned by `withState`. -'## Type Classes +'## Typeclasses -' Our arrSum function is pretty neat. It lets us work with any type index - to compute the sum. However, it annoyingly only works for integers. +' Our `arrSum` function is pretty neat. It lets us work with arrays with any + index type and computes the sum. However, `arrSum` explicitly takes only + integer arrays (of type `n => Int32`). :t arrSum -' If we apply it to floats we get the following error. +' If we try to apply `arrSum` to float arrays, we get the following error: arrSum for i : (Fin 5). 10.0 -' We can compare the type of our sum to the built-in Dex `sum`. +' We can compare the type of our `arrSum` function to the `sum` function found + in the Dex Prelude. :t sum -' It has another type variable `v` for the output. It also has the extra annotation - `(Add v) ?=>`. This is a constraint that tells us that `v` can be any type in the - `Add` type class. +' The Prelude-defined `sum` function also has an additional argument, spelled + like: `(Add v) ?=> ...`. This is a constraint telling us that the function + expects an `Add v` typeclass instance, where `v` is any type that implements + the `Add` typeclass. -' If we wanted to, we could look in the Dex prelude to see what this looks like. But we can - probably guess what it means. `v` needs to be something where `add` works on it. - We can do that! Let's define our own type class. +' We could look in the Dex Prelude to see exactly how `sum` is defined and what + `Add v` means. But we can guess what the `Add v` constraint means: `v` needs + to be a type that works with `add`. We can do that! Let's define our own + typeclass. interface MyAdd a:Type where myAdd : a -> a -> a myZero : a -' This tells us that to be in the `MyAdd` type class, a type `a` needs to have - a function `myAdd` and `myZero`. A type can then join the class like this. - +' This declares a typeclass (i.e. interface or trait) called `MyAdd` with some + typeclass methods (interface requirements). To implement the `MyAdd` + typeclass, a type `a` needs to define functions `myAdd` and `myZero` in a + "typeclass instance", like so: instance int32MyAdd : MyAdd Int32 where myAdd = \x y. x + y myZero = 0 instance float32MyAdd : MyAdd Float32 where - myAdd = \x y. (x + y) + myAdd = \x y. (x + y) myZero = 0.0 -' Once we have these two definitions, we can revisit our sum function. Here is how we modify - the type. +' Once we have these two instance definitions (`MyAdd Int32` and + `MyAdd Float32`), we can revisit our array sum function and add a typeclass + constraint: -def arrSum2 (_:MyAdd v) ?=> (x : a => v) : v = +def arrSumGeneric (_:MyAdd v) ?=> (x : a => v) : v = snd $ withState myZero $ \acc. for i. acc := myAdd (get acc) x.i -arrSum2 for i : (Fin 5). 10 -arrSum2 for i : (Fin 5). 10.0 +arrSumGeneric for i : (Fin 5). 10 +arrSumGeneric for i : (Fin 5). 10.0 -arrSum2 $ for i : (Fin 5). - for j : (Fin 10). 10.0 +arrSumGeneric $ for i : (Fin 5). + for j : (Fin 10). 10.0 -' So it works for ints and it works for floats. But it failed when we tried to - pass in a 2D array. What went wrong? The error tells us that it can't produce - a class dictionary for `MyAdd ((Fin 10) => Float32)`. This makes sense as - we have have not written one. We need to tell the system how to add columns. - -' If we want, we can take the type checker literally and make this instance :). +' This sum function works for any type that implements `MyAdd`, like `Int32` and + `Float32`. But it failed when we tried to pass in a 2D array. What went wrong? + The error tells us that the function could not find a `MyAdd` instance for + `MyAdd ((Fin 10) => Float32)`. This makes sense because we have have not + written one. We need to tell the system "how to add array columns". +' One option is to directly satisfy the type checker and provide a specific + `MyAdd ((Fin 10) => Float32)` instance: instance specMyAdd : MyAdd ((Fin 10) => Float32) where - myAdd = \x y. for i: (Fin 10). (x.i + y.i) + myAdd = \x y. for i: (Fin 10). (x.i + y.i) myZero = for i: (Fin 10). 0.0 -arrSum2 $ for i : (Fin 5). - for j : (Fin 10). 10.0 - +arrSumGeneric $ for i : (Fin 5). + for j : (Fin 10). 10.0 -' Or we can treat it a more generally and extend to all 1D arrays. +' To be more general, we can instead define a `MyAdd` instance for all array + types. This instance requires that the array element type `v` also has an + `MyAdd` instance; this requirement is represented as a `(MyAdd v) ?=> ...` + constraint. instance arrMyAdd : (MyAdd v) ?=> MyAdd (a => v) where - myAdd = \x y. for i. (myAdd x.i y.i) + myAdd = \x y. for i. (myAdd x.i y.i) myZero = for i. myZero -arrSum2 $ for i : (Fin 5). - for j : (Fin 9). 10.0 +arrSumGeneric $ for i : (Fin 5). + for j : (Fin 9). 10.0 +' This instance not only works for 2D arrays, but also 3D and higher-dimensional + arrays: -' This now works for 3D arrays too. - -arrSum2 $ for i : (Fin 5). +arrSumGeneric $ for i : (Fin 5). for j : (Fin 9). for k : (Fin 9). 10.0 - - -' ## Prelude Practice +' ## Learn the Prelude -' There are a bunch of goodies implemented in the prelude - that are worth knowing. It's good practice just to - infer what these functions do from their type. - -' Here are a couple that come up a lot. +' The Prelude contains many handy functions. Since Dex types contain so much + information, it is possible to infer what many of these functions do just by + reading and understanding their type. +' Here are a few used, commonly-used Prelude functions. ' * `select` for filtering -:t select +:t select select True 1 2 select False 1 2 @@ -413,23 +415,21 @@ select False 1 2 myzero1 : (Fin 20 & Fin 10) => Float32 = zero myzero2 : (Fin 20) => (Fin 10) => Float32 = zero -' * `zip` for creating tables of pairs +' * `zip` for creating arrays of pairs :t zip :t zip x x :t for i. zip x.i x.i - -' * `iota` for create aranges +' * `iota` for creating "aranges" :t iota :p (iota (Fin 10)) :p for i. for j. (iota (Fin 4 & Fin 4)).(i, j) - -' * Random numbers +' * Pseudorandom number generation :t newKey :t splitKey @@ -440,55 +440,47 @@ key = newKey 0 :p randn key1 -' * `randVec` creates a random vector +' * `randVec` for creating a vector of random numbers - -randv = randVec 20 randn key2 +randv = randVec 20 randn key2 :t randv randv2 = randVec 20 randInt key3 :t randv2 +'## Worked examples: Project Euler -'## Worked Examples: Project Euler - -' To demonstrate Dex in practice, here are some example - functions solving problems on https://projecteuler.net/ +' To demonstrate Dex in practice, below are some examples of solving problems + from [Project Euler](https://projecteuler.net). - def ignore (y:a) (x : Maybe a) : a = case x of Just x -> x Nothing -> y - + instance maybeAdd : (Add v) ?=> Add (Maybe v) where add = \x y. Just $ ignore zero x + ignore zero y sub = \x y. Just $ ignore zero x - ignore zero y zero = Just zero - ' ### Problem 1: Find the sum of all the multiples of 3 or 5 below 1000. - - prob1 = for i : (Fin 1000). i' = ordinal i case ((i' `mod` 3) == 0 || (i' `mod` 5) == 0) of True -> Just i' False -> Nothing - + :p fromJust $ sum prob1 ' ### Problem 2: By considering the terms in the Fibonacci sequence whose values do not exceed four million, find the sum of the even-valued terms. - ... - -- def maybeList (x : Maybe a) : List a = -- case x of -- Just a -> AsList 1 $ for i : (Fin 1). a -- Nothing -> mempty -- def remMaybe (x: n => Maybe a) : List a = --- concat $ for i. maybeList x.i +-- concat $ for i. maybeList x.i \ No newline at end of file From 1288790aac9760a14ed16c0948fb6b7ed98d7f81 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 9 Jan 2021 09:25:58 -0500 Subject: [PATCH 105/105] Consistently spell "n-D" as "nD". --- examples/tutorial.dx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/tutorial.dx b/examples/tutorial.dx index 2e73be438..2c5b54042 100644 --- a/examples/tutorial.dx +++ b/examples/tutorial.dx @@ -126,7 +126,7 @@ x4 = for i:(Fin 5). for j:(Fin 10). (ordinal i) + (ordinal j) ' Many algorithms in Dex come down to being able to pack and unpack these indices. For example, we have seen that it is easy to sum over one dimension - of a 2-D array. However, if we have a 1D array indexed by a pair, we can + of a 2D array. However, if we have a 1D array indexed by a pair, we can easily turn it into a 2D array using two `for` constructors. x3 = for i. for j. x2.(i, j) @@ -483,4 +483,4 @@ prob1 = for i : (Fin 1000). -- Nothing -> mempty -- def remMaybe (x: n => Maybe a) : List a = --- concat $ for i. maybeList x.i \ No newline at end of file +-- concat $ for i. maybeList x.i