diff --git a/README.md b/README.md index 47e3e9f12..2bdca7943 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ or these example programs: * [Ray tracer](https://google-research.github.io/dex-lang/examples/raytrace.html) * [Estimating pi](https://google-research.github.io/dex-lang/examples/pi.html) * [Hamiltonian Monte Carlo](https://google-research.github.io/dex-lang/examples/mcmc.html) - * [ODE integrator](https://google-research.github.io/dex-lang/oexamples/de-integrator.html) + * [ODE integrator](https://google-research.github.io/dex-lang/examples/ode-integrator.html) * [Sierpinski triangle](https://google-research.github.io/dex-lang/examples/sierpinski.html) * [Basis function regression](https://google-research.github.io/dex-lang/examples/regression.html) * [Brownian bridge](https://google-research.github.io/dex-lang/examples/brownian_motion.html) @@ -56,10 +56,10 @@ development mode: ```console # Linux: -alias dex="stack exec dex --" +alias dex="stack exec dex -- --lib-path lib" # macOS: -alias dex="stack exec --stack-yaml=stack-macos.yaml dex --" +alias dex="stack exec --stack-yaml=stack-macos.yaml dex -- --lib-path lib" ``` ## Running diff --git a/dex.cabal b/dex.cabal index 8c4f5acd7..7a3ffbb4b 100644 --- a/dex.cabal +++ b/dex.cabal @@ -13,6 +13,8 @@ maintainer: dougalm@google.com license-file: LICENSE build-type: Simple +data-files: lib/*.dx + flag cuda description: Enables building with CUDA support default: False @@ -31,9 +33,9 @@ library dex-resources library exposed-modules: Env, Syntax, Type, Inference, JIT, LLVMExec, - Parser, Util, Imp, Imp.Embed, Imp.Optimize, + Parser, Util, Imp, Imp.Builder, Imp.Optimize, PPrint, Algebra, Parallelize, Optimize, Serialize - Actor, Cat, Embed, Export, + Actor, Builder, Cat, Export, RenderHtml, LiveOutput, Simplify, TopLevel, Autodiff, Interpreter, Logging, CUDA, LLVM.JIT, LLVM.Shims @@ -49,6 +51,7 @@ library store, -- Notebook support warp, wai, blaze-html, aeson, http-types, cmark, binary + other-modules: Paths_dex if !os(darwin) exposed-modules: Resources hs-source-dirs: src/resources diff --git a/examples/brownian_motion.dx b/examples/brownian_motion.dx index 9f9456291..3297bd243 100644 --- a/examples/brownian_motion.dx +++ b/examples/brownian_motion.dx @@ -1,5 +1,5 @@ -include "plot.dx" +import plot UnitInterval = Float diff --git a/examples/fluidsim.dx b/examples/fluidsim.dx index e1ea7e9ac..1fa0c6d57 100644 --- a/examples/fluidsim.dx +++ b/examples/fluidsim.dx @@ -2,7 +2,7 @@ Fluid simulation code based on [Real-Time Fluid Dynamics for Games](https://www.josstam.com/publications) by Jos Stam -include "plot.dx" +import plot def wrapidx (n:Type) -> (i:Int) : n = -- Index wrapping around at ends. diff --git a/examples/isomorphisms.dx b/examples/isomorphisms.dx index 9668eac42..102f644d0 100644 --- a/examples/isomorphisms.dx +++ b/examples/isomorphisms.dx @@ -46,7 +46,7 @@ that produce isos. We will start with the first two: :t #b : Iso {a:Int & b:Float & c:Unit} _ > (Iso {a: Int32 & b: Float32 & c: Unit} (Float32 & {a: Int32 & c: Unit})) > === parse === -> _ans_ = +> _ans_ = > MkIso {bwd = \(x, r). {b = x, ...r}, fwd = \{b = x, ...r}. (,) x r} > : Iso {a: Int & b: Float & c: Unit} _ @@ -54,7 +54,7 @@ that produce isos. We will start with the first two: :t #?b : Iso {a:Int | b:Float | c:Unit} _ > (Iso {a: Int32 | b: Float32 | c: Unit} (Float32 | {a: Int32 | c: Unit})) > === parse === -> _ans_ = +> _ans_ = > MkIso > { bwd = \v. case v > ((Left x)) -> {| b = x |} @@ -142,7 +142,7 @@ another. For instance: > ({ &} & {a: Int32 & b: Float32 & c: Unit}) > ({a: Int32} & {b: Float32 & c: Unit})) > === parse === -> _ans_ = +> _ans_ = > MkIso > { bwd = \({a = x, ...l}, {, ...r}). (,) {, ...l} {a = x, ...r} > , fwd = \({, ...l}, {a = x, ...r}). (,) {a = x, ...l} {, ...r}} @@ -212,7 +212,7 @@ zipper isomorphisms: > ({ |} | {a: Int32 | b: Float32 | c: Unit}) > ({a: Int32} | {b: Float32 | c: Unit})) > === parse === -> _ans_ = +> _ans_ = > MkIso > { bwd = \v. case v > ((Left w)) -> (case w diff --git a/examples/mandelbrot.dx b/examples/mandelbrot.dx index 468ef17dd..32d5471ee 100644 --- a/examples/mandelbrot.dx +++ b/examples/mandelbrot.dx @@ -1,6 +1,6 @@ '# Mandelbrot set -include "plot.dx" +import plot 'Escape time algorithm diff --git a/examples/mcmc.dx b/examples/mcmc.dx index d205cdbd7..6676c913c 100644 --- a/examples/mcmc.dx +++ b/examples/mcmc.dx @@ -2,7 +2,7 @@ '## General MCMC utilities -include "plot.dx" +import plot LogProb : Type = Float diff --git a/examples/ode-integrator.dx b/examples/ode-integrator.dx index 53e568d5a..297cd8500 100644 --- a/examples/ode-integrator.dx +++ b/examples/ode-integrator.dx @@ -4,7 +4,7 @@ This version is a port of the [Jax implementation](https://github.com/google/jax One difference is that it uses a lower-triangular matrix type for the Butcher tableau, and so avoids zero-padding everywhere. -include "plot.dx" +import plot Time = Float diff --git a/examples/raytrace.dx b/examples/raytrace.dx index 051722f56..b0a7e97f6 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -5,7 +5,7 @@ described [here](https://blog.evjang.com/2019/11/jaxpt.html). -include "plot.dx" +import plot ' ### Generic Helper Functions Some of these should probably go in prelude. @@ -212,7 +212,7 @@ def sampleLightRadiance (surfNor, surf) = osurf (rayPos, _) = inRay (MkScene objs) = scene - yieldAccum \radiance. + yieldAccum (AddMonoid Float) \radiance. for i. case objs.i of PassiveObject _ _ -> () Light lightPos hw _ -> @@ -227,7 +227,7 @@ def sampleLightRadiance def trace (params:Params) (scene:Scene n) (initRay:Ray) (k:Key) : Color = noFilter = [1.0, 1.0, 1.0] - yieldAccum \radiance. + yieldAccum (AddMonoid Float) \radiance. runState noFilter \filter. runState initRay \ray. boundedIter (getAt #maxBounces params) () \i. diff --git a/examples/regression.dx b/examples/regression.dx index ca8ce2731..65f9f080b 100644 --- a/examples/regression.dx +++ b/examples/regression.dx @@ -1,6 +1,6 @@ '# Basis function regression -include "plot.dx" +import plot -- Conjugate gradients solver def solve (mat:m=>m=>Float) (b:m=>Float) : m=>Float = diff --git a/examples/sierpinski.dx b/examples/sierpinski.dx index 64a8f8aea..f007355a3 100644 --- a/examples/sierpinski.dx +++ b/examples/sierpinski.dx @@ -1,6 +1,6 @@ '# Sierpinski triangle ("chaos game") -include "plot.dx" +import plot def update (points:n=>Point) (key:Key) ((x,y):Point) : Point = (x', y') = points.(randIdx key) diff --git a/examples/tiled-matmul.dx b/examples/tiled-matmul.dx index 7238d671e..677a80746 100644 --- a/examples/tiled-matmul.dx +++ b/examples/tiled-matmul.dx @@ -16,7 +16,7 @@ def matmul (k : Type) ?-> (n : Type) ?-> (m : Type) ?-> vectorTile = Fin VectorWidth colTile = (colVectors & vectorTile) (tile2d (\nt:(Tile n rowTile). \mt:(Tile m colTile). - ct = yieldAccum \acc. + ct = yieldAccum (AddMonoid Float) \acc. for l:k. for i:rowTile. ail = broadcastVector a.(nt +> i).l diff --git a/lib/diagram.dx b/lib/diagram.dx index bbfff2ef4..c97579700 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -1,5 +1,7 @@ '# Vector graphics library +import png + Point : Type = (Float & Float) data Geom = diff --git a/lib/plot.dx b/lib/plot.dx index 56c36f647..bb981a293 100644 --- a/lib/plot.dx +++ b/lib/plot.dx @@ -1,7 +1,7 @@ '# Plotting library -include "diagram.dx" -include "png.dx" +import diagram +import png data CompactSet a:Type = Interval a a diff --git a/lib/prelude.dx b/lib/prelude.dx index 596f9ad3f..2302d4fcf 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -2,7 +2,8 @@ 'Runs before every Dex program unless an alternative is provided with `--prelude`. -'## Wrappers around primitives +'## Essentials +### Primitive Types Unit = %UnitType Type = %TyKind @@ -23,6 +24,8 @@ RawPtr : Type = %Word8Ptr Int = Int32 Float = Float32 +'### Casting + def internalCast (b:Type) (x:a) : b = %cast b x def F64ToF (x : Float64) : Float = internalCast _ x @@ -40,6 +43,11 @@ def FToI (x:Float) : Int = internalCast _ x def I64ToRawPtr (x:Int64 ) : RawPtr = internalCast _ x def RawPtrToI64 (x:RawPtr) : Int64 = internalCast _ x +'### Basic Arithmetic +#### Add +Things that can be added. +This defines the `Add` [group](https://en.wikipedia.org/wiki/Group_(mathematics)) and its operators. + interface Add a add : a -> a -> a sub : a -> a -> a @@ -83,6 +91,10 @@ instance [Add a] Add (n=>a) sub = \xs ys. for i. xs.i - ys.i zero = for _. zero +'#### Mul +Things that can be multiplied. +This defines the `Mul` [Monoid](https://en.wikipedia.org/wiki/Monoid), and its operator. + interface Mul a mul : a -> a -> a one : a @@ -113,6 +125,12 @@ instance Mul Unit mul = \x y. () one = () +instance [Mul a] Mul (n=>a) + mul = \xs ys. for i. xs.i * ys.i + one = for _. one + +'#### Integral +Integer-like things. interface Integral a idiv : a->a->a @@ -130,6 +148,9 @@ instance Integral Word8 idiv = \x y. %idiv x y rem = \x y. %irem x y +'#### Fractional +Rational-like things. +Includes floating point and two field rational representations. interface Fractional a divide : a -> a -> a @@ -199,6 +220,9 @@ def not (x:Bool) : Bool = W8ToB $ %not x' '## Sum types +A [sum type, or tagged union](https://en.wikipedia.org/wiki/Tagged_union) can hold values from a fixed set of types, distinguished by tags. +For those familiar with the C language, they can be though of as a combination of an `enum` with a `union`. +Here we define several basic kinds, and some operators on them. data Maybe a = Nothing @@ -210,10 +234,18 @@ def isNothing (x:Maybe a) : Bool = case x of def isJust (x:Maybe a) : Bool = not $ isNothing x +def maybe (d: b) (f : (a -> b)) (x: Maybe a) : b = + case x of + Nothing -> d + Just x' -> f x' + data (|) a b = Left a Right b +'## More Boolean operations +TODO: move these with the others? + def select (p:Bool) (x:a) (y:a) : a = case p of True -> x False -> y @@ -221,13 +253,79 @@ def select (p:Bool) (x:a) (y:a) : a = case p of def BToI (x:Bool) : Int = W8ToI $ BToW8 x def BToF (x:Bool) : Float = IToF (BToI x) +'## Ordering +TODO: move this down to with `Ord`? + +data Ordering = + LT + EQ + GT + +def OToW8 (x : Ordering) : Word8 = %dataConTag x + +'### Monoid +A [monoid](https://en.wikipedia.org/wiki/Monoid) is a things that have an associative binary operator and an identity element. +This is a very useful and general calls of things. +It includes: + - Addition and Multiplication of Numbers + - Boolean Logic + - Concatenation of Lists (including strings) +Monoids support `fold` operations, and similar. + +interface Monoid a + mempty : a + mcombine : a -> a -> a -- can't use `<>` just for parser reasons? + +def (<>) [Monoid a] : a -> a -> a = mcombine + +instance [Monoid a] Monoid (n=>a) + mempty = for i. mempty + mcombine = \x y. for i. mcombine x.i y.i + +named-instance AndMonoid : Monoid Bool + mempty = True + mcombine = (&&) + +named-instance OrMonoid : Monoid Bool + mempty = False + mcombine = (||) + +def AddMonoid (a:Type) -> (_:Add a) ?=> : Monoid a = + A = a -- XXX: Typing `Monoid a` below would quantify it over a, which we don't want + named-instance result : Monoid A + mempty = zero + mcombine = add + result + +def MulMonoid (a:Type) -> (_:Mul a) ?=> : Monoid a = + A = a -- XXX: Typing `Monoid a` below would quantify it over a, which we don't want + named-instance result : Monoid A + mempty = one + mcombine = mul + result + '## Effects def Ref (r:Type) (a:Type) : Type = %Ref r a def get (ref:Ref h s) : {State h} s = %get ref def (:=) (ref:Ref h s) (x:s) : {State h} Unit = %put ref x + def ask (ref:Ref h r) : {Read h} r = %ask ref -def (+=) (ref:Ref h w) (x:w) : {Accum h} Unit = %tell ref x + +data AccumMonoid h w = UnsafeMkAccumMonoid (Monoid w) + +@instance +def tableAccumMonoid ((UnsafeMkAccumMonoid m):AccumMonoid h w) ?=> : AccumMonoid h (n=>w) = + %instance mHint = m + def liftTableMonoid (tm:Monoid (n=>w)) ?=> : Monoid (n=>w) = tm + UnsafeMkAccumMonoid liftTableMonoid + +def (+=) (am:AccumMonoid h w) ?=> (ref:Ref h w) (x:w) : {Accum h} Unit = + (UnsafeMkAccumMonoid m) = am + %instance mHint = m + updater = \v. mcombine v x + %mextend ref updater + def (!) (ref:Ref h (n=>a)) (i:n) : Ref h a = %indexRef ref i def fstRef (ref: Ref h (a & b)) : Ref h a = %fstRef ref def sndRef (ref: Ref h (a & b)) : Ref h b = %sndRef ref @@ -245,16 +343,29 @@ def withReader : {|eff} a = runReader init action +def MonoidLifter (b:Type) (w:Type) : Type = h:Type -> AccumMonoid h b ?=> AccumMonoid h w + def runAccum - (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) + (mlift:MonoidLifter b w) ?=> + (bm:Monoid b) + (action: (h:Type ?-> AccumMonoid h b ?=> Ref h w -> {Accum h|eff} a)) : {|eff} (a & w) = - def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = action ref - %runWriter explicitAction + -- Normally, only the ?=> lambda binders participate in dictionary synthesis, + -- so we need to explicitly declare `m` as a hint. + %instance bmHint = bm + empty : b = mempty + combine : b -> b -> b = mcombine + def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = + %instance accumBaseMonoidHint : AccumMonoid h' b = UnsafeMkAccumMonoid bm + action ref + %runWriter empty combine explicitAction def yieldAccum - (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) + (mlift:MonoidLifter b w) ?=> + (m:Monoid b) + (action: (h:Type ?-> AccumMonoid h b ?=> Ref h w -> {Accum h|eff} a)) : {|eff} w = - snd $ runAccum action + snd $ runAccum m action def runState (init:s) @@ -281,11 +392,24 @@ def unreachable (():Unit) : a = unsafeIO do '## Type classes +'### Eq and Ord + +'#### Eq +Equatable. +Things that we can tell if they are equal or not to other things. + interface Eq a (==) : a -> a -> Bool def (/=) [Eq a] (x:a) (y:a) : Bool = not $ x == y +'#### Ord +Orderable / Comparable. +Things that can be place in a total order. +i.e. things that can be compared to other things to find if larger, smaller or equal in value. + +'We take the standard false-hood and pretend that this applies to Floats, even though strictly speaking this not true as our floats follow [IEEE754](https://en.wikipedia.org/wiki/IEEE_754), and thus have `NaN < 1.0 == false` and `1.0 < NaN == false`. + interface [Eq a] Ord a (>) : a -> a -> Bool (<) : a -> a -> Bool @@ -348,15 +472,53 @@ instance [Ord a, Ord b] Ord (a & b) (>) = \(x1,x2) (y1,y2). x1 > y1 || (x1 == y1 && x2 > y2) (<) = \(x1,x2) (y1,y2). x1 < y1 || (x1 == y1 && x2 < y2) --- TODO: accumulate using the True/&& monoid +instance Eq Ordering + (==) = \x y. OToW8 x == OToW8 y + +def scan (init:a) (body:n->a->(a&b)) : (a & n=>b) = + swap $ runState init \s. for i. + c = get s + (c', y) = body i c + s := c' + y + +def fold (init:a) (body:(n->a->a)) : a = fst $ scan init \i x. (body i x, ()) + +def compare [Ord a] (x:a) (y:a) : Ordering = + if x < y + then LT + else if x == y + then EQ + else GT + +instance Monoid Ordering + mempty = EQ + mcombine = \x y. + case x of + LT -> LT + GT -> GT + EQ -> y + instance [Eq a] Eq (n=>a) (==) = \xs ys. - numDifferent : Float = - yieldAccum \ref. for i. - ref += (IToF (BToI (xs.i /= ys.i))) - numDifferent == 0.0 - -'## Transcencendental functions + yieldAccum AndMonoid \ref. + for i. ref += xs.i == ys.i + +instance [Ord a] Ord (n=>a) + (>) = \xs ys. + f: Ordering = + fold EQ $ \i c. c <> (compare xs.i ys.i) + f == GT + (<) = \xs ys. + f: Ordering = + fold EQ $ \i c. c <> (compare xs.i ys.i) + f == LT + +'## Elementary/Special Functions +This is more or less the standard [LibM fare](https://en.wikipedia.org/wiki/C_mathematical_functions). +Roughly it lines up with some definitions of the set of [Elementary](https://en.wikipedia.org/wiki/Elementary_function) and/or [Special](https://en.wikipedia.org/wiki/Special_functions). +In truth, nothing is elementary or special except that we humans have decided it is. +Many, but not all of these functions are [Transcendental](https://en.wikipedia.org/wiki/Transcendental_function). interface Floating a exp : a -> a @@ -537,16 +699,39 @@ def sq [Mul a] (x:a) : a = x * x def abs [Add a, Ord a] (x:a) : a = select (x > zero) x (zero - x) def mod (x:Int) (y:Int) : Int = rem (y + rem x y) y +'## Table Operations + +instance [Floating a] Floating (n=>a) + exp = map exp + exp2 = map exp2 + log = map log + log2 = map log2 + log10 = map log10 + log1p = map log1p + sin = map sin + cos = map cos + tan = map tan + sinh = map sinh + cosh = map cosh + tanh = map tanh + floor = map floor + ceil = map ceil + round = map round + sqrt = map sqrt + pow = \x y. for i. pow x.i y.i + lgamma = map lgamma + +'### Axis Restructuring + +def axis1 (x : a => b => c) : b => a => c = for j. for i. x.i.j +def axis2 (x : a => b => c => d) : c => a => b => d = for k. for i. for j. x.i.j.k + + def reindex (ixr: b -> a) (tab: a=>v) : b=>v = for i. tab.(ixr i) -def scan (init:a) (body:n->a->(a&b)) : (a & n=>b) = - swap $ runState init \s. for i. - c = get s - (c', y) = body i c - s := c' - y -def fold (init:a) (body:(n->a->a)) : a = fst $ scan init \i x. (body i x, ()) +'### Reductions + def reduce (identity:a) (combine:(a->a->a)) (xs:n=>a) : a = -- `combine` should be a commutative and associative, and form a -- commutative monoid with `identity` @@ -556,18 +741,22 @@ def reduce (identity:a) (combine:(a->a->a)) (xs:n=>a) : a = -- TODO: call this `scan` and call the current `scan` something else def scan' (init:a) (body:n->a->a) : n=>a = snd $ scan init \i x. dup (body i x) -- TODO: allow tables-via-lambda and get rid of this -def fsum (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs i +def fsum (xs:n=>Float) : Float = yieldAccum (AddMonoid Float) \ref. for i. ref += xs i def sum [Add v] (xs:n=>v) : v = reduce zero (+) xs def prod [Mul v] (xs:n=>v) : v = reduce one (*) xs -def mean (xs:n=>Float) : Float = sum xs / IToF (size n) -def std (xs:n=>Float) : Float = sqrt $ mean (map sq xs) - sq (mean xs) +def mean [VSpace v] (xs:n=>v) : v = sum xs / IToF (size n) +def std [Mul v, VSpace v, Floating v] (xs:n=>v) : v = sqrt $ mean (map sq xs) - sq (mean xs) def any (xs:n=>Bool) : Bool = reduce False (||) xs def all (xs:n=>Bool) : Bool = reduce True (&&) xs +'### ApplyN + def applyN (n:Int) (x:a) (f:a -> a) : a = yieldState x \ref. for _:(Fin n). ref := f (get ref) +'### Linear Algebra + def linspace (n:Type) (low:Float) (high:Float) : n=>Float = dx = (high - low) / IToF (size n) for i:n. low + IToF (ordinal i) * dx @@ -590,6 +779,11 @@ def eye [Eq n] : n=>n=>Float = for i j. select (i == j) 1.0 0.0 '## Pseudorandom number generator utilities +Dex does not use a stateful random number generator. +Rather it uses what is known as a split-able random number generator, which is based on a hash function. +Dex's PRNG system is modelled directly after [JAX's](https://github.com/google/jax/blob/master/design_notes/prng.md), which is based on a well established but shockingly underused idea from the functional programming community: the splittable PRNG. It's a good idea for many reasons, but it's especially helpful in a parallel setting. If you want to read more, [Splittable pseudorandom number generators using cryptographic hashing](http://publications.lib.chalmers.se/records/fulltext/183348/local_183348.pdf) describes the splitting model itself and [D.E. Shaw Research's counter-based PRNG](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) proposes the particular hash function we use. + +'### Key functions -- TODO: newtype Key = Int64 @@ -602,6 +796,11 @@ def many (f:Key->a) (k:Key) (i:n) : a = f (hash k (ordinal i)) def ixkey (k:Key) (i:n) : Key = hash k (ordinal i) def ixkey2 (k:Key) (i:n) (j:m) : Key = hash (hash k (ordinal i)) (ordinal j) def splitKey (k:Key) : Fin n => Key = for i. ixkey k i + +'### Sample Generators +These functions generate samples taken from, different distributions. +Such as `randMat` with samples from the distribution of floating point matrices where each element is taken from a i.i.d. uniform distribution. + def rand (k:Key) : Float = unsafeIO do F64ToF $ %ffi randunif Float64 k def randVec (n:Int) (f: Key -> a) (k: Key) : Fin n => a = for i:(Fin n). f (ixkey k i) @@ -623,6 +822,10 @@ def bern (p:Float) (k:Key) : Bool = rand k < p def randnVec (k:Key) : n=>Float = for i. randn (ixkey k i) +'## cumSum +TODO: Move this to be with reductions? +It's a kind of `scan`. + def cumSum (xs: n=>Float) : n=>Float = withState 0.0 \total. for i. @@ -632,6 +835,8 @@ def cumSum (xs: n=>Float) : n=>Float = '## Automatic differentiation +'### AD operations + -- TODO: add vector space constraints def linearize (f:a->b) (x:a) : (b & a --o b) = %linearize f x def jvp (f:a->b) (x:a) : a --o b = snd (linearize f x) @@ -647,6 +852,9 @@ def deriv (f:Float->Float) (x:Float) : Float = jvp f x 1.0 def derivRev (f:Float->Float) (x:Float) : Float = snd (vjp f x) 1.0 +'### Approximate Equality +TODO: move this outside the AD section to be with equality? + interface HasAllClose a allclose : a -> a -> a -> a -> Bool @@ -677,6 +885,7 @@ instance [HasDefaultTolerance t] HasDefaultTolerance (n=>t) atol = for i. atol rtol = for i. rtol +'### AD Checking tools def checkDerivBase (f:Float->Float) (x:Float) : Bool = eps = 0.01 @@ -750,18 +959,10 @@ def tile1 (fTile : (t:(Tile n l) -> {|eff} m=>l=>a)) -- arr.(t +> UNSAFEFromOrdinal idx 2) -- arr.(t +> UNSAFEFromOrdinal idx 3)) -'## Monoid typeclass - -interface Monoid a - mempty : a - mcombine : a -> a -> a -- can't use `<>` just for parser reasons? - -def (<>) [Monoid a] : a -> a -> a = mcombine - '## Length-erased lists data List a = - AsList n:Int foo:(Fin n => a) + AsList n:Int elements:(Fin n => a) def unsafeCastTable (m:Type) (xs:n=>a) : m=>a = for i. xs.(unsafeFromOrdinal _ (ordinal i)) @@ -805,16 +1006,18 @@ def (&>>) (iso1: Iso a b) (iso2: Iso b c) : Iso a c = def (<<&) (iso2: Iso b c) (iso1: Iso a b) : Iso a c = iso1 &>> iso2 --- Lens-like accessors --- (note: #foo is an Iso {foo: a & ...b} (a & {&...b})) +'### Lens-like accessors +note: `#foo is an Iso {foo: a & ...b} (a & {&...b}))` + def getAt (iso: Iso a (b & c)) : a -> b = fst <<< appIso iso def popAt (iso: Iso a (b & c)) : a -> c = snd <<< appIso iso def pushAt (iso: Iso a (b & c)) (x:b) (r:c) : a = revIso iso (x, r) def setAt (iso: Iso a (b & c)) (x:b) (r:a) : a = pushAt iso x $ popAt iso r --- Prism-like accessors --- (note: #?foo is an Iso {foo: a | ...b} (a | {|...b})) +'### Prism-like accessors +note: `#?foo is an Iso {foo: a | ...b} (a | {|...b}))` + def matchWith (iso: Iso a (b | c)) (x: a) : Maybe b = case appIso iso x of Left x -> Just x @@ -824,7 +1027,9 @@ def buildWith (iso: Iso a (b | c)) (x: b) : a = revIso iso $ Left x swapPairIso : Iso (a & b) (b & a) = MkIso {fwd = \(a, b). (b, a), bwd = \(b, a). (a, b)} --- Complement the focus of a lens-like isomorphism +'### Complement lens +Complement the focus of a lens-like isomorphism + exceptLens : Iso a (b & c) -> Iso a (c & b) = \iso. iso &>> swapPairIso swapEitherIso : Iso (a | b) (b | a) = @@ -836,33 +1041,40 @@ swapEitherIso : Iso (a | b) (b | a) = Right l -> Left l MkIso {fwd, bwd} --- Complement the focus of a prism-like isomorphism +'### Complement prism +Complement the focus of a prism-like isomorphism + exceptPrism : Iso a (b | c) -> Iso a (c | b) = \iso. iso &>> swapEitherIso -- Use a lens-like iso to split a 1d table into a 2d table def overLens (iso: Iso a (b & c)) (tab: a=>v) : (b=>c=>v) = for i j. tab.(revIso iso (i, j)) --- Zipper isomorphisms to easily specify many record/variant fields: --- #&foo is an Iso ({&...l} & {foo:a & ...r}) ({foo:a & ...l} & {&...r}) --- #|foo is an Iso ({|...l} | {foo:a | ...r}) ({foo:a | ...l} | {|...r}) +'### Zipper +Zipper isomorphisms to easily specify many record/variant fields: +``` +#&foo is an Iso ({&...l} & {foo:a & ...r}) ({foo:a & ...l} & {&...r}) +#|foo is an Iso ({|...l} | {foo:a | ...r}) ({foo:a | ...l} | {|...r}) +``` + +' Convert a record zipper isomorphism to a normal lens-like isomorphism +by using `splitR &>> iso` --- Convert a record zipper isomorphism to a normal lens-like isomorphism --- by using splitR &>> iso splitR : Iso a ({&} & a) = MkIso {fwd=\x. ({}, x), bwd=\({}, x). x} def overFields (iso: Iso ({&} & a) (b & c)) (tab: a=>v) : b=>c=>v = overLens (splitR &>> iso) tab --- Convert a variant zipper isomorphism to a normal prism-like isomorphism --- by using splitV &>> iso +'Convert a variant zipper isomorphism to a normal prism-like isomorphism +by using `splitV &>> iso` + splitV : Iso a ({|} | a) = MkIso {fwd=\x. Right x, bwd=\v. case v of Right x -> x} def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = reindex (buildWith $ splitV &>> iso) tab -'Dynamic buffer +'## Dynamic buffer -- TODO: would be nice to be able to use records here data DynBuffer a = @@ -933,6 +1145,13 @@ def stringFromCharPtr (n:Int) (ptr:Ptr Char) : {IO} String = -- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint def codepoint (c:Char) : Int = W8ToI c +'### Show interface +For things that can be shown. +`show` gives a string representation of its input. +No particular promises are made to exactly what that representation will contain. +In particular it is **not** promised to be parseable. +Nor does it promise a particular level of precision for numeric values. + interface Show a show : a -> String @@ -959,10 +1178,13 @@ instance Show Float64 (n, ptr) = %ffi showFloat64 (Int32 & RawPtr) x stringFromCharPtr n $ MkPtr ptr --- pipe-like reverse function application +'## pipe-like reverse function application +TODO: move this + def (|>) (x:a) (f: a -> b) : b = f x '## Floating-point helper functions +TODO: Move these to be with Elementary/Special functions. Or move those to be here. def sign (x:Float) : Float = case x > 0.0 of @@ -989,7 +1211,7 @@ def isnan (x:Float) : Bool = not (x >= x && x <= x) -- Todo: use IEEE-754R 5.11: Floating Point Comparison Relation cmpUnordered. def either_is_nan (x:Float) (y:Float) : Bool = (isnan x) || (isnan y) -'File system operations +'## File system operations FilePath : Type = String data CString = MkCString RawPtr @@ -1016,6 +1238,8 @@ def withCString (s:String) (action: CString -> {IO} a) : {IO} a = (AsList n s') = s <> "\NUL" withTabPtr s' \(MkPtr ptr). action $ MkCString ptr +'### Stream IO + def fopen (path:String) (mode:StreamMode) : {IO} (Stream mode) = modeStr = case mode of ReadMode -> "r" @@ -1037,6 +1261,9 @@ def fwrite (stream:Stream WriteMode) (s:String) : {IO} Unit = %ffi fflush Int64 stream' () +'### Iteration +TODO: move this out of the file-system section + def while (body: Unit -> {|eff} Bool) : {|eff} Unit = body' : Unit -> {|eff} Word8 = \_. BToW8 $ body () %while body' @@ -1085,6 +1312,8 @@ def boundedIter (maxIters:Int) (fallback:a) then Done fallback else body i +'### Environment Variables + def fromCString (s:CString) : {IO} (Maybe String) = case cStringPtr s of Nothing -> Nothing @@ -1104,6 +1333,8 @@ def getEnv (name:String) : {IO} Maybe String = def checkEnv (name:String) : {IO} Bool = isJust $ getEnv name +'### More Stream IO + def fread (stream:Stream ReadMode) : {IO} String = (MkStream stream') = stream -- TODO: allow reading longer files! @@ -1119,6 +1350,8 @@ def fread (stream:Stream ReadMode) : {IO} String = else Done () loadDynBuffer buf +'### File Operations + def deleteFile (f:FilePath) : {IO} Unit = withCString f \(MkCString ptr). %ffi remove Int64 ptr @@ -1138,6 +1371,9 @@ def writeFile (f:FilePath) (s:String) : {IO} Unit = def readFile (f:FilePath) : {IO} String = withFile f ReadMode \stream. fread stream + +'### Temporary Files + def newTempFile (_:Unit) : {IO} FilePath = withCString "/tmp/dex-XXXXXX" \(MkCString ptr). fd = %ffi mkstemp Int32 ptr @@ -1156,12 +1392,16 @@ def withTempFiles (action: (n=>FilePath) -> {IO} a) : {IO} a = for i. deleteFile tmpFiles.i result +'### Print + def getOutputStream (_:Unit) : {IO} Stream WriteMode = MkStream $ %ptrLoad OUT_STREAM_PTR def print (s:String) : {IO} Unit = fwrite (getOutputStream ()) (s <> "\n") +'### Shelling Out + def shellOut (command:String) : {IO} String = modeStr = "r" withCString command \(MkCString commandPtr). @@ -1169,7 +1409,12 @@ def shellOut (command:String) : {IO} String = pipe = MkStream %ffi popen RawPtr commandPtr modePtr fread pipe -'Partial functions +'## Partial functions +A partial function in this context is a function that can error. +i.e. a function that is not actually defined for all of its supposed domain. +Not to be confused with a partially applied function + +'### Error throwing def error (s:String) : a = unsafeIO do print s @@ -1177,6 +1422,8 @@ def error (s:String) : a = unsafeIO do def todo : a = error "TODO: implement it!" +'### Table operations + def fromOrdinal (n:Type) (i:Int) : n = case (0 <= i) && (i < size n) of True -> unsafeFromOrdinal _ i @@ -1203,11 +1450,13 @@ def tail (xs:n=>a) (start:Int) : List a = numElts = size n - start toList $ slice xs start (Fin numElts) +--TODO: move this to be with other random number functions def randIdx (k:Key) : n = unif = rand k fromOrdinal n $ FToI $ floor $ unif * IToF (size n) -'Type class for generating example values +'## Arbitrary +Type class for generating example values interface Arbitrary a arb : Key -> a @@ -1221,12 +1470,20 @@ instance Arbitrary Int32 instance [Arbitrary a] Arbitrary (n=>a) arb = \key. for i. arb $ ixkey key i +instance [Arbitrary a, Arbitrary b] Arbitrary (a & b) + arb = \key. + [k1, k2] = splitKey key + (arb k1, arb k2) + instance Arbitrary (Fin n) arb = randIdx -'Control flow +'## Ord on Arrays + +'### Searching + +'returns the highest index `i` such that `xs.i <= x` --- returns the highest index `i` such that `xs.i <= x` def searchSorted [Ord a] (xs:n=>a) (x:a) : Maybe n = if size n == 0 then Nothing @@ -1243,7 +1500,7 @@ def searchSorted [Ord a] (xs:n=>a) (x:a) : Maybe n = else low := centerIx Continue -'min / max etc +'### min / max etc def minBy [Ord o] (f:a->o) (x:a) (y:a) : a = select (f x < f y) x y def maxBy [Ord o] (f:a->o) (x:a) (y:a) : a = select (f x > f y) x y @@ -1259,17 +1516,27 @@ def maximumBy [Ord o] (f:a->o) (xs:n=>a) : a = def minimum [Ord o] (xs:n=>o) : o = minimumBy id xs def maximum [Ord o] (xs:n=>o) : o = maximumBy id xs -def argmin [Ord o] (xs:n=>o) : n = +'### argmin/argmax +TODO: put in same section as `searchsorted` + +def argscan (comp:o->o->Bool) (xs:n=>o) : n = zeroth = (0@_, xs.(0@_)) compare = \(idx1, x1) (idx2, x2). - select (x1 < x2) (idx1, x1) (idx2, x2) + select (comp x1 x2) (idx1, x1) (idx2, x2) zipped = for i. (i, xs.i) fst $ reduce zeroth compare zipped +def argmin [Ord o] (xs:n=>o) : n = argscan (<) xs +def argmax [Ord o] (xs:n=>o) : n = argscan (>) xs + +'### clip + def clip [Ord a] ((low,high):(a&a)) (x:a) : a = min high $ max low x '## Trigonometric functions +TODO: these should be with the other Elementary/Special Functions +### atan/atan2 def atan_inner (x:Float) : Float = -- From "Computing accurate Horner form approximations to @@ -1437,6 +1704,7 @@ def (.|.) (x:Byte) (y:Byte) : Byte = %or x y def (.&.) (x:Byte) (y:Byte) : Byte = %and x y '## Miscellaneous utilities +TODO: all of these should be in some other section def reverse (x:n=>a) : n=>a = s = size n @@ -1492,6 +1760,9 @@ def concat (lists:n=>(List a)) : List a = eltIdx := eltIdxVal + 1 xs.(eltIdxVal@_) + +'## Probability + def cumSumLow (xs: n=>Float) : n=>Float = withState 0.0 \total. for i. @@ -1520,8 +1791,6 @@ def categoricalBatch (logprobs: n=>Float) (key: Key) : m=>n = cdf = cdfForCategorical logprobs for i. categoricalFromCDF cdf $ ixkey key i -'Numerical utilities - def logsumexp (x: n=>Float) : Float = m = maximum x m + (log $ sum for i. exp (x.i - m)) @@ -1536,13 +1805,20 @@ def softmax (x: n=>Float) : n=>Float = s = sum e for i. e.i / s +'## Polynomials +TODO: Move this somewhere else + def evalpoly [VSpace v] (coefficients:n=>v) (x:Float) : v = -- Evaluate a polynomial at x. Same as Numpy's polyval. fold zero \i c. coefficients.i + x .* c +'## TestMode +TODO: move this to be in Testing Helpers + def dex_test_mode (():Unit) : Bool = unsafeIO do checkEnv "DEX_TEST_MODE" '## Exception effect +TODO: move `error` and `todo` to here. def catch (f:Unit -> {Except|eff} a) : {|eff} Maybe a = %catchException f @@ -1552,3 +1828,13 @@ def throw (_:Unit) : {Except} a = def assert (b:Bool) : {Except} Unit = if not b then throw () + + +'## Testing Helpers + +-- -- Reliably causes a segfault if pointers aren't initialized to zero. +-- -- TODO: add this test when we cache modules +-- justSomeDataToTestCaching = toList for i:(Fin 100). +-- if ordinal i == 0 +-- then Left (toList [1,2,3]) +-- else Right 1 diff --git a/makefile b/makefile index d1b74519d..13157df5e 100644 --- a/makefile +++ b/makefile @@ -72,7 +72,7 @@ build-prof: dexrt-llvm # For some reason stack fails to detect modifications to foreign library files build-python: dexrt-llvm $(STACK) build $(STACK_FLAGS) --force-dirty - $(eval STACK_INSTALL_DIR=$(shell stack path --local-install-root)) + $(eval STACK_INSTALL_DIR=$(shell stack $(STACK_FLAGS) path --local-install-root)) cp $(STACK_INSTALL_DIR)/lib/libDex.so python/dex/ build-ci: dexrt-llvm diff --git a/misc/dex.el b/misc/dex.el index c7bf86dad..6533fa456 100644 --- a/misc/dex.el +++ b/misc/dex.el @@ -10,7 +10,7 @@ ("^'\\(.\\|\n.\\)*\n\n" . font-lock-comment-face) ("\\w+:" . font-lock-comment-face) ("^:\\w*" . font-lock-preprocessor-face) - ("\\bdef\\b\\|\\bfor\\b\\|\\brof\\b\\|\\bcase\\b\\|\\bdata\\b\\|\\bwhere\\b\\|\\bof\\b\\|\\bif\\b\\|\\bthen\\b\\|\\belse\\b\\|\\binterface\\b\\|\\binstance\\b\\|\\bdo\\b\\|\\bview\\b" . + ("\\bdef\\b\\|\\bfor\\b\\|\\brof\\b\\|\\bcase\\b\\|\\bdata\\b\\|\\bwhere\\b\\|\\bof\\b\\|\\bif\\b\\|\\bthen\\b\\|\\belse\\b\\|\\binterface\\b\\|\\binstance\\b\\|\\bdo\\b\\|\\bview\\b\\|\\bimport\\b" . font-lock-keyword-face) ("--o" . font-lock-variable-name-face) ("[-.,!$^&*:~+/=<>|?\\\\]" . font-lock-variable-name-face) diff --git a/src/Dex/Foreign/Context.hs b/src/Dex/Foreign/Context.hs index 7a0e3cbb1..bb78f0f49 100644 --- a/src/Dex/Foreign/Context.hs +++ b/src/Dex/Foreign/Context.hs @@ -42,7 +42,7 @@ setError msg = withCStringLen msg $ \(ptr, len) -> dexCreateContext :: IO (Ptr Context) dexCreateContext = do - let evalConfig = EvalConfig LLVM Nothing + let evalConfig = EvalConfig LLVM Nothing Nothing maybePreludeEnv <- evalPrelude evalConfig preludeSource case maybePreludeEnv of Right preludeEnv -> toStablePtr $ Context evalConfig preludeEnv @@ -77,8 +77,8 @@ dexInsert ctxPtr namePtr atomPtr = do Context evalConfig env <- fromStablePtr ctxPtr name <- GlobalName . fromString <$> peekCString namePtr atom <- fromStablePtr atomPtr - let env' = env <> name @> (getType atom, LetBound PlainLet (Atom atom)) - toStablePtr $ Context evalConfig env' + let newBinding = name @> (getType atom, LetBound PlainLet (Atom atom)) + toStablePtr $ Context evalConfig $ env <> TopEnv newBinding mempty dexEvalExpr :: Ptr Context -> CString -> IO (Ptr Atom) dexEvalExpr ctxPtr sourcePtr = do @@ -88,10 +88,11 @@ dexEvalExpr ctxPtr sourcePtr = do Right expr -> do let (v, m) = exprAsModule expr let block = SourceBlock 0 0 LogNothing source (RunModule m) Nothing - (resultEnv, Result [] maybeErr) <- evalSourceBlock evalConfig env block + (resultEnv, Result [] maybeErr) <- + evalSourceBlock evalConfig env block case maybeErr of Right () -> do - let (_, LetBound _ (Atom atom)) = resultEnv ! v + let (_, LetBound _ (Atom atom)) = topBindings resultEnv ! v toStablePtr atom Left err -> setError (pprint err) $> nullPtr Left err -> setError (pprint err) $> nullPtr @@ -100,7 +101,7 @@ dexLookup :: Ptr Context -> CString -> IO (Ptr Atom) dexLookup ctxPtr namePtr = do Context _ env <- fromStablePtr ctxPtr name <- peekCString namePtr - case envLookup env (GlobalName $ fromString name) of + case envLookup (topBindings env) (GlobalName $ fromString name) of Just (_, LetBound _ (Atom atom)) -> toStablePtr atom Just _ -> setError "Looking up an expression" $> nullPtr Nothing -> setError "Unbound name" $> nullPtr diff --git a/src/Dex/Foreign/JIT.hs b/src/Dex/Foreign/JIT.hs index d40a4b4a0..362664e81 100644 --- a/src/Dex/Foreign/JIT.hs +++ b/src/Dex/Foreign/JIT.hs @@ -37,6 +37,7 @@ import LLVMExec import JIT import Syntax hiding (sizeOf) import Export +import TopLevel import Dex.Foreign.Util import Dex.Foreign.Context @@ -82,7 +83,8 @@ dexCompile jitPtr ctxPtr funcAtomPtr = do ForeignJIT{..} <- fromStablePtr jitPtr Context _ env <- fromStablePtr ctxPtr funcAtom <- fromStablePtr funcAtomPtr - let (impMod, nativeSignature) = prepareFunctionForExport env "userFunc" funcAtom + let (impMod, nativeSignature) = prepareFunctionForExport + (topBindings env) "userFunc" funcAtom nativeModule <- execLogger Nothing $ \logger -> do llvmAST <- impToLLVM logger impMod LLVM.JIT.compileModule jit llvmAST diff --git a/src/dex.hs b/src/dex.hs index d6102f503..9c0070186 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -45,7 +45,7 @@ runMode evalMode preludeFile opts = do key <- case preludeFile of Nothing -> return $ show curResourceVersion -- memoizeFileEval already checks compiler version Just path -> show <$> getModificationTime path - env <- cached "prelude" key $ evalPrelude opts preludeFile + env <- cachedWithSnapshot "prelude" key $ evalPrelude opts preludeFile let runEnv m = evalStateT m env case evalMode of ReplMode prompt -> do @@ -67,7 +67,7 @@ runMode evalMode preludeFile opts = do let exportedFuns = foldMap (\case (ExportedFun name f) -> [(name, f)]; _ -> []) outputs unless (backendName opts == LLVM) $ liftEitherIO $ throw CompilerErr "Export only supported with the LLVM CPU backend" - exportFunctions objPath exportedFuns env + exportFunctions objPath exportedFuns $ topBindings env evalPrelude :: EvalConfig -> Maybe FilePath -> IO TopEnv evalPrelude opts fname = flip execStateT initTopEnv $ do @@ -89,14 +89,12 @@ replLoop prompt opts = do dexCompletions :: CompletionFunc (StateT TopEnv IO) dexCompletions (line, _) = do env <- get - let varNames = map pprint $ envNames env + let varNames = map pprint $ envNames $ topBindings env -- note: line and thus word and rest have character order reversed let (word, rest) = break (== ' ') line - let anywhereKeywords = ["def", "for", "rof", "case", "data", "where", "of", "if", - "then", "else", "interface", "instance", "do", "view"] let startoflineKeywords = ["%bench \"", ":p", ":t", ":html", ":export"] let candidates = (if null rest then startoflineKeywords else []) ++ - anywhereKeywords ++ varNames + keyWordStrs ++ varNames let completions = map simpleCompletion $ filter (reverse word `isPrefixOf`) candidates return (rest, completions) @@ -183,6 +181,7 @@ parseEvalOpts = EvalConfig , ("interpreter", Interpreter)]) (long "backend" <> value LLVM <> helpOption "Backend" "llvm (default) | llvm-cuda | llvm-mc | interpreter") + <*> optional (strOption $ long "lib-path" <> metavar "PATH" <> help "Library path") <*> optional (strOption $ long "logto" <> metavar "FILE" <> help "File to log to" <> showDefault) diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs index bf8a3a87a..41a1eb4b0 100644 --- a/src/lib/Algebra.hs +++ b/src/lib/Algebra.hs @@ -20,8 +20,8 @@ import Data.Coerce import Env import Syntax import PPrint -import Embed ( MonadEmbed, iadd, imul, idiv, clampPositive, ptrOffset - , indexToIntE, indexSetSizeE ) +import Builder ( MonadBuilder, iadd, imul, idiv, clampPositive, ptrOffset + , indexToIntE, indexSetSizeE ) -- MVar is like Var, but it additionally defines Ord. The invariant here is that the variables -- should never be shadowing, and so it is sufficient to only use the name for equality and @@ -50,7 +50,7 @@ type ClampPolynomial = PolynomialP ClampMonomial data SumPolynomial = SumPolynomial Polynomial Var deriving (Show, Eq) data SumClampPolynomial = SumClampPolynomial ClampPolynomial Var deriving (Show, Eq) -applyIdxs :: MonadEmbed m => Atom -> IndexStructure -> m Atom +applyIdxs :: MonadBuilder m => Atom -> IndexStructure -> m Atom applyIdxs ptr Empty = return ptr applyIdxs ptr idxs@(Nest ~(Bind i) rest) = do ordinal <- indexToIntE $ Var i @@ -58,10 +58,10 @@ applyIdxs ptr idxs@(Nest ~(Bind i) rest) = do ptr' <- ptrOffset ptr offset applyIdxs ptr' rest -offsetToE :: MonadEmbed m => IndexStructure -> Atom -> m Atom +offsetToE :: MonadBuilder m => IndexStructure -> Atom -> m Atom offsetToE idxs i = evalSumClampPolynomial (offsets idxs) i -elemCountE :: MonadEmbed m => IndexStructure -> m Atom +elemCountE :: MonadBuilder m => IndexStructure -> m Atom elemCountE idxs = case idxs of Empty -> return $ IdxRepVal 1 Nest b _ -> offsetToE idxs =<< indexSetSizeE (binderType b) @@ -124,12 +124,12 @@ toPolynomial atom = case atom of fromInt i = poly [((fromIntegral i) % 1, mono [])] unreachable = error $ "Unsupported or invalid atom in index set: " ++ pprint atom --- === Embedding === +-- === Building === -_evalClampPolynomial :: MonadEmbed m => ClampPolynomial -> m Atom +_evalClampPolynomial :: MonadBuilder m => ClampPolynomial -> m Atom _evalClampPolynomial cp = evalPolynomialP (evalClampMonomial Var) cp -evalSumClampPolynomial :: MonadEmbed m => SumClampPolynomial -> Atom -> m Atom +evalSumClampPolynomial :: MonadBuilder m => SumClampPolynomial -> Atom -> m Atom evalSumClampPolynomial (SumClampPolynomial cp summedVar) a = evalPolynomialP (evalClampMonomial varVal) cp where varVal v = if MVar v == sumVar summedVar then a else Var v @@ -139,7 +139,7 @@ evalSumClampPolynomial (SumClampPolynomial cp summedVar) a = -- coefficients. This is why we have to find the least common multiples and do the -- accumulation over numbers multiplied by that LCM. We essentially do fixed point -- fractional math here. -evalPolynomialP :: MonadEmbed m => (mono -> m Atom) -> PolynomialP mono -> m Atom +evalPolynomialP :: MonadBuilder m => (mono -> m Atom) -> PolynomialP mono -> m Atom evalPolynomialP evalMono p = do let constLCM = asAtom $ foldl lcm 1 $ fmap (denominator . snd) $ toList p monoAtoms <- flip traverse (toList p) $ \(m, c) -> do @@ -153,19 +153,19 @@ evalPolynomialP evalMono p = do -- because it might be causing overflows due to all arithmetic being shifted. asAtom = IdxRepVal . fromInteger -evalMonomial :: MonadEmbed m => (Var -> Atom) -> Monomial -> m Atom +evalMonomial :: MonadBuilder m => (Var -> Atom) -> Monomial -> m Atom evalMonomial varVal m = do varAtoms <- traverse (\(MVar v, e) -> ipow (varVal v) e) $ toList m foldM imul (IdxRepVal 1) varAtoms -evalClampMonomial :: MonadEmbed m => (Var -> Atom) -> ClampMonomial -> m Atom +evalClampMonomial :: MonadBuilder m => (Var -> Atom) -> ClampMonomial -> m Atom evalClampMonomial varVal (ClampMonomial clamps m) = do valuesToClamp <- traverse (evalPolynomialP (evalMonomial varVal) . coerce) clamps clampsProduct <- foldM imul (IdxRepVal 1) =<< traverse clampPositive valuesToClamp mval <- evalMonomial varVal m imul clampsProduct mval -ipow :: MonadEmbed m => Atom -> Int -> m Atom +ipow :: MonadBuilder m => Atom -> Int -> m Atom ipow x i = foldM imul (IdxRepVal 1) (replicate i x) -- === Polynomial math === diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index 1da5eac39..8783fa1b9 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -22,7 +22,7 @@ import Type import Env import Syntax import PPrint -import Embed +import Builder import Cat import Util (bindM2, zipWithT, enumerate, restructure) import GHC.Stack @@ -38,17 +38,17 @@ data DerivWrt = DerivWrt { activeVars :: Env Type -- arguments to the linearized function. data TangentEnv = TangentEnv { tangentVals :: SubstEnv, activeRefs :: [Name], rematVals :: SubstEnv } -type PrimalM = ReaderT DerivWrt Embed -type TangentM = ReaderT TangentEnv Embed +type PrimalM = ReaderT DerivWrt Builder +type TangentM = ReaderT TangentEnv Builder newtype LinA a = LinA { runLinA :: PrimalM (a, TangentM a) } linearize :: Scope -> Atom -> Atom -linearize scope ~(Lam (Abs b (_, block))) = fst $ flip runEmbed scope $ do +linearize scope ~(Lam (Abs b (_, block))) = fst $ flip runBuilder scope $ do buildLam b PureArrow \x@(Var v) -> do (y, yt) <- flip runReaderT (DerivWrt (varAsEnv v) [] mempty) $ runLinA $ linearizeBlock (b@>x) block -- TODO: check linearity fLin <- buildLam (fmap tangentType b) LinArrow \xt -> runReaderT yt $ TangentEnv (v @> xt) [] mempty - fLinChecked <- checkEmbed fLin + fLinChecked <- checkBuilder fLin return $ PairVal y fLinChecked linearizeBlock :: SubstEnv -> Block -> LinA Atom @@ -64,7 +64,7 @@ linearizeBlock env (Block decls result) = case decls of -- Technically, we could do this and later run the code through a simplification -- pass that would eliminate a bunch of multiplications with zeros, but this seems -- simpler to do for now. - freeAtoms <- traverse (substEmbed env . Var) $ bindingsAsVars $ freeVars expr + freeAtoms <- traverse (substBuilder env . Var) $ bindingsAsVars $ freeVars expr varsAreActive <- traverse isActive $ bindingsAsVars $ freeVars freeAtoms if any id varsAreActive then do @@ -90,7 +90,7 @@ linearizeBlock env (Block decls result) = case decls of let nontrivialTs = if vIsTrivial then [] else [t] extendTangentEnv (newEnv nontrivialVs nontrivialTs) [] bodyLin) else do - expr' <- substEmbed env expr + expr' <- substBuilder env expr x <- emit expr' runLinA $ linearizeBlock (env <> b @> x) body @@ -98,7 +98,7 @@ linearizeExpr :: SubstEnv -> Expr -> LinA Atom linearizeExpr env expr = case expr of Hof e -> linearizeHof env e Case e alts _ -> LinA $ do - e' <- substEmbed env e + e' <- substBuilder env e hasActiveScrutinee <- any id <$> (mapM isActive $ bindingsAsVars $ freeVars e') case hasActiveScrutinee of True -> notImplemented @@ -111,7 +111,7 @@ linearizeExpr env expr = case expr of linearizeInactiveAlt (Abs bs body) = do buildNAbs bs \xs -> tangentFunAsLambda $ linearizeBlock (env <> newEnv bs xs) body _ -> LinA $ do - expr' <- substEmbed env expr + expr' <- substBuilder env expr runLinA $ case expr' of App x i | isTabTy (getType x) -> liftA (flip App i) (linearizeAtom x) `bindLin` emit Op e -> linearizeOp e @@ -124,13 +124,26 @@ linearizeOp :: Op -> LinA Atom linearizeOp op = case op of ScalarUnOp uop x -> linearizeUnOp uop x ScalarBinOp bop x y -> linearizeBinOp bop x y + PrimEffect refArg (MExtend ~(LamVal b body)) -> LinA $ do + (primalRef, mkTangentRef) <- runLinA $ la refArg + (primalUpdate, mkTangentUpdate) <- + buildLamAux b (const $ return PureArrow) \x@(Var v) -> + extendWrt [v] [] $ runLinA $ linearizeBlock (b @> x) body + let LamVal (Bind primalStateVar) _ = primalUpdate + ans <- emitOp $ PrimEffect primalRef $ MExtend primalUpdate + return (ans, do + tangentRef <- mkTangentRef + -- TODO: Assert that tangent update doesn't close over anything? + tangentUpdate <- buildLam (Bind $ "t":>tangentType (varType primalStateVar)) PureArrow \tx -> + extendTangentEnv (primalStateVar @> tx) [] $ mkTangentUpdate + emitOp $ PrimEffect tangentRef $ MExtend tangentUpdate) PrimEffect refArg m -> liftA2 PrimEffect (la refArg) (case m of - MAsk -> pure MAsk - MTell x -> liftA MTell $ la x - MGet -> pure MGet - MPut x -> liftA MPut $ la x) `bindLin` emitOp + MAsk -> pure MAsk + MExtend _ -> error "Unhandled MExtend" + MGet -> pure MGet + MPut x -> liftA MPut $ la x) `bindLin` emitOp IndexRef ref i -> (IndexRef <$> la ref <*> pure i) `bindLin` emitOp FstRef ref -> (FstRef <$> la ref ) `bindLin` emitOp SndRef ref -> (SndRef <$> la ref ) `bindLin` emitOp @@ -181,7 +194,7 @@ linearizeOp op = case op of emitWithZero :: LinA Atom emitWithZero = LinA $ withZeroTangent <$> emitOp op -emitUnOp :: MonadEmbed m => UnOp -> Atom -> m Atom +emitUnOp :: MonadBuilder m => UnOp -> Atom -> m Atom emitUnOp op x = emitOp $ ScalarUnOp op x linearizeUnOp :: UnOp -> Atom -> LinA Atom @@ -255,17 +268,25 @@ linearizeBinOp op x' y' = LinA $ do linearizeHof :: SubstEnv -> Hof -> LinA Atom linearizeHof env hof = case hof of For ~(RegularFor d) ~(LamVal i body) -> LinA $ do - i' <- mapM (substEmbed env) i + i' <- mapM (substBuilder env) i (ansWithLinTab, vi'') <- buildForAux d i' \i''@(Var vi'') -> (,vi'') <$> (willRemat vi'' $ tangentFunAsLambda $ linearizeBlock (env <> i@>i'') body) (ans, linTab) <- unzipTab ansWithLinTab return (ans, buildFor d i' \i'' -> provideRemat vi'' i'' $ app linTab i'' >>= applyLinToTangents) Tile _ _ _ -> notImplemented - RunWriter lam -> linearizeEff Nothing lam True (const RunWriter) (emitRunWriter "r") Writer - RunReader val lam -> linearizeEff (Just val) lam False RunReader (emitRunReader "r") Reader - RunState val lam -> linearizeEff (Just val) lam True RunState (emitRunState "r") State + RunWriter bm ~lam@(BinaryFunVal _ refBinder _ _) -> LinA $ do + unless (checkZeroPlusFloatMonoid bm) $ + error "AD of Accum effect only supported when the base monoid is (0, +) on Float" + let RefTy _ accTy = binderType refBinder + linearizeEff lam Writer (RunWriter bm) (emitRunWriter "r" accTy bm) + RunReader val lam -> LinA $ do + (val', mkLinInit) <- runLinA <$> linearizeAtom =<< substBuilder env val + linearizeEff lam Reader (RunReader val') \f -> mkLinInit >>= emitRunReader "r" `flip` f + RunState val lam -> LinA $ do + (val', mkLinInit) <- runLinA <$> linearizeAtom =<< substBuilder env val + linearizeEff lam State (RunState val') \f -> mkLinInit >>= emitRunState "r" `flip` f RunIO ~(Lam (Abs _ (arrow, body))) -> LinA $ do - arrow' <- substEmbed env arrow + arrow' <- substBuilder env arrow -- TODO: consider the possibility of other effects here besides IO lam <- buildLam (Ignore UnitTy) arrow' \_ -> tangentFunAsLambda $ linearizeBlock env body @@ -279,39 +300,28 @@ linearizeHof env hof = case hof of CatchException _ -> notImplemented Linearize _ -> error "Unexpected linearization" Transpose _ -> error "Unexpected transposition" - PTileReduce _ _ -> error "Unexpected PTileReduce" + PTileReduce _ _ _ -> error "Unexpected PTileReduce" where - linearizeEff maybeInit lam hasResult hofMaker emitter eff = LinA $ do - (valHofMaker, maybeLinInit) <- case maybeInit of - Just val -> do - (val', linVal) <- runLinA <$> linearizeAtom =<< substEmbed env val - return (hofMaker val', Just linVal) - Nothing -> return (hofMaker undefined, Nothing) + linearizeEff lam eff primalHofCon tangentEmitter = do (lam', ref) <- linearizeEffectFun eff lam - (ans, linBody) <- case hasResult of - True -> do - (ansLin, w) <- fromPair =<< emit (Hof $ valHofMaker lam') + -- The reader effect doesn't return any additional values + (ans, linBody) <- case eff of + Reader -> fromPair =<< emit (Hof $ primalHofCon lam') + _ -> do + (ansLin, w) <- fromPair =<< emit (Hof $ primalHofCon lam') (ans, linBody) <- fromPair ansLin return (PairVal ans w, linBody) - False -> fromPair =<< emit (Hof $ valHofMaker lam') - let lin = do - valEmitter <- case maybeLinInit of - Just linVal -> emitter <$> linVal - Nothing -> do - let (BinaryFunTy _ b _ _) = getType lam' - let RefTy _ wTy = binderType b - return $ emitter $ tangentType wTy - valEmitter \ref'@(Var (_:> RefTy (Var (h:>_)) _)) -> do - extendTangentEnv (ref @> ref') [h] $ applyLinToTangents linBody + let lin = tangentEmitter \ref'@(Var (_:> RefTy (Var (h:>_)) _)) -> do + extendTangentEnv (ref @> ref') [h] $ applyLinToTangents linBody return (ans, lin) linearizeEffectFun :: RWS -> Atom -> PrimalM (Atom, Var) linearizeEffectFun rws ~(BinaryFunVal h ref eff body) = do - h' <- mapM (substEmbed env) h + h' <- mapM (substBuilder env) h buildLamAux h' (const $ return PureArrow) \h''@(Var hVar) -> do let env' = env <> h@>h'' - eff' <- substEmbed env' eff - ref' <- mapM (substEmbed env') ref + eff' <- substBuilder env' eff + ref' <- mapM (substBuilder env') ref buildLamAux ref' (const $ return $ PlainArrow eff') \ref''@(Var refVar) -> extendWrt [refVar] [RWSEffect rws (varName hVar)] $ (,refVar) <$> (tangentFunAsLambda $ linearizeBlock (env' <> ref@>ref'') body) @@ -391,24 +401,6 @@ tangentType ty = case ty of _ -> unsupported where unsupported = error $ "Can't differentiate wrt type " ++ pprint ty -addTangent :: MonadEmbed m => Atom -> Atom -> m Atom -addTangent x y = case getType x of - RecordTy (NoExt tys) -> do - elems <- bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y) - return $ Record $ restructure elems tys - TabTy b _ -> buildFor Fwd b \i -> bindM2 addTangent (tabGet x i) (tabGet y i) - TC con -> case con of - BaseType (Scalar _) -> emitOp $ ScalarBinOp FAdd x y - BaseType (Vector _) -> emitOp $ VectorBinOp FAdd x y - UnitType -> return UnitVal - PairType _ _ -> do - (xa, xb) <- fromPair x - (ya, yb) <- fromPair y - PairVal <$> addTangent xa ya <*> addTangent xb yb - _ -> notTangent - _ -> notTangent - where notTangent = error $ "Not a tangent type: " ++ pprint (getType x) - isTrivialForAD :: Expr -> Bool isTrivialForAD expr = isSingletonType tangentTy && exprEffs expr == mempty where tangentTy = tangentType $ getType expr @@ -445,7 +437,7 @@ tangentFunAsLambda m = do -- Like buildLam, but doesn't try to deshadow the binder. makeLambda v f = do block <- buildScoped $ do - embedExtend $ asFst $ v @> (varType v, LamBound (void PureArrow)) + builderExtend $ asFst $ v @> (varType v, LamBound (void PureArrow)) f v return $ Lam $ makeAbs (Bind v) (PureArrow, block) @@ -473,7 +465,7 @@ applyLinToTangents f = do let args = (toList rematVals) ++ hs' ++ tangents ++ [UnitVal] naryApp f args -bindLin :: LinA a -> (a -> Embed b) -> LinA b +bindLin :: LinA a -> (a -> Builder b) -> LinA b bindLin (LinA m) f = LinA $ do (e, t) <- m x <- lift $ f e @@ -543,10 +535,10 @@ instance Semigroup TransposeEnv where instance Monoid TransposeEnv where mempty = TransposeEnv mempty mempty mempty mempty -type TransposeM a = ReaderT TransposeEnv Embed a +type TransposeM a = ReaderT TransposeEnv Builder a transpose :: Scope -> Atom -> Atom -transpose scope ~(Lam (Abs b (_, block))) = fst $ flip runEmbed scope $ do +transpose scope ~(Lam (Abs b (_, block))) = fst $ flip runBuilder scope $ do buildLam (Bind $ "ct" :> getType block) LinArrow \ct -> do snd <$> (flip runReaderT mempty $ withLinVar b $ transposeBlock block ct) @@ -614,12 +606,16 @@ transposeOp op ct = case op of refArg' <- substTranspose linRefSubst refArg let emitEff = emitOp . PrimEffect refArg' case m of - MAsk -> void $ emitEff $ MTell ct - MTell x -> transposeAtom x =<< emitEff MAsk + MAsk -> void $ emitEff . MExtend =<< (updateAddAt ct) + -- XXX: This assumes that the update function uses a tangent (0, +) monoid, + -- which is why we can ignore the binder (we even can't; we only have a + -- reader reference!). This should have been checked in the transposeHof + -- rule for RunWriter. + MExtend ~(LamVal _ body) -> transposeBlock body =<< emitEff MAsk -- TODO: Do something more efficient for state. We should be able -- to do in-place addition, just like we do for the Writer effect. - MGet -> void $ emitEff . MPut =<< addTangent ct =<< emitEff MGet - MPut x -> do + MGet -> void $ emitEff . MPut =<< addTangent ct =<< emitEff MGet + MPut x -> do transposeAtom x =<< emitEff MGet void $ emitEff $ MPut $ zeroAt $ getType x TabCon ~(TabTy b _) es -> forM_ (enumerate es) \(i, e) -> do @@ -685,11 +681,16 @@ transposeHof hof ct = case hof of return UnitVal where flipDir dir = case dir of Fwd -> Rev; Rev -> Fwd RunReader r ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do - (_, ctr) <- (fromPair =<<) $ emitRunWriter "w" (getType r) \ref -> do + let RefTy _ valTy = binderType b + let baseTy = getBaseMonoidType valTy + baseMonoid <- tangentBaseMonoidFor baseTy + (_, ctr) <- (fromPair =<<) $ emitRunWriter "w" valTy baseMonoid \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ct return UnitVal transposeAtom r ctr - RunWriter ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do + RunWriter bm ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do + unless (checkZeroPlusFloatMonoid bm) $ + error "AD of Accum effect only supported when the base monoid is (0, +) on Float" (ctBody, ctEff) <- fromPair ct void $ emitRunReader "r" ctEff \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ctBody @@ -706,7 +707,7 @@ transposeHof hof ct = case hof of CatchException _ -> notImplemented Linearize _ -> error "Unexpected linearization" Transpose _ -> error "Unexpected transposition" - PTileReduce _ _ -> error "Unexpected PTileReduce" + PTileReduce _ _ _ -> error "Unexpected PTileReduce" transposeAtom :: Atom -> Atom -> TransposeM () transposeAtom atom ct = case atom of @@ -776,7 +777,7 @@ isLinEff _ = error "Can't transpose polymorphic effects" emitCTToRef :: Maybe Atom -> Atom -> TransposeM () emitCTToRef mref ct = case mref of - Just ref -> void $ emitOp $ PrimEffect ref (MTell ct) + Just ref -> void . emitOp . PrimEffect ref . MExtend =<< updateAddAt ct Nothing -> return () substTranspose :: Subst a => (TransposeEnv -> SubstEnv) -> a -> TransposeM a @@ -789,13 +790,15 @@ substNonlin :: Subst a => a -> TransposeM a substNonlin = substTranspose nonlinSubst withLinVar :: Binder -> TransposeM a -> TransposeM (a, Atom) -withLinVar b body = case - singletonTypeVal (binderType b) of - Nothing -> flip evalStateT Nothing $ do - ans <- emitRunWriter "ref" (binderType b) \ref -> do - lift (localLinRef (b@>Just ref) body) >>= put . Just >> return UnitVal - (,) <$> (fromJust <$> get) <*> (getSnd ans) - Just x -> (,x) <$> (localLinRef (b@>Nothing) body) -- optimization to avoid accumulating into unit +withLinVar b body = case singletonTypeVal (binderType b) of + Nothing -> flip evalStateT Nothing $ do + let accTy = binderType b + let baseTy = getBaseMonoidType accTy + baseMonoid <- tangentBaseMonoidFor baseTy + ans <- emitRunWriter "ref" accTy baseMonoid \ref -> do + lift (localLinRef (b@>Just ref) body) >>= put . Just >> return UnitVal + (,) <$> (fromJust <$> get) <*> (getSnd ans) + Just x -> (,x) <$> (localLinRef (b@>Nothing) body) -- optimization to avoid accumulating into unit localLinRef :: Env (Maybe Atom) -> TransposeM a -> TransposeM a localLinRef refs = local (<> mempty { linRefs = refs }) @@ -808,3 +811,54 @@ localLinRefSubst s = local (<> mempty { linRefSubst = s }) localNonlinSubst :: SubstEnv -> TransposeM a -> TransposeM a localNonlinSubst s = local (<> mempty { nonlinSubst = s }) + +-- === The (0, +) monoid for tangent types === + +zeroAt :: Type -> Atom +zeroAt ty = case ty of + BaseTy bt -> Con $ Lit $ zeroLit bt + TabTy i a -> TabValA i $ zeroAt a + UnitTy -> UnitVal + PairTy a b -> PairVal (zeroAt a) (zeroAt b) + RecordTy (Ext tys Nothing) -> Record $ fmap zeroAt tys + _ -> unreachable + where + unreachable = error $ "Missing zero case for a tangent type: " ++ pprint ty + zeroLit bt = case bt of + Scalar Float64Type -> Float64Lit 0.0 + Scalar Float32Type -> Float32Lit 0.0 + Vector st -> VecLit $ replicate vectorWidth $ zeroLit $ Scalar st + _ -> unreachable + +updateAddAt :: MonadBuilder m => Atom -> m Atom +updateAddAt x = buildLam (Bind ("t":>getType x)) PureArrow $ addTangent x + +addTangent :: MonadBuilder m => Atom -> Atom -> m Atom +addTangent x y = case getType x of + RecordTy (NoExt tys) -> do + elems <- bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y) + return $ Record $ restructure elems tys + TabTy b _ -> buildFor Fwd b \i -> bindM2 addTangent (tabGet x i) (tabGet y i) + TC con -> case con of + BaseType (Scalar _) -> emitOp $ ScalarBinOp FAdd x y + BaseType (Vector _) -> emitOp $ VectorBinOp FAdd x y + UnitType -> return UnitVal + PairType _ _ -> do + (xa, xb) <- fromPair x + (ya, yb) <- fromPair y + PairVal <$> addTangent xa ya <*> addTangent xb yb + _ -> notTangent + _ -> notTangent + where notTangent = error $ "Not a tangent type: " ++ pprint (getType x) + +tangentBaseMonoidFor :: MonadBuilder m => Type -> m BaseMonoid +tangentBaseMonoidFor ty = BaseMonoid (zeroAt ty) <$> buildLam (Bind ("t":>ty)) PureArrow updateAddAt + +checkZeroPlusFloatMonoid :: BaseMonoid -> Bool +checkZeroPlusFloatMonoid (BaseMonoid zero plus) = checkZero zero && checkPlus plus + where + checkZero z = z == (Con (Lit (Float32Lit 0.0))) + checkPlus f = case f of + BinaryFunVal (Bind x) (Bind y) Pure (Block Empty (Op (ScalarBinOp FAdd (Var x') (Var y')))) -> + (x == x' && y == y') || (x == y' && y == x') + _ -> False diff --git a/src/lib/Embed.hs b/src/lib/Builder.hs similarity index 61% rename from src/lib/Embed.hs rename to src/lib/Builder.hs index 4fd4d6a27..6d4491adc 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Builder.hs @@ -1,4 +1,4 @@ --- Copyright 2019 Google LLC +-- Copyright 2021 Google LLC -- -- Use of this source code is governed by a BSD-style -- license that can be found in the LICENSE file or at @@ -10,30 +10,32 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE UndecidableInstances #-} -module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildPi, - getAllowedEffects, withEffects, modifyAllowedEffects, - buildLam, EmbedT, Embed, MonadEmbed, buildScoped, runEmbedT, - runSubstEmbed, runEmbed, getScope, embedLook, liftEmbed, - app, - add, mul, sub, neg, div', - iadd, imul, isub, idiv, ilt, ieq, - fpow, flog, fLitLike, recGetHead, buildImplicitNaryLam, - select, substEmbed, substEmbedR, emitUnpack, getUnpacked, - fromPair, getFst, getSnd, getFstRef, getSndRef, - naryApp, appReduce, appTryReduce, buildAbs, - buildFor, buildForAux, buildForAnn, buildForAnnAux, - emitBlock, unzipTab, isSingletonType, emitDecl, withNameHint, - singletonTypeVal, scopedDecls, embedScoped, extendScope, checkEmbed, - embedExtend, unpackConsList, emitRunWriter, applyPreludeFunction, - emitRunState, emitMaybeCase, emitWhile, buildDataDef, - emitRunReader, tabGet, SubstEmbedT, SubstEmbed, runSubstEmbedT, - ptrOffset, ptrLoad, unsafePtrLoad, - evalBlockE, substTraversalDef, - TraversalDef, traverseDecls, traverseDecl, traverseDeclsOpen, - traverseBlock, traverseExpr, traverseAtom, - clampPositive, buildNAbs, buildNAbsAux, buildNestedLam, zeroAt, - transformModuleAsBlock, dropSub, appReduceTraversalDef, - indexSetSizeE, indexToIntE, intToIndexE, freshVarE) where +module Builder (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildPi, + getAllowedEffects, withEffects, modifyAllowedEffects, + buildLam, BuilderT, Builder, MonadBuilder, buildScoped, runBuilderT, + runSubstBuilder, runBuilder, getScope, builderLook, liftBuilder, + app, + add, mul, sub, neg, div', + iadd, imul, isub, idiv, ilt, ieq, + fpow, flog, fLitLike, recGetHead, buildImplicitNaryLam, + select, substBuilder, substBuilderR, emitUnpack, getUnpacked, + fromPair, getFst, getSnd, getFstRef, getSndRef, + naryApp, appReduce, appTryReduce, buildAbs, + buildFor, buildForAux, buildForAnn, buildForAnnAux, + emitBlock, unzipTab, isSingletonType, emitDecl, withNameHint, + singletonTypeVal, scopedDecls, builderScoped, extendScope, checkBuilder, + builderExtend, applyPreludeFunction, + unpackConsList, unpackLeftLeaningConsList, + emitRunWriter, emitRunWriters, mextendForRef, monoidLift, + emitRunState, emitMaybeCase, emitWhile, buildDataDef, + emitRunReader, tabGet, SubstBuilderT, SubstBuilder, runSubstBuilderT, + ptrOffset, ptrLoad, unsafePtrLoad, + evalBlockE, substTraversalDef, + TraversalDef, traverseDecls, traverseDecl, traverseDeclsOpen, + traverseBlock, traverseExpr, traverseAtom, + clampPositive, buildNAbs, buildNAbsAux, buildNestedLam, + transformModuleAsBlock, dropSub, appReduceTraversalDef, + indexSetSizeE, indexToIntE, intToIndexE, freshVarE) where import Control.Applicative import Control.Monad @@ -56,84 +58,84 @@ import Type import PPrint import Util (bindM2, scanM, restructure) -newtype EmbedT m a = EmbedT (ReaderT EmbedEnvR (CatT EmbedEnvC m) a) +newtype BuilderT m a = BuilderT (ReaderT BuilderEnvR (CatT BuilderEnvC m) a) deriving (Functor, Applicative, Monad, MonadIO, MonadFail, Alternative) -type Embed = EmbedT Identity -type EmbedEnv = (EmbedEnvR, EmbedEnvC) +type Builder = BuilderT Identity +type BuilderEnv = (BuilderEnvR, BuilderEnvC) -type SubstEmbedT m = ReaderT SubstEnv (EmbedT m) -type SubstEmbed = SubstEmbedT Identity +type SubstBuilderT m = ReaderT SubstEnv (BuilderT m) +type SubstBuilder = SubstBuilderT Identity -- Carries the vars in scope (with optional definitions) and the emitted decls -type EmbedEnvC = (Scope, Nest Decl) +type BuilderEnvC = (Scope, Nest Decl) -- Carries a name suggestion and the allowable effects -type EmbedEnvR = (Tag, EffectRow) +type BuilderEnvR = (Tag, EffectRow) -runEmbedT :: Monad m => EmbedT m a -> Scope -> m (a, EmbedEnvC) -runEmbedT (EmbedT m) scope = do +runBuilderT :: Monad m => BuilderT m a -> Scope -> m (a, BuilderEnvC) +runBuilderT (BuilderT m) scope = do (ans, env) <- runCatT (runReaderT m ("tmp", Pure)) (scope, Empty) return (ans, env) -runEmbed :: Embed a -> Scope -> (a, EmbedEnvC) -runEmbed m scope = runIdentity $ runEmbedT m scope +runBuilder :: Builder a -> Scope -> (a, BuilderEnvC) +runBuilder m scope = runIdentity $ runBuilderT m scope -runSubstEmbedT :: Monad m => SubstEmbedT m a -> Scope -> m (a, EmbedEnvC) -runSubstEmbedT m scope = runEmbedT (runReaderT m mempty) scope +runSubstBuilderT :: Monad m => SubstBuilderT m a -> Scope -> m (a, BuilderEnvC) +runSubstBuilderT m scope = runBuilderT (runReaderT m mempty) scope -runSubstEmbed :: SubstEmbed a -> Scope -> (a, EmbedEnvC) -runSubstEmbed m scope = runIdentity $ runEmbedT (runReaderT m mempty) scope +runSubstBuilder :: SubstBuilder a -> Scope -> (a, BuilderEnvC) +runSubstBuilder m scope = runIdentity $ runBuilderT (runReaderT m mempty) scope -emit :: MonadEmbed m => Expr -> m Atom +emit :: MonadBuilder m => Expr -> m Atom emit expr = emitAnn PlainLet expr -emitAnn :: MonadEmbed m => LetAnn -> Expr -> m Atom +emitAnn :: MonadBuilder m => LetAnn -> Expr -> m Atom emitAnn ann expr = do v <- getNameHint emitTo v ann expr -- Guarantees that the name will be used, possibly with a modified counter -emitTo :: MonadEmbed m => Name -> LetAnn -> Expr -> m Atom +emitTo :: MonadBuilder m => Name -> LetAnn -> Expr -> m Atom emitTo name ann expr = do scope <- getScope -- Deshadow type because types from DataDef may have binders that shadow local vars let ty = deShadow (getType expr) scope let expr' = deShadow expr scope v <- freshVarE (LetBound ann expr') $ Bind (name:>ty) - embedExtend $ asSnd $ Nest (Let ann (Bind v) expr') Empty + builderExtend $ asSnd $ Nest (Let ann (Bind v) expr') Empty return $ Var v -emitOp :: MonadEmbed m => Op -> m Atom +emitOp :: MonadBuilder m => Op -> m Atom emitOp op = emit $ Op op -emitUnpack :: MonadEmbed m => Expr -> m [Atom] +emitUnpack :: MonadBuilder m => Expr -> m [Atom] emitUnpack expr = getUnpacked =<< emit expr -emitBlock :: MonadEmbed m => Block -> m Atom +emitBlock :: MonadBuilder m => Block -> m Atom emitBlock block = emitBlockRec mempty block -emitBlockRec :: MonadEmbed m => SubstEnv -> Block -> m Atom +emitBlockRec :: MonadBuilder m => SubstEnv -> Block -> m Atom emitBlockRec env (Block (Nest (Let ann b expr) decls) result) = do - expr' <- substEmbed env expr + expr' <- substBuilder env expr x <- emitTo (binderNameHint b) ann expr' emitBlockRec (env <> b@>x) $ Block decls result -emitBlockRec env (Block Empty (Atom atom)) = substEmbed env atom -emitBlockRec env (Block Empty expr) = substEmbed env expr >>= emit +emitBlockRec env (Block Empty (Atom atom)) = substBuilder env atom +emitBlockRec env (Block Empty expr) = substBuilder env expr >>= emit -freshVarE :: MonadEmbed m => BinderInfo -> Binder -> m Var +freshVarE :: MonadBuilder m => BinderInfo -> Binder -> m Var freshVarE bInfo b = do v <- case b of Ignore _ -> getNameHint Bind (v:>_) -> return v scope <- getScope let v' = genFresh v scope - embedExtend $ asFst $ v' @> (binderType b, bInfo) + builderExtend $ asFst $ v' @> (binderType b, bInfo) return $ v' :> binderType b -freshNestedBinders :: MonadEmbed m => Nest Binder -> m (Nest Var) +freshNestedBinders :: MonadBuilder m => Nest Binder -> m (Nest Var) freshNestedBinders bs = freshNestedBindersRec mempty bs -freshNestedBindersRec :: MonadEmbed m => Env Atom -> Nest Binder -> m (Nest Var) +freshNestedBindersRec :: MonadBuilder m => Env Atom -> Nest Binder -> m (Nest Var) freshNestedBindersRec _ Empty = return Empty freshNestedBindersRec substEnv (Nest b bs) = do scope <- getScope @@ -141,7 +143,7 @@ freshNestedBindersRec substEnv (Nest b bs) = do vs <- freshNestedBindersRec (substEnv <> b@>Var v) bs return $ Nest v vs -buildPi :: (MonadError Err m, MonadEmbed m) +buildPi :: (MonadError Err m, MonadBuilder m) => Binder -> (Atom -> m (Arrow, Type)) -> m Atom buildPi b f = do scope <- getScope @@ -155,7 +157,7 @@ buildPi b f = do Nothing -> throw CompilerErr $ "Unexpected irreducible decls in pi type: " ++ pprint decls -buildAbs :: MonadEmbed m => Binder -> (Atom -> m a) -> m (Abs Binder (Nest Decl, a)) +buildAbs :: MonadBuilder m => Binder -> (Atom -> m a) -> m (Abs Binder (Nest Decl, a)) buildAbs b f = do ((b', ans), decls) <- scopedDecls $ do v <- freshVarE UnknownBinder b @@ -163,14 +165,14 @@ buildAbs b f = do return (b, ans) return (Abs b' (decls, ans)) -buildLam :: MonadEmbed m => Binder -> Arrow -> (Atom -> m Atom) -> m Atom +buildLam :: MonadBuilder m => Binder -> Arrow -> (Atom -> m Atom) -> m Atom buildLam b arr body = buildDepEffLam b (const (return arr)) body -buildDepEffLam :: MonadEmbed m +buildDepEffLam :: MonadBuilder m => Binder -> (Atom -> m Arrow) -> (Atom -> m Atom) -> m Atom buildDepEffLam b fArr fBody = liftM fst $ buildLamAux b fArr \x -> (,()) <$> fBody x -buildLamAux :: MonadEmbed m +buildLamAux :: MonadBuilder m => Binder -> (Atom -> m Arrow) -> (Atom -> m (Atom, a)) -> m (Atom, a) buildLamAux b fArr fBody = do ((b', arr, ans, aux), decls) <- scopedDecls $ do @@ -178,15 +180,15 @@ buildLamAux b fArr fBody = do let x = Var v arr <- fArr x -- overwriting the previous binder info know that we know more - embedExtend $ asFst $ v @> (varType v, LamBound (void arr)) + builderExtend $ asFst $ v @> (varType v, LamBound (void arr)) (ans, aux) <- withEffects (arrowEff arr) $ fBody x return (Bind v, arr, ans, aux) return (Lam $ makeAbs b' (arr, wrapDecls decls ans), aux) -buildNAbs :: MonadEmbed m => Nest Binder -> ([Atom] -> m Atom) -> m Alt +buildNAbs :: MonadBuilder m => Nest Binder -> ([Atom] -> m Atom) -> m Alt buildNAbs bs body = liftM fst $ buildNAbsAux bs \xs -> (,()) <$> body xs -buildNAbsAux :: MonadEmbed m => Nest Binder -> ([Atom] -> m (Atom, a)) -> m (Alt, a) +buildNAbsAux :: MonadBuilder m => Nest Binder -> ([Atom] -> m (Atom, a)) -> m (Alt, a) buildNAbsAux bs body = do ((bs', (ans, aux)), decls) <- scopedDecls $ do vs <- freshNestedBinders bs @@ -194,7 +196,7 @@ buildNAbsAux bs body = do return (fmap Bind vs, result) return (Abs bs' $ wrapDecls decls ans, aux) -buildDataDef :: MonadEmbed m +buildDataDef :: MonadBuilder m => Name -> Nest Binder -> ([Atom] -> m [DataConDef]) -> m DataDef buildDataDef tyConName paramBinders body = do ((paramBinders', dataDefs), _) <- scopedDecls $ do @@ -203,11 +205,11 @@ buildDataDef tyConName paramBinders body = do return (fmap Bind vs, result) return $ DataDef tyConName paramBinders' dataDefs -buildImplicitNaryLam :: MonadEmbed m => (Nest Binder) -> ([Atom] -> m Atom) -> m Atom +buildImplicitNaryLam :: MonadBuilder m => (Nest Binder) -> ([Atom] -> m Atom) -> m Atom buildImplicitNaryLam Empty body = body [] buildImplicitNaryLam (Nest b bs) body = buildLam b ImplicitArrow \x -> do - bs' <- substEmbed (b@>x) bs + bs' <- substBuilder (b@>x) bs buildImplicitNaryLam bs' \xs -> body $ x:xs recGetHead :: Label -> Atom -> Atom @@ -216,7 +218,7 @@ recGetHead l x = do let i = fromJust $ elemIndex l $ map fst $ toList $ reflectLabels r getProjection [i] x -buildScoped :: MonadEmbed m => m Atom -> m Block +buildScoped :: MonadBuilder m => m Atom -> m Block buildScoped m = do (ans, decls) <- scopedDecls m return $ wrapDecls decls ans @@ -231,108 +233,92 @@ inlineLastDecl block@(Block decls result) = Block (toNest (reverse rest)) expr _ -> block -zeroAt :: Type -> Atom -zeroAt ty = case ty of - BaseTy bt -> Con $ Lit $ zeroLit bt - TabTy i a -> TabValA i $ zeroAt a - UnitTy -> UnitVal - PairTy a b -> PairVal (zeroAt a) (zeroAt b) - RecordTy (Ext tys Nothing) -> Record $ fmap zeroAt tys - _ -> unreachable - where - unreachable = error $ "Missing zero case for a tangent type: " ++ pprint ty - zeroLit bt = case bt of - Scalar Float64Type -> Float64Lit 0.0 - Scalar Float32Type -> Float32Lit 0.0 - Vector st -> VecLit $ replicate vectorWidth $ zeroLit $ Scalar st - _ -> unreachable - fLitLike :: Double -> Atom -> Atom fLitLike x t = case getType t of BaseTy (Scalar Float64Type) -> Con $ Lit $ Float64Lit x BaseTy (Scalar Float32Type) -> Con $ Lit $ Float32Lit $ realToFrac x _ -> error "Expected a floating point scalar" -neg :: MonadEmbed m => Atom -> m Atom +neg :: MonadBuilder m => Atom -> m Atom neg x = emitOp $ ScalarUnOp FNeg x -add :: MonadEmbed m => Atom -> Atom -> m Atom +add :: MonadBuilder m => Atom -> Atom -> m Atom add x y = emitOp $ ScalarBinOp FAdd x y -- TODO: Implement constant folding for fixed-width integer types as well! -iadd :: MonadEmbed m => Atom -> Atom -> m Atom +iadd :: MonadBuilder m => Atom -> Atom -> m Atom iadd (Con (Lit l)) y | getIntLit l == 0 = return y iadd x (Con (Lit l)) | getIntLit l == 0 = return x iadd x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (+) x y iadd x y = emitOp $ ScalarBinOp IAdd x y -mul :: MonadEmbed m => Atom -> Atom -> m Atom +mul :: MonadBuilder m => Atom -> Atom -> m Atom mul x y = emitOp $ ScalarBinOp FMul x y -imul :: MonadEmbed m => Atom -> Atom -> m Atom +imul :: MonadBuilder m => Atom -> Atom -> m Atom imul (Con (Lit l)) y | getIntLit l == 1 = return y imul x (Con (Lit l)) | getIntLit l == 1 = return x imul x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (*) x y imul x y = emitOp $ ScalarBinOp IMul x y -sub :: MonadEmbed m => Atom -> Atom -> m Atom +sub :: MonadBuilder m => Atom -> Atom -> m Atom sub x y = emitOp $ ScalarBinOp FSub x y -isub :: MonadEmbed m => Atom -> Atom -> m Atom +isub :: MonadBuilder m => Atom -> Atom -> m Atom isub x (Con (Lit l)) | getIntLit l == 0 = return x isub x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (-) x y isub x y = emitOp $ ScalarBinOp ISub x y -select :: MonadEmbed m => Atom -> Atom -> Atom -> m Atom +select :: MonadBuilder m => Atom -> Atom -> Atom -> m Atom select (Con (Lit (Word8Lit p))) x y = return $ if p /= 0 then x else y select p x y = emitOp $ Select p x y -div' :: MonadEmbed m => Atom -> Atom -> m Atom +div' :: MonadBuilder m => Atom -> Atom -> m Atom div' x y = emitOp $ ScalarBinOp FDiv x y -idiv :: MonadEmbed m => Atom -> Atom -> m Atom +idiv :: MonadBuilder m => Atom -> Atom -> m Atom idiv x (Con (Lit l)) | getIntLit l == 1 = return x idiv x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp div x y idiv x y = emitOp $ ScalarBinOp IDiv x y -irem :: MonadEmbed m => Atom -> Atom -> m Atom +irem :: MonadBuilder m => Atom -> Atom -> m Atom irem x y = emitOp $ ScalarBinOp IRem x y -fpow :: MonadEmbed m => Atom -> Atom -> m Atom +fpow :: MonadBuilder m => Atom -> Atom -> m Atom fpow x y = emitOp $ ScalarBinOp FPow x y -flog :: MonadEmbed m => Atom -> m Atom +flog :: MonadBuilder m => Atom -> m Atom flog x = emitOp $ ScalarUnOp Log x -ilt :: MonadEmbed m => Atom -> Atom -> m Atom +ilt :: MonadBuilder m => Atom -> Atom -> m Atom ilt x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (<) x y ilt x y = emitOp $ ScalarBinOp (ICmp Less) x y -ieq :: MonadEmbed m => Atom -> Atom -> m Atom +ieq :: MonadBuilder m => Atom -> Atom -> m Atom ieq x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (==) x y ieq x y = emitOp $ ScalarBinOp (ICmp Equal) x y -fromPair :: MonadEmbed m => Atom -> m (Atom, Atom) +fromPair :: MonadBuilder m => Atom -> m (Atom, Atom) fromPair pair = do ~[x, y] <- getUnpacked pair return (x, y) -getFst :: MonadEmbed m => Atom -> m Atom +getFst :: MonadBuilder m => Atom -> m Atom getFst p = fst <$> fromPair p -getSnd :: MonadEmbed m => Atom -> m Atom +getSnd :: MonadBuilder m => Atom -> m Atom getSnd p = snd <$> fromPair p -getFstRef :: MonadEmbed m => Atom -> m Atom +getFstRef :: MonadBuilder m => Atom -> m Atom getFstRef r = emitOp $ FstRef r -getSndRef :: MonadEmbed m => Atom -> m Atom +getSndRef :: MonadBuilder m => Atom -> m Atom getSndRef r = emitOp $ SndRef r -- XXX: getUnpacked must reduce its argument to enforce the invariant that -- ProjectElt atoms are always fully reduced (to avoid type errors between two -- equivalent types spelled differently). -getUnpacked :: MonadEmbed m => Atom -> m [Atom] +getUnpacked :: MonadBuilder m => Atom -> m [Atom] getUnpacked atom = do scope <- getScope let len = projectLength $ getType atom @@ -340,19 +326,19 @@ getUnpacked atom = do res = map (\i -> getProjection [i] atom') [0..(len-1)] return res -app :: MonadEmbed m => Atom -> Atom -> m Atom +app :: MonadBuilder m => Atom -> Atom -> m Atom app x i = emit $ App x i -naryApp :: MonadEmbed m => Atom -> [Atom] -> m Atom +naryApp :: MonadBuilder m => Atom -> [Atom] -> m Atom naryApp f xs = foldM app f xs -appReduce :: MonadEmbed m => Atom -> Atom -> m Atom +appReduce :: MonadBuilder m => Atom -> Atom -> m Atom appReduce (Lam (Abs v (_, b))) a = runReaderT (evalBlockE substTraversalDef b) (v @> a) appReduce _ _ = error "appReduce expected a lambda as the first argument" -- TODO: this would be more convenient if we could add type inference too -applyPreludeFunction :: MonadEmbed m => String -> [Atom] -> m Atom +applyPreludeFunction :: MonadBuilder m => String -> [Atom] -> m Atom applyPreludeFunction s xs = do scope <- getScope case envLookup scope fname of @@ -360,22 +346,22 @@ applyPreludeFunction s xs = do Just (ty, _) -> naryApp (Var (fname:>ty)) xs where fname = GlobalName (fromString s) -appTryReduce :: MonadEmbed m => Atom -> Atom -> m Atom +appTryReduce :: MonadBuilder m => Atom -> Atom -> m Atom appTryReduce f x = case f of Lam _ -> appReduce f x _ -> app f x -ptrOffset :: MonadEmbed m => Atom -> Atom -> m Atom +ptrOffset :: MonadBuilder m => Atom -> Atom -> m Atom ptrOffset x i = emitOp $ PtrOffset x i -unsafePtrLoad :: MonadEmbed m => Atom -> m Atom +unsafePtrLoad :: MonadBuilder m => Atom -> m Atom unsafePtrLoad x = emit $ Hof $ RunIO $ Lam $ Abs (Ignore UnitTy) $ (PlainArrow (oneEffect IOEffect), Block Empty (Op (PtrLoad x))) -ptrLoad :: MonadEmbed m => Atom -> m Atom +ptrLoad :: MonadBuilder m => Atom -> m Atom ptrLoad x = emitOp $ PtrLoad x -unpackConsList :: MonadEmbed m => Atom -> m [Atom] +unpackConsList :: MonadBuilder m => Atom -> m [Atom] unpackConsList xs = case getType xs of UnitTy -> return [] --PairTy _ UnitTy -> (:[]) <$> getFst xs @@ -384,13 +370,24 @@ unpackConsList xs = case getType xs of liftM (x:) $ unpackConsList rest _ -> error $ "Not a cons list: " ++ pprint (getType xs) -emitWhile :: MonadEmbed m => m Atom -> m () +-- ((...((ans, x{n}), x{n-1})..., x2), x1) -> (ans, [x1, ..., x{n}]) +-- This is useful for unpacking results of stacked effect handlers (as produced +-- by e.g. emitRunWriters). +unpackLeftLeaningConsList :: MonadBuilder m => Int -> Atom -> m (Atom, [Atom]) +unpackLeftLeaningConsList depth atom = go depth atom [] + where + go 0 curAtom xs = return (curAtom, reverse xs) + go remDepth curAtom xs = do + (consTail, x) <- fromPair curAtom + go (remDepth - 1) consTail (x : xs) + +emitWhile :: MonadBuilder m => m Atom -> m () emitWhile body = do eff <- getAllowedEffects lam <- buildLam (Ignore UnitTy) (PlainArrow eff) \_ -> body void $ emit $ Hof $ While lam -emitMaybeCase :: MonadEmbed m => Atom -> m Atom -> (Atom -> m Atom) -> m Atom +emitMaybeCase :: MonadBuilder m => Atom -> m Atom -> (Atom -> m Atom) -> m Atom emitMaybeCase scrut nothingCase justCase = do let (MaybeTy a) = getType scrut nothingAlt <- buildNAbs Empty \[] -> nothingCase @@ -399,51 +396,87 @@ emitMaybeCase scrut nothingCase justCase = do let resultTy = getType nothingBody emit $ Case scrut [nothingAlt, justAlt] resultTy -emitRunWriter :: MonadEmbed m => Name -> Type -> (Atom -> m Atom) -> m Atom -emitRunWriter v ty body = do - emit . Hof . RunWriter =<< mkBinaryEffFun Writer v ty body +monoidLift :: Type -> Type -> Nest Binder +monoidLift baseTy accTy = case baseTy == accTy of + True -> Empty + False -> case accTy of + TabTy n b -> Nest n $ monoidLift baseTy b + _ -> error $ "Base monoid type mismatch: can't lift " ++ + pprint baseTy ++ " to " ++ pprint accTy + +mextendForRef :: MonadBuilder m => Atom -> BaseMonoid -> Atom -> m Atom +mextendForRef ref (BaseMonoid _ combine) update = do + buildLam (Bind $ "refVal":>accTy) PureArrow \refVal -> + buildNestedFor (fmap (Fwd,) $ toList liftIndices) $ \indices -> do + refElem <- tabGetNd refVal indices + updateElem <- tabGetNd update indices + bindM2 appTryReduce (appTryReduce combine refElem) (return updateElem) + where + TC (RefType _ accTy) = getType ref + FunTy (BinderAnn baseTy) _ _ = getType combine + liftIndices = monoidLift baseTy accTy + +emitRunWriter :: MonadBuilder m => Name -> Type -> BaseMonoid -> (Atom -> m Atom) -> m Atom +emitRunWriter v accTy bm body = do + emit . Hof . RunWriter bm =<< mkBinaryEffFun Writer v accTy body -emitRunReader :: MonadEmbed m => Name -> Atom -> (Atom -> m Atom) -> m Atom +emitRunWriters :: MonadBuilder m => [(Name, Type, BaseMonoid)] -> ([Atom] -> m Atom) -> m Atom +emitRunWriters inits body = go inits [] + where + go [] refs = body $ reverse refs + go ((v, accTy, bm):rest) refs = emitRunWriter v accTy bm $ \ref -> go rest (ref:refs) + +emitRunReader :: MonadBuilder m => Name -> Atom -> (Atom -> m Atom) -> m Atom emitRunReader v x0 body = do emit . Hof . RunReader x0 =<< mkBinaryEffFun Reader v (getType x0) body -emitRunState :: MonadEmbed m => Name -> Atom -> (Atom -> m Atom) -> m Atom +emitRunState :: MonadBuilder m => Name -> Atom -> (Atom -> m Atom) -> m Atom emitRunState v x0 body = do emit . Hof . RunState x0 =<< mkBinaryEffFun State v (getType x0) body -mkBinaryEffFun :: MonadEmbed m => RWS -> Name -> Type -> (Atom -> m Atom) -> m Atom +mkBinaryEffFun :: MonadBuilder m => RWS -> Name -> Type -> (Atom -> m Atom) -> m Atom mkBinaryEffFun rws v ty body = do eff <- getAllowedEffects buildLam (Bind ("h":>TyKind)) PureArrow \r@(Var (rName:>_)) -> do let arr = PlainArrow $ extendEffect (RWSEffect rws rName) eff buildLam (Bind (v:> RefTy r ty)) arr body -buildForAnnAux :: MonadEmbed m => ForAnn -> Binder -> (Atom -> m (Atom, a)) -> m (Atom, a) +buildForAnnAux :: MonadBuilder m => ForAnn -> Binder -> (Atom -> m (Atom, a)) -> m (Atom, a) buildForAnnAux ann i body = do -- TODO: consider only tracking the effects that are actually needed. eff <- getAllowedEffects (lam, aux) <- buildLamAux i (const $ return $ PlainArrow eff) body (,aux) <$> (emit $ Hof $ For ann lam) -buildForAnn :: MonadEmbed m => ForAnn -> Binder -> (Atom -> m Atom) -> m Atom +buildForAnn :: MonadBuilder m => ForAnn -> Binder -> (Atom -> m Atom) -> m Atom buildForAnn ann i body = fst <$> buildForAnnAux ann i (\x -> (,()) <$> body x) -buildForAux :: MonadEmbed m => Direction -> Binder -> (Atom -> m (Atom, a)) -> m (Atom, a) +buildForAux :: MonadBuilder m => Direction -> Binder -> (Atom -> m (Atom, a)) -> m (Atom, a) buildForAux = buildForAnnAux . RegularFor -- Do we need this variant? -buildFor :: MonadEmbed m => Direction -> Binder -> (Atom -> m Atom) -> m Atom +buildFor :: MonadBuilder m => Direction -> Binder -> (Atom -> m Atom) -> m Atom buildFor = buildForAnn . RegularFor -buildNestedLam :: MonadEmbed m => Arrow -> [Binder] -> ([Atom] -> m Atom) -> m Atom +buildNestedFor :: forall m. MonadBuilder m => [(Direction, Binder)] -> ([Atom] -> m Atom) -> m Atom +buildNestedFor specs body = go specs [] + where + go :: [(Direction, Binder)] -> [Atom] -> m Atom + go [] indices = body $ reverse indices + go ((d,b):t) indices = buildFor d b $ \i -> go t (i:indices) + +buildNestedLam :: MonadBuilder m => Arrow -> [Binder] -> ([Atom] -> m Atom) -> m Atom buildNestedLam _ [] f = f [] buildNestedLam arr (b:bs) f = buildLam b arr \x -> buildNestedLam arr bs \xs -> f (x:xs) -tabGet :: MonadEmbed m => Atom -> Atom -> m Atom -tabGet x i = emit $ App x i +tabGet :: MonadBuilder m => Atom -> Atom -> m Atom +tabGet tab idx = emit $ App tab idx + +tabGetNd :: MonadBuilder m => Atom -> [Atom] -> m Atom +tabGetNd tab idxs = foldM (flip tabGet) tab idxs -unzipTab :: MonadEmbed m => Atom -> m (Atom, Atom) +unzipTab :: MonadBuilder m => Atom -> m (Atom, Atom) unzipTab tab = do fsts <- buildLam (Bind ("i":>binderType v)) TabArrow \i -> liftM fst $ app tab i >>= fromPair @@ -452,20 +485,20 @@ unzipTab tab = do return (fsts, snds) where TabTy v _ = getType tab -substEmbedR :: (MonadEmbed m, MonadReader SubstEnv m, Subst a) +substBuilderR :: (MonadBuilder m, MonadReader SubstEnv m, Subst a) => a -> m a -substEmbedR x = do +substBuilderR x = do env <- ask - substEmbed env x + substBuilder env x -substEmbed :: (MonadEmbed m, Subst a) +substBuilder :: (MonadBuilder m, Subst a) => SubstEnv -> a -> m a -substEmbed env x = do +substBuilder env x = do scope <- getScope return $ subst (env, scope) x -checkEmbed :: (HasCallStack, MonadEmbed m, HasVars a, HasType a) => a -> m a -checkEmbed x = do +checkBuilder :: (HasCallStack, MonadBuilder m, HasVars a, HasType a) => a -> m a +checkBuilder x = do scope <- getScope let globals = freeVars x `envDiff` scope unless (all (isGlobal . (:>())) $ envNames globals) $ @@ -490,114 +523,114 @@ singletonTypeVal (TC con) = case con of _ -> Nothing singletonTypeVal _ = Nothing -indexAsInt :: MonadEmbed m => Atom -> m Atom +indexAsInt :: MonadBuilder m => Atom -> m Atom indexAsInt idx = emitOp $ ToOrdinal idx -instance MonadTrans EmbedT where - lift m = EmbedT $ lift $ lift m - -class Monad m => MonadEmbed m where - embedLook :: m EmbedEnvC - embedExtend :: EmbedEnvC -> m () - embedScoped :: m a -> m (a, EmbedEnvC) - embedAsk :: m EmbedEnvR - embedLocal :: (EmbedEnvR -> EmbedEnvR) -> m a -> m a - -instance Monad m => MonadEmbed (EmbedT m) where - embedLook = EmbedT look - embedExtend env = EmbedT $ extend env - embedScoped (EmbedT m) = EmbedT $ scoped m - embedAsk = EmbedT ask - embedLocal f (EmbedT m) = EmbedT $ local f m - -instance MonadEmbed m => MonadEmbed (ReaderT r m) where - embedLook = lift embedLook - embedExtend x = lift $ embedExtend x - embedScoped m = ReaderT \r -> embedScoped $ runReaderT m r - embedAsk = lift embedAsk - embedLocal v m = ReaderT \r -> embedLocal v $ runReaderT m r - -instance MonadEmbed m => MonadEmbed (StateT s m) where - embedLook = lift embedLook - embedExtend x = lift $ embedExtend x - embedScoped m = do +instance MonadTrans BuilderT where + lift m = BuilderT $ lift $ lift m + +class Monad m => MonadBuilder m where + builderLook :: m BuilderEnvC + builderExtend :: BuilderEnvC -> m () + builderScoped :: m a -> m (a, BuilderEnvC) + builderAsk :: m BuilderEnvR + builderLocal :: (BuilderEnvR -> BuilderEnvR) -> m a -> m a + +instance Monad m => MonadBuilder (BuilderT m) where + builderLook = BuilderT look + builderExtend env = BuilderT $ extend env + builderScoped (BuilderT m) = BuilderT $ scoped m + builderAsk = BuilderT ask + builderLocal f (BuilderT m) = BuilderT $ local f m + +instance MonadBuilder m => MonadBuilder (ReaderT r m) where + builderLook = lift builderLook + builderExtend x = lift $ builderExtend x + builderScoped m = ReaderT \r -> builderScoped $ runReaderT m r + builderAsk = lift builderAsk + builderLocal v m = ReaderT \r -> builderLocal v $ runReaderT m r + +instance MonadBuilder m => MonadBuilder (StateT s m) where + builderLook = lift builderLook + builderExtend x = lift $ builderExtend x + builderScoped m = do s <- get - ((x, s'), env) <- lift $ embedScoped $ runStateT m s + ((x, s'), env) <- lift $ builderScoped $ runStateT m s put s' return (x, env) - embedAsk = lift embedAsk - embedLocal v m = do + builderAsk = lift builderAsk + builderLocal v m = do s <- get - (x, s') <- lift $ embedLocal v $ runStateT m s + (x, s') <- lift $ builderLocal v $ runStateT m s put s' return x -instance (Monoid env, MonadEmbed m) => MonadEmbed (CatT env m) where - embedLook = lift embedLook - embedExtend x = lift $ embedExtend x - embedScoped m = do +instance (Monoid env, MonadBuilder m) => MonadBuilder (CatT env m) where + builderLook = lift builderLook + builderExtend x = lift $ builderExtend x + builderScoped m = do env <- look - ((ans, env'), scopeEnv) <- lift $ embedScoped $ runCatT m env + ((ans, env'), scopeEnv) <- lift $ builderScoped $ runCatT m env extend env' return (ans, scopeEnv) - embedAsk = lift embedAsk - embedLocal v m = do + builderAsk = lift builderAsk + builderLocal v m = do env <- look - (ans, env') <- lift $ embedLocal v $ runCatT m env + (ans, env') <- lift $ builderLocal v $ runCatT m env extend env' return ans -instance (Monoid w, MonadEmbed m) => MonadEmbed (WriterT w m) where - embedLook = lift embedLook - embedExtend x = lift $ embedExtend x - embedScoped m = do - ((x, w), env) <- lift $ embedScoped $ runWriterT m +instance (Monoid w, MonadBuilder m) => MonadBuilder (WriterT w m) where + builderLook = lift builderLook + builderExtend x = lift $ builderExtend x + builderScoped m = do + ((x, w), env) <- lift $ builderScoped $ runWriterT m tell w return (x, env) - embedAsk = lift embedAsk - embedLocal v m = WriterT $ embedLocal v $ runWriterT m + builderAsk = lift builderAsk + builderLocal v m = WriterT $ builderLocal v $ runWriterT m -instance (Monoid env, MonadCat env m) => MonadCat env (EmbedT m) where +instance (Monoid env, MonadCat env m) => MonadCat env (BuilderT m) where look = lift look extend x = lift $ extend x - scoped (EmbedT m) = EmbedT $ do + scoped (BuilderT m) = BuilderT $ do name <- ask env <- look ((ans, env'), scopeEnv) <- lift $ lift $ scoped $ runCatT (runReaderT m name) env extend env' return (ans, scopeEnv) -instance MonadError e m => MonadError e (EmbedT m) where +instance MonadError e m => MonadError e (BuilderT m) where throwError = lift . throwError catchError m catch = do - envC <- embedLook - envR <- embedAsk - (ans, envC') <- lift $ runEmbedT' m (envR, envC) - `catchError` (\e -> runEmbedT' (catch e) (envR, envC)) - embedExtend envC' + envC <- builderLook + envR <- builderAsk + (ans, envC') <- lift $ runBuilderT' m (envR, envC) + `catchError` (\e -> runBuilderT' (catch e) (envR, envC)) + builderExtend envC' return ans -instance MonadReader r m => MonadReader r (EmbedT m) where +instance MonadReader r m => MonadReader r (BuilderT m) where ask = lift ask local r m = do - envC <- embedLook - envR <- embedAsk - (ans, envC') <- lift $ local r $ runEmbedT' m (envR, envC) - embedExtend envC' + envC <- builderLook + envR <- builderAsk + (ans, envC') <- lift $ local r $ runBuilderT' m (envR, envC) + builderExtend envC' return ans -instance MonadState s m => MonadState s (EmbedT m) where +instance MonadState s m => MonadState s (BuilderT m) where get = lift get put = lift . put -getNameHint :: MonadEmbed m => m Name +getNameHint :: MonadBuilder m => m Name getNameHint = do - tag <- fst <$> embedAsk + tag <- fst <$> builderAsk return $ Name GenName tag 0 -- This is purely for human readability. `const id` would be a valid implementation. -withNameHint :: (MonadEmbed m, HasName a) => a -> m b -> m b -withNameHint name m = embedLocal (\(_, eff) -> (tag, eff)) m +withNameHint :: (MonadBuilder m, HasName a) => a -> m b -> m b +withNameHint name m = builderLocal (\(_, eff) -> (tag, eff)) m where tag = case getName name of Just (Name _ t _) -> t @@ -605,53 +638,53 @@ withNameHint name m = embedLocal (\(_, eff) -> (tag, eff)) m Just (GlobalArrayName _) -> "arr" Nothing -> "tmp" -runEmbedT' :: Monad m => EmbedT m a -> EmbedEnv -> m (a, EmbedEnvC) -runEmbedT' (EmbedT m) (envR, envC) = runCatT (runReaderT m envR) envC +runBuilderT' :: Monad m => BuilderT m a -> BuilderEnv -> m (a, BuilderEnvC) +runBuilderT' (BuilderT m) (envR, envC) = runCatT (runReaderT m envR) envC -getScope :: MonadEmbed m => m Scope -getScope = fst <$> embedLook +getScope :: MonadBuilder m => m Scope +getScope = fst <$> builderLook -extendScope :: MonadEmbed m => Scope -> m () -extendScope scope = embedExtend $ asFst scope +extendScope :: MonadBuilder m => Scope -> m () +extendScope scope = builderExtend $ asFst scope -getAllowedEffects :: MonadEmbed m => m EffectRow -getAllowedEffects = snd <$> embedAsk +getAllowedEffects :: MonadBuilder m => m EffectRow +getAllowedEffects = snd <$> builderAsk -withEffects :: MonadEmbed m => EffectRow -> m a -> m a +withEffects :: MonadBuilder m => EffectRow -> m a -> m a withEffects effs m = modifyAllowedEffects (const effs) m -modifyAllowedEffects :: MonadEmbed m => (EffectRow -> EffectRow) -> m a -> m a -modifyAllowedEffects f m = embedLocal (\(name, eff) -> (name, f eff)) m +modifyAllowedEffects :: MonadBuilder m => (EffectRow -> EffectRow) -> m a -> m a +modifyAllowedEffects f m = builderLocal (\(name, eff) -> (name, f eff)) m -emitDecl :: MonadEmbed m => Decl -> m () -emitDecl decl = embedExtend (bindings, Nest decl Empty) +emitDecl :: MonadBuilder m => Decl -> m () +emitDecl decl = builderExtend (bindings, Nest decl Empty) where bindings = case decl of Let ann b expr -> b @> (binderType b, LetBound ann expr) -scopedDecls :: MonadEmbed m => m a -> m (a, Nest Decl) +scopedDecls :: MonadBuilder m => m a -> m (a, Nest Decl) scopedDecls m = do - (ans, (_, decls)) <- embedScoped m + (ans, (_, decls)) <- builderScoped m return (ans, decls) -liftEmbed :: MonadEmbed m => Embed a -> m a -liftEmbed action = do - envR <- embedAsk - envC <- embedLook - let (ans, envC') = runIdentity $ runEmbedT' action (envR, envC) - embedExtend envC' +liftBuilder :: MonadBuilder m => Builder a -> m a +liftBuilder action = do + envR <- builderAsk + envC <- builderLook + let (ans, envC') = runIdentity $ runBuilderT' action (envR, envC) + builderExtend envC' return ans -- === generic traversal === type TraversalDef m = (Decl -> m SubstEnv, Expr -> m Expr, Atom -> m Atom) -substTraversalDef :: (MonadEmbed m, MonadReader SubstEnv m) => TraversalDef m +substTraversalDef :: (MonadBuilder m, MonadReader SubstEnv m) => TraversalDef m substTraversalDef = ( traverseDecl substTraversalDef , traverseExpr substTraversalDef , traverseAtom substTraversalDef ) -appReduceTraversalDef :: (MonadEmbed m, MonadReader SubstEnv m) => TraversalDef m +appReduceTraversalDef :: (MonadBuilder m, MonadReader SubstEnv m) => TraversalDef m appReduceTraversalDef = ( traverseDecl appReduceTraversalDef , reduceAppExpr , traverseAtom appReduceTraversalDef @@ -668,11 +701,11 @@ appReduceTraversalDef = ( traverseDecl appReduceTraversalDef _ -> traverseExpr appReduceTraversalDef expr -- With `def = (traverseExpr def, traverseAtom def)` this should be a no-op -traverseDecls :: (MonadEmbed m, MonadReader SubstEnv m) +traverseDecls :: (MonadBuilder m, MonadReader SubstEnv m) => TraversalDef m -> Nest Decl -> m ((Nest Decl), SubstEnv) traverseDecls def decls = liftM swap $ scopedDecls $ traverseDeclsOpen def decls -traverseDeclsOpen :: (MonadEmbed m, MonadReader SubstEnv m) +traverseDeclsOpen :: (MonadBuilder m, MonadReader SubstEnv m) => TraversalDef m -> Nest Decl -> m SubstEnv traverseDeclsOpen _ Empty = return mempty traverseDeclsOpen def@(fDecl, _, _) (Nest decl decls) = do @@ -680,21 +713,20 @@ traverseDeclsOpen def@(fDecl, _, _) (Nest decl decls) = do env' <- extendR env $ traverseDeclsOpen def decls return (env <> env') -traverseDecl :: (MonadEmbed m, MonadReader SubstEnv m) +traverseDecl :: (MonadBuilder m, MonadReader SubstEnv m) => TraversalDef m -> Decl -> m SubstEnv traverseDecl (_, fExpr, _) decl = case decl of Let letAnn b expr -> do expr' <- fExpr expr case expr' of - Atom a | not (isGlobalBinder b) -> return $ b @> a - -- TODO: Do we need to use the name hint here? + Atom a | not (isGlobalBinder b) && letAnn == PlainLet -> return $ b @> a _ -> (b@>) <$> emitTo (binderNameHint b) letAnn expr' -traverseBlock :: (MonadEmbed m, MonadReader SubstEnv m) +traverseBlock :: (MonadBuilder m, MonadReader SubstEnv m) => TraversalDef m -> Block -> m Block traverseBlock def block = buildScoped $ evalBlockE def block -evalBlockE :: (MonadEmbed m, MonadReader SubstEnv m) +evalBlockE :: (MonadBuilder m, MonadReader SubstEnv m) => TraversalDef m -> Block -> m Atom evalBlockE def@(_, fExpr, _) (Block decls result) = do env <- traverseDeclsOpen def decls @@ -703,7 +735,7 @@ evalBlockE def@(_, fExpr, _) (Block decls result) = do Atom a -> return a _ -> emit resultExpr -traverseExpr :: (MonadEmbed m, MonadReader SubstEnv m) +traverseExpr :: (MonadBuilder m, MonadReader SubstEnv m) => TraversalDef m -> Expr -> m Expr traverseExpr def@(_, _, fAtom) expr = case expr of App g x -> App <$> fAtom g <*> fAtom x @@ -716,19 +748,19 @@ traverseExpr def@(_, _, fAtom) expr = case expr of bs' <- mapM (mapM fAtom) bs buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ evalBlockE def body -traverseAtom :: forall m . (MonadEmbed m, MonadReader SubstEnv m) +traverseAtom :: forall m . (MonadBuilder m, MonadReader SubstEnv m) => TraversalDef m -> Atom -> m Atom traverseAtom def@(_, _, fAtom) atom = case atom of - Var _ -> substEmbedR atom + Var _ -> substBuilderR atom Lam (Abs b (arr, body)) -> do b' <- mapM fAtom b buildDepEffLam b' - (\x -> extendR (b'@>x) (substEmbedR arr)) + (\x -> extendR (b'@>x) (substBuilderR arr)) (\x -> extendR (b'@>x) (evalBlockE def body)) - Pi _ -> substEmbedR atom + Pi _ -> substBuilderR atom Con con -> Con <$> traverse fAtom con TC tc -> TC <$> traverse fAtom tc - Eff _ -> substEmbedR atom + Eff _ -> substBuilderR atom DataCon dataDef params con args -> DataCon dataDef <$> traverse fAtom params <*> pure con <*> traverse fAtom args TypeCon dataDef params -> TypeCon dataDef <$> traverse fAtom params @@ -756,13 +788,13 @@ traverseAtom def@(_, _, fAtom) atom = case atom of case decls of Empty -> return $ BoxedRef b' ptr' size' body' _ -> error "Traversing the body atom shouldn't produce decls" - ProjectElt _ _ -> substEmbedR atom + ProjectElt _ _ -> substBuilderR atom where traverseNestedArgs :: Nest DataConRefBinding -> m (Nest DataConRefBinding) traverseNestedArgs Empty = return Empty traverseNestedArgs (Nest (DataConRefBinding b ref) rest) = do ref' <- fAtom ref - b' <- substEmbedR b + b' <- substBuilderR b v <- freshVarE UnknownBinder b' rest' <- extendR (b @> Var v) $ traverseNestedArgs rest return $ Nest (DataConRefBinding (Bind v) ref') rest' @@ -785,7 +817,7 @@ transformModuleAsBlock transform (Module ir decls bindings) = do dropSub :: MonadReader SubstEnv m => m a -> m a dropSub m = local mempty m -indexSetSizeE :: MonadEmbed m => Type -> m Atom +indexSetSizeE :: MonadBuilder m => Type -> m Atom indexSetSizeE (TC con) = case con of UnitType -> return $ IdxRepVal 1 IntRange low high -> clampPositive =<< high `isub` low @@ -810,7 +842,7 @@ indexSetSizeE (VariantTy (NoExt types)) = do foldM iadd (IdxRepVal 0) sizes indexSetSizeE ty = error $ "Not implemented " ++ pprint ty -clampPositive :: MonadEmbed m => Atom -> m Atom +clampPositive :: MonadBuilder m => Atom -> m Atom clampPositive x = do isNegative <- x `ilt` (IdxRepVal 0) select isNegative (IdxRepVal 0) x @@ -819,7 +851,7 @@ clampPositive x = do -- IndexAsInt instruction, as for Int and IndexRanges it will -- generate the same instruction again, potentially leading to an -- infinite loop. -indexToIntE :: MonadEmbed m => Atom -> m Atom +indexToIntE :: MonadBuilder m => Atom -> m Atom indexToIntE (Con (IntRangeVal _ _ i)) = return i indexToIntE (Con (IndexRangeVal _ _ _ i)) = return i indexToIntE idx = case getType idx of @@ -852,7 +884,7 @@ indexToIntE idx = case getType idx of emit $ Case idx alts IdxRepTy ty -> error $ "Unexpected type " ++ pprint ty -intToIndexE :: MonadEmbed m => Type -> Atom -> m Atom +intToIndexE :: MonadBuilder m => Type -> Atom -> m Atom intToIndexE (TC con) i = case con of IntRange low high -> return $ Con $ IntRangeVal low high i IndexRange from low high -> return $ Con $ IndexRangeVal from low high i diff --git a/src/lib/Export.hs b/src/lib/Export.hs index 0db91a501..a01cc8613 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -21,7 +21,7 @@ import Data.List (nub, intercalate) import Algebra import Syntax -import Embed +import Builder import Cat import Env import Type @@ -33,7 +33,7 @@ import LLVMExec import PPrint import Optimize -exportFunctions :: FilePath -> [(String, Atom)] -> TopEnv -> IO () +exportFunctions :: FilePath -> [(String, Atom)] -> Bindings -> IO () exportFunctions objPath funcs env = do let names = fmap fst funcs unless (length (nub names) == length names) $ liftEitherIO $ @@ -47,9 +47,9 @@ exportFunctions objPath funcs env = do type CArgList = [IBinder] -- ^ List of arguments to the C call data CArgEnv = CArgEnv { -- | Maps scalar atom binders to their CArgs. All atoms are Vars. cargScalarScope :: Env Atom - -- | Tracks the CArg names used so far (globally scoped, unlike Embed) + -- | Tracks the CArg names used so far (globally scoped, unlike Builder) , cargScope :: Env () } -type CArgM = WriterT CArgList (CatT CArgEnv Embed) +type CArgM = WriterT CArgList (CatT CArgEnv Builder) instance Semigroup CArgEnv where (CArgEnv a1 a2) <> (CArgEnv b1 b2) = CArgEnv (a1 <> b1) (a2 <> b2) @@ -57,15 +57,15 @@ instance Semigroup CArgEnv where instance Monoid CArgEnv where mempty = CArgEnv mempty mempty -runCArg :: CArgEnv -> CArgM a -> Embed (a, [IBinder], CArgEnv) +runCArg :: CArgEnv -> CArgM a -> Builder (a, [IBinder], CArgEnv) runCArg initEnv m = repack <$> runCatT (runWriterT m) initEnv where repack ((ans, cargs), env) = (ans, cargs, env) -prepareFunctionForExport :: TopEnv -> String -> Atom -> (ImpModule, ExportedSignature) +prepareFunctionForExport :: Bindings -> String -> Atom -> (ImpModule, ExportedSignature) prepareFunctionForExport env nameStr func = do -- Create a module that simulates an application of arguments to the function -- TODO: Assert that the type of func is closed? - let ((dest, cargs, apiDesc), (_, decls)) = flip runEmbed (freeVars func) $ do + let ((dest, cargs, apiDesc), (_, decls)) = flip runBuilder (freeVars func) $ do (args, cargArgs, cargEnv) <- runCArg mempty $ createArgs $ getType func let (atomArgs, exportedArgSig) = unzip args resultAtom <- naryApp func atomArgs @@ -121,7 +121,7 @@ prepareFunctionForExport env nameStr func = do return (destAtom, exportArg) TabTy b elemTy -> do buildLamAux b (const $ return TabArrow) $ \(Var i) -> do - elemTy' <- substEmbed (b@>Var i) elemTy + elemTy' <- substBuilder (b@>Var i) elemTy createTabArg vis (idx <> Nest (Bind i) Empty) elemTy' _ -> unsupported where unsupported = error $ "Unsupported table type suffix: " ++ pprint ty @@ -140,7 +140,7 @@ prepareFunctionForExport env nameStr func = do return (dest, exportResult) TabTy b elemTy -> do (destTab, exportResult) <- buildLamAux b (const $ return TabArrow) $ \(Var i) -> do - elemTy' <- substEmbed (b@>Var i) elemTy + elemTy' <- substBuilder (b@>Var i) elemTy createDest (idx <> Nest (Bind i) Empty) elemTy' return (Con $ TabRef destTab, exportResult) _ -> unsupported diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index c264c5b2b..881b2d1f4 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -36,7 +36,7 @@ import GHC.Stack import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M -import Embed +import Builder import Syntax import Env import Type @@ -75,7 +75,7 @@ data ImpCatEnv = ImpCatEnv type ImpM = ExceptT () (ReaderT ImpCtx (Cat ImpCatEnv)) type AtomRecon = Abs (Nest Binder) Atom -toImpModule :: TopEnv -> Backend -> CallingConvention -> Name +toImpModule :: Bindings -> Backend -> CallingConvention -> Name -> [IBinder] -> Maybe Dest -> Block -> (ImpFunction, ImpModule, AtomRecon) toImpModule env backend cc entryName argBinders maybeDest block = do @@ -114,7 +114,7 @@ requiredFunctions scope expr = -- We don't emit any results when a destination is provided, since they are already -- going to be available through the dest. -translateTopLevel :: TopEnv -> WithDest Block -> ImpM (AtomRecon, [IExpr]) +translateTopLevel :: Bindings -> WithDest Block -> ImpM (AtomRecon, [IExpr]) translateTopLevel topEnv (maybeDest, block) = do outDest <- case maybeDest of Nothing -> makeAllocDest Unmanaged $ getType block @@ -140,7 +140,7 @@ toImpStandalone fname ~(LamVal b body) = do let outTy = getType body backend <- asks impBackend curDev <- asks curDevice - (ptrSizes, ~(Con (ConRef (PairCon outDest argDest)))) <- fromEmbed $ + (ptrSizes, ~(Con (ConRef (PairCon outDest argDest)))) <- fromBuilder $ makeDest (backend, curDev, Unmanaged) (PairTy outTy argTy) impBlock <- scopedErrBlock $ do arg <- destToAtom argDest @@ -248,9 +248,14 @@ toImpOp (maybeDest, op) = case op of destToAtom dest PrimEffect refDest m -> do case m of - MAsk -> returnVal =<< destToAtom refDest - MTell x -> addToAtom refDest x >> returnVal UnitVal - MPut x -> copyAtom refDest x >> returnVal UnitVal + MAsk -> returnVal =<< destToAtom refDest + MExtend ~(Lam f) -> do + -- TODO: Update in-place? + refValue <- destToAtom refDest + result <- translateBlock mempty (Nothing, snd $ applyAbs f refValue) + copyAtom refDest result + returnVal UnitVal + MPut x -> copyAtom refDest x >> returnVal UnitVal MGet -> do dest <- allocDest maybeDest resultTy -- It might be more efficient to implement a specialized copy for dests @@ -395,10 +400,10 @@ toImpHof env (maybeDest, hof) = do dest <- allocDest maybeDest resultTy buildKernel idxTy \LaunchInfo{..} buildBody -> do liftM (,()) $ buildBody \ThreadInfo{..} -> do - let threadBody = fst $ flip runSubstEmbed (freeVars fbody) $ + let threadBody = fst $ flip runSubstBuilder (freeVars fbody) $ buildLam (Bind $ "hwidx" :> threadRange) PureArrow \hwidx -> appReduce fbody =<< (emitOp $ Inject hwidx) - let threadDest = Con $ TabRef $ fst $ flip runSubstEmbed (freeVars dest) $ + let threadDest = Con $ TabRef $ fst $ flip runSubstBuilder (freeVars dest) $ buildLam (Bind $ "hwidx" :> threadRange) TabArrow \hwidx -> indexDest dest =<< (emitOp $ Inject hwidx) void $ toImpHof env (Just threadDest, For (RegularFor Fwd) threadBody) @@ -414,36 +419,37 @@ toImpHof env (maybeDest, hof) = do emitLoop (binderNameHint tb) Fwd nTiles \iTile -> do tileOffset <- toScalarAtom <$> iTile `imulI` tileLen let tileAtom = Con $ IndexSliceVal idxTy tileIdxTy tileOffset - tileDest <- fromEmbed $ sliceDestDim d dest tileOffset tileIdxTy + tileDest <- fromBuilder $ sliceDestDim d dest tileOffset tileIdxTy void $ translateBlock (env <> tb @> tileAtom) (Just tileDest, tBody) emitLoop (binderNameHint sb) Fwd nEpilogue \iEpi -> do i <- iEpi `iaddI` epilogueOff idx <- intToIndex idxTy i - sDest <- fromEmbed $ indexDestDim d dest idx + sDest <- fromBuilder $ indexDestDim d dest idx void $ translateBlock (env <> sb @> idx) (Just sDest, sBody) destToAtom dest - PTileReduce idxTy' ~(BinaryFunVal gtidB nthrB _ body) -> do + PTileReduce baseMonoids idxTy' ~(BinaryFunVal gtidB nthrB _ body) -> do idxTy <- impSubst env idxTy' (mappingDest, finalAccDest) <- destPairUnpack <$> allocDest maybeDest resultTy - let PairTy _ accType = resultTy - (numTileWorkgroups, wgResArr, widIdxTy) <- buildKernel idxTy \LaunchInfo{..} buildBody -> do + let PairTy _ accTypes = resultTy + (numTileWorkgroups, wgAccsArr, widIdxTy) <- buildKernel idxTy \LaunchInfo{..} buildBody -> do let widIdxTy = Fin $ toScalarAtom numWorkgroups let tidIdxTy = Fin $ toScalarAtom workgroupSize - wgResArr <- alloc $ TabTy (Ignore widIdxTy) accType - thrAccArr <- alloc $ TabTy (Ignore widIdxTy) $ TabTy (Ignore tidIdxTy) accType + wgAccsArr <- alloc $ TabTy (Ignore widIdxTy) accTypes + thrAccsArr <- alloc $ TabTy (Ignore widIdxTy) $ TabTy (Ignore tidIdxTy) accTypes mappingKernelBody <- buildBody \ThreadInfo{..} -> do let TC (ParIndexRange _ gtid nthr) = threadRange - let scope = freeVars mappingDest - let tileDest = Con $ TabRef $ fst $ flip runSubstEmbed scope $ do + let tileDest = Con $ TabRef $ fst $ flip runSubstBuilder (freeVars mappingDest) $ do buildLam (Bind $ "hwidx":>threadRange) TabArrow \hwidx -> do indexDest mappingDest =<< (emitOp $ Inject hwidx) - wgAccs <- destGet thrAccArr =<< intToIndex widIdxTy wid - thrAcc <- destGet wgAccs =<< intToIndex tidIdxTy tid - let threadDest = Con $ ConRef $ PairCon tileDest thrAcc + wgThrAccs <- destGet thrAccsArr =<< intToIndex widIdxTy wid + thrAccs <- destGet wgThrAccs =<< intToIndex tidIdxTy tid + let thrAccsList = fromDestConsList thrAccs + let threadDest = foldr ((Con . ConRef) ... flip PairCon) tileDest thrAccsList + -- TODO: Make sure that threadDest has the right type void $ translateBlock (env <> gtidB @> gtid <> nthrB @> nthr) (Just threadDest, body) - wgRes <- destGet wgResArr =<< intToIndex widIdxTy wid - workgroupReduce tid wgRes wgAccs workgroupSize - return (mappingKernelBody, (numWorkgroups, wgResArr, widIdxTy)) + wgAccs <- destGet wgAccsArr =<< intToIndex widIdxTy wid + workgroupReduce tid wgAccs wgThrAccs workgroupSize + return (mappingKernelBody, (numWorkgroups, wgAccsArr, widIdxTy)) -- TODO: Skip the reduction kernel if unnecessary? -- TODO: Reduce sequentially in the CPU backend? -- TODO: Actually we only need the previous-power-of-2 many threads @@ -453,13 +459,14 @@ toImpHof env (maybeDest, hof) = do moreThanOneGroup <- (IIdxRepVal 1) `iltI` numWorkgroups guardBlock moreThanOneGroup $ emitStatement IThrowError redKernelBody <- buildBody \ThreadInfo{..} -> - workgroupReduce tid finalAccDest wgResArr numTileWorkgroups + workgroupReduce tid finalAccDest wgAccsArr numTileWorkgroups return (redKernelBody, ()) PairVal <$> destToAtom mappingDest <*> destToAtom finalAccDest where guardBlock cond m = do block <- scopedErrBlock m emitStatement $ ICond cond block (ImpBlock mempty mempty) + -- XXX: Overwrites the contents of arrDest, writes result in resDest workgroupReduce tid resDest arrDest elemCount = do elemCountDown2 <- prevPowerOf2 elemCount let RawRefTy (TabTy arrIdxB _) = getType arrDest @@ -472,7 +479,7 @@ toImpHof env (maybeDest, hof) = do shouldAdd <- bindM2 bandI (tid `iltI` off) (loadIdx `iltI` elemCount) guardBlock shouldAdd $ do threadDest <- destGet arrDest =<< intToIndex arrIdxTy tid - addToAtom threadDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy loadIdx + combineWithDest threadDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy loadIdx emitStatement ISyncWorkgroup copyAtom offPtr . toScalarAtom =<< off `idivI` (IIdxRepVal 2) cond <- liftM snd $ scopedBlock $ do @@ -484,6 +491,13 @@ toImpHof env (maybeDest, hof) = do firstThread <- tid `iltI` (IIdxRepVal 1) guardBlock firstThread $ copyAtom resDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy tid + combineWithDest :: Dest -> Atom -> ImpM () + combineWithDest accsDest accsUpdates = do + let accsDestList = fromDestConsList accsDest + let Right accsUpdatesList = fromConsList accsUpdates + forM_ (zip3 accsDestList baseMonoids accsUpdatesList) $ \(dest, bm, update) -> do + extender <- fromBuilder $ mextendForRef dest bm update + void $ toImpOp (Nothing, PrimEffect dest $ MExtend extender) -- TODO: Do some popcount tricks? prevPowerOf2 :: IExpr -> ImpM IExpr prevPowerOf2 x = do @@ -506,12 +520,13 @@ toImpHof env (maybeDest, hof) = do rDest <- alloc $ getType r copyAtom rDest =<< impSubst env r translateBlock (env <> ref @> rDest) (maybeDest, body) - RunWriter ~(BinaryFunVal _ ref _ body) -> do + RunWriter (BaseMonoid e' _) ~(BinaryFunVal _ ref _ body) -> do + let PairTy _ accTy = resultTy (aDest, wDest) <- destPairUnpack <$> allocDest maybeDest resultTy - let RefTy _ wTy = getType ref - copyAtom wDest (zeroAt wTy) + copyAtom wDest =<< (liftNeutral accTy <$> impSubst env e') void $ translateBlock (env <> ref @> wDest) (Just aDest, body) PairVal <$> destToAtom aDest <*> destToAtom wDest + where liftNeutral accTy e = foldr TabValA e $ monoidLift (getType e) accTy RunState s ~(BinaryFunVal _ ref _ body) -> do (aDest, sDest) <- destPairUnpack <$> allocDest maybeDest resultTy copyAtom sDest =<< impSubst env s @@ -579,10 +594,10 @@ data DestEnv = DestEnv -- The Cat env carries names for the pointers needed, along with their types and -- blocks that compute allocation sizes needed -type DestM = ReaderT DestEnv (CatT (Env (Type, Block)) Embed) +type DestM = ReaderT DestEnv (CatT (Env (Type, Block)) Builder) -- builds a dest and a list of pointer binders along with their required allocation sizes -makeDest :: AllocInfo -> Type -> Embed ([(Binder, Atom)], Dest) +makeDest :: AllocInfo -> Type -> Builder ([(Binder, Atom)], Dest) makeDest allocInfo ty = do (dest, ptrs) <- flip runCatT mempty $ flip runReaderT env $ makeDestRec ty ptrs' <- forM (envPairs ptrs) \(v, (ptrTy, numel)) -> do @@ -603,7 +618,7 @@ makeDestRec ty = case ty of makeBoxes (envPairs ptrs) dest else do lam <- buildLam (Bind ("i":> binderAnn b)) TabArrow \(Var i) -> do - bodyTy' <- substEmbed (b@>Var i) bodyTy + bodyTy' <- substBuilder (b@>Var i) bodyTy withEnclosingIdxs (Bind i) $ makeDestRec bodyTy' return $ Con $ TabRef lam TypeCon def params -> do @@ -673,7 +688,7 @@ makeDataConDest (Nest b rest) = do let ty = binderAnn b dest <- makeDestRec ty v <- freshVarE UnknownBinder b -- TODO: scope names more carefully - rest' <- substEmbed (b @> Var v) rest + rest' <- substBuilder (b @> Var v) rest rest'' <- withDepVar (Bind v) $ makeDataConDest rest' return $ Nest (DataConRefBinding (Bind v) dest) rest'' @@ -719,10 +734,10 @@ copyDataConArgs (Nest (DataConRefBinding b ref) rest) (x:xs) = do copyDataConArgs bindings args = error $ "Mismatched bindings/args: " ++ pprint (bindings, args) -loadDest :: MonadEmbed m => Dest -> m Atom +loadDest :: MonadBuilder m => Dest -> m Atom loadDest (BoxedRef b ptrPtr _ body) = do ptr <- unsafePtrLoad ptrPtr - body' <- substEmbed (b@>ptr) body + body' <- substBuilder (b@>ptr) body loadDest body' loadDest (DataConRef def params bs) = do DataCon def params 0 <$> loadDataConArgs bs @@ -730,7 +745,7 @@ loadDest (Con dest) = do case dest of BaseTypeRef ptr -> unsafePtrLoad ptr TabRef (TabVal b body) -> buildLam b TabArrow \i -> do - body' <- substEmbed (b@>i) body + body' <- substBuilder (b@>i) body result <- emitBlock body' loadDest result RecordRef xs -> Record <$> traverse loadDest xs @@ -744,14 +759,14 @@ loadDest (Con dest) = do _ -> error $ "Not a valid dest: " ++ pprint dest loadDest dest = error $ "Not a valid dest: " ++ pprint dest -loadDataConArgs :: MonadEmbed m => Nest DataConRefBinding -> m [Atom] +loadDataConArgs :: MonadBuilder m => Nest DataConRefBinding -> m [Atom] loadDataConArgs Empty = return [] loadDataConArgs (Nest (DataConRefBinding b ref) rest) = do val <- loadDest ref - rest' <- substEmbed (b@>val) rest + rest' <- substBuilder (b@>val) rest (val:) <$> loadDataConArgs rest' -indexDestDim :: MonadEmbed m => Int->Dest -> Atom -> m Dest +indexDestDim :: MonadBuilder m => Int->Dest -> Atom -> m Dest indexDestDim 0 dest i = indexDest dest i indexDestDim d dest i = buildFor Fwd (Bind ("i":>idxTy)) \j -> do dest' <- indexDest dest j @@ -760,11 +775,11 @@ indexDestDim d dest i = buildFor Fwd (Bind ("i":>idxTy)) \j -> do RawRefTy (TabTy idxBinder _) = dest idxTy = binderType idxBinder -indexDest :: MonadEmbed m => Dest -> Atom -> m Dest +indexDest :: MonadBuilder m => Dest -> Atom -> m Dest indexDest (Con (TabRef tabVal)) i = appReduce tabVal i indexDest dest _ = error $ pprint dest -sliceDestDim :: MonadEmbed m => Int -> Dest -> Atom -> Type -> m Dest +sliceDestDim :: MonadBuilder m => Int -> Dest -> Atom -> Type -> m Dest sliceDestDim 0 dest i sliceIdxTy = sliceDest dest i sliceIdxTy sliceDestDim d dest i sliceIdxTy = buildFor Fwd (Bind ("i":>idxTy)) \j -> do dest' <- indexDest dest j @@ -773,7 +788,7 @@ sliceDestDim d dest i sliceIdxTy = buildFor Fwd (Bind ("i":>idxTy)) \j -> do RawRefTy (TabTy idxBinder _) = dest idxTy = binderType idxBinder -sliceDest :: MonadEmbed m => Dest -> Atom -> Type -> m Dest +sliceDest :: MonadBuilder m => Dest -> Atom -> Type -> m Dest sliceDest ~(Con (TabRef tab@(TabVal b _))) i sliceIdxTy = (Con . TabRef) <$> do buildFor Fwd (Bind ("j" :> sliceIdxTy)) \j -> do j' <- indexToIntE j @@ -782,15 +797,21 @@ sliceDest ~(Con (TabRef tab@(TabVal b _))) i sliceIdxTy = (Con . TabRef) <$> do appReduce tab vidx destToAtom :: Dest -> ImpM Atom -destToAtom dest = fromEmbed $ loadDest dest +destToAtom dest = fromBuilder $ loadDest dest destGet :: Dest -> Atom -> ImpM Dest -destGet dest i = fromEmbed $ indexDest dest i +destGet dest i = fromBuilder $ indexDest dest i destPairUnpack :: Dest -> (Dest, Dest) destPairUnpack (Con (ConRef (PairCon l r))) = (l, r) destPairUnpack d = error $ "Not a pair destination: " ++ show d +fromDestConsList :: Dest -> [Dest] +fromDestConsList dest = case dest of + Con (ConRef (PairCon h t)) -> h : fromDestConsList t + Con (ConRef UnitCon) -> [] + _ -> error $ "Not a dest cons list: " ++ pprint dest + makeAllocDest :: AllocType -> Type -> ImpM Dest makeAllocDest allocTy ty = fst <$> makeAllocDestWithPtrs allocTy ty @@ -798,7 +819,7 @@ makeAllocDestWithPtrs :: AllocType -> Type -> ImpM (Dest, [IExpr]) makeAllocDestWithPtrs allocTy ty = do backend <- asks impBackend curDev <- asks curDevice - (ptrsSizes, dest) <- fromEmbed $ makeDest (backend, curDev, allocTy) ty + (ptrsSizes, dest) <- fromBuilder $ makeDest (backend, curDev, allocTy) ty (env, ptrs) <- flip foldMapM ptrsSizes \(Bind (ptr:>PtrTy ptrTy), size) -> do ptr' <- emitAlloc ptrTy $ fromScalarAtom size case ptrTy of @@ -927,21 +948,21 @@ toScalarType b = BaseTy b -- === Type classes === -fromEmbed :: Subst a => Embed a -> ImpM a -fromEmbed m = do +fromBuilder :: Subst a => Builder a -> ImpM a +fromBuilder m = do scope <- variableScope - let (ans, (_, decls)) = runEmbed m scope + let (ans, (_, decls)) = runBuilder m scope env <- catFoldM translateDecl mempty $ fmap (Nothing,) decls impSubst env ans intToIndex :: Type -> IExpr -> ImpM Atom -intToIndex ty i = fromEmbed (intToIndexE ty (toScalarAtom i)) +intToIndex ty i = fromBuilder (intToIndexE ty (toScalarAtom i)) indexToInt :: Atom -> ImpM IExpr -indexToInt idx = fromScalarAtom <$> fromEmbed (indexToIntE idx) +indexToInt idx = fromScalarAtom <$> fromBuilder (indexToIntE idx) indexSetSize :: Type -> ImpM IExpr -indexSetSize ty = fromScalarAtom <$> fromEmbed (indexSetSizeE ty) +indexSetSize ty = fromScalarAtom <$> fromBuilder (indexSetSizeE ty) zipTabDestAtom :: HasCallStack => (Dest -> Atom -> ImpM ()) -> Dest -> Atom -> ImpM () zipTabDestAtom f ~dest@(Con (TabRef (TabVal b _))) ~src@(TabVal b' _) = do @@ -963,29 +984,6 @@ zipWithRefConM f destCon srcCon = case (destCon, srcCon) of (IndexRangeVal _ _ _ iRef, IndexRangeVal _ _ _ i) -> f iRef i _ -> error $ "Unexpected ref/val " ++ pprint (destCon, srcCon) --- TODO: put this in userspace using type classes -addToAtom :: Dest -> Atom -> ImpM () -addToAtom dest src = case (dest, src) of - (Con (BaseTypeRef ptr), x) -> do - let ptr' = fromScalarAtom ptr - let x' = fromScalarAtom x - cur <- loadAnywhere ptr' - let op = case getIType cur of - Scalar _ -> ScalarBinOp - Vector _ -> VectorBinOp - _ -> error $ "The result of load cannot be a reference" - updated <- emitInstr $ IPrimOp $ op FAdd cur x' - storeAnywhere ptr' updated - (Con (TabRef _), TabVal _ _) -> zipTabDestAtom addToAtom dest src - (Con (ConRef (SumAsProd _ _ payloadDest)), Con (SumAsProd _ tag payload)) -> do - unless (all null payload) $ -- optimization - emitSwitch (fromScalarAtom tag) $ - zipWith (zipWithM_ addToAtom) payloadDest payload - (Con (ConRef destCon), Con srcCon) -> zipWithRefConM addToAtom destCon srcCon - (Con (RecordRef dests), Record srcs) -> - zipWithM_ addToAtom (toList dests) (toList srcs) - _ -> error $ "Not implemented " ++ pprint (dest, src) - loadAnywhere :: IExpr -> ImpM IExpr loadAnywhere ptr = do curDev <- asks curDevice @@ -1011,29 +1009,29 @@ storeAnywhere ptr val = do allocateStackSingleton :: IType -> ImpM IExpr allocateStackSingleton ty = allocateBuffer Stack False ty (IIdxRepVal 1) --- === Imp embedding === +-- === Imp IR builders === -embedBinOp :: (Atom -> Atom -> Embed Atom) -> (IExpr -> IExpr -> ImpM IExpr) -embedBinOp f x y = - fromScalarAtom <$> fromEmbed (f (toScalarAtom x) (toScalarAtom y)) +buildBinOp :: (Atom -> Atom -> Builder Atom) -> (IExpr -> IExpr -> ImpM IExpr) +buildBinOp f x y = + fromScalarAtom <$> fromBuilder (f (toScalarAtom x) (toScalarAtom y)) iaddI :: IExpr -> IExpr -> ImpM IExpr -iaddI = embedBinOp iadd +iaddI = buildBinOp iadd isubI :: IExpr -> IExpr -> ImpM IExpr -isubI = embedBinOp isub +isubI = buildBinOp isub imulI :: IExpr -> IExpr -> ImpM IExpr -imulI = embedBinOp imul +imulI = buildBinOp imul idivI :: IExpr -> IExpr -> ImpM IExpr -idivI = embedBinOp idiv +idivI = buildBinOp idiv iltI :: IExpr -> IExpr -> ImpM IExpr -iltI = embedBinOp ilt +iltI = buildBinOp ilt ieqI :: IExpr -> IExpr -> ImpM IExpr -ieqI = embedBinOp ieq +ieqI = buildBinOp ieq bandI :: IExpr -> IExpr -> ImpM IExpr bandI x y = emitInstr $ IPrimOp $ ScalarBinOp BAnd x y @@ -1180,7 +1178,7 @@ instance Checkable ImpFunction where checkValid f@(ImpFunction (_:> IFunType cc _ _) bs block) = addContext ctx $ do let scope = foldMap (binderAsEnv . fmap (const ())) bs let env = foldMap (binderAsEnv ) bs - <> fmap (fromScalarType . fst) initTopEnv + <> fmap (fromScalarType . fst) initBindings void $ flip runReaderT (env, deviceFromCallingConvention cc) $ flip runStateT scope $ checkBlock block where ctx = "Checking:\n" ++ pprint f diff --git a/src/lib/Imp/Embed.hs b/src/lib/Imp/Builder.hs similarity index 67% rename from src/lib/Imp/Embed.hs rename to src/lib/Imp/Builder.hs index 315a00226..987bf4e06 100644 --- a/src/lib/Imp/Embed.hs +++ b/src/lib/Imp/Builder.hs @@ -4,20 +4,20 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module Imp.Embed ( ISubstEmbedT, ISubstEnv (..) - , runISubstEmbedT, liftSE - , emit, freshIVar, extendValSubst - , buildScoped - -- embedding - , iadd, imul - , alloc, ptrOffset - -- traversal - , traverseImpModule, traverseImpFunction - , traverseImpBlock, evalImpBlock - , traverseImpDecl, traverseImpInstr - , traverseIExpr, traverseIFunVar - , ITraversalDef, substTraversalDef - ) where +module Imp.Builder ( ISubstBuilderT, ISubstEnv (..) + , runISubstBuilderT, liftSE + , emit, freshIVar, extendValSubst + , buildScoped + -- IR builders + , iadd, imul + , alloc, ptrOffset + -- traversal + , traverseImpModule, traverseImpFunction + , traverseImpBlock, evalImpBlock + , traverseImpDecl, traverseImpInstr + , traverseIExpr, traverseIFunVar + , ITraversalDef, substTraversalDef + ) where import Control.Monad.Reader @@ -28,7 +28,7 @@ import Imp import Util (bindM2) -- XXX: Scope is actually global within each function -data IEmbedEnv = IEmbedEnv +data IBuilderEnv = IBuilderEnv { scope :: Env () , blockDecls :: Nest ImpDecl } @@ -37,40 +37,40 @@ data ISubstEnv = ISubstEnv , funcSubst :: Env IFunVar } -type IEmbedT m = CatT IEmbedEnv m +type IBuilderT m = CatT IBuilderEnv m type ISubstT m = ReaderT ISubstEnv m -type ISubstEmbedT m = IEmbedT (ISubstT m) +type ISubstBuilderT m = IBuilderT (ISubstT m) -runIEmbedT :: Monad m => IEmbedT m a -> m a -runIEmbedT m = fst <$> runCatT m mempty +runIBuilderT :: Monad m => IBuilderT m a -> m a +runIBuilderT m = fst <$> runCatT m mempty runISubstT :: Monad m => ISubstEnv -> ISubstT m a -> m a runISubstT env m = runReaderT m env -runISubstEmbedT :: Monad m => ISubstEnv -> ISubstEmbedT m a -> m a -runISubstEmbedT env = (runISubstT env) . runIEmbedT +runISubstBuilderT :: Monad m => ISubstEnv -> ISubstBuilderT m a -> m a +runISubstBuilderT env = (runISubstT env) . runIBuilderT -liftSE :: Monad m => m a -> ISubstEmbedT m a +liftSE :: Monad m => m a -> ISubstBuilderT m a liftSE = lift . lift -extendScope :: Monad m => Env a -> IEmbedT m () -extendScope s = extend $ IEmbedEnv (fmap (const ()) s) mempty +extendScope :: Monad m => Env a -> IBuilderT m () +extendScope s = extend $ IBuilderEnv (fmap (const ()) s) mempty -emit :: Monad m => ImpInstr -> IEmbedT m [IExpr] +emit :: Monad m => ImpInstr -> IBuilderT m [IExpr] emit instr = do vs <- traverse (freshIVar . Ignore) $ impInstrTypes instr emitTo vs instr -emitTo :: Monad m => [IVar] -> ImpInstr -> IEmbedT m [IExpr] +emitTo :: Monad m => [IVar] -> ImpInstr -> IBuilderT m [IExpr] emitTo bs instr = do extend $ mempty { blockDecls = (Nest (ImpLet (fmap Bind bs) instr) Empty) } return $ fmap IVar bs -instance Semigroup IEmbedEnv where - (IEmbedEnv s d) <> (IEmbedEnv s' d') = IEmbedEnv (s <> s') (d <> d') +instance Semigroup IBuilderEnv where + (IBuilderEnv s d) <> (IBuilderEnv s' d') = IBuilderEnv (s <> s') (d <> d') -instance Monoid IEmbedEnv where - mempty = IEmbedEnv mempty mempty +instance Monoid IBuilderEnv where + mempty = IBuilderEnv mempty mempty instance Semigroup ISubstEnv where (ISubstEnv v f) <> (ISubstEnv v' f') = ISubstEnv (v <> v') (f <> f') @@ -78,24 +78,24 @@ instance Semigroup ISubstEnv where instance Monoid ISubstEnv where mempty = ISubstEnv mempty mempty --- === Imp embedding === +-- === Imp IR builders === -ptrOffset :: Monad m => IExpr -> IExpr -> IEmbedT m IExpr +ptrOffset :: Monad m => IExpr -> IExpr -> IBuilderT m IExpr ptrOffset ptr off = liftM head $ emit $ IPrimOp $ PtrOffset ptr off -imul :: Monad m => IExpr -> IExpr -> IEmbedT m IExpr +imul :: Monad m => IExpr -> IExpr -> IBuilderT m IExpr imul x y = liftM head $ emit $ IPrimOp $ ScalarBinOp IMul x y -iadd :: Monad m => IExpr -> IExpr -> IEmbedT m IExpr +iadd :: Monad m => IExpr -> IExpr -> IBuilderT m IExpr iadd x y = liftM head $ emit $ IPrimOp $ ScalarBinOp IAdd x y -alloc :: Monad m => AddressSpace -> IType -> IExpr -> IEmbedT m IExpr +alloc :: Monad m => AddressSpace -> IType -> IExpr -> IBuilderT m IExpr alloc addrSpc ty size = liftM head $ emit $ Alloc addrSpc ty size -- === Imp IR traversal === -type ITraversalDef m = ( ImpDecl -> ISubstEmbedT m (Env IExpr) - , ImpInstr -> ISubstEmbedT m ImpInstr +type ITraversalDef m = ( ImpDecl -> ISubstBuilderT m (Env IExpr) + , ImpInstr -> ISubstBuilderT m ImpInstr ) substTraversalDef :: Monad m => ITraversalDef m @@ -115,7 +115,7 @@ traverseImpModule fTrav (ImpModule funcs) = ImpModule . fst <$> runCatT (travers traverseImpFunction :: Monad m => ITraversalDef m -> Env IFunVar -> ImpFunction -> m ImpFunction traverseImpFunction _ _ (FFIFunction f ) = return $ FFIFunction f -traverseImpFunction def fenv (ImpFunction name args body) = runISubstEmbedT env $ do +traverseImpFunction def fenv (ImpFunction name args body) = runISubstBuilderT env $ do extendScope $ foldMap binderAsEnv args body' <- extendValSubst (foldMap argSub args) $ traverseImpBlock def body return $ ImpFunction name args body' @@ -125,10 +125,10 @@ traverseImpFunction def fenv (ImpFunction name args body) = runISubstEmbedT env Bind v -> v @> IVar v env = ISubstEnv mempty fenv -traverseImpBlock :: Monad m => ITraversalDef m -> ImpBlock -> ISubstEmbedT m ImpBlock +traverseImpBlock :: Monad m => ITraversalDef m -> ImpBlock -> ISubstBuilderT m ImpBlock traverseImpBlock def block = buildScoped $ evalImpBlock def block -evalImpBlock :: Monad m => ITraversalDef m -> ImpBlock -> ISubstEmbedT m [IExpr] +evalImpBlock :: Monad m => ITraversalDef m -> ImpBlock -> ISubstBuilderT m [IExpr] evalImpBlock def@(fDecl, _) (ImpBlock decls results) = do case decls of Nest decl rest -> do @@ -136,12 +136,12 @@ evalImpBlock def@(fDecl, _) (ImpBlock decls results) = do extendValSubst env' $ evalImpBlock def $ ImpBlock rest results Empty -> traverse traverseIExpr results -traverseImpDecl :: Monad m => ITraversalDef m -> ImpDecl -> ISubstEmbedT m (Env IExpr) +traverseImpDecl :: Monad m => ITraversalDef m -> ImpDecl -> ISubstBuilderT m (Env IExpr) traverseImpDecl (_, fInstr) (ImpLet bs instr) = do vs <- bindM2 emitTo (traverse freshIVar bs) (fInstr instr) return $ newEnv bs vs -traverseImpInstr :: Monad m => ITraversalDef m -> ImpInstr -> ISubstEmbedT m ImpInstr +traverseImpInstr :: Monad m => ITraversalDef m -> ImpInstr -> ISubstBuilderT m ImpInstr traverseImpInstr def instr = case instr of IFor dir b size body -> do b' <- freshIVar b @@ -168,14 +168,14 @@ traverseImpInstr def instr = case instr of ICastOp ty val -> ICastOp ty <$> traverseIExpr val IPrimOp op -> IPrimOp <$> traverse traverseIExpr op -traverseIExpr :: Monad m => IExpr -> ISubstEmbedT m IExpr +traverseIExpr :: Monad m => IExpr -> ISubstBuilderT m IExpr traverseIExpr (ILit l) = return $ ILit l traverseIExpr (IVar v) = (!v) <$> asks valSubst -traverseIFunVar :: Monad m => IFunVar -> ISubstEmbedT m IFunVar +traverseIFunVar :: Monad m => IFunVar -> ISubstBuilderT m IFunVar traverseIFunVar fv = (!fv) <$> asks funcSubst -freshIVar :: Monad m => IBinder -> IEmbedT m IVar +freshIVar :: Monad m => IBinder -> IBuilderT m IVar freshIVar b = do let nameHint = case b of Bind (name:>_) -> name @@ -184,11 +184,11 @@ freshIVar b = do extendScope $ name @> () return $ name :> binderAnn b -buildScoped :: Monad m => IEmbedT m [IExpr] -> IEmbedT m ImpBlock +buildScoped :: Monad m => IBuilderT m [IExpr] -> IBuilderT m ImpBlock buildScoped m = do - (results, IEmbedEnv scopeExt decls) <- scoped m - extend $ IEmbedEnv scopeExt mempty -- Names are global in Imp IR + (results, IBuilderEnv scopeExt decls) <- scoped m + extend $ IBuilderEnv scopeExt mempty -- Names are global in Imp IR return $ ImpBlock decls results -extendValSubst :: Monad m => Env IExpr -> ISubstEmbedT m a -> ISubstEmbedT m a +extendValSubst :: Monad m => Env IExpr -> ISubstBuilderT m a -> ISubstBuilderT m a extendValSubst s = local (\env -> env { valSubst = valSubst env <> s }) diff --git a/src/lib/Imp/Optimize.hs b/src/lib/Imp/Optimize.hs index 385856f3c..a98c86560 100644 --- a/src/lib/Imp/Optimize.hs +++ b/src/lib/Imp/Optimize.hs @@ -12,7 +12,7 @@ import PPrint import Env import Cat import Syntax -import Imp.Embed +import Imp.Builder -- TODO: DCE! @@ -30,7 +30,7 @@ liftCUDAAllocations m = ImpFunction (fname:>IFunType cc argTys retTys) argBs' body' -> case cc of CUDAKernelLaunch -> do let ((argBs, body), fAllocEnv) = - flip runCat mempty $ runISubstEmbedT (ISubstEnv mempty fenv) $ do + flip runCat mempty $ runISubstBuilderT (ISubstEnv mempty fenv) $ do ~args@(tid:wid:wsz:_) <- traverse freshIVar argBs' newBody <- extendValSubst (newEnv argBs' $ fmap IVar args) $ buildScoped $ do gtid <- iadd (IVar tid) =<< imul (IVar wid) (IVar wsz) @@ -64,7 +64,7 @@ liftCUDAAllocations m = amendLaunch :: ITraversalDef (Cat ModAllocEnv) amendLaunch = (traverseImpDecl amendLaunch, amendLaunchInstr) where - amendLaunchInstr :: ImpInstr -> ISubstEmbedT (Cat ModAllocEnv) ImpInstr + amendLaunchInstr :: ImpInstr -> ISubstBuilderT (Cat ModAllocEnv) ImpInstr amendLaunchInstr instr = case instr of ILaunch f' s' args' -> do s <- traverseIExpr s' diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index ed8f3c1d2..3c801c350 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -14,7 +14,7 @@ import Control.Applicative import Control.Monad import Control.Monad.Reader import Control.Monad.Except hiding (Except) -import Data.Foldable (fold, toList, asum) +import Data.Foldable (fold, toList) import Data.Functor import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M @@ -24,14 +24,14 @@ import Data.Text.Prettyprint.Doc import Syntax import Interpreter (indicesNoIO) -import Embed hiding (sub) +import Builder hiding (sub) import Env import Type import PPrint import Cat import Util -type UInferM = ReaderT SubstEnv (ReaderT SrcCtx ((EmbedT (SolverT (Either Err))))) +type UInferM = ReaderT SubstEnv (ReaderT SrcCtx ((BuilderT (SolverT (Either Err))))) type SigmaType = Type -- may start with an implicit lambda type RhoType = Type -- doesn't start with an implicit lambda @@ -46,7 +46,7 @@ pattern Check t <- {-# COMPLETE Infer, Check #-} -inferModule :: TopEnv -> UModule -> Except Module +inferModule :: Bindings -> UModule -> Except Module inferModule scope (UModule decls) = do ((), (bindings, decls')) <- runUInferM mempty scope $ mapM_ (inferUDecl True) decls @@ -59,7 +59,7 @@ inferModule scope (UModule decls) = do runUInferM :: (Subst a, Pretty a) => SubstEnv -> Scope -> UInferM a -> Except (a, (Scope, Nest Decl)) runUInferM env scope m = runSolverT $ - runEmbedT (runReaderT (runReaderT m env) Nothing) scope + runBuilderT (runReaderT (runReaderT m env) Nothing) scope checkSigma :: UExpr -> (Type -> RequiredTy Type) -> SigmaType -> UInferM Atom checkSigma expr reqCon sTy = case sTy of @@ -137,10 +137,10 @@ checkOrInferRho (WithSrc pos expr) reqTy = do -- is safe and doesn't make the type checking depend on the program order. infTy <- getType <$> zonk fVal piTy <- addSrcContext' (srcPos f) $ fromPiType True arr infTy - (xVal, embedEnv@(_, xDecls)) <- embedScoped $ checkSigma x Suggest (absArgType piTy) + (xVal, builderEnv@(_, xDecls)) <- builderScoped $ checkSigma x Suggest (absArgType piTy) (xVal', arr') <- case piTy of Abs b rhs@(arr', _) -> case b `isin` freeVars rhs of - False -> embedExtend embedEnv $> (xVal, arr') + False -> builderExtend builderEnv $> (xVal, arr') True -> do xValMaybeRed <- flip typeReduceBlock (Block xDecls (Atom xVal)) <$> getScope case xValMaybeRed of @@ -339,7 +339,14 @@ inferUDecl topLevel (ULet letAnn (p, ann) rhs) = do expr <- zonk $ Atom val if topLevel then unpackTopPat letAnn p expr $> mempty - else bindPat p val + else do + env <- bindPat p val + -- XXX: We have to preserve the non-standard let annotations + -- Ideally we would just put the annotations on the equations + -- elaborated during inference, but this approach is simpler. + case letAnn of + PlainLet -> return env + _ -> forM env $ emitAnn letAnn . Atom inferUDecl True (UData tc dcs) = do (tc', paramBs) <- inferUConDef tc dataDef <- buildDataDef tc' paramBs \params -> do @@ -361,20 +368,23 @@ inferUDecl True (UInterface superclasses tc methods) = do emitSuperclassGetters dataDef emitMethodGetters dataDef return mempty -inferUDecl True (UInstance argBinders instanceTy methods) = do - instanceDict <- checkInstance argBinders instanceTy methods - let instanceName = Name TypeClassGenName "instance" 0 - void $ emitTo instanceName InstanceLet $ Atom instanceDict - return mempty +inferUDecl topLevel (UInstance maybeName argBinders instanceTy methods) = do + instanceDict <- checkInstance argBinders instanceTy methods + case (topLevel, maybeName) of + (False, Nothing) -> error "anonymous instance definitions should be top-level" + (False, Just n ) -> return $ n @> instanceDict + (True , Nothing) -> mempty <$ emitTo nameHint InstanceLet (Atom instanceDict) + where nameHint = Name TypeClassGenName "instance" 0 + (True , Just n ) -> mempty <$ (checkNotInScope gn >> emitTo gn PlainLet (Atom instanceDict)) + where gn = asGlobal $ varName n inferUDecl False (UData _ _ ) = error "data definitions should be top-level" inferUDecl False (UInterface _ _ _) = error "interface definitions should be top-level" -inferUDecl False (UInstance _ _ _) = error "instance definitions should be top-level" -freshClassGenName :: MonadEmbed m => m Name +freshClassGenName :: MonadBuilder m => m Name freshClassGenName = do scope <- getScope let v' = genFresh (Name TypeClassGenName "classgen" 0) scope - embedExtend $ asFst $ v' @> (UnitTy, UnknownBinder) + builderExtend $ asFst $ v' @> (UnitTy, UnknownBinder) return v' mkMethod :: UAnnBinder -> UInferM (Label, Type) @@ -418,7 +428,7 @@ emitMethodGetters def@(DataDef _ paramBs (ClassDictDef _ _ methodTys)) = do emitTo methodName PlainLet $ Atom f emitMethodGetters (DataDef _ _ _) = error "Not a class dictionary" -emitSuperclassGetters :: MonadEmbed m => DataDef -> m () +emitSuperclassGetters :: MonadBuilder m => DataDef -> m () emitSuperclassGetters def@(DataDef _ paramBs (ClassDictDef _ superclassTys _)) = do forM_ (getLabels superclassTys) \l -> do f <- buildImplicitNaryLam paramBs \params -> do @@ -447,7 +457,7 @@ checkShadows vs = do inferUConDef :: UConDef -> UInferM (Name, Nest Binder) inferUConDef (UConDef v bs) = do - (bs', _) <- embedScoped $ checkNestedBinders bs + (bs', _) <- builderScoped $ checkNestedBinders bs let v' = asGlobal v checkNotInScope v' return (v', bs') @@ -801,12 +811,12 @@ getSrcCtx = lift ask -- We have two variants here because at the top level we want error messages and -- internally we want to consider all alternatives. -type SynthPassM = SubstEmbedT (Either Err) -type SynthDictM = SubstEmbedT [] +type SynthPassM = SubstBuilderT (Either Err) +type SynthDictM = SubstBuilderT [] synthModule :: Scope -> Module -> Except Module synthModule scope (Module Typed decls bindings) = do - decls' <- fst . fst <$> runSubstEmbedT + decls' <- fst . fst <$> runSubstBuilderT (traverseDecls (traverseHoles synthDictTop) decls) scope return $ Module Core decls' bindings synthModule _ _ = error $ "Unexpected IR variant" @@ -814,45 +824,56 @@ synthModule _ _ = error $ "Unexpected IR variant" synthDictTop :: SrcCtx -> Type -> SynthPassM Atom synthDictTop ctx ty = do scope <- getScope - let solutions = runSubstEmbedT (synthDict ty) scope + let solutions = runSubstBuilderT (synthDict ty) scope addSrcContext ctx $ case solutions of [] -> throw TypeErr $ "Couldn't synthesize a class dictionary for: " ++ pprint ty - [(ans, env)] -> embedExtend env $> ans + [(ans, env)] -> builderExtend env $> ans _ -> throw TypeErr $ "Multiple candidate class dictionaries for: " ++ pprint ty ++ "\n" ++ pprint solutions -traverseHoles :: (MonadReader SubstEnv m, MonadEmbed m) +traverseHoles :: (MonadReader SubstEnv m, MonadBuilder m) => (SrcCtx -> Type -> m Atom) -> TraversalDef m traverseHoles fillHole = (traverseDecl recur, traverseExpr recur, synthPassAtom) where synthPassAtom atom = case atom of - Con (ClassDictHole ctx ty) -> fillHole ctx =<< substEmbedR ty + Con (ClassDictHole ctx ty) -> fillHole ctx =<< substBuilderR ty _ -> traverseAtom recur atom recur = traverseHoles fillHole synthDict :: Type -> SynthDictM Atom -synthDict ty = do - (d, bindingInfo) <- getBinding - case bindingInfo of - LetBound InstanceLet _ -> do - block <- buildScoped $ inferToSynth $ instantiateAndCheck ty d +synthDict ty = case ty of + PiTy b arr body -> synthesizeNow <|> introFirst + where + introFirst = buildDepEffLam b + (\x -> extendR (b @> x) $ substBuilderR arr) + (\x -> extendR (b @> x) $ substBuilderR body >>= synthDict) + _ -> synthesizeNow + where + synthesizeNow = do + (d, bindingInfo) <- getBinding + case bindingInfo of + LetBound InstanceLet _ -> trySynth d + LamBound ClassArrow -> withSuperclasses d >>= trySynth + _ -> empty + trySynth step = do + block <- buildScoped $ inferToSynth $ instantiateAndCheck ty step + -- NOTE: It's ok to emit unconditionally here. It will only ever emit + -- blocks that fully resolved without any holes, and if we ever + -- end up with two results, we don't use the duplicate code because + -- it's an error! traverseBlock (traverseHoles (const synthDict)) block >>= emitBlock - LamBound ClassArrow -> do - d' <- superclass d - inferToSynth $ instantiateAndCheck ty d' - _ -> empty -- TODO: this doesn't de-dup, so we'll get multiple results if we have a -- diamond-shaped hierarchy. -superclass :: Atom -> SynthDictM Atom -superclass dict = return dict <|> do +withSuperclasses :: Atom -> SynthDictM Atom +withSuperclasses dict = return dict <|> do (f, LetBound SuperclassLet _) <- getBinding inferToSynth $ tryApply f dict getBinding :: SynthDictM (Atom, BinderInfo) getBinding = do scope <- getScope - (v, (ty, bindingInfo)) <- asum $ map return $ envPairs scope + (v, (ty, bindingInfo)) <- lift $ lift $ envPairs scope return (Var (v:>ty), bindingInfo) inferToSynth :: (Pretty a, Subst a) => UInferM a -> SynthDictM a @@ -906,8 +927,8 @@ solveLocal :: Subst a => UInferM a -> UInferM a solveLocal m = do (ans, env@(SolverEnv freshVars sub)) <- scoped $ do -- This might get expensive. TODO: revisit once we can measure performance. - (ans, embedEnv) <- zonk =<< embedScoped m - embedExtend embedEnv + (ans, builderEnv) <- zonk =<< builderScoped m + builderExtend builderEnv return ans extend $ SolverEnv (unsolved env) (sub `envDiff` freshVars) return ans @@ -1067,7 +1088,7 @@ instance Monoid SolverEnv where mempty = SolverEnv mempty mempty mappend = (<>) -typeReduceScoped :: MonadEmbed m => m Atom -> m (Maybe Atom) +typeReduceScoped :: MonadBuilder m => m Atom -> m (Maybe Atom) typeReduceScoped m = do block <- buildScoped m scope <- getScope diff --git a/src/lib/Interpreter.hs b/src/lib/Interpreter.hs index 3f9911119..cf52be47e 100644 --- a/src/lib/Interpreter.hs +++ b/src/lib/Interpreter.hs @@ -23,7 +23,7 @@ import Cat import Syntax import Env import PPrint -import Embed +import Builder import Util (enumerate, restructure) import LLVMExec @@ -108,7 +108,7 @@ evalOp expr = case expr of ToOrdinal idxArg -> case idxArg of Con (IntRangeVal _ _ i) -> return i Con (IndexRangeVal _ _ _ i) -> return i - _ -> evalEmbed (indexToIntE idxArg) + _ -> evalBuilder (indexToIntE idxArg) _ -> error $ "Not implemented: " ++ pprint expr -- We can use this when we know we won't be dereferencing pointers. A better @@ -147,12 +147,12 @@ indices ty = do indexSetSize :: Type -> InterpM Int indexSetSize ty = do - IdxRepVal l <- evalEmbed (indexSetSizeE ty) + IdxRepVal l <- evalBuilder (indexSetSizeE ty) return $ fromIntegral l -evalEmbed :: EmbedT InterpM Atom -> InterpM Atom -evalEmbed embed = do - (atom, (_, decls)) <- runEmbedT embed mempty +evalBuilder :: BuilderT InterpM Atom -> InterpM Atom +evalBuilder builder = do + (atom, (_, decls)) <- runBuilderT builder mempty evalBlock mempty $ Block decls (Atom atom) pattern Int64Val :: Int64 -> Atom diff --git a/src/lib/JIT.hs b/src/lib/JIT.hs index 860829bc3..ecdf6adf1 100644 --- a/src/lib/JIT.hs +++ b/src/lib/JIT.hs @@ -262,7 +262,12 @@ compileInstr instr = case instr of Heap dev -> do numBytes <- mul (sizeof elemTy) =<< (`asIntWidth` i64) =<< compileExpr s case dev of - CPU -> malloc elemTy numBytes + CPU -> case t of + -- XXX: it's important to initialize pointers to zero so that we don't + -- try to dereference them when we serialize. + PtrType _ -> malloc True elemTy numBytes + _ -> malloc False elemTy numBytes + -- TODO: initialize GPU pointers too, once we handle serialization GPU -> cuMemAlloc elemTy numBytes where elemTy = scalarTy t Free ptr -> [] <$ do @@ -734,10 +739,13 @@ alloca elems ty = do return $ L.LocalReference (hostPtrTy ty) v where instr = L.Alloca ty (Just $ i32Lit elems) 0 [] -malloc :: L.Type -> Operand -> Compile Operand -malloc ty bytes = do +malloc :: Bool -> L.Type -> Operand -> Compile Operand +malloc initialize ty bytes = do bytes64 <- asIntWidth bytes i64 - castLPtr ty =<< emitExternCall mallocFun [bytes64] + ptr <- if initialize + then emitExternCall mallocInitializedFun [bytes64] + else emitExternCall mallocFun [bytes64] + castLPtr ty ptr free :: Operand -> Compile () free ptr = do @@ -984,6 +992,10 @@ mathFlags = L.noFastMathFlags { L.allowContract = allowContractions } mallocFun :: ExternFunSpec mallocFun = ExternFunSpec "malloc_dex" (hostPtrTy i8) [L.NoAlias] [] [i64] +mallocInitializedFun :: ExternFunSpec +mallocInitializedFun = + ExternFunSpec "dex_malloc_initialized" (hostPtrTy i8) [L.NoAlias] [] [i64] + freeFun :: ExternFunSpec freeFun = ExternFunSpec "free_dex" L.VoidType [] [] [hostPtrTy i8] diff --git a/src/lib/LiveOutput.hs b/src/lib/LiveOutput.hs index 83a302b43..c391b5516 100644 --- a/src/lib/LiveOutput.hs +++ b/src/lib/LiveOutput.hs @@ -94,7 +94,7 @@ sourceBlockToDag block = do -- that contain interface instance definitions. extend (foldMap ((@>n) . Bind) $ envAsVars $ boundUVars block, [n]) case sbContents block of - IncludeSourceFile _ -> extend $ asSnd [n] + ImportModule _ -> extend $ asSnd [n] _ -> return () return n diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 44022ba70..84a92a19a 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -13,7 +13,7 @@ import Data.Foldable import Data.Maybe import Syntax -import Embed +import Builder import Cat import Env import Type @@ -83,7 +83,7 @@ dceAtom atom = case atom of -- === For inlining === -type InlineM = SubstEmbed +type InlineM = SubstBuilder inlineTraversalDef :: TraversalDef InlineM inlineTraversalDef = (inlineTraverseDecl, inlineTraverseExpr, traverseAtom inlineTraversalDef) @@ -91,7 +91,7 @@ inlineTraversalDef = (inlineTraverseDecl, inlineTraverseExpr, traverseAtom inlin inlineModule :: Module -> Module inlineModule m = transformModuleAsBlock inlineBlock (computeInlineHints m) where - inlineBlock block = fst $ runSubstEmbed (traverseBlock inlineTraversalDef block) mempty + inlineBlock block = fst $ runSubstBuilder (traverseBlock inlineTraversalDef block) mempty inlineTraverseDecl :: Decl -> InlineM SubstEnv inlineTraverseDecl decl = case decl of diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index c09da8a2d..60039d3d1 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -252,8 +252,8 @@ prettyPrecPrimCon con = case con of instance PrettyPrec e => Pretty (PrimOp e) where pretty = prettyFromPrettyPrec instance PrettyPrec e => PrettyPrec (PrimOp e) where prettyPrec op = case op of - PrimEffect ref (MPut val ) -> atPrec LowestPrec $ pApp ref <+> ":=" <+> pApp val - PrimEffect ref (MTell val) -> atPrec LowestPrec $ pApp ref <+> "+=" <+> pApp val + PrimEffect ref (MPut val ) -> atPrec LowestPrec $ pApp ref <+> ":=" <+> pApp val + PrimEffect ref (MExtend update) -> atPrec LowestPrec $ "extend" <+> pApp ref <+> "using" <+> pLowest update PtrOffset ptr idx -> atPrec LowestPrec $ pApp ptr <+> "+>" <+> pApp idx PtrLoad ptr -> atPrec AppPrec $ pAppArg "load" [ptr] RecordCons items rest -> @@ -287,10 +287,17 @@ instance Pretty ClassName where instance Pretty Decl where pretty decl = case decl of - Let _ (Ignore _) bound -> pLowest bound + Let ann (Ignore _) bound -> p ann <+> pLowest bound -- This is just to reduce clutter a bit. We can comment it out when needed. -- Let (v:>Pi _) bound -> p v <+> "=" <+> p bound - Let _ b rhs -> align $ p b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) + Let ann b rhs -> align $ p ann <+> p b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) + +instance Pretty LetAnn where + pretty ann = case ann of + PlainLet -> "" + InstanceLet -> "%instance" + SuperclassLet -> "%superclass" + NoInlineLet -> "%noinline" prettyPiTypeHelper :: PiType -> Doc ann prettyPiTypeHelper (Abs binder (arr, body)) = let @@ -625,14 +632,16 @@ instance Pretty a => Pretty (Limit a) where pretty (InclusiveLim x) = "incLim" <+> p x instance Pretty UDecl where - pretty (ULet _ b rhs) = - align $ prettyUBinder b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) + pretty (ULet ann b rhs) = + align $ p ann <+> prettyUBinder b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) pretty (UData tyCon dataCons) = "data" <+> p tyCon <+> "where" <> nest 2 (hardline <> prettyLines dataCons) pretty (UInterface cs def methods) = "interface" <+> p cs <+> p def <> hardline <> prettyLines methods - pretty (UInstance bs ty methods) = + pretty (UInstance Nothing bs ty methods) = "instance" <+> p bs <+> p ty <> hardline <> prettyLines methods + pretty (UInstance (Just v) bs ty methods) = + "named-instance" <+> p v <+> ":" <+> p bs <+> p ty <> hardline <> prettyLines methods instance Pretty UMethodDef where pretty (UMethodDef b rhs) = p b <+> "=" <+> p rhs diff --git a/src/lib/Parallelize.hs b/src/lib/Parallelize.hs index e11842020..7b162bafc 100644 --- a/src/lib/Parallelize.hs +++ b/src/lib/Parallelize.hs @@ -15,11 +15,12 @@ import Data.Foldable import Optimize import Syntax -import Embed +import Builder import Cat import Env import Type import PPrint +import Util (for) -- TODO: extractParallelism can benefit a lot from horizontal fusion (can happen be after) -- TODO: Parallelism extraction can emit some really cheap (but not trivial) @@ -37,7 +38,7 @@ asABlock :: Block -> ABlock asABlock block = ABlock decls result where scope = freeVars block - ((result, decls), _) = flip runEmbed scope $ scopedDecls $ emitBlock block + ((result, decls), _) = flip runBuilder scope $ scopedDecls $ emitBlock block data LoopEnv = LoopEnv @@ -45,9 +46,9 @@ data LoopEnv = LoopEnv , delayedApps :: Env (Atom, [Atom]) -- (n @> (arr, bs)), n and bs in scope of the original program -- arr in scope of the newly constructed program! } -data AccEnv = AccEnv { activeAccs :: Env Var } +data AccEnv = AccEnv { activeAccs :: Env (Var, BaseMonoid) } -- (reference, its base monoid) -type TLParallelM = SubstEmbedT (State AccEnv) -- Top-level non-parallel statements +type TLParallelM = SubstBuilderT (State AccEnv) -- Top-level non-parallel statements type LoopM = ReaderT LoopEnv TLParallelM -- Generation of (parallel) loop nests runLoopM :: LoopM a -> TLParallelM a @@ -55,7 +56,7 @@ runLoopM m = runReaderT m $ LoopEnv mempty mempty extractParallelism :: Module -> Module extractParallelism = transformModuleAsBlock go - where go block = fst $ evalState (runSubstEmbedT (traverseBlock parallelTrav block) mempty) $ AccEnv mempty + where go block = fst $ evalState (runSubstBuilderT (traverseBlock parallelTrav block) mempty) $ AccEnv mempty parallelTrav :: TraversalDef TLParallelM parallelTrav = ( traverseDecl parallelTrav @@ -69,25 +70,25 @@ parallelTraverseExpr expr = case expr of Hof (For (RegularFor _) fbody@(LamVal b body)) -> do -- TODO: functionEffs is an overapproximation of the effects that really appear inside refs <- gets activeAccs - let allowedRegions = foldMap (\(varType -> RefTy (Var reg) _) -> reg @> ()) refs - (EffectRow bodyEffs t) <- substEmbedR $ functionEffs fbody + let allowedRegions = foldMap (\(varType . fst -> RefTy (Var reg) _) -> reg @> ()) refs + (EffectRow bodyEffs t) <- substBuilderR $ functionEffs fbody let onlyAllowedEffects = all (parallelizableEffect allowedRegions) $ toList bodyEffs case t == Nothing && onlyAllowedEffects of True -> do - b' <- substEmbedR b + b' <- substBuilderR b liftM Atom $ runLoopM $ withLoopBinder b' $ buildParallelBlock $ asABlock body False -> nothingSpecial - Hof (RunWriter (BinaryFunVal h b _ body)) -> do + Hof (RunWriter bm (BinaryFunVal h b _ body)) -> do ~(RefTy _ accTy) <- traverseAtom substTraversalDef $ binderType b - liftM Atom $ emitRunWriter (binderNameHint b) accTy \ref@(Var refVar) -> do + liftM Atom $ emitRunWriter (binderNameHint b) accTy bm \ref@(Var refVar) -> do let RefTy h' _ = varType refVar - modify \accEnv -> accEnv { activeAccs = activeAccs accEnv <> b @> refVar } + modify \accEnv -> accEnv { activeAccs = activeAccs accEnv <> b @> (refVar, bm) } extendR (h @> h' <> b @> ref) $ evalBlockE parallelTrav body -- TODO: Do some alias analysis. This is not fundamentally hard, but it is a little annoying. -- We would have to track not only the base references, but also all the aliases, along -- with their relationships. Then, when we emit local effects in emitLoops, we would have -- to recreate all the aliases. We could do that by carrying around a block and using - -- SubstEmbed to take care of renaming for us. + -- SubstBuilder to take care of renaming for us. Op (IndexRef ref _) -> disallowRef ref >> nothingSpecial Op (FstRef ref ) -> disallowRef ref >> nothingSpecial Op (SndRef ref ) -> disallowRef ref >> nothingSpecial @@ -129,12 +130,12 @@ buildParallelBlock ablock@(ABlock decls result) = do prologueCtxArrs <- mapM (unflattenConsTab lbs) =<< unzipConsListTab prologueCtxAtom return $ foldMap (\(v, arr) -> v @> (arr, loopVars)) $ zip prologueCtxVars prologueCtxArrs delayApps prologueApps $ do - i' <- lift $ substEmbedR i + i' <- lift $ substBuilderR i loopAtom <- withLoopBinder i' $ buildParallelBlock $ asABlock lbody delayApps (arrb @> (loopAtom, loopVars)) $ buildParallelBlock $ ABlock epilogue result -unzipConsListTab :: MonadEmbed m => Atom -> m [Atom] +unzipConsListTab :: MonadBuilder m => Atom -> m [Atom] unzipConsListTab tab = case getType tab of TabTy _ UnitTy -> return [] TabTy _ (PairTy _ _) -> do @@ -142,7 +143,7 @@ unzipConsListTab tab = case getType tab of (x:) <$> unzipConsListTab t _ -> error $ "Expected a table cons list, got: " ++ pprint (getType tab) -unflattenConsTab :: MonadEmbed m => [Var] -> Atom -> m Atom +unflattenConsTab :: MonadBuilder m => [Var] -> Atom -> m Atom unflattenConsTab ivs arr = buildNestedLam TabArrow (fmap Bind ivs) $ app arr . mkConsList type Loop = Abs Binder Block @@ -205,7 +206,7 @@ emitLoops buildPureLoop (ABlock decls result) = do extendR (newEnv lbs is) $ do ctxEnv <- flip traverseNames dapps \_ (arr, idx) -> -- XXX: arr is namespaced in the new program - foldM appTryReduce arr =<< substEmbedR idx + foldM appTryReduce arr =<< substBuilderR idx extendR ctxEnv $ evalBlockE appReduceTraversalDef $ Block decls $ Atom result lift $ case null refs of True -> buildPureLoop (Bind $ "pari" :> iterTy) buildBody @@ -214,23 +215,16 @@ emitLoops buildPureLoop (ABlock decls result) = do buildLam (Bind $ "gtid" :> IdxRepTy) PureArrow \gtid -> do buildLam (Bind $ "nthr" :> IdxRepTy) PureArrow \nthr -> do let threadRange = TC $ ParIndexRange iterTy gtid nthr - let accTys = mkConsListTy $ fmap (derefType . varType) newRefs - emitRunWriter "refsList" accTys \localRefsList -> do - localRefs <- unpackRefConsList localRefsList + let writerSpecs = for newRefs \(ref, bm) -> (varName ref, derefType (varType ref), bm) + emitRunWriters writerSpecs $ \localRefs -> do buildFor Fwd (Bind $ "tidx" :> threadRange) \tidx -> do pari <- emitOp $ Inject tidx extendR (newEnv oldRefNames localRefs) $ buildBody pari - (ans, updateList) <- fromPair =<< (emit $ Hof $ PTileReduce iterTy body) + (ans, updateList) <- fromPair =<< (emit $ Hof $ PTileReduce (fmap snd newRefs) iterTy body) updates <- unpackConsList updateList - forM_ (zip newRefs updates) \(ref, update) -> - emitOp $ PrimEffect (Var ref) $ MTell update + forM_ (zip newRefs updates) $ \((ref, bm), update) -> do + updater <- mextendForRef (Var ref) bm update + emitOp $ PrimEffect (Var ref) $ MExtend updater return ans - where - derefType ~(RefTy _ accTy) = accTy - unpackRefConsList xs = case derefType $ getType xs of - UnitTy -> return [] - PairTy _ _ -> do - x <- getFstRef xs - rest <- getSndRef xs - (x:) <$> unpackRefConsList rest - _ -> error $ "Not a ref cons list: " ++ pprint (getType xs) + where + derefType ~(RefTy _ accTy) = accTy diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index cb6574b8b..284b6c987 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -59,8 +59,12 @@ mustParseit s p = case parseit s p of Right ans -> ans Left e -> error $ "This shouldn't happen:\n" ++ pprint e -includeSourceFile :: Parser String -includeSourceFile = symbol "include" >> stringLiteral <* eol +importModule :: Parser SourceBlock' +importModule = ImportModule <$> do + keyWord ImportKW + s <- (:) <$> letterChar <*> many alphaNumChar + eol + return s sourceBlock :: Parser SourceBlock sourceBlock = do @@ -117,7 +121,8 @@ sourceBlock' = proseBlock <|> topLevelCommand <|> liftM declToModule (topDecl <* eolf) - <|> liftM declToModule (instanceDef <* eolf) + <|> liftM declToModule (instanceDef True <* eolf) + <|> liftM declToModule (instanceDef False <* eolf) <|> liftM declToModule (interfaceDef <* eolf) <|> liftM (Command (EvalExpr Printed) . exprAsModule) (expr <* eol) <|> hidden (some eol >> return EmptyLines) @@ -131,7 +136,7 @@ proseBlock = label "prose block" $ char '\'' >> fmap (ProseBlock . fst) (withSou topLevelCommand :: Parser SourceBlock' topLevelCommand = - liftM IncludeSourceFile includeSourceFile + importModule <|> explicitCommand "top-level command" @@ -333,9 +338,11 @@ decl = do rhs <- sym "=" >> blockOrExpr return $ lhs rhs -instanceDef :: Parser UDecl -instanceDef = do - keyWord InstanceKW +instanceDef :: Bool -> Parser UDecl +instanceDef isNamed = do + name <- case isNamed of + False -> keyWord InstanceKW $> Nothing + True -> keyWord NamedInstanceKW *> (Just . (:>()) <$> anyName) <* sym ":" explicitArgs <- many defArg constraints <- classConstraints classTy <- uType @@ -347,7 +354,7 @@ instanceDef = do explicitArgs ++ [((UnderscoreUPat, Just c) , ClassArrow ) | c <- constraints] methods <- onePerLine instanceMethod - return $ UInstance (toNest argBinders) classTy methods + return $ UInstance name (toNest argBinders) classTy methods where addClassConstraint :: UType -> UType -> UType addClassConstraint c ty = ns $ UPi (UnderscoreUPat, Just c) ClassArrow ty @@ -361,9 +368,10 @@ instanceMethod = do simpleLet :: Parser (UExpr -> UDecl) simpleLet = label "let binding" $ do + letAnn <- (InstanceLet <$ string "%instance" <* sc) <|> (pure PlainLet) p <- try $ (letPat <|> leafPat) <* lookAhead (sym "=" <|> sym ":") - ann <- optional $ annot uType - return $ ULet PlainLet (p, ann) + typeAnn <- optional $ annot uType + return $ ULet letAnn (p, typeAnn) letPat :: Parser UPat letPat = withSrc $ nameToPat <$> anyName @@ -511,7 +519,7 @@ wrapUStatements statements = case statements of [] -> error "Shouldn't be reachable" uStatement :: Parser UStatement -uStatement = withPos $ liftM Left decl +uStatement = withPos $ liftM Left (instanceDef True <|> decl) <|> liftM Right expr -- TODO: put the `try` only around the `x:` not the annotation itself @@ -965,7 +973,7 @@ type Lexer = Parser data KeyWord = DefKW | ForKW | For_KW | RofKW | Rof_KW | CaseKW | OfKW | ReadKW | WriteKW | StateKW | DataKW | InterfaceKW | InstanceKW | WhereKW | IfKW | ThenKW | ElseKW | DoKW - | ExceptKW | IOKW | ViewKW + | ExceptKW | IOKW | ViewKW | ImportKW | NamedInstanceKW upperName :: Lexer Name upperName = liftM mkName $ label "upper-case name" $ lexeme $ @@ -1009,14 +1017,16 @@ keyWord kw = lexeme $ try $ string s >> notFollowedBy nameTailChar DataKW -> "data" InterfaceKW -> "interface" InstanceKW -> "instance" + NamedInstanceKW -> "named-instance" WhereKW -> "where" DoKW -> "do" ViewKW -> "view" + ImportKW -> "import" keyWordStrs :: [String] keyWordStrs = ["def", "for", "for_", "rof", "rof_", "case", "of", "llam", "Read", "Write", "Accum", "Except", "IO", "data", "interface", - "instance", "where", "if", "then", "else", "do", "view"] + "instance", "named-instance", "where", "if", "then", "else", "do", "view", "import"] fieldLabel :: Lexer Label fieldLabel = label "field label" $ lexeme $ diff --git a/src/lib/Serialize.hs b/src/lib/Serialize.hs index 602fdeed2..7d3abf962 100644 --- a/src/lib/Serialize.hs +++ b/src/lib/Serialize.hs @@ -4,24 +4,39 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE CPP #-} -module Serialize (pprintVal, cached, getDexString) where +module Serialize (pprintVal, cached, getDexString, cachedWithSnapshot, + HasPtrs (..)) where import Prelude hiding (pi, abs) import Control.Monad import qualified Data.ByteString as BS +import Data.ByteString.Internal (memcpy) +import Data.ByteString.Unsafe (unsafeUseAsCString) import System.Directory import System.FilePath -import Data.Foldable (toList) +import Control.Monad.Writer +import Control.Monad.State.Strict +import Data.Foldable import qualified Data.Map.Strict as M +import Data.Int import Data.Store hiding (size) import Data.Text.Prettyprint.Doc hiding (brackets) +import Foreign.Ptr +import Foreign.Marshal.Array +import GHC.Generics (Generic) import Interpreter import Syntax import Type import PPrint +import Env + +foreign import ccall "malloc_dex" dexMalloc :: Int64 -> IO (Ptr ()) +foreign import ccall "dex_allocation_size" dexAllocSize :: Ptr () -> IO Int64 pprintVal :: Val -> IO String pprintVal val = asStr <$> prettyVal val @@ -104,10 +119,87 @@ prettyVal val = case val of return $ align $ group innerDoc atom -> return $ prettyPrec atom LowestPrec +-- === taking memory snapshots for serializing heap-backed dex values === + +data WithSnapshot a = WithSnapshot a [PtrSnapshot] deriving Generic +type RawPtr = Ptr () +-- TODO: we could consider using some mmap-able instead of ByteString +data PtrSnapshot = ByteArray BS.ByteString + | PtrArray [PtrSnapshot] + | NullPtr deriving Generic + +class HasPtrs a where + traversePtrs :: Applicative f => (PtrType -> RawPtr -> f RawPtr) -> a -> f a + +takeSnapshot :: HasPtrs a => a -> IO (WithSnapshot a) +takeSnapshot x = + -- TODO: we're using `Writer []` (as we do elsewhere) which has bad + -- asymptotics. We should switch all of these uses to use snoc lists instead. + liftM (WithSnapshot x) $ execWriterT $ flip traversePtrs x \ptrTy ptrVal -> do + snapshot <- lift $ takePtrSnapshot ptrTy ptrVal + tell [snapshot] + return ptrVal + +takePtrSnapshot :: PtrType -> RawPtr -> IO PtrSnapshot +takePtrSnapshot _ ptrVal | ptrVal == nullPtr = return NullPtr +takePtrSnapshot (Heap CPU, ptrTy) ptrVal = case ptrTy of + PtrType eltTy -> do + childPtrs <- loadPtrPtrs ptrVal + PtrArray <$> mapM (takePtrSnapshot eltTy) childPtrs + _ -> ByteArray <$> loadPtrBytes ptrVal +takePtrSnapshot (Heap GPU, _) _ = error "Snapshots of GPU memory not implemented" +takePtrSnapshot (Stack , _) _ = error "Can't take snapshots of stack memory" + +loadPtrBytes :: RawPtr -> IO BS.ByteString +loadPtrBytes ptr = do + numBytes <- fromIntegral <$> dexAllocSize ptr + liftM BS.pack $ peekArray numBytes $ castPtr ptr + +loadPtrPtrs :: RawPtr -> IO [RawPtr] +loadPtrPtrs ptr = do + numBytes <- fromIntegral <$> dexAllocSize ptr + peekArray (numBytes `div` ptrSize) $ castPtr ptr + +restoreSnapshot :: HasPtrs a => WithSnapshot a -> IO a +restoreSnapshot (WithSnapshot x snapshots) = + flip evalStateT snapshots $ flip traversePtrs x \_ _ -> do + (s:ss) <- get + put ss + lift $ restorePtrSnapshot s + +restorePtrSnapshot :: PtrSnapshot -> IO RawPtr +restorePtrSnapshot snapshot = case snapshot of + PtrArray children -> storePtrPtrs =<< mapM restorePtrSnapshot children + ByteArray bytes -> storePtrBytes bytes + NullPtr -> return nullPtr + +storePtrBytes :: BS.ByteString -> IO RawPtr +storePtrBytes xs = do + let numBytes = BS.length xs + destPtr <- dexMalloc $ fromIntegral numBytes + -- this is safe because we don't modify srcPtr's memory or let it escape + unsafeUseAsCString xs \srcPtr -> + memcpy (castPtr destPtr) (castPtr srcPtr) numBytes + return destPtr + +storePtrPtrs :: [RawPtr] -> IO RawPtr +storePtrPtrs ptrs = do + ptr <- dexMalloc $ fromIntegral $ length ptrs * ptrSize + pokeArray (castPtr ptr) ptrs + return ptr + +-- === persistent cache === + -- TODO: this isn't enough, since this module's compilation might be cached curCompilerVersion :: String curCompilerVersion = __TIME__ +cachedWithSnapshot :: (Eq k, Store k, Store a, HasPtrs a) + => String -> k -> IO a -> IO a +cachedWithSnapshot cacheName key create = do + result <- cached cacheName key $ create >>= takeSnapshot + restoreSnapshot result + cached :: (Eq k, Store k, Store a) => String -> k -> IO a -> IO a cached cacheName key create = do cacheDir <- getXdgDirectory XdgCache "dex" @@ -133,3 +225,81 @@ cached cacheName key create = do BS.writeFile cacheKeyPath $ encode (curCompilerVersion, key) BS.writeFile cachePath $ encode result return result + +-- === instances === + +tp :: (HasPtrs a, Applicative f) => (PtrType -> RawPtr -> f RawPtr) -> a -> f a +tp = traversePtrs + +instance HasPtrs Expr where + traversePtrs f expr = case expr of + App e1 e2 -> App <$> tp f e1 <*> tp f e2 + Atom x -> Atom <$> tp f x + Op e -> Op <$> traverse (tp f) e + Hof e -> Hof <$> traverse (tp f) e + Case e alts resultTy -> + Case <$> tp f e <*> traverse (tp f) alts <*> tp f resultTy + +instance (HasPtrs a, HasPtrs b) => HasPtrs (Abs a b) where + traversePtrs f (Abs b body) = Abs <$> tp f b <*> tp f body + +instance HasPtrs Block where + traversePtrs f (Block decls result) = + Block <$> traverse (tp f) decls <*> tp f result + +instance HasPtrs Decl where + traversePtrs f (Let ann b body) = Let ann <$> tp f b <*> tp f body + +instance (HasPtrs a, HasPtrs b) => HasPtrs (a, b) where + traversePtrs f (x, y) = (,) <$> tp f x <*> tp f y + +instance HasPtrs eff => HasPtrs (ArrowP eff) where + traversePtrs f arrow = case arrow of + PlainArrow eff -> PlainArrow <$> tp f eff + _ -> pure arrow + +instance (HasPtrs a, HasPtrs b) => HasPtrs (ExtLabeledItems a b) where + traversePtrs f (Ext items t) = + Ext <$> traverse (tp f) items <*> traverse (tp f) t + +instance HasPtrs DataConRefBinding where + traversePtrs f (DataConRefBinding b ref) = + DataConRefBinding <$> tp f b <*> tp f ref + +instance HasPtrs Atom where + traversePtrs f atom = case atom of + Var v -> Var <$> traverse (tp f) v + Lam lam -> Lam <$> tp f lam + Pi ty -> Pi <$> tp f ty + TC tc -> TC <$> traverse (tp f) tc + Con (Lit (PtrLit ptrTy ptr)) -> (Con . Lit . PtrLit ptrTy) <$> f ptrTy ptr + Con con -> Con <$> traverse (tp f) con + Eff eff -> Eff <$> tp f eff + DataCon def ps con args -> DataCon def <$> tp f ps <*> pure con <*> tp f args + TypeCon def ps -> TypeCon def <$> tp f ps + LabeledRow row -> LabeledRow <$> tp f row + Record la -> Record <$> traverse (tp f) la + Variant row label i val -> + Variant <$> tp f row <*> pure label <*> pure i <*> tp f val + RecordTy row -> RecordTy <$> tp f row + VariantTy row -> VariantTy <$> tp f row + ACase v alts rty -> ACase <$> tp f v <*> tp f alts <*> tp f rty + DataConRef def params args -> DataConRef def <$> tp f params <*> tp f args + BoxedRef b ptr size body -> + BoxedRef <$> tp f b <*> tp f ptr <*> tp f size <*> tp f body + ProjectElt idxs v -> pure $ ProjectElt idxs v + +instance HasPtrs Name where traversePtrs _ x = pure x +instance HasPtrs EffectRow where traversePtrs _ x = pure x + +instance HasPtrs a => HasPtrs [a] where traversePtrs f xs = traverse (tp f) xs +instance HasPtrs a => HasPtrs (Nest a) where traversePtrs f xs = traverse (tp f) xs +instance HasPtrs a => HasPtrs (BinderP a) where traversePtrs f xs = traverse (tp f) xs + +instance HasPtrs BinderInfo where + traversePtrs f binfo = case binfo of + LetBound ann expr -> LetBound ann <$> tp f expr + _ -> pure binfo + +instance Store a => Store (WithSnapshot a) +instance Store PtrSnapshot diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index d1c03e4fd..b8fd75c5b 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -22,16 +22,16 @@ import Autodiff import Env import Syntax import Cat -import Embed +import Builder import Type import PPrint import Util -type SimplifyM = SubstEmbed +type SimplifyM = SubstBuilder -simplifyModule :: TopEnv -> Module -> Module +simplifyModule :: Bindings -> Module -> Module simplifyModule scope (Module Core decls bindings) = do - let simpDecls = snd $ snd $ runSubstEmbed (simplifyDecls decls) scope + let simpDecls = snd $ snd $ runSubstBuilder (simplifyDecls decls) scope -- We don't have to check that the binders are global here, since all local -- Atom binders have been inlined as part of the simplification. let isAtomDecl decl = case decl of Let _ _ (Atom _) -> True; _ -> False @@ -40,26 +40,26 @@ simplifyModule scope (Module Core decls bindings) = do Module Simp (toNest declsNotDone) (bindings <> bindings') simplifyModule _ (Module ir _ _) = error $ "Expected Core, got: " ++ show ir -splitSimpModule :: TopEnv -> Module -> (Block, Abs Binder Module) +splitSimpModule :: Bindings -> Module -> (Block, Abs Binder Module) splitSimpModule scope m = do let (Module Simp decls bindings) = hoistDepDataCons scope m let localVars = filter (not . isGlobal) $ bindingsAsVars $ freeVars bindings let block = Block decls $ Atom $ mkConsList $ map Var localVars let (Abs b (decls', bindings')) = - fst $ flip runEmbed scope $ buildAbs (Bind ("result":>getType block)) $ + fst $ flip runBuilder scope $ buildAbs (Bind ("result":>getType block)) $ \result -> do results <- unpackConsList result - substEmbed (newEnv localVars results) bindings + substBuilder (newEnv localVars results) bindings (block, Abs b (Module Evaluated decls' bindings')) -- Bundling up the free vars in a result with a dependent constructor like -- `AsList n xs` doesn't give us a well typed term. This is a short-term -- workaround. -hoistDepDataCons :: TopEnv -> Module -> Module +hoistDepDataCons :: Bindings -> Module -> Module hoistDepDataCons scope (Module Simp decls bindings) = Module Simp decls' bindings' where - (bindings', (_, decls')) = flip runEmbed scope $ do + (bindings', (_, decls')) = flip runBuilder scope $ do mapM_ emitDecl decls forM bindings \(ty, info) -> case info of LetBound ann x | isData ty -> do x' <- emit x @@ -88,7 +88,7 @@ simplifyDecl (Let ann b expr) = do simplifyStandalone :: Expr -> SimplifyM Atom simplifyStandalone (Atom (LamVal b body)) = do - b' <- mapM substEmbedR b + b' <- mapM substBuilderR b buildLam b' PureArrow \x -> extendR (b@>x) $ simplifyBlock body simplifyStandalone block = @@ -109,8 +109,8 @@ simplifyAtom atom = case atom of Nothing -> case envLookup scope v of Just (_, info) -> case info of LetBound ann (Atom x) | ann /= NoInlineLet -> dropSub $ simplifyAtom x - _ -> substEmbedR atom - _ -> substEmbedR atom + _ -> substBuilderR atom + _ -> substBuilderR atom -- Tables that only contain data aren't necessarily getting inlined, -- so this might be the last chance to simplify them. TabVal _ _ -> do @@ -118,34 +118,34 @@ simplifyAtom atom = case atom of True -> do ~(tab', Nothing) <- simplifyLam atom return tab' - False -> substEmbedR atom + False -> substBuilderR atom -- We don't simplify body of lam because we'll beta-reduce it soon. - Lam _ -> substEmbedR atom - Pi _ -> substEmbedR atom + Lam _ -> substBuilderR atom + Pi _ -> substBuilderR atom Con con -> Con <$> mapM simplifyAtom con TC tc -> TC <$> mapM simplifyAtom tc - Eff eff -> Eff <$> substEmbedR eff + Eff eff -> Eff <$> substBuilderR eff TypeCon def params -> TypeCon def <$> mapM simplifyAtom params DataCon def params con args -> DataCon def <$> mapM simplifyAtom params <*> pure con <*> mapM simplifyAtom args Record items -> Record <$> mapM simplifyAtom items RecordTy items -> RecordTy <$> simplifyExtLabeledItems items Variant types label i value -> Variant <$> - substEmbedR types <*> pure label <*> pure i <*> simplifyAtom value + substBuilderR types <*> pure label <*> pure i <*> simplifyAtom value VariantTy items -> VariantTy <$> simplifyExtLabeledItems items LabeledRow items -> LabeledRow <$> simplifyExtLabeledItems items ACase e alts rty -> do - e' <- substEmbedR e + e' <- substBuilderR e case simplifyCase e' alts of Just (env, result) -> extendR env $ simplifyAtom result Nothing -> do alts' <- forM alts \(Abs bs a) -> do - bs' <- mapM (mapM substEmbedR) bs + bs' <- mapM (mapM substBuilderR) bs (Abs bs'' b) <- buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ simplifyAtom a case b of Block Empty (Atom r) -> return $ Abs bs'' r _ -> error $ "Nontrivial block in ACase simplification" - ACase e' alts' <$> (substEmbedR rty) + ACase e' alts' <$> (substBuilderR rty) DataConRef _ _ _ -> error "Should only occur in Imp lowering" BoxedRef _ _ _ _ -> error "Should only occur in Imp lowering" ProjectElt idxs v -> getProjection (toList idxs) <$> simplifyAtom (Var v) @@ -153,7 +153,7 @@ simplifyAtom atom = case atom of simplifyExtLabeledItems :: ExtLabeledItems Atom Name -> SimplifyM (ExtLabeledItems Atom Name) simplifyExtLabeledItems (Ext items ext) = do items' <- mapM simplifyAtom items - ext' <- substEmbedR (Ext NoLabeledItems ext) + ext' <- substBuilderR (Ext NoLabeledItems ext) return $ prefixExtLabeledItems items' ext' simplifyCase :: Atom -> [AltP a] -> Maybe (SubstEnv, a) @@ -180,10 +180,10 @@ simplifyLam = simplifyLams 1 simplifyBinaryLam :: Atom -> SimplifyM (Atom, Reconstruct SimplifyM Atom) simplifyBinaryLam = simplifyLams 2 --- Unlike `substEmbedR`, this simplifies under the binder too. +-- Unlike `substBuilderR`, this simplifies under the binder too. simplifyLams :: Int -> Atom -> SimplifyM (Atom, Reconstruct SimplifyM Atom) simplifyLams numArgs lam = do - lam' <- substEmbedR lam + lam' <- substBuilderR lam dropSub $ go numArgs mempty $ Block Empty $ Atom lam' where go 0 scope block = do @@ -199,8 +199,8 @@ simplifyLams numArgs lam = do atomf dat' <$> recon dat' ctx' ) go n scope ~(Block Empty (Atom (Lam (Abs b (arr, body))))) = do - b' <- mapM substEmbedR b - buildLamAux b' (\x -> extendR (b@>x) $ substEmbedR arr) \x@(Var v) -> do + b' <- mapM substBuilderR b + buildLamAux b' (\x -> extendR (b@>x) $ substBuilderR arr) \x@(Var v) -> do let scope' = scope <> v @> (varType v, LamBound (void arr)) extendR (b@>x) $ go (n-1) scope' body @@ -209,7 +209,7 @@ defunBlock localScope block = do if isData (getType block) then Left <$> simplifyBlock block else do - (result, (localScope', decls)) <- embedScoped $ simplifyBlock block + (result, (localScope', decls)) <- builderScoped $ simplifyBlock block mapM_ emitDecl decls Right <$> separateDataComponent (localScope <> localScope') result @@ -233,7 +233,7 @@ type AtomFac m = -- TODO: Records -- Guarantees that data elements are entirely type driven (e.g. won't be deduplicated based on -- the supplied atom). The same guarantee doesn't apply to the non-data closures. -separateDataComponent :: forall m. MonadEmbed m => Scope -> Atom -> m (AtomFac m) +separateDataComponent :: forall m. MonadBuilder m => Scope -> Atom -> m (AtomFac m) separateDataComponent localVars v = do (dat, (ctx, recon), atomf) <- rec v let (ctx', ctxRec) = dedup dat ctx @@ -315,21 +315,21 @@ simplifyExpr expr = case expr of Atom x -> simplifyAtom x Case e alts resultTy -> do e' <- simplifyAtom e - resultTy' <- substEmbedR resultTy + resultTy' <- substBuilderR resultTy case simplifyCase e' alts of Just (env, body) -> extendR env $ simplifyBlock body Nothing -> do if isData resultTy' then do alts' <- forM alts \(Abs bs body) -> do - bs' <- mapM (mapM substEmbedR) bs + bs' <- mapM (mapM substBuilderR) bs buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ simplifyBlock body emit $ Case e' alts' resultTy' else do -- Construct the blocks of new cases. The results will only get replaced -- later, once we learn the closures of the non-data component of each case. (alts', facs) <- liftM unzip $ forM alts \(Abs bs body) -> do - bs' <- mapM (mapM substEmbedR) bs + bs' <- mapM (mapM substBuilderR) bs buildNAbsAux bs' \xs -> do ~(Right fac@(dat, (ctx, _), _)) <- extendR (newEnv bs' xs) $ defunBlock (boundVars bs') body -- NB: The return value here doesn't really matter as we're going to replace it afterwards. @@ -432,6 +432,9 @@ simplifyOp op = case op of -- Simplify the case away if we can. dropSub $ simplifyExpr $ Case full alts $ VariantTy resultRow _ -> emitOp op + PrimEffect ref (MExtend f) -> dropSub $ do + ~(f', Nothing) <- simplifyLam f + emitOp $ PrimEffect ref $ MExtend f' _ -> emitOp op simplifyHof :: Hof -> SimplifyM Atom @@ -446,7 +449,7 @@ simplifyHof hof = case hof of ~(fT', Nothing) <- simplifyLam fT ~(fS', Nothing) <- simplifyLam fS emit $ Hof $ Tile d fT' fS' - PTileReduce _ _ -> error "Unexpected PTileReduce" + PTileReduce _ _ _ -> error "Unexpected PTileReduce" While body -> do ~(body', Nothing) <- simplifyLam body emit $ Hof $ While body' @@ -463,9 +466,11 @@ simplifyHof hof = case hof of r' <- simplifyAtom r ~(lam', recon) <- simplifyBinaryLam lam applyRecon recon =<< (emit $ Hof $ RunReader r' lam') - RunWriter lam -> do + RunWriter (BaseMonoid e combine) lam -> do + e' <- simplifyAtom e + ~(combine', Nothing) <- simplifyBinaryLam combine ~(lam', recon) <- simplifyBinaryLam lam - (ans, w) <- fromPair =<< (emit $ Hof $ RunWriter lam') + (ans, w) <- fromPair =<< (emit $ Hof $ RunWriter (BaseMonoid e' combine') lam') ans' <- applyRecon recon ans return $ PairVal ans' w RunState s lam -> do @@ -485,10 +490,10 @@ simplifyHof hof = case hof of applyRecon Nothing x = return x applyRecon (Just f) x = f x -exceptToMaybeBlock :: Block -> SubstEmbed Atom +exceptToMaybeBlock :: Block -> SubstBuilder Atom exceptToMaybeBlock (Block Empty result) = exceptToMaybeExpr result exceptToMaybeBlock (Block (Nest (Let _ b expr) decls) result) = do - a <- substEmbedR $ getType result + a <- substBuilderR $ getType result maybeResult <- exceptToMaybeExpr expr case maybeResult of -- These two cases are just an optimization @@ -498,25 +503,25 @@ exceptToMaybeBlock (Block (Nest (Let _ b expr) decls) result) = do emitMaybeCase maybeResult (return $ NothingAtom a) \x -> do extendR (b@>x) $ exceptToMaybeBlock $ Block decls result -exceptToMaybeExpr :: Expr -> SubstEmbed Atom +exceptToMaybeExpr :: Expr -> SubstBuilder Atom exceptToMaybeExpr expr = do - a <- substEmbedR $ getType expr + a <- substBuilderR $ getType expr case expr of Case e alts resultTy -> do - e' <- substEmbedR e - resultTy' <- substEmbedR $ MaybeTy resultTy + e' <- substBuilderR e + resultTy' <- substBuilderR $ MaybeTy resultTy alts' <- forM alts \(Abs bs body) -> do - bs' <- substEmbedR bs + bs' <- substBuilderR bs buildNAbs bs' \xs -> extendR (newEnv bs' xs) $ exceptToMaybeBlock body emit $ Case e' alts' resultTy' - Atom x -> substEmbedR $ JustAtom (getType x) x + Atom x -> substBuilderR $ JustAtom (getType x) x Op (ThrowException _) -> return $ NothingAtom a Hof (For ann ~(Lam (Abs b (_, body)))) -> do - b' <- substEmbedR b + b' <- substBuilderR b maybes <- buildForAnn ann b' \i -> extendR (b@>i) $ exceptToMaybeBlock body catMaybesE maybes Hof (RunState s lam) -> do - s' <- substEmbedR s + s' <- substBuilderR s let BinaryFunVal _ b _ body = lam result <- emitRunState "ref" s' \ref -> extendR (b@>ref) $ exceptToMaybeBlock body @@ -529,7 +534,7 @@ exceptToMaybeExpr expr = do exceptToMaybeBlock body runMaybeWhile lam _ | not (hasExceptions expr) -> do - x <- substEmbedR expr >>= emit + x <- substBuilderR expr >>= emit return $ JustAtom (getType x) x | otherwise -> error $ "Unexpected exception-throwing expression: " ++ pprint expr @@ -540,17 +545,17 @@ hasExceptions expr = case t of Just _ -> error "Shouldn't have tail left" where (EffectRow effs t) = exprEffs expr -catMaybesE :: MonadEmbed m => Atom -> m Atom -catMaybesE maybes = simplifyEmbed $ do +catMaybesE :: MonadBuilder m => Atom -> m Atom +catMaybesE maybes = simplifyBuilder $ do let (TabTy b (MaybeTy a)) = getType maybes applyPreludeFunction "seqMaybes" [binderAnn b, a, maybes] -runMaybeWhile :: MonadEmbed m => Atom -> m Atom -runMaybeWhile lam = simplifyEmbed $ do +runMaybeWhile :: MonadBuilder m => Atom -> m Atom +runMaybeWhile lam = simplifyBuilder $ do let (Pi (Abs _ (PlainArrow eff, _))) = getType lam applyPreludeFunction "whileMaybe" [Eff eff, lam] -simplifyEmbed :: MonadEmbed m => m Atom -> m Atom -simplifyEmbed m = do +simplifyBuilder :: MonadBuilder m => m Atom -> m Atom +simplifyBuilder m = do block <- buildScoped m - liftEmbed $ runReaderT (simplifyBlock block) mempty + liftBuilder $ runReaderT (simplifyBlock block) mempty diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 4aa7d8d3c..3beb24cef 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -23,14 +23,14 @@ module Syntax ( BinOp (..), UnOp (..), CmpOp (..), SourceBlock (..), ReachedEOF, SourceBlock' (..), SubstEnv, ScopedSubstEnv, Scope, CmdName (..), HasIVars (..), ForAnn (..), - Val, TopEnv, Op, Con, Hof, TC, Module (..), DataConRefBinding (..), + Val, Op, Con, Hof, TC, Module (..), DataConRefBinding (..), ImpModule (..), ImpBlock (..), ImpFunction (..), ImpDecl (..), IExpr (..), IVal, ImpInstr (..), Backend (..), Device (..), IPrimOp, IVar, IBinder, IType, SetVal (..), MonMap (..), LitProg, IFunType (..), IFunVar, CallingConvention (..), IsCUDARequired (..), UAlt (..), AltP, Alt, Label, LabeledItems (..), labeledSingleton, lookupLabelHead, reflectLabels, withLabels, ExtLabeledItems (..), - prefixExtLabeledItems, getLabels, + prefixExtLabeledItems, getLabels, ModuleName, IScope, BinderInfo (..), Bindings, CUDAKernel (..), BenchStats, SrcCtx, Result (..), Output (..), OutFormat (..), Err (..), ErrType (..), Except, throw, throwIf, modifyErr, addContext, @@ -45,11 +45,13 @@ module Syntax ( DataDef (..), DataConDef (..), UConDef (..), Nest (..), toNest, subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, applyNaryAbs, applyDataDefParams, freshSkolemVar, IndexStructure, - mkConsList, mkConsListTy, fromConsList, fromConsListTy, extendEffRow, - getProjection, outputStreamPtrName, initTopEnv, + mkConsList, mkConsListTy, fromConsList, fromConsListTy, fromLeftLeaningConsListTy, + extendEffRow, + getProjection, outputStreamPtrName, initBindings, varType, binderType, isTabTy, LogLevel (..), IRVariant (..), + BaseMonoidP (..), BaseMonoid, getBaseMonoidType, applyIntBinOp, applyIntCmpOp, applyFloatBinOp, applyFloatUnOp, - getIntLit, getFloatLit, sizeOf, vectorWidth, + getIntLit, getFloatLit, sizeOf, ptrSize, vectorWidth, pattern MaybeTy, pattern JustAtom, pattern NothingAtom, pattern IdxRepTy, pattern IdxRepVal, pattern IIdxRepVal, pattern IIdxRepTy, pattern TagRepTy, pattern TagRepVal, pattern Word8Ty, @@ -110,7 +112,7 @@ data Atom = Var Var -- XXX: Variable name must not be an alias for another name or for -- a statically-known atom. This is because the variable name used -- here may also appear in the type of the atom. (We maintain this - -- invariant during substitution and in Embed.hs.) + -- invariant during substitution and in Builder.hs.) | ProjectElt (NE.NonEmpty Int) Var deriving (Show, Generic) @@ -166,7 +168,6 @@ type Op = PrimOp Atom type Hof = PrimHof Atom data Module = Module IRVariant (Nest Decl) Bindings deriving Show -type TopEnv = Scope data IRVariant = Surface | Typed | Core | Simp | Evaluated deriving (Show, Eq, Ord, Generic) @@ -250,7 +251,7 @@ data UDecl = ULet LetAnn UPatAnn UExpr | UData UConDef [UConDef] | UInterface [UType] UConDef [UAnnBinder] -- superclasses, constructor, methods - | UInstance (Nest UPatAnnArrow) UType [UMethodDef] -- args, type, methods + | UInstance (Maybe UVar) (Nest UPatAnnArrow) UType [UMethodDef] -- name, args, type, methods deriving (Show, Generic) type UType = UExpr @@ -385,16 +386,20 @@ data PrimHof e = | Tile Int e e -- dimension number, tiled body, scalar body | While e | RunReader e e - | RunWriter e + | RunWriter (BaseMonoidP e) e | RunState e e | RunIO e | CatchException e | Linearize e | Transpose e - | PTileReduce e e -- index set, thread body + | PTileReduce [BaseMonoidP e] e e -- accumulator monoids, index set, thread body deriving (Show, Eq, Generic, Functor, Foldable, Traversable) -data PrimEffect e = MAsk | MTell e | MGet | MPut e +data BaseMonoidP e = BaseMonoid { baseEmpty :: e, baseCombine :: e } + deriving (Show, Eq, Generic, Functor, Foldable, Traversable) +type BaseMonoid = BaseMonoidP Atom + +data PrimEffect e = MAsk | MExtend e | MGet | MPut e deriving (Show, Eq, Generic, Functor, Foldable, Traversable) data BinOp = IAdd | ISub | IMul | IDiv | ICmp CmpOp @@ -441,6 +446,11 @@ primNameToStr prim = case lookup prim $ map swap $ M.toList builtinNames of showPrimName :: PrimExpr e -> String showPrimName prim = primNameToStr $ fmap (const ()) prim +getBaseMonoidType :: Type -> Type +getBaseMonoidType ty = case ty of + TabTy _ b -> getBaseMonoidType b + _ -> ty + -- === effects === data EffectRow = EffectRow (S.Set Effect) (Maybe Name) @@ -457,8 +467,8 @@ pattern Pure <- ((\(EffectRow effs t) -> (S.null effs, t)) -> (True, Nothing)) outputStreamPtrName :: Name outputStreamPtrName = GlobalName "OUT_STREAM_PTR" -initTopEnv :: TopEnv -initTopEnv = fold [v @> (ty, LamBound ImplicitArrow) | (v, ty) <- +initBindings :: Bindings +initBindings = fold [v @> (ty, LamBound ImplicitArrow) | (v, ty) <- [(outputStreamPtrName , BaseTy $ hostPtrTy $ hostPtrTy $ Scalar Word8Type)]] hostPtrTy :: BaseType -> BaseType @@ -489,10 +499,11 @@ data SourceBlock = SourceBlock type BlockId = Int type ReachedEOF = Bool +type ModuleName = String data SourceBlock' = RunModule UModule | Command CmdName (Name, UModule) | GetNameType Name - | IncludeSourceFile String + | ImportModule ModuleName | ProseBlock String | CommentLine | EmptyLines @@ -592,9 +603,12 @@ sizeOf t = case t of Scalar Word8Type -> 1 Scalar Float64Type -> 8 Scalar Float32Type -> 4 - PtrType _ -> 8 + PtrType _ -> ptrSize Vector st -> vectorWidth * sizeOf (Scalar st) +ptrSize :: Int +ptrSize = 8 + vectorWidth :: Int vectorWidth = 4 @@ -801,7 +815,7 @@ instance HasUVars UDecl where freeUVars (UData (UConDef _ bs) dataCons) = freeUVars $ Abs bs dataCons freeUVars (UInterface superclasses tc methods) = freeUVars $ Abs tc (superclasses, methods) - freeUVars (UInstance bsArrows ty methods) = freeUVars $ Abs bs (ty, methods) + freeUVars (UInstance _ bsArrows ty methods) = freeUVars $ Abs bs (ty, methods) where bs = fmap fst bsArrows instance HasUVars UMethodDef where @@ -812,10 +826,11 @@ instance BindsUVars UPatAnn where instance BindsUVars UDecl where boundUVars decl = case decl of - ULet _ (p,_) _ -> boundUVars p - UData tyCon dataCons -> boundUVars tyCon <> foldMap boundUVars dataCons - UInterface _ _ _ -> mempty - UInstance _ _ _ -> mempty + ULet _ (p,_) _ -> boundUVars p + UData tyCon dataCons -> boundUVars tyCon <> foldMap boundUVars dataCons + UInterface _ _ _ -> mempty + UInstance Nothing _ _ _ -> mempty + UInstance (Just v) _ _ _ -> v @> () instance HasUVars UModule where freeUVars (UModule decls) = freeUVars decls @@ -1477,6 +1492,15 @@ fromConsListTy ty = case ty of PairTy t rest -> (t:) <$> fromConsListTy rest _ -> throw CompilerErr $ "Not a pair or unit: " ++ show ty +-- ((...((ans & x{n}) & x{n-1})... & x2) & x1) -> (ans, [x1, ..., x{n}]) +fromLeftLeaningConsListTy :: MonadError Err m => Int -> Type -> m (Type, [Type]) +fromLeftLeaningConsListTy depth initTy = go depth initTy [] + where + go 0 ty xs = return (ty, reverse xs) + go remDepth ty xs = case ty of + PairTy lt rt -> go (remDepth - 1) lt (rt : xs) + _ -> throw CompilerErr $ "Not a pair: " ++ show xs + fromConsList :: MonadError Err m => Atom -> m [Atom] fromConsList xs = case xs of UnitVal -> return [] @@ -1596,7 +1620,7 @@ builtinNames = M.fromList , ("throwError" , OpExpr $ ThrowError ()) , ("throwException" , OpExpr $ ThrowException ()) , ("ask" , OpExpr $ PrimEffect () $ MAsk) - , ("tell" , OpExpr $ PrimEffect () $ MTell ()) + , ("mextend" , OpExpr $ PrimEffect () $ MExtend ()) , ("get" , OpExpr $ PrimEffect () $ MGet) , ("put" , OpExpr $ PrimEffect () $ MPut ()) , ("indexRef" , OpExpr $ IndexRef () ()) @@ -1606,7 +1630,7 @@ builtinNames = M.fromList , ("linearize" , HofExpr $ Linearize ()) , ("linearTranspose" , HofExpr $ Transpose ()) , ("runReader" , HofExpr $ RunReader () ()) - , ("runWriter" , HofExpr $ RunWriter ()) + , ("runWriter" , HofExpr $ RunWriter (BaseMonoid () ()) ()) , ("runState" , HofExpr $ RunState () ()) , ("runIO" , HofExpr $ RunIO ()) , ("catchException" , HofExpr $ CatchException ()) @@ -1661,6 +1685,7 @@ instance Store a => Store (Nest a) instance Store a => Store (ArrowP a) instance Store a => Store (Limit a) instance Store a => Store (PrimEffect a) +instance Store a => Store (BaseMonoidP a) instance Store a => Store (LabeledItems a) instance (Store a, Store b) => Store (ExtLabeledItems a b) instance Store ForAnn diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 63b203099..bdf2da43b 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -6,8 +6,10 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE DeriveGeneric #-} -module TopLevel (evalSourceBlock, evalDecl, evalSource, evalFile, EvalConfig (..)) where +module TopLevel (evalSourceBlock, evalDecl, evalSource, evalFile, + initTopEnv, EvalConfig (..), TopEnv (..)) where import Control.Monad.State.Strict import Control.Monad.Reader @@ -16,9 +18,14 @@ import Data.Text.Prettyprint.Doc import Data.String import Data.List (partition) import qualified Data.Map.Strict as M +import Data.Store (Store) +import GHC.Generics (Generic) +import System.FilePath + +import Paths_dex (getDataFileName) import Syntax -import Embed +import Builder import Cat import Env import Type @@ -39,6 +46,7 @@ import Parallelize data EvalConfig = EvalConfig { backendName :: Backend + , libPath :: Maybe FilePath , logFile :: Maybe FilePath } @@ -48,6 +56,16 @@ data TopPassEnv = TopPassEnv , evalConfig :: EvalConfig } type TopPassM a = ReaderT TopPassEnv IO a +data TopEnv = TopEnv + { topBindings :: Bindings + , modulesImported :: M.Map ModuleName ModuleImportStatus} + deriving Generic + +data ModuleImportStatus = CurrentlyImporting | FullyImported deriving Generic + +initTopEnv :: TopEnv +initTopEnv = TopEnv initBindings mempty + evalDecl :: EvalConfig -> SourceBlock -> StateT TopEnv IO Result evalDecl opts block = do env <- get @@ -81,11 +99,13 @@ runTopPassM bench opts m = runLogger (logFile opts) \logger -> runExceptT $ catchIOExcept $ runReaderT m $ TopPassEnv logger bench opts evalSourceBlockM :: TopEnv -> SourceBlock -> TopPassM TopEnv -evalSourceBlockM env block = case sbContents block of - RunModule m -> evalUModule env m +evalSourceBlockM env@(TopEnv bindings _) block = case sbContents block of + RunModule m -> do + newBindings <- evalUModule bindings m + return $ mempty { topBindings = newBindings } Command cmd (v, m) -> mempty <$ case cmd of EvalExpr fmt -> do - val <- evalUModuleVal env v m + val <- evalUModuleVal bindings v m case fmt of Printed -> do s <- liftIO $ pprintVal val @@ -95,22 +115,30 @@ evalSourceBlockM env block = case sbContents block of s <- liftIO $ getDexString val logTop $ HtmlOut s ExportFun name -> do - f <- evalUModuleVal env v m + f <- evalUModuleVal bindings v m void $ traverseLiterals f \val -> case val of PtrLit _ _ -> liftEitherIO $ throw CompilerErr $ "Can't export functions with captured pointers (not implemented)." _ -> return $ Con $ Lit val logTop $ ExportedFun name f GetType -> do -- TODO: don't actually evaluate it - val <- evalUModuleVal env v m + val <- evalUModuleVal bindings v m logTop $ TextOut $ pprint $ getType val - GetNameType v -> case envLookup env (v:>()) of + GetNameType v -> case envLookup bindings (v:>()) of Just (ty, _) -> logTop (TextOut $ pprint ty) >> return mempty _ -> liftEitherIO $ throw UnboundVarErr $ pprint v - IncludeSourceFile fname -> do - fullPath <- liftIO $ findSourceFile fname - source <- liftIO $ readFile fullPath - evalSourceBlocks env $ parseProg source + ImportModule moduleName -> + case M.lookup moduleName $ modulesImported env of + Just CurrentlyImporting -> liftEitherIO $ throw MiscErr $ + "Circular import detected: " ++ pprint moduleName + Just FullyImported -> return mempty + Nothing -> do + fullPath <- findModulePath moduleName + source <- liftIO $ readFile fullPath + newTopEnv <- evalSourceBlocks + (env <> moduleStatus moduleName CurrentlyImporting) $ + parseProg source + return $ newTopEnv <> moduleStatus moduleName FullyImported UnParseable _ s -> liftEitherIO $ throw ParseErr s _ -> return mempty @@ -152,19 +180,19 @@ isLogInfo out = case out of evalSourceBlocks :: TopEnv -> [SourceBlock] -> TopPassM TopEnv evalSourceBlocks env blocks = catFoldM evalSourceBlockM env blocks -evalUModuleVal :: TopEnv -> Name -> UModule -> TopPassM Val +evalUModuleVal :: Bindings -> Name -> UModule -> TopPassM Val evalUModuleVal env v m = do env' <- evalUModule env m return $ lookupBindings (env <> env') (v:>()) -lookupBindings :: Scope -> VarP ann -> Atom +lookupBindings :: Bindings -> VarP ann -> Atom lookupBindings scope v = x where (_, LetBound PlainLet (Atom x)) = scope ! v -- TODO: extract only the relevant part of the env we can check for module-level -- unbound vars and upstream errors here. This should catch all unbound variable -- errors, but there could still be internal shadowing errors. -evalUModule :: TopEnv -> UModule -> TopPassM TopEnv +evalUModule :: Bindings -> UModule -> TopPassM Bindings evalUModule env untyped = do logPass Parse untyped typed <- liftEitherIO $ inferModule env untyped @@ -191,7 +219,7 @@ evalUModule env untyped = do checkPass ResultPass $ Module Evaluated Empty newBindings return newBindings -evalBackend :: TopEnv -> Block -> TopPassM Atom +evalBackend :: Bindings -> Block -> TopPassM Atom evalBackend env block = do backend <- asks (backendName . evalConfig) bench <- asks benchmark @@ -266,7 +294,7 @@ abstractPtrLiterals block = flip evalState mempty $ do return (impBinders, vals, block') class HasTraversal a where - traverseCore :: (MonadEmbed m, MonadReader SubstEnv m) => TraversalDef m -> a -> m a + traverseCore :: (MonadBuilder m, MonadReader SubstEnv m) => TraversalDef m -> a -> m a instance HasTraversal Block where traverseCore = traverseBlock @@ -276,13 +304,35 @@ instance HasTraversal Atom where traverseLiterals :: (HasTraversal e, Monad m) => e -> (LitVal -> m Atom) -> m e traverseLiterals block f = - liftM fst $ flip runSubstEmbedT mempty $ traverseCore def block + liftM fst $ flip runSubstBuilderT mempty $ traverseCore def block where def = (traverseDecl def, traverseExpr def, traverseAtomLiterals) traverseAtomLiterals atom = case atom of Con (Lit x) -> lift $ lift $ f x _ -> traverseAtom def atom --- TODO: use something like a `DEXPATH` env var for finding source files -findSourceFile :: FilePath -> IO FilePath -findSourceFile fpath = return $ "lib/" ++ fpath +findModulePath :: ModuleName -> TopPassM FilePath +findModulePath moduleName = do + let fname = moduleName ++ ".dx" + specifiedPath <- asks (libPath . evalConfig) + case specifiedPath of + Nothing -> liftIO $ getDataFileName $ "lib/" ++ fname + Just path -> return $ path fname + +instance Semigroup TopEnv where + (TopEnv env ms) <> (TopEnv env' ms') = + -- Data.Map is left-biased so we flip the order + TopEnv (env <> env') (ms' <> ms) + +instance Monoid TopEnv where + mempty = TopEnv mempty mempty + +moduleStatus :: ModuleName -> ModuleImportStatus -> TopEnv +moduleStatus name status = mempty { modulesImported = M.singleton name status} + +instance HasPtrs TopEnv where + traversePtrs f (TopEnv bindings status) = + TopEnv <$> traverse (traversePtrs f) bindings <*> pure status + +instance Store TopEnv +instance Store ModuleImportStatus diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 29248533a..4fd549e21 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -272,10 +272,10 @@ exprEffs expr = case expr of App f _ -> functionEffs f Op op -> case op of PrimEffect ref m -> case m of - MGet -> oneEffect (RWSEffect State h) - MPut _ -> oneEffect (RWSEffect State h) - MAsk -> oneEffect (RWSEffect Reader h) - MTell _ -> oneEffect (RWSEffect Writer h) + MGet -> oneEffect (RWSEffect State h) + MPut _ -> oneEffect (RWSEffect State h) + MAsk -> oneEffect (RWSEffect Reader h) + MExtend _ -> oneEffect (RWSEffect Writer h) where RefTy (Var (h:>_)) _ = getType ref ThrowException _ -> oneEffect ExceptionEffect IOAlloc _ _ -> oneEffect IOEffect @@ -291,9 +291,9 @@ exprEffs expr = case expr of Linearize _ -> mempty -- Body has to be a pure function Transpose _ -> mempty -- Body has to be a pure function RunReader _ f -> handleRWSRunner Reader f - RunWriter f -> handleRWSRunner Writer f + RunWriter _ f -> handleRWSRunner Writer f RunState _ f -> handleRWSRunner State f - PTileReduce _ _ -> mempty + PTileReduce _ _ _ -> mempty RunIO ~(Lam (Abs _ (PlainArrow (EffectRow effs t), _))) -> EffectRow (S.delete IOEffect effs) t CatchException ~(Lam (Abs _ (PlainArrow (EffectRow effs t), _))) -> @@ -445,14 +445,14 @@ instance CoreVariant (PrimHof a) where For _ _ -> alwaysAllowed While _ -> alwaysAllowed RunReader _ _ -> alwaysAllowed - RunWriter _ -> alwaysAllowed + RunWriter _ _ -> alwaysAllowed RunState _ _ -> alwaysAllowed RunIO _ -> alwaysAllowed Linearize _ -> goneBy Simp Transpose _ -> goneBy Simp Tile _ _ _ -> alwaysAllowed - PTileReduce _ _ -> absentUntil Simp -- really absent until parallelization - CatchException _ -> goneBy Simp + PTileReduce _ _ _ -> absentUntil Simp -- really absent until parallelization + CatchException _ -> goneBy Simp -- TODO: namespace restrictions? alwaysAllowed :: VariantM () @@ -704,10 +704,10 @@ typeCheckOp op = case op of PrimEffect ref m -> do TC (RefType ~(Just (Var (h':>TyKind))) s) <- typeCheck ref case m of - MGet -> declareEff (RWSEffect State h') $> s - MPut x -> x|:s >> declareEff (RWSEffect State h') $> UnitTy - MAsk -> declareEff (RWSEffect Reader h') $> s - MTell x -> x|:s >> declareEff (RWSEffect Writer h') $> UnitTy + MGet -> declareEff (RWSEffect State h') $> s + MPut x -> x|:s >> declareEff (RWSEffect State h') $> UnitTy + MAsk -> declareEff (RWSEffect Reader h') $> s + MExtend x -> x|:(s --> s) >> declareEff (RWSEffect Writer h') $> UnitTy IndexRef ref i -> do RefTy h (TabTyAbs a) <- typeCheck ref i |: absArgType a @@ -855,15 +855,16 @@ typeCheckHof hof = case hof of replaceDim 0 (TabTy _ b) n = TabTy (Ignore n) b replaceDim d (TabTy dv b) n = TabTy dv $ replaceDim (d-1) b n replaceDim _ _ _ = error "This should be checked before" - PTileReduce n mapping -> do - -- mapping : gtid:IdxRepTy -> nthr:IdxRepTy -> ((ParIndexRange n gtid nthr)=>a, r) + PTileReduce baseMonoids n mapping -> do + -- mapping : gtid:IdxRepTy -> nthr:IdxRepTy -> (...((ParIndexRange n gtid nthr)=>a, acc{n})..., acc1) BinaryFunTy (Bind gtid) (Bind nthr) Pure mapResultTy <- typeCheck mapping - PairTy tiledArrTy accTy <- return mapResultTy + (tiledArrTy, accTys) <- fromLeftLeaningConsListTy (length baseMonoids) mapResultTy let threadRange = TC $ ParIndexRange n (Var gtid) (Var nthr) TabTy threadRange' tileElemTy <- return tiledArrTy checkEq threadRange (binderType threadRange') - -- PTileReduce n mapping : (n=>a, ro) - return $ PairTy (TabTy (Ignore n) tileElemTy) accTy + -- TODO: Check compatibility of baseMonoids and accTys (need to be careful about lifting!) + -- PTileReduce n mapping : (n=>a, (acc1, ..., acc{n})) + return $ PairTy (TabTy (Ignore n) tileElemTy) $ mkConsListTy accTys While body -> do Pi (Abs (Ignore UnitTy) (arr , condTy)) <- typeCheck body declareEffs $ arrowEff arr @@ -879,7 +880,14 @@ typeCheckHof hof = case hof of (resultTy, readTy) <- checkRWSAction Reader f r |: readTy return resultTy - RunWriter f -> uncurry PairTy <$> checkRWSAction Writer f + RunWriter _ f -> do + -- XXX: We can't verify compatibility between the base monoid and f, because + -- the only way in which they are related in the runAccum definition is via + -- the AccumMonoid typeclass. The frontend constraints should be sufficient + -- to ensure that only well typed programs are accepted, but it is a bit + -- disappointing that we cannot verify that internally. We might want to consider + -- e.g. only disabling this check for prelude. + uncurry PairTy <$> checkRWSAction Writer f RunState s f -> do (resultTy, stateTy) <- checkRWSAction State f s |: stateTy diff --git a/src/lib/dexrt.cpp b/src/lib/dexrt.cpp index be89d4028..505e25387 100644 --- a/src/lib/dexrt.cpp +++ b/src/lib/dexrt.cpp @@ -23,19 +23,33 @@ extern "C" { +// XXX: Changes to this value might require additional changes to parameter attributes in LLVM +const int64_t alignment = 64; + char* malloc_dex(int64_t nbytes) { - // XXX: Changes to this value might require additional changes to parameter attributes in LLVM - static const int64_t alignment = 64; + // reserves `alignment` bytes before the data region to store the size of the allocation + int64_t nbytes_total = nbytes + alignment; char *ptr; - if (posix_memalign(reinterpret_cast(&ptr), alignment, nbytes)) { + if (posix_memalign(reinterpret_cast(&ptr), alignment, nbytes_total)) { fprintf(stderr, "Failed to allocate %ld bytes", (long)nbytes); std::abort(); } + *(reinterpret_cast(ptr)) = nbytes; + return ptr + alignment; +} + +char* dex_malloc_initialized(int64_t nbytes) { + char *ptr = malloc_dex(nbytes); + memset(ptr, 0, nbytes); return ptr; } void free_dex(char* ptr) { - free(ptr); + free(ptr - alignment); +} + +int64_t dex_allocation_size (char* ptr) { + return *(reinterpret_cast(ptr - alignment)); } void* fdopen_w(int fd) { diff --git a/src/resources/Resources.hs b/src/resources/Resources.hs index d834767e1..6fe1db9d2 100644 --- a/src/resources/Resources.hs +++ b/src/resources/Resources.hs @@ -17,4 +17,4 @@ preludeSource = B.unpack $(embedFile "lib/prelude.dx") -- The source code of the CSS used for rendering Dex programs as HTML. cssSource :: String -cssSource = B.unpack $(embedFile "static/style.css") \ No newline at end of file +cssSource = B.unpack $(embedFile "static/style.css") diff --git a/tests/ad-tests.dx b/tests/ad-tests.dx index 6affc69f6..a843a49ad 100644 --- a/tests/ad-tests.dx +++ b/tests/ad-tests.dx @@ -1,6 +1,6 @@ -- TODO: use prelude sum instead once we can differentiate state effect -def sum' (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs.i +def sum' (xs:n=>Float) : Float = yieldAccum (AddMonoid Float) \ref. for i. ref += xs.i :p f : Float -> Float = \x. x @@ -69,7 +69,7 @@ def sum' (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs.i :p jvp sum' [1., 2.] [10.0, 20.0] > 30. -f : Float -> Float = \x. yieldAccum \ref. ref += x +f : Float -> Float = \x. yieldAccum (AddMonoid Float) \ref. ref += x :p jvp f 1.0 1.0 > 1. @@ -167,7 +167,7 @@ tripleit : Float --o Float = \x. x + x + x > [2., 4.] myOtherSquare : Float -> Float = - \x. yieldAccum \w. w += x * x + \x. yieldAccum (AddMonoid Float) \w. w += x * x :p checkDeriv myOtherSquare 3.0 > True @@ -225,7 +225,7 @@ vec = [1.] :p f : Float -> Float = \x. y = x * 2.0 - yieldAccum \a. + yieldAccum (AddMonoid Float) \a. a += x * 2.0 a += y grad f 1.0 diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx index 08a28c6d0..42f84c043 100644 --- a/tests/adt-tests.dx +++ b/tests/adt-tests.dx @@ -102,7 +102,7 @@ myTab = [MyLeft 1, MyRight 3.5, MyLeft 123, MyLeft 456] > Runtime error :p - yieldAccum \ref. + yieldAccum (AddMonoid Float) \ref. for i. case myTab.i of MyLeft tmp -> () MyRight val -> ref += 1.0 + val @@ -110,7 +110,7 @@ myTab = [MyLeft 1, MyRight 3.5, MyLeft 123, MyLeft 456] :p -- check that the order of the case alternatives doesn't matter - yieldAccum \ref. + yieldAccum (AddMonoid Float) \ref. for i. case myTab.i of MyRight val -> ref += 1.0 + val MyLeft tmp -> () @@ -128,7 +128,7 @@ threeCaseTab : (Fin 4)=>ThreeCases = > [(TheIntCase 3), TheEmptyCase, (ThePairCase 2 0.1), TheEmptyCase] :p - yieldAccum \ref. + yieldAccum (AddMonoid Float) \ref. for i. case threeCaseTab.i of TheEmptyCase -> ref += 1000.0 ThePairCase x y -> ref += 100.0 + y + IToF x diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx index f853dceab..2426304f7 100644 --- a/tests/eval-tests.dx +++ b/tests/eval-tests.dx @@ -698,7 +698,7 @@ def newtonSolve (tol:Float) (f : Float -> Float) (x0:Float) : Float = > 415 :p - (f, w) = runAccum \ref. + (f, w) = runAccum (AddMonoid Float) \ref. ref += 2.0 w = 2 \z. z + w @@ -716,17 +716,6 @@ arr2d = for i:(Fin 2). for j:(Fin 2). (iota _).(i,j) arr2d.(1@_) > [2, 3] -:p - runState (1,2) \ref. - r1 = fstRef ref - r2 = sndRef ref - x = get r1 - y = get r2 - r2 := x - r1 := y -> ((), (2, 1)) - - :p any [True, False] > True :p any [False, False] diff --git a/tests/monad-tests.dx b/tests/monad-tests.dx index 8f9994252..58e116d7f 100644 --- a/tests/monad-tests.dx +++ b/tests/monad-tests.dx @@ -27,6 +27,7 @@ :p def rwsAction (rh:Type) ?-> (wh:Type) ?-> (sh:Type) ?-> + (_:AccumMonoid wh Float) ?=> (r:Ref rh Int) (w:Ref wh Float) (s:Ref sh Bool) : {Read rh, Accum wh, State sh} Int = x = get s @@ -38,7 +39,7 @@ withReader 2 \r. runState True \s. - runAccum \w. + runAccum (AddMonoid Float) \w. rwsAction r w s > ((4, 6.), False) @@ -56,29 +57,31 @@ :p def m (wh:Type) ?-> (sh:Type) ?-> + (_:AccumMonoid wh Float) ?=> (w:Ref wh Float) (s:Ref sh Float) : {Accum wh, State sh} Unit = x = get s w += x - runState 1.0 \s. runAccum \w . m w s + runState 1.0 \s. runAccum (AddMonoid Float) \w . m w s > (((), 1.), 1.) -def myAction (w:Ref hw Float) (r:Ref hr Float) : {Read hr, Accum hw} Unit = +def myAction [AccumMonoid hw Float] (w:Ref hw Float) (r:Ref hr Float) : {Read hr, Accum hw} Unit = x = ask r w += x w += 2.0 -:p withReader 1.5 \r. runAccum \w. myAction w r +:p withReader 1.5 \r. runAccum (AddMonoid Float) \w. myAction w r > ((), 3.5) :p def m (h1:Type) ?-> (h2:Type) ?-> + (_:AccumMonoid h1 Float) ?=> (_:AccumMonoid h2 Float) ?=> (w1:Ref h1 Float) (w2:Ref h2 Float) : {Accum h1, Accum h2} Unit = w1 += 1.0 w2 += 3.0 w1 += 1.0 - runAccum \w1. runAccum \w2. m w1 w2 + runAccum (AddMonoid Float) \w1. runAccum (AddMonoid Float) \w2. m w1 w2 > (((), 3.), 2.) def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = @@ -125,8 +128,8 @@ def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = -- (maybe just explicit implicit args) :p withReader 2.0 \r. - runAccum \w. - runAccum \w'. + runAccum (AddMonoid Float) \w. + runAccum (AddMonoid Float) \w'. runState 3 \s. x = ask r y = get s @@ -151,19 +154,19 @@ symmetrizeInPlace [[1.,2.],[3.,4.]] :p withReader 5 \r. () > () -:p yieldAccum \w. +:p yieldAccum (AddMonoid Float) \w. for i:(Fin 2). w += 1.0 w += 1.0 > 4. -:p yieldAccum \w. +:p yieldAccum (AddMonoid Float) \w. for i:(Fin 2). w += 1.0 w += 1.0 > 3. -:p yieldAccum \ref. +:p yieldAccum (AddMonoid Float) \ref. ref += [1.,2.,3.] ref += [2.,4.,5.] > [3., 6., 8.] diff --git a/tests/parser-combinator-tests.dx b/tests/parser-combinator-tests.dx index 983a4da83..bcb7ff88a 100644 --- a/tests/parser-combinator-tests.dx +++ b/tests/parser-combinator-tests.dx @@ -1,5 +1,5 @@ -include "parser.dx" +import parser parseABC : Parser Unit = MkParser \h. parse h $ pChar 'A' diff --git a/tests/parser-tests.dx b/tests/parser-tests.dx index 93d216354..3729dae35 100644 --- a/tests/parser-tests.dx +++ b/tests/parser-tests.dx @@ -113,7 +113,7 @@ def myInt : {State h} Int = 1 > Nullary def can't have effects :p - yieldAccum \ref. + yieldAccum (AddMonoid Float) \ref. x = if True then 1. else 3. if True then ref += x diff --git a/tests/type-tests.dx b/tests/type-tests.dx index 2261e6878..3da155c71 100644 --- a/tests/type-tests.dx +++ b/tests/type-tests.dx @@ -376,3 +376,16 @@ def weakerInferenceReduction (l : i:n=>(..i)=>Float) (j:n): Unit = l.i'.k () () + +-- Tests for table + +a = [0, 1] +b = [0, 1] + +:p a == b +> True + +c = [1, 2] + +:p a < c +> True