Skip to content

Commit

Permalink
Parametrize runWriter by a monoid used for the reduction
Browse files Browse the repository at this point in the history
This makes it possible to express (parallel) reductions over arbitrary
monoids. Thanks to this, we can start removing some nasty hacks (like
the one used for `Eq (n=>a)`) and make the (work-in-progress) FFT example
parallel!

Anyway, this whole change turned out to be surprisingly difficult, but
thanks to many chats with @dougalm, I think that we've arrived at a
particularly nice solution.

The crux of the matter is the fact that Dex, unlike most other
languages with some form of a built-in reduction operator, allows
slicing the accumulator. This poses an interesting problem: if the user
was to specify the `Monoid` instance for the full accumulator (e.g. a
matrix), then what monoid are we supposed to use for its slice?! As it
turns out, this might not even be well defined! For example, the type of square
matrices with identity matrix and matrix multiplication forms a monoid,
but there is no natural "sub-monoid" we could use in an expression
of the form `ref!i += ...`.

So, unless we're ok with giving up reference slicing (which we know we
want for sure, since this is a way to express e.g. parallel scatters and
histograms), we have to come up with a way of constructing those
sub-monoids. And here, and answer is to turn the problem around: instead
of asking the users to provide us the monoids for the full references,
we expect the monoid to refer to some _base type_ (and we call it a
_base monoid_). That is, when the `Accum` reference is of type
`n=>m=>...=>k=>a`, then any of `m=>...=>k=>a`, ..., `k=>a` and even
`a` are considered base types. While this is a bit surprising at first,
it turns out to actually be quite convenient, since it does seem more
straightforward to say "I want this to be a reduction over `(Float, 0.0, +)`"
instead of mentioning the full table type, a broadcast version of `0.0`
and a pointwise-lifted version of `+`.

Finally, because many data types have multiple valid monoids (`Float`
has at least four: `+`, `*`, `min`, `max`), the monoid argument is
explicit and those instances can be obtained via the `named-instance`
syntax added in the previous commits. Note that I've also included some
helper functions which make it possible to synthesize `Monoid` instances
automatically from `Add` and `Mul` instance for any given type (see
`AddMonoid` and `MulMonoid`).

I haven't been fully able to verify the correctness of the
parallelization change, because the CUDA backend seems to be broken
anyway (sigh...), but the code it generates looks ok.
  • Loading branch information
apaszke authored and srush committed Jan 19, 2021
1 parent b93ae77 commit 72934fe
Show file tree
Hide file tree
Showing 18 changed files with 359 additions and 229 deletions.
8 changes: 4 additions & 4 deletions examples/isomorphisms.dx
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ that produce isos. We will start with the first two:
:t #b : Iso {a:Int & b:Float & c:Unit} _
> (Iso {a: Int32 & b: Float32 & c: Unit} (Float32 & {a: Int32 & c: Unit}))
> === parse ===
> _ans_ =
> _ans_ =
> MkIso {bwd = \(x, r). {b = x, ...r}, fwd = \{b = x, ...r}. (,) x r}
> : Iso {a: Int & b: Float & c: Unit} _

%passes parse
:t #?b : Iso {a:Int | b:Float | c:Unit} _
> (Iso {a: Int32 | b: Float32 | c: Unit} (Float32 | {a: Int32 | c: Unit}))
> === parse ===
> _ans_ =
> _ans_ =
> MkIso
> { bwd = \v. case v
> ((Left x)) -> {| b = x |}
Expand Down Expand Up @@ -142,7 +142,7 @@ another. For instance:
> ({ &} & {a: Int32 & b: Float32 & c: Unit})
> ({a: Int32} & {b: Float32 & c: Unit}))
> === parse ===
> _ans_ =
> _ans_ =
> MkIso
> { bwd = \({a = x, ...l}, {, ...r}). (,) {, ...l} {a = x, ...r}
> , fwd = \({, ...l}, {a = x, ...r}). (,) {a = x, ...l} {, ...r}}
Expand Down Expand Up @@ -212,7 +212,7 @@ zipper isomorphisms:
> ({ |} | {a: Int32 | b: Float32 | c: Unit})
> ({a: Int32} | {b: Float32 | c: Unit}))
> === parse ===
> _ans_ =
> _ans_ =
> MkIso
> { bwd = \v. case v
> ((Left w)) -> (case w
Expand Down
4 changes: 2 additions & 2 deletions examples/raytrace.dx
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def sampleLightRadiance
(surfNor, surf) = osurf
(rayPos, _) = inRay
(MkScene objs) = scene
yieldAccum \radiance.
yieldAccum (AddMonoid Float) \radiance.
for i. case objs.i of
PassiveObject _ _ -> ()
Light lightPos hw _ ->
Expand All @@ -227,7 +227,7 @@ def sampleLightRadiance

def trace (params:Params) (scene:Scene n) (initRay:Ray) (k:Key) : Color =
noFilter = [1.0, 1.0, 1.0]
yieldAccum \radiance.
yieldAccum (AddMonoid Float) \radiance.
runState noFilter \filter.
runState initRay \ray.
boundedIter (getAt #maxBounces params) () \i.
Expand Down
2 changes: 1 addition & 1 deletion examples/tiled-matmul.dx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def matmul (k : Type) ?-> (n : Type) ?-> (m : Type) ?->
vectorTile = Fin VectorWidth
colTile = (colVectors & vectorTile)
(tile2d (\nt:(Tile n rowTile). \mt:(Tile m colTile).
ct = yieldAccum \acc.
ct = yieldAccum (AddMonoid Float) \acc.
for l:k.
for i:rowTile.
ail = broadcastVector a.(nt +> i).l
Expand Down
49 changes: 37 additions & 12 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,23 @@ def MulMonoid (a:Type) -> (_:Mul a) ?=> : Monoid a =
def Ref (r:Type) (a:Type) : Type = %Ref r a
def get (ref:Ref h s) : {State h} s = %get ref
def (:=) (ref:Ref h s) (x:s) : {State h} Unit = %put ref x

def ask (ref:Ref h r) : {Read h} r = %ask ref
def (+=) (ref:Ref h w) (x:w) : {Accum h} Unit = %tell ref x

data AccumMonoid h w = UnsafeMkAccumMonoid (Monoid w)

@instance
def tableAccumMonoid ((UnsafeMkAccumMonoid m):AccumMonoid h w) ?=> : AccumMonoid h (n=>w) =
%instance mHint = m
def liftTableMonoid (tm:Monoid (n=>w)) ?=> : Monoid (n=>w) = tm
UnsafeMkAccumMonoid liftTableMonoid

def (+=) (am:AccumMonoid h w) ?=> (ref:Ref h w) (x:w) : {Accum h} Unit =
(UnsafeMkAccumMonoid m) = am
%instance mHint = m
updater = \v. mcombine v x
%mextend ref updater

def (!) (ref:Ref h (n=>a)) (i:n) : Ref h a = %indexRef ref i
def fstRef (ref: Ref h (a & b)) : Ref h a = %fstRef ref
def sndRef (ref: Ref h (a & b)) : Ref h b = %sndRef ref
Expand All @@ -328,16 +343,29 @@ def withReader
: {|eff} a =
runReader init action

def MonoidLifter (b:Type) (w:Type) : Type = h:Type -> AccumMonoid h b ?=> AccumMonoid h w

def runAccum
(action: (h:Type ?-> Ref h w -> {Accum h|eff} a))
(mlift:MonoidLifter b w) ?=>
(bm:Monoid b)
(action: (h:Type ?-> AccumMonoid h b ?=> Ref h w -> {Accum h|eff} a))
: {|eff} (a & w) =
def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = action ref
%runWriter explicitAction
-- Normally, only the ?=> lambda binders participate in dictionary synthesis,
-- so we need to explicitly declare `m` as a hint.
%instance bmHint = bm
empty : b = mempty
combine : b -> b -> b = mcombine
def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a =
%instance accumBaseMonoidHint : AccumMonoid h' b = UnsafeMkAccumMonoid bm
action ref
%runWriter empty combine explicitAction

def yieldAccum
(action: (h:Type ?-> Ref h w -> {Accum h|eff} a))
(mlift:MonoidLifter b w) ?=>
(m:Monoid b)
(action: (h:Type ?-> AccumMonoid h b ?=> Ref h w -> {Accum h|eff} a))
: {|eff} w =
snd $ runAccum action
snd $ runAccum m action

def runState
(init:s)
Expand Down Expand Up @@ -471,13 +499,10 @@ instance Monoid Ordering
GT -> GT
EQ -> y

-- TODO: accumulate using the True/&& monoid
instance [Eq a] Eq (n=>a)
(==) = \xs ys.
numDifferent : Float =
yieldAccum \ref. for i.
ref += (IToF (BToI (xs.i /= ys.i)))
numDifferent == 0.0
yieldAccum AndMonoid \ref.
for i. ref += xs.i == ys.i

instance [Ord a] Ord (n=>a)
(>) = \xs ys.
Expand Down Expand Up @@ -716,7 +741,7 @@ def reduce (identity:a) (combine:(a->a->a)) (xs:n=>a) : a =
-- TODO: call this `scan` and call the current `scan` something else
def scan' (init:a) (body:n->a->a) : n=>a = snd $ scan init \i x. dup (body i x)
-- TODO: allow tables-via-lambda and get rid of this
def fsum (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs i
def fsum (xs:n=>Float) : Float = yieldAccum (AddMonoid Float) \ref. for i. ref += xs i
def sum [Add v] (xs:n=>v) : v = reduce zero (+) xs
def prod [Mul v] (xs:n=>v) : v = reduce one (*) xs
def mean [VSpace v] (xs:n=>v) : v = sum xs / IToF (size n)
Expand Down
Loading

0 comments on commit 72934fe

Please sign in to comment.