From 72934fec1b87d120c04512a3b4ba905c26dc68e5 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 8 Jan 2021 17:47:12 +0000 Subject: [PATCH] Parametrize runWriter by a monoid used for the reduction This makes it possible to express (parallel) reductions over arbitrary monoids. Thanks to this, we can start removing some nasty hacks (like the one used for `Eq (n=>a)`) and make the (work-in-progress) FFT example parallel! Anyway, this whole change turned out to be surprisingly difficult, but thanks to many chats with @dougalm, I think that we've arrived at a particularly nice solution. The crux of the matter is the fact that Dex, unlike most other languages with some form of a built-in reduction operator, allows slicing the accumulator. This poses an interesting problem: if the user was to specify the `Monoid` instance for the full accumulator (e.g. a matrix), then what monoid are we supposed to use for its slice?! As it turns out, this might not even be well defined! For example, the type of square matrices with identity matrix and matrix multiplication forms a monoid, but there is no natural "sub-monoid" we could use in an expression of the form `ref!i += ...`. So, unless we're ok with giving up reference slicing (which we know we want for sure, since this is a way to express e.g. parallel scatters and histograms), we have to come up with a way of constructing those sub-monoids. And here, and answer is to turn the problem around: instead of asking the users to provide us the monoids for the full references, we expect the monoid to refer to some _base type_ (and we call it a _base monoid_). That is, when the `Accum` reference is of type `n=>m=>...=>k=>a`, then any of `m=>...=>k=>a`, ..., `k=>a` and even `a` are considered base types. While this is a bit surprising at first, it turns out to actually be quite convenient, since it does seem more straightforward to say "I want this to be a reduction over `(Float, 0.0, +)`" instead of mentioning the full table type, a broadcast version of `0.0` and a pointwise-lifted version of `+`. Finally, because many data types have multiple valid monoids (`Float` has at least four: `+`, `*`, `min`, `max`), the monoid argument is explicit and those instances can be obtained via the `named-instance` syntax added in the previous commits. Note that I've also included some helper functions which make it possible to synthesize `Monoid` instances automatically from `Add` and `Mul` instance for any given type (see `AddMonoid` and `MulMonoid`). I haven't been fully able to verify the correctness of the parallelization change, because the CUDA backend seems to be broken anyway (sigh...), but the code it generates looks ok. --- examples/isomorphisms.dx | 8 +- examples/raytrace.dx | 4 +- examples/tiled-matmul.dx | 2 +- lib/prelude.dx | 49 ++++++++--- src/lib/Autodiff.hs | 174 +++++++++++++++++++++++++-------------- src/lib/Embed.hs | 81 ++++++++++++------ src/lib/Imp.hs | 86 ++++++++++--------- src/lib/Inference.hs | 4 +- src/lib/PPrint.hs | 4 +- src/lib/Parallelize.hs | 34 ++++---- src/lib/Simplify.hs | 11 ++- src/lib/Syntax.hs | 33 ++++++-- src/lib/Type.hs | 46 ++++++----- tests/ad-tests.dx | 8 +- tests/adt-tests.dx | 6 +- tests/eval-tests.dx | 13 +-- tests/monad-tests.dx | 23 +++--- tests/parser-tests.dx | 2 +- 18 files changed, 359 insertions(+), 229 deletions(-) diff --git a/examples/isomorphisms.dx b/examples/isomorphisms.dx index 9668eac42..102f644d0 100644 --- a/examples/isomorphisms.dx +++ b/examples/isomorphisms.dx @@ -46,7 +46,7 @@ that produce isos. We will start with the first two: :t #b : Iso {a:Int & b:Float & c:Unit} _ > (Iso {a: Int32 & b: Float32 & c: Unit} (Float32 & {a: Int32 & c: Unit})) > === parse === -> _ans_ = +> _ans_ = > MkIso {bwd = \(x, r). {b = x, ...r}, fwd = \{b = x, ...r}. (,) x r} > : Iso {a: Int & b: Float & c: Unit} _ @@ -54,7 +54,7 @@ that produce isos. We will start with the first two: :t #?b : Iso {a:Int | b:Float | c:Unit} _ > (Iso {a: Int32 | b: Float32 | c: Unit} (Float32 | {a: Int32 | c: Unit})) > === parse === -> _ans_ = +> _ans_ = > MkIso > { bwd = \v. case v > ((Left x)) -> {| b = x |} @@ -142,7 +142,7 @@ another. For instance: > ({ &} & {a: Int32 & b: Float32 & c: Unit}) > ({a: Int32} & {b: Float32 & c: Unit})) > === parse === -> _ans_ = +> _ans_ = > MkIso > { bwd = \({a = x, ...l}, {, ...r}). (,) {, ...l} {a = x, ...r} > , fwd = \({, ...l}, {a = x, ...r}). (,) {a = x, ...l} {, ...r}} @@ -212,7 +212,7 @@ zipper isomorphisms: > ({ |} | {a: Int32 | b: Float32 | c: Unit}) > ({a: Int32} | {b: Float32 | c: Unit})) > === parse === -> _ans_ = +> _ans_ = > MkIso > { bwd = \v. case v > ((Left w)) -> (case w diff --git a/examples/raytrace.dx b/examples/raytrace.dx index c6678101f..b0a7e97f6 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -212,7 +212,7 @@ def sampleLightRadiance (surfNor, surf) = osurf (rayPos, _) = inRay (MkScene objs) = scene - yieldAccum \radiance. + yieldAccum (AddMonoid Float) \radiance. for i. case objs.i of PassiveObject _ _ -> () Light lightPos hw _ -> @@ -227,7 +227,7 @@ def sampleLightRadiance def trace (params:Params) (scene:Scene n) (initRay:Ray) (k:Key) : Color = noFilter = [1.0, 1.0, 1.0] - yieldAccum \radiance. + yieldAccum (AddMonoid Float) \radiance. runState noFilter \filter. runState initRay \ray. boundedIter (getAt #maxBounces params) () \i. diff --git a/examples/tiled-matmul.dx b/examples/tiled-matmul.dx index 7238d671e..677a80746 100644 --- a/examples/tiled-matmul.dx +++ b/examples/tiled-matmul.dx @@ -16,7 +16,7 @@ def matmul (k : Type) ?-> (n : Type) ?-> (m : Type) ?-> vectorTile = Fin VectorWidth colTile = (colVectors & vectorTile) (tile2d (\nt:(Tile n rowTile). \mt:(Tile m colTile). - ct = yieldAccum \acc. + ct = yieldAccum (AddMonoid Float) \acc. for l:k. for i:rowTile. ail = broadcastVector a.(nt +> i).l diff --git a/lib/prelude.dx b/lib/prelude.dx index ccd80ee16..2302d4fcf 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -309,8 +309,23 @@ def MulMonoid (a:Type) -> (_:Mul a) ?=> : Monoid a = def Ref (r:Type) (a:Type) : Type = %Ref r a def get (ref:Ref h s) : {State h} s = %get ref def (:=) (ref:Ref h s) (x:s) : {State h} Unit = %put ref x + def ask (ref:Ref h r) : {Read h} r = %ask ref -def (+=) (ref:Ref h w) (x:w) : {Accum h} Unit = %tell ref x + +data AccumMonoid h w = UnsafeMkAccumMonoid (Monoid w) + +@instance +def tableAccumMonoid ((UnsafeMkAccumMonoid m):AccumMonoid h w) ?=> : AccumMonoid h (n=>w) = + %instance mHint = m + def liftTableMonoid (tm:Monoid (n=>w)) ?=> : Monoid (n=>w) = tm + UnsafeMkAccumMonoid liftTableMonoid + +def (+=) (am:AccumMonoid h w) ?=> (ref:Ref h w) (x:w) : {Accum h} Unit = + (UnsafeMkAccumMonoid m) = am + %instance mHint = m + updater = \v. mcombine v x + %mextend ref updater + def (!) (ref:Ref h (n=>a)) (i:n) : Ref h a = %indexRef ref i def fstRef (ref: Ref h (a & b)) : Ref h a = %fstRef ref def sndRef (ref: Ref h (a & b)) : Ref h b = %sndRef ref @@ -328,16 +343,29 @@ def withReader : {|eff} a = runReader init action +def MonoidLifter (b:Type) (w:Type) : Type = h:Type -> AccumMonoid h b ?=> AccumMonoid h w + def runAccum - (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) + (mlift:MonoidLifter b w) ?=> + (bm:Monoid b) + (action: (h:Type ?-> AccumMonoid h b ?=> Ref h w -> {Accum h|eff} a)) : {|eff} (a & w) = - def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = action ref - %runWriter explicitAction + -- Normally, only the ?=> lambda binders participate in dictionary synthesis, + -- so we need to explicitly declare `m` as a hint. + %instance bmHint = bm + empty : b = mempty + combine : b -> b -> b = mcombine + def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = + %instance accumBaseMonoidHint : AccumMonoid h' b = UnsafeMkAccumMonoid bm + action ref + %runWriter empty combine explicitAction def yieldAccum - (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) + (mlift:MonoidLifter b w) ?=> + (m:Monoid b) + (action: (h:Type ?-> AccumMonoid h b ?=> Ref h w -> {Accum h|eff} a)) : {|eff} w = - snd $ runAccum action + snd $ runAccum m action def runState (init:s) @@ -471,13 +499,10 @@ instance Monoid Ordering GT -> GT EQ -> y --- TODO: accumulate using the True/&& monoid instance [Eq a] Eq (n=>a) (==) = \xs ys. - numDifferent : Float = - yieldAccum \ref. for i. - ref += (IToF (BToI (xs.i /= ys.i))) - numDifferent == 0.0 + yieldAccum AndMonoid \ref. + for i. ref += xs.i == ys.i instance [Ord a] Ord (n=>a) (>) = \xs ys. @@ -716,7 +741,7 @@ def reduce (identity:a) (combine:(a->a->a)) (xs:n=>a) : a = -- TODO: call this `scan` and call the current `scan` something else def scan' (init:a) (body:n->a->a) : n=>a = snd $ scan init \i x. dup (body i x) -- TODO: allow tables-via-lambda and get rid of this -def fsum (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs i +def fsum (xs:n=>Float) : Float = yieldAccum (AddMonoid Float) \ref. for i. ref += xs i def sum [Add v] (xs:n=>v) : v = reduce zero (+) xs def prod [Mul v] (xs:n=>v) : v = reduce one (*) xs def mean [VSpace v] (xs:n=>v) : v = sum xs / IToF (size n) diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index 1da5eac39..c6b0b7b41 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -124,13 +124,26 @@ linearizeOp :: Op -> LinA Atom linearizeOp op = case op of ScalarUnOp uop x -> linearizeUnOp uop x ScalarBinOp bop x y -> linearizeBinOp bop x y + PrimEffect refArg (MExtend ~(LamVal b body)) -> LinA $ do + (primalRef, mkTangentRef) <- runLinA $ la refArg + (primalUpdate, mkTangentUpdate) <- + buildLamAux b (const $ return PureArrow) \x@(Var v) -> + extendWrt [v] [] $ runLinA $ linearizeBlock (b @> x) body + let LamVal (Bind primalStateVar) _ = primalUpdate + ans <- emitOp $ PrimEffect primalRef $ MExtend primalUpdate + return (ans, do + tangentRef <- mkTangentRef + -- TODO: Assert that tangent update doesn't close over anything? + tangentUpdate <- buildLam (Bind $ "t":>tangentType (varType primalStateVar)) PureArrow \tx -> + extendTangentEnv (primalStateVar @> tx) [] $ mkTangentUpdate + emitOp $ PrimEffect tangentRef $ MExtend tangentUpdate) PrimEffect refArg m -> liftA2 PrimEffect (la refArg) (case m of - MAsk -> pure MAsk - MTell x -> liftA MTell $ la x - MGet -> pure MGet - MPut x -> liftA MPut $ la x) `bindLin` emitOp + MAsk -> pure MAsk + MExtend _ -> error "Unhandled MExtend" + MGet -> pure MGet + MPut x -> liftA MPut $ la x) `bindLin` emitOp IndexRef ref i -> (IndexRef <$> la ref <*> pure i) `bindLin` emitOp FstRef ref -> (FstRef <$> la ref ) `bindLin` emitOp SndRef ref -> (SndRef <$> la ref ) `bindLin` emitOp @@ -261,9 +274,17 @@ linearizeHof env hof = case hof of (ans, linTab) <- unzipTab ansWithLinTab return (ans, buildFor d i' \i'' -> provideRemat vi'' i'' $ app linTab i'' >>= applyLinToTangents) Tile _ _ _ -> notImplemented - RunWriter lam -> linearizeEff Nothing lam True (const RunWriter) (emitRunWriter "r") Writer - RunReader val lam -> linearizeEff (Just val) lam False RunReader (emitRunReader "r") Reader - RunState val lam -> linearizeEff (Just val) lam True RunState (emitRunState "r") State + RunWriter bm ~lam@(BinaryFunVal _ refBinder _ _) -> LinA $ do + unless (checkZeroPlusFloatMonoid bm) $ + error "AD of Accum effect only supported when the base monoid is (0, +) on Float" + let RefTy _ accTy = binderType refBinder + linearizeEff lam Writer (RunWriter bm) (emitRunWriter "r" accTy bm) + RunReader val lam -> LinA $ do + (val', mkLinInit) <- runLinA <$> linearizeAtom =<< substEmbed env val + linearizeEff lam Reader (RunReader val') \f -> mkLinInit >>= emitRunReader "r" `flip` f + RunState val lam -> LinA $ do + (val', mkLinInit) <- runLinA <$> linearizeAtom =<< substEmbed env val + linearizeEff lam State (RunState val') \f -> mkLinInit >>= emitRunState "r" `flip` f RunIO ~(Lam (Abs _ (arrow, body))) -> LinA $ do arrow' <- substEmbed env arrow -- TODO: consider the possibility of other effects here besides IO @@ -279,30 +300,19 @@ linearizeHof env hof = case hof of CatchException _ -> notImplemented Linearize _ -> error "Unexpected linearization" Transpose _ -> error "Unexpected transposition" - PTileReduce _ _ -> error "Unexpected PTileReduce" + PTileReduce _ _ _ -> error "Unexpected PTileReduce" where - linearizeEff maybeInit lam hasResult hofMaker emitter eff = LinA $ do - (valHofMaker, maybeLinInit) <- case maybeInit of - Just val -> do - (val', linVal) <- runLinA <$> linearizeAtom =<< substEmbed env val - return (hofMaker val', Just linVal) - Nothing -> return (hofMaker undefined, Nothing) + linearizeEff lam eff primalHofCon tangentEmitter = do (lam', ref) <- linearizeEffectFun eff lam - (ans, linBody) <- case hasResult of - True -> do - (ansLin, w) <- fromPair =<< emit (Hof $ valHofMaker lam') + -- The reader effect doesn't return any additional values + (ans, linBody) <- case eff of + Reader -> fromPair =<< emit (Hof $ primalHofCon lam') + _ -> do + (ansLin, w) <- fromPair =<< emit (Hof $ primalHofCon lam') (ans, linBody) <- fromPair ansLin return (PairVal ans w, linBody) - False -> fromPair =<< emit (Hof $ valHofMaker lam') - let lin = do - valEmitter <- case maybeLinInit of - Just linVal -> emitter <$> linVal - Nothing -> do - let (BinaryFunTy _ b _ _) = getType lam' - let RefTy _ wTy = binderType b - return $ emitter $ tangentType wTy - valEmitter \ref'@(Var (_:> RefTy (Var (h:>_)) _)) -> do - extendTangentEnv (ref @> ref') [h] $ applyLinToTangents linBody + let lin = tangentEmitter \ref'@(Var (_:> RefTy (Var (h:>_)) _)) -> do + extendTangentEnv (ref @> ref') [h] $ applyLinToTangents linBody return (ans, lin) linearizeEffectFun :: RWS -> Atom -> PrimalM (Atom, Var) @@ -391,24 +401,6 @@ tangentType ty = case ty of _ -> unsupported where unsupported = error $ "Can't differentiate wrt type " ++ pprint ty -addTangent :: MonadEmbed m => Atom -> Atom -> m Atom -addTangent x y = case getType x of - RecordTy (NoExt tys) -> do - elems <- bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y) - return $ Record $ restructure elems tys - TabTy b _ -> buildFor Fwd b \i -> bindM2 addTangent (tabGet x i) (tabGet y i) - TC con -> case con of - BaseType (Scalar _) -> emitOp $ ScalarBinOp FAdd x y - BaseType (Vector _) -> emitOp $ VectorBinOp FAdd x y - UnitType -> return UnitVal - PairType _ _ -> do - (xa, xb) <- fromPair x - (ya, yb) <- fromPair y - PairVal <$> addTangent xa ya <*> addTangent xb yb - _ -> notTangent - _ -> notTangent - where notTangent = error $ "Not a tangent type: " ++ pprint (getType x) - isTrivialForAD :: Expr -> Bool isTrivialForAD expr = isSingletonType tangentTy && exprEffs expr == mempty where tangentTy = tangentType $ getType expr @@ -614,12 +606,16 @@ transposeOp op ct = case op of refArg' <- substTranspose linRefSubst refArg let emitEff = emitOp . PrimEffect refArg' case m of - MAsk -> void $ emitEff $ MTell ct - MTell x -> transposeAtom x =<< emitEff MAsk + MAsk -> void $ emitEff . MExtend =<< (updateAddAt ct) + -- XXX: This assumes that the update function uses a tangent (0, +) monoid, + -- which is why we can ignore the binder (we even can't; we only have a + -- reader reference!). This should have been checked in the transposeHof + -- rule for RunWriter. + MExtend ~(LamVal _ body) -> transposeBlock body =<< emitEff MAsk -- TODO: Do something more efficient for state. We should be able -- to do in-place addition, just like we do for the Writer effect. - MGet -> void $ emitEff . MPut =<< addTangent ct =<< emitEff MGet - MPut x -> do + MGet -> void $ emitEff . MPut =<< addTangent ct =<< emitEff MGet + MPut x -> do transposeAtom x =<< emitEff MGet void $ emitEff $ MPut $ zeroAt $ getType x TabCon ~(TabTy b _) es -> forM_ (enumerate es) \(i, e) -> do @@ -685,11 +681,16 @@ transposeHof hof ct = case hof of return UnitVal where flipDir dir = case dir of Fwd -> Rev; Rev -> Fwd RunReader r ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do - (_, ctr) <- (fromPair =<<) $ emitRunWriter "w" (getType r) \ref -> do + let RefTy _ valTy = binderType b + let baseTy = getBaseMonoidType valTy + baseMonoid <- tangentBaseMonoidFor baseTy + (_, ctr) <- (fromPair =<<) $ emitRunWriter "w" valTy baseMonoid \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ct return UnitVal transposeAtom r ctr - RunWriter ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do + RunWriter bm ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do + unless (checkZeroPlusFloatMonoid bm) $ + error "AD of Accum effect only supported when the base monoid is (0, +) on Float" (ctBody, ctEff) <- fromPair ct void $ emitRunReader "r" ctEff \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ctBody @@ -706,7 +707,7 @@ transposeHof hof ct = case hof of CatchException _ -> notImplemented Linearize _ -> error "Unexpected linearization" Transpose _ -> error "Unexpected transposition" - PTileReduce _ _ -> error "Unexpected PTileReduce" + PTileReduce _ _ _ -> error "Unexpected PTileReduce" transposeAtom :: Atom -> Atom -> TransposeM () transposeAtom atom ct = case atom of @@ -776,7 +777,7 @@ isLinEff _ = error "Can't transpose polymorphic effects" emitCTToRef :: Maybe Atom -> Atom -> TransposeM () emitCTToRef mref ct = case mref of - Just ref -> void $ emitOp $ PrimEffect ref (MTell ct) + Just ref -> void . emitOp . PrimEffect ref . MExtend =<< updateAddAt ct Nothing -> return () substTranspose :: Subst a => (TransposeEnv -> SubstEnv) -> a -> TransposeM a @@ -789,13 +790,15 @@ substNonlin :: Subst a => a -> TransposeM a substNonlin = substTranspose nonlinSubst withLinVar :: Binder -> TransposeM a -> TransposeM (a, Atom) -withLinVar b body = case - singletonTypeVal (binderType b) of - Nothing -> flip evalStateT Nothing $ do - ans <- emitRunWriter "ref" (binderType b) \ref -> do - lift (localLinRef (b@>Just ref) body) >>= put . Just >> return UnitVal - (,) <$> (fromJust <$> get) <*> (getSnd ans) - Just x -> (,x) <$> (localLinRef (b@>Nothing) body) -- optimization to avoid accumulating into unit +withLinVar b body = case singletonTypeVal (binderType b) of + Nothing -> flip evalStateT Nothing $ do + let accTy = binderType b + let baseTy = getBaseMonoidType accTy + baseMonoid <- tangentBaseMonoidFor baseTy + ans <- emitRunWriter "ref" accTy baseMonoid \ref -> do + lift (localLinRef (b@>Just ref) body) >>= put . Just >> return UnitVal + (,) <$> (fromJust <$> get) <*> (getSnd ans) + Just x -> (,x) <$> (localLinRef (b@>Nothing) body) -- optimization to avoid accumulating into unit localLinRef :: Env (Maybe Atom) -> TransposeM a -> TransposeM a localLinRef refs = local (<> mempty { linRefs = refs }) @@ -808,3 +811,54 @@ localLinRefSubst s = local (<> mempty { linRefSubst = s }) localNonlinSubst :: SubstEnv -> TransposeM a -> TransposeM a localNonlinSubst s = local (<> mempty { nonlinSubst = s }) + +-- === The (0, +) monoid for tangent types === + +zeroAt :: Type -> Atom +zeroAt ty = case ty of + BaseTy bt -> Con $ Lit $ zeroLit bt + TabTy i a -> TabValA i $ zeroAt a + UnitTy -> UnitVal + PairTy a b -> PairVal (zeroAt a) (zeroAt b) + RecordTy (Ext tys Nothing) -> Record $ fmap zeroAt tys + _ -> unreachable + where + unreachable = error $ "Missing zero case for a tangent type: " ++ pprint ty + zeroLit bt = case bt of + Scalar Float64Type -> Float64Lit 0.0 + Scalar Float32Type -> Float32Lit 0.0 + Vector st -> VecLit $ replicate vectorWidth $ zeroLit $ Scalar st + _ -> unreachable + +updateAddAt :: MonadEmbed m => Atom -> m Atom +updateAddAt x = buildLam (Bind ("t":>getType x)) PureArrow $ addTangent x + +addTangent :: MonadEmbed m => Atom -> Atom -> m Atom +addTangent x y = case getType x of + RecordTy (NoExt tys) -> do + elems <- bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y) + return $ Record $ restructure elems tys + TabTy b _ -> buildFor Fwd b \i -> bindM2 addTangent (tabGet x i) (tabGet y i) + TC con -> case con of + BaseType (Scalar _) -> emitOp $ ScalarBinOp FAdd x y + BaseType (Vector _) -> emitOp $ VectorBinOp FAdd x y + UnitType -> return UnitVal + PairType _ _ -> do + (xa, xb) <- fromPair x + (ya, yb) <- fromPair y + PairVal <$> addTangent xa ya <*> addTangent xb yb + _ -> notTangent + _ -> notTangent + where notTangent = error $ "Not a tangent type: " ++ pprint (getType x) + +tangentBaseMonoidFor :: MonadEmbed m => Type -> m BaseMonoid +tangentBaseMonoidFor ty = BaseMonoid (zeroAt ty) <$> buildLam (Bind ("t":>ty)) PureArrow updateAddAt + +checkZeroPlusFloatMonoid :: BaseMonoid -> Bool +checkZeroPlusFloatMonoid (BaseMonoid zero plus) = checkZero zero && checkPlus plus + where + checkZero z = z == (Con (Lit (Float32Lit 0.0))) + checkPlus f = case f of + BinaryFunVal (Bind x) (Bind y) Pure (Block Empty (Op (ScalarBinOp FAdd (Var x') (Var y')))) -> + (x == x' && y == y') || (x == y' && y == x') + _ -> False diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index f1d735829..cc2dec229 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -1,4 +1,4 @@ --- Copyright 2019 Google LLC +-- Copyright 2021 Google LLC -- -- Use of this source code is governed by a BSD-style -- license that can be found in the LICENSE file or at @@ -24,14 +24,16 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP buildFor, buildForAux, buildForAnn, buildForAnnAux, emitBlock, unzipTab, isSingletonType, emitDecl, withNameHint, singletonTypeVal, scopedDecls, embedScoped, extendScope, checkEmbed, - embedExtend, unpackConsList, emitRunWriter, applyPreludeFunction, - emitRunState, emitMaybeCase, emitWhile, buildDataDef, + embedExtend, applyPreludeFunction, + unpackConsList, unpackLeftLeaningConsList, + emitRunWriter, emitRunWriters, mextendForRef, monoidLift, + emitRunState, emitMaybeCase, emitWhile, buildDataDef, emitRunReader, tabGet, SubstEmbedT, SubstEmbed, runSubstEmbedT, ptrOffset, ptrLoad, unsafePtrLoad, evalBlockE, substTraversalDef, TraversalDef, traverseDecls, traverseDecl, traverseDeclsOpen, traverseBlock, traverseExpr, traverseAtom, - clampPositive, buildNAbs, buildNAbsAux, buildNestedLam, zeroAt, + clampPositive, buildNAbs, buildNAbsAux, buildNestedLam, transformModuleAsBlock, dropSub, appReduceTraversalDef, indexSetSizeE, indexToIntE, intToIndexE, freshVarE) where @@ -231,22 +233,6 @@ inlineLastDecl block@(Block decls result) = Block (toNest (reverse rest)) expr _ -> block -zeroAt :: Type -> Atom -zeroAt ty = case ty of - BaseTy bt -> Con $ Lit $ zeroLit bt - TabTy i a -> TabValA i $ zeroAt a - UnitTy -> UnitVal - PairTy a b -> PairVal (zeroAt a) (zeroAt b) - RecordTy (Ext tys Nothing) -> Record $ fmap zeroAt tys - _ -> unreachable - where - unreachable = error $ "Missing zero case for a tangent type: " ++ pprint ty - zeroLit bt = case bt of - Scalar Float64Type -> Float64Lit 0.0 - Scalar Float32Type -> Float32Lit 0.0 - Vector st -> VecLit $ replicate vectorWidth $ zeroLit $ Scalar st - _ -> unreachable - fLitLike :: Double -> Atom -> Atom fLitLike x t = case getType t of BaseTy (Scalar Float64Type) -> Con $ Lit $ Float64Lit x @@ -384,6 +370,17 @@ unpackConsList xs = case getType xs of liftM (x:) $ unpackConsList rest _ -> error $ "Not a cons list: " ++ pprint (getType xs) +-- ((...((ans, x{n}), x{n-1})..., x2), x1) -> (ans, [x1, ..., x{n}]) +-- This is useful for unpacking results of stacked effect handlers (as produced +-- by e.g. emitRunWriters). +unpackLeftLeaningConsList :: MonadEmbed m => Int -> Atom -> m (Atom, [Atom]) +unpackLeftLeaningConsList depth atom = go depth atom [] + where + go 0 curAtom xs = return (curAtom, reverse xs) + go remDepth curAtom xs = do + (consTail, x) <- fromPair curAtom + go (remDepth - 1) consTail (x : xs) + emitWhile :: MonadEmbed m => m Atom -> m () emitWhile body = do eff <- getAllowedEffects @@ -399,9 +396,35 @@ emitMaybeCase scrut nothingCase justCase = do let resultTy = getType nothingBody emit $ Case scrut [nothingAlt, justAlt] resultTy -emitRunWriter :: MonadEmbed m => Name -> Type -> (Atom -> m Atom) -> m Atom -emitRunWriter v ty body = do - emit . Hof . RunWriter =<< mkBinaryEffFun Writer v ty body +monoidLift :: Type -> Type -> Nest Binder +monoidLift baseTy accTy = case baseTy == accTy of + True -> Empty + False -> case accTy of + TabTy n b -> Nest n $ monoidLift baseTy b + _ -> error $ "Base monoid type mismatch: can't lift " ++ + pprint baseTy ++ " to " ++ pprint accTy + +mextendForRef :: MonadEmbed m => Atom -> BaseMonoid -> Atom -> m Atom +mextendForRef ref (BaseMonoid _ combine) update = do + buildLam (Bind $ "refVal":>accTy) PureArrow \refVal -> + buildNestedFor (fmap (Fwd,) $ toList liftIndices) $ \indices -> do + refElem <- tabGetNd refVal indices + updateElem <- tabGetNd update indices + bindM2 appTryReduce (appTryReduce combine refElem) (return updateElem) + where + TC (RefType _ accTy) = getType ref + FunTy (BinderAnn baseTy) _ _ = getType combine + liftIndices = monoidLift baseTy accTy + +emitRunWriter :: MonadEmbed m => Name -> Type -> BaseMonoid -> (Atom -> m Atom) -> m Atom +emitRunWriter v accTy bm body = do + emit . Hof . RunWriter bm =<< mkBinaryEffFun Writer v accTy body + +emitRunWriters :: MonadEmbed m => [(Name, Type, BaseMonoid)] -> ([Atom] -> m Atom) -> m Atom +emitRunWriters inits body = go inits [] + where + go [] refs = body $ reverse refs + go ((v, accTy, bm):rest) refs = emitRunWriter v accTy bm $ \ref -> go rest (ref:refs) emitRunReader :: MonadEmbed m => Name -> Atom -> (Atom -> m Atom) -> m Atom emitRunReader v x0 body = do @@ -435,13 +458,23 @@ buildForAux = buildForAnnAux . RegularFor buildFor :: MonadEmbed m => Direction -> Binder -> (Atom -> m Atom) -> m Atom buildFor = buildForAnn . RegularFor +buildNestedFor :: forall m. MonadEmbed m => [(Direction, Binder)] -> ([Atom] -> m Atom) -> m Atom +buildNestedFor specs body = go specs [] + where + go :: [(Direction, Binder)] -> [Atom] -> m Atom + go [] indices = body $ reverse indices + go ((d,b):t) indices = buildFor d b $ \i -> go t (i:indices) + buildNestedLam :: MonadEmbed m => Arrow -> [Binder] -> ([Atom] -> m Atom) -> m Atom buildNestedLam _ [] f = f [] buildNestedLam arr (b:bs) f = buildLam b arr \x -> buildNestedLam arr bs \xs -> f (x:xs) tabGet :: MonadEmbed m => Atom -> Atom -> m Atom -tabGet x i = emit $ App x i +tabGet tab idx = emit $ App tab idx + +tabGetNd :: MonadEmbed m => Atom -> [Atom] -> m Atom +tabGetNd tab idxs = foldM (flip tabGet) tab idxs unzipTab :: MonadEmbed m => Atom -> m (Atom, Atom) unzipTab tab = do diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index d95abe9a6..adea44bfe 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -248,9 +248,14 @@ toImpOp (maybeDest, op) = case op of destToAtom dest PrimEffect refDest m -> do case m of - MAsk -> returnVal =<< destToAtom refDest - MTell x -> addToAtom refDest x >> returnVal UnitVal - MPut x -> copyAtom refDest x >> returnVal UnitVal + MAsk -> returnVal =<< destToAtom refDest + MExtend ~(Lam f) -> do + -- TODO: Update in-place? + refValue <- destToAtom refDest + result <- translateBlock mempty (Nothing, snd $ applyAbs f refValue) + copyAtom refDest result + returnVal UnitVal + MPut x -> copyAtom refDest x >> returnVal UnitVal MGet -> do dest <- allocDest maybeDest resultTy -- It might be more efficient to implement a specialized copy for dests @@ -422,28 +427,29 @@ toImpHof env (maybeDest, hof) = do sDest <- fromEmbed $ indexDestDim d dest idx void $ translateBlock (env <> sb @> idx) (Just sDest, sBody) destToAtom dest - PTileReduce idxTy' ~(BinaryFunVal gtidB nthrB _ body) -> do + PTileReduce baseMonoids idxTy' ~(BinaryFunVal gtidB nthrB _ body) -> do idxTy <- impSubst env idxTy' (mappingDest, finalAccDest) <- destPairUnpack <$> allocDest maybeDest resultTy - let PairTy _ accType = resultTy - (numTileWorkgroups, wgResArr, widIdxTy) <- buildKernel idxTy \LaunchInfo{..} buildBody -> do + let PairTy _ accTypes = resultTy + (numTileWorkgroups, wgAccsArr, widIdxTy) <- buildKernel idxTy \LaunchInfo{..} buildBody -> do let widIdxTy = Fin $ toScalarAtom numWorkgroups let tidIdxTy = Fin $ toScalarAtom workgroupSize - wgResArr <- alloc $ TabTy (Ignore widIdxTy) accType - thrAccArr <- alloc $ TabTy (Ignore widIdxTy) $ TabTy (Ignore tidIdxTy) accType + wgAccsArr <- alloc $ TabTy (Ignore widIdxTy) accTypes + thrAccsArr <- alloc $ TabTy (Ignore widIdxTy) $ TabTy (Ignore tidIdxTy) accTypes mappingKernelBody <- buildBody \ThreadInfo{..} -> do let TC (ParIndexRange _ gtid nthr) = threadRange - let scope = freeVars mappingDest - let tileDest = Con $ TabRef $ fst $ flip runSubstEmbed scope $ do + let tileDest = Con $ TabRef $ fst $ flip runSubstEmbed (freeVars mappingDest) $ do buildLam (Bind $ "hwidx":>threadRange) TabArrow \hwidx -> do indexDest mappingDest =<< (emitOp $ Inject hwidx) - wgAccs <- destGet thrAccArr =<< intToIndex widIdxTy wid - thrAcc <- destGet wgAccs =<< intToIndex tidIdxTy tid - let threadDest = Con $ ConRef $ PairCon tileDest thrAcc + wgThrAccs <- destGet thrAccsArr =<< intToIndex widIdxTy wid + thrAccs <- destGet wgThrAccs =<< intToIndex tidIdxTy tid + let thrAccsList = fromDestConsList thrAccs + let threadDest = foldr ((Con . ConRef) ... flip PairCon) tileDest thrAccsList + -- TODO: Make sure that threadDest has the right type void $ translateBlock (env <> gtidB @> gtid <> nthrB @> nthr) (Just threadDest, body) - wgRes <- destGet wgResArr =<< intToIndex widIdxTy wid - workgroupReduce tid wgRes wgAccs workgroupSize - return (mappingKernelBody, (numWorkgroups, wgResArr, widIdxTy)) + wgAccs <- destGet wgAccsArr =<< intToIndex widIdxTy wid + workgroupReduce tid wgAccs wgThrAccs workgroupSize + return (mappingKernelBody, (numWorkgroups, wgAccsArr, widIdxTy)) -- TODO: Skip the reduction kernel if unnecessary? -- TODO: Reduce sequentially in the CPU backend? -- TODO: Actually we only need the previous-power-of-2 many threads @@ -453,13 +459,14 @@ toImpHof env (maybeDest, hof) = do moreThanOneGroup <- (IIdxRepVal 1) `iltI` numWorkgroups guardBlock moreThanOneGroup $ emitStatement IThrowError redKernelBody <- buildBody \ThreadInfo{..} -> - workgroupReduce tid finalAccDest wgResArr numTileWorkgroups + workgroupReduce tid finalAccDest wgAccsArr numTileWorkgroups return (redKernelBody, ()) PairVal <$> destToAtom mappingDest <*> destToAtom finalAccDest where guardBlock cond m = do block <- scopedErrBlock m emitStatement $ ICond cond block (ImpBlock mempty mempty) + -- XXX: Overwrites the contents of arrDest, writes result in resDest workgroupReduce tid resDest arrDest elemCount = do elemCountDown2 <- prevPowerOf2 elemCount let RawRefTy (TabTy arrIdxB _) = getType arrDest @@ -472,7 +479,7 @@ toImpHof env (maybeDest, hof) = do shouldAdd <- bindM2 bandI (tid `iltI` off) (loadIdx `iltI` elemCount) guardBlock shouldAdd $ do threadDest <- destGet arrDest =<< intToIndex arrIdxTy tid - addToAtom threadDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy loadIdx + combineWithDest threadDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy loadIdx emitStatement ISyncWorkgroup copyAtom offPtr . toScalarAtom =<< off `idivI` (IIdxRepVal 2) cond <- liftM snd $ scopedBlock $ do @@ -484,6 +491,13 @@ toImpHof env (maybeDest, hof) = do firstThread <- tid `iltI` (IIdxRepVal 1) guardBlock firstThread $ copyAtom resDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy tid + combineWithDest :: Dest -> Atom -> ImpM () + combineWithDest accsDest accsUpdates = do + let accsDestList = fromDestConsList accsDest + let Right accsUpdatesList = fromConsList accsUpdates + forM_ (zip3 accsDestList baseMonoids accsUpdatesList) $ \(dest, bm, update) -> do + extender <- fromEmbed $ mextendForRef dest bm update + void $ toImpOp (Nothing, PrimEffect dest $ MExtend extender) -- TODO: Do some popcount tricks? prevPowerOf2 :: IExpr -> ImpM IExpr prevPowerOf2 x = do @@ -506,12 +520,13 @@ toImpHof env (maybeDest, hof) = do rDest <- alloc $ getType r copyAtom rDest =<< impSubst env r translateBlock (env <> ref @> rDest) (maybeDest, body) - RunWriter ~(BinaryFunVal _ ref _ body) -> do + RunWriter (BaseMonoid e' _) ~(BinaryFunVal _ ref _ body) -> do + let PairTy _ accTy = resultTy (aDest, wDest) <- destPairUnpack <$> allocDest maybeDest resultTy - let RefTy _ wTy = getType ref - copyAtom wDest (zeroAt wTy) + copyAtom wDest =<< (liftNeutral accTy <$> impSubst env e') void $ translateBlock (env <> ref @> wDest) (Just aDest, body) PairVal <$> destToAtom aDest <*> destToAtom wDest + where liftNeutral accTy e = foldr TabValA e $ monoidLift (getType e) accTy RunState s ~(BinaryFunVal _ ref _ body) -> do (aDest, sDest) <- destPairUnpack <$> allocDest maybeDest resultTy copyAtom sDest =<< impSubst env s @@ -791,6 +806,12 @@ destPairUnpack :: Dest -> (Dest, Dest) destPairUnpack (Con (ConRef (PairCon l r))) = (l, r) destPairUnpack d = error $ "Not a pair destination: " ++ show d +fromDestConsList :: Dest -> [Dest] +fromDestConsList dest = case dest of + Con (ConRef (PairCon h t)) -> h : fromDestConsList t + Con (ConRef UnitCon) -> [] + _ -> error $ "Not a dest cons list: " ++ pprint dest + makeAllocDest :: AllocType -> Type -> ImpM Dest makeAllocDest allocTy ty = fst <$> makeAllocDestWithPtrs allocTy ty @@ -963,29 +984,6 @@ zipWithRefConM f destCon srcCon = case (destCon, srcCon) of (IndexRangeVal _ _ _ iRef, IndexRangeVal _ _ _ i) -> f iRef i _ -> error $ "Unexpected ref/val " ++ pprint (destCon, srcCon) --- TODO: put this in userspace using type classes -addToAtom :: Dest -> Atom -> ImpM () -addToAtom dest src = case (dest, src) of - (Con (BaseTypeRef ptr), x) -> do - let ptr' = fromScalarAtom ptr - let x' = fromScalarAtom x - cur <- loadAnywhere ptr' - let op = case getIType cur of - Scalar _ -> ScalarBinOp - Vector _ -> VectorBinOp - _ -> error $ "The result of load cannot be a reference" - updated <- emitInstr $ IPrimOp $ op FAdd cur x' - storeAnywhere ptr' updated - (Con (TabRef _), TabVal _ _) -> zipTabDestAtom addToAtom dest src - (Con (ConRef (SumAsProd _ _ payloadDest)), Con (SumAsProd _ tag payload)) -> do - unless (all null payload) $ -- optimization - emitSwitch (fromScalarAtom tag) $ - zipWith (zipWithM_ addToAtom) payloadDest payload - (Con (ConRef destCon), Con srcCon) -> zipWithRefConM addToAtom destCon srcCon - (Con (RecordRef dests), Record srcs) -> - zipWithM_ addToAtom (toList dests) (toList srcs) - _ -> error $ "Not implemented " ++ pprint (dest, src) - loadAnywhere :: IExpr -> ImpM IExpr loadAnywhere ptr = do curDev <- asks curDevice diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 76e726593..cdf9cacfc 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -865,8 +865,8 @@ synthDict ty = case ty of -- TODO: this doesn't de-dup, so we'll get multiple results if we have a -- diamond-shaped hierarchy. -superclass :: Atom -> SynthDictM Atom -superclass dict = return dict <|> do +withSuperclasses :: Atom -> SynthDictM Atom +withSuperclasses dict = return dict <|> do (f, LetBound SuperclassLet _) <- getBinding inferToSynth $ tryApply f dict diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 655bc9195..60039d3d1 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -252,8 +252,8 @@ prettyPrecPrimCon con = case con of instance PrettyPrec e => Pretty (PrimOp e) where pretty = prettyFromPrettyPrec instance PrettyPrec e => PrettyPrec (PrimOp e) where prettyPrec op = case op of - PrimEffect ref (MPut val ) -> atPrec LowestPrec $ pApp ref <+> ":=" <+> pApp val - PrimEffect ref (MTell val) -> atPrec LowestPrec $ pApp ref <+> "+=" <+> pApp val + PrimEffect ref (MPut val ) -> atPrec LowestPrec $ pApp ref <+> ":=" <+> pApp val + PrimEffect ref (MExtend update) -> atPrec LowestPrec $ "extend" <+> pApp ref <+> "using" <+> pLowest update PtrOffset ptr idx -> atPrec LowestPrec $ pApp ptr <+> "+>" <+> pApp idx PtrLoad ptr -> atPrec AppPrec $ pAppArg "load" [ptr] RecordCons items rest -> diff --git a/src/lib/Parallelize.hs b/src/lib/Parallelize.hs index e11842020..b273faca7 100644 --- a/src/lib/Parallelize.hs +++ b/src/lib/Parallelize.hs @@ -20,6 +20,7 @@ import Cat import Env import Type import PPrint +import Util (for) -- TODO: extractParallelism can benefit a lot from horizontal fusion (can happen be after) -- TODO: Parallelism extraction can emit some really cheap (but not trivial) @@ -45,7 +46,7 @@ data LoopEnv = LoopEnv , delayedApps :: Env (Atom, [Atom]) -- (n @> (arr, bs)), n and bs in scope of the original program -- arr in scope of the newly constructed program! } -data AccEnv = AccEnv { activeAccs :: Env Var } +data AccEnv = AccEnv { activeAccs :: Env (Var, BaseMonoid) } -- (reference, its base monoid) type TLParallelM = SubstEmbedT (State AccEnv) -- Top-level non-parallel statements type LoopM = ReaderT LoopEnv TLParallelM -- Generation of (parallel) loop nests @@ -69,7 +70,7 @@ parallelTraverseExpr expr = case expr of Hof (For (RegularFor _) fbody@(LamVal b body)) -> do -- TODO: functionEffs is an overapproximation of the effects that really appear inside refs <- gets activeAccs - let allowedRegions = foldMap (\(varType -> RefTy (Var reg) _) -> reg @> ()) refs + let allowedRegions = foldMap (\(varType . fst -> RefTy (Var reg) _) -> reg @> ()) refs (EffectRow bodyEffs t) <- substEmbedR $ functionEffs fbody let onlyAllowedEffects = all (parallelizableEffect allowedRegions) $ toList bodyEffs case t == Nothing && onlyAllowedEffects of @@ -77,11 +78,11 @@ parallelTraverseExpr expr = case expr of b' <- substEmbedR b liftM Atom $ runLoopM $ withLoopBinder b' $ buildParallelBlock $ asABlock body False -> nothingSpecial - Hof (RunWriter (BinaryFunVal h b _ body)) -> do + Hof (RunWriter bm (BinaryFunVal h b _ body)) -> do ~(RefTy _ accTy) <- traverseAtom substTraversalDef $ binderType b - liftM Atom $ emitRunWriter (binderNameHint b) accTy \ref@(Var refVar) -> do + liftM Atom $ emitRunWriter (binderNameHint b) accTy bm \ref@(Var refVar) -> do let RefTy h' _ = varType refVar - modify \accEnv -> accEnv { activeAccs = activeAccs accEnv <> b @> refVar } + modify \accEnv -> accEnv { activeAccs = activeAccs accEnv <> b @> (refVar, bm) } extendR (h @> h' <> b @> ref) $ evalBlockE parallelTrav body -- TODO: Do some alias analysis. This is not fundamentally hard, but it is a little annoying. -- We would have to track not only the base references, but also all the aliases, along @@ -214,23 +215,16 @@ emitLoops buildPureLoop (ABlock decls result) = do buildLam (Bind $ "gtid" :> IdxRepTy) PureArrow \gtid -> do buildLam (Bind $ "nthr" :> IdxRepTy) PureArrow \nthr -> do let threadRange = TC $ ParIndexRange iterTy gtid nthr - let accTys = mkConsListTy $ fmap (derefType . varType) newRefs - emitRunWriter "refsList" accTys \localRefsList -> do - localRefs <- unpackRefConsList localRefsList + let writerSpecs = for newRefs \(ref, bm) -> (varName ref, derefType (varType ref), bm) + emitRunWriters writerSpecs $ \localRefs -> do buildFor Fwd (Bind $ "tidx" :> threadRange) \tidx -> do pari <- emitOp $ Inject tidx extendR (newEnv oldRefNames localRefs) $ buildBody pari - (ans, updateList) <- fromPair =<< (emit $ Hof $ PTileReduce iterTy body) + (ans, updateList) <- fromPair =<< (emit $ Hof $ PTileReduce (fmap snd newRefs) iterTy body) updates <- unpackConsList updateList - forM_ (zip newRefs updates) \(ref, update) -> - emitOp $ PrimEffect (Var ref) $ MTell update + forM_ (zip newRefs updates) $ \((ref, bm), update) -> do + updater <- mextendForRef (Var ref) bm update + emitOp $ PrimEffect (Var ref) $ MExtend updater return ans - where - derefType ~(RefTy _ accTy) = accTy - unpackRefConsList xs = case derefType $ getType xs of - UnitTy -> return [] - PairTy _ _ -> do - x <- getFstRef xs - rest <- getSndRef xs - (x:) <$> unpackRefConsList rest - _ -> error $ "Not a ref cons list: " ++ pprint (getType xs) + where + derefType ~(RefTy _ accTy) = accTy diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 51aab25cb..5ae910659 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -432,6 +432,9 @@ simplifyOp op = case op of -- Simplify the case away if we can. dropSub $ simplifyExpr $ Case full alts $ VariantTy resultRow _ -> emitOp op + PrimEffect ref (MExtend f) -> dropSub $ do + ~(f', Nothing) <- simplifyLam f + emitOp $ PrimEffect ref $ MExtend f' _ -> emitOp op simplifyHof :: Hof -> SimplifyM Atom @@ -446,7 +449,7 @@ simplifyHof hof = case hof of ~(fT', Nothing) <- simplifyLam fT ~(fS', Nothing) <- simplifyLam fS emit $ Hof $ Tile d fT' fS' - PTileReduce _ _ -> error "Unexpected PTileReduce" + PTileReduce _ _ _ -> error "Unexpected PTileReduce" While body -> do ~(body', Nothing) <- simplifyLam body emit $ Hof $ While body' @@ -463,9 +466,11 @@ simplifyHof hof = case hof of r' <- simplifyAtom r ~(lam', recon) <- simplifyBinaryLam lam applyRecon recon =<< (emit $ Hof $ RunReader r' lam') - RunWriter lam -> do + RunWriter (BaseMonoid e combine) lam -> do + e' <- simplifyAtom e + ~(combine', Nothing) <- simplifyBinaryLam combine ~(lam', recon) <- simplifyBinaryLam lam - (ans, w) <- fromPair =<< (emit $ Hof $ RunWriter lam') + (ans, w) <- fromPair =<< (emit $ Hof $ RunWriter (BaseMonoid e' combine') lam') ans' <- applyRecon recon ans return $ PairVal ans' w RunState s lam -> do diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 707be5460..54a57925e 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -45,9 +45,11 @@ module Syntax ( DataDef (..), DataConDef (..), UConDef (..), Nest (..), toNest, subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, applyNaryAbs, applyDataDefParams, freshSkolemVar, IndexStructure, - mkConsList, mkConsListTy, fromConsList, fromConsListTy, extendEffRow, + mkConsList, mkConsListTy, fromConsList, fromConsListTy, fromLeftLeaningConsListTy, + extendEffRow, getProjection, outputStreamPtrName, initBindings, varType, binderType, isTabTy, LogLevel (..), IRVariant (..), + BaseMonoidP (..), BaseMonoid, getBaseMonoidType, applyIntBinOp, applyIntCmpOp, applyFloatBinOp, applyFloatUnOp, getIntLit, getFloatLit, sizeOf, ptrSize, vectorWidth, pattern MaybeTy, pattern JustAtom, pattern NothingAtom, @@ -384,16 +386,20 @@ data PrimHof e = | Tile Int e e -- dimension number, tiled body, scalar body | While e | RunReader e e - | RunWriter e + | RunWriter (BaseMonoidP e) e | RunState e e | RunIO e | CatchException e | Linearize e | Transpose e - | PTileReduce e e -- index set, thread body + | PTileReduce [BaseMonoidP e] e e -- accumulator monoids, index set, thread body deriving (Show, Eq, Generic, Functor, Foldable, Traversable) -data PrimEffect e = MAsk | MTell e | MGet | MPut e +data BaseMonoidP e = BaseMonoid { baseEmpty :: e, baseCombine :: e } + deriving (Show, Eq, Generic, Functor, Foldable, Traversable) +type BaseMonoid = BaseMonoidP Atom + +data PrimEffect e = MAsk | MExtend e | MGet | MPut e deriving (Show, Eq, Generic, Functor, Foldable, Traversable) data BinOp = IAdd | ISub | IMul | IDiv | ICmp CmpOp @@ -440,6 +446,11 @@ primNameToStr prim = case lookup prim $ map swap $ M.toList builtinNames of showPrimName :: PrimExpr e -> String showPrimName prim = primNameToStr $ fmap (const ()) prim +getBaseMonoidType :: Type -> Type +getBaseMonoidType ty = case ty of + TabTy _ b -> getBaseMonoidType b + _ -> ty + -- === effects === data EffectRow = EffectRow (S.Set Effect) (Maybe Name) @@ -1481,6 +1492,15 @@ fromConsListTy ty = case ty of PairTy t rest -> (t:) <$> fromConsListTy rest _ -> throw CompilerErr $ "Not a pair or unit: " ++ show ty +-- ((...((ans & x{n}) & x{n-1})... & x2) & x1) -> (ans, [x1, ..., x{n}]) +fromLeftLeaningConsListTy :: MonadError Err m => Int -> Type -> m (Type, [Type]) +fromLeftLeaningConsListTy depth initTy = go depth initTy [] + where + go 0 ty xs = return (ty, reverse xs) + go remDepth ty xs = case ty of + PairTy lt rt -> go (remDepth - 1) lt (rt : xs) + _ -> throw CompilerErr $ "Not a pair: " ++ show xs + fromConsList :: MonadError Err m => Atom -> m [Atom] fromConsList xs = case xs of UnitVal -> return [] @@ -1600,7 +1620,7 @@ builtinNames = M.fromList , ("throwError" , OpExpr $ ThrowError ()) , ("throwException" , OpExpr $ ThrowException ()) , ("ask" , OpExpr $ PrimEffect () $ MAsk) - , ("tell" , OpExpr $ PrimEffect () $ MTell ()) + , ("mextend" , OpExpr $ PrimEffect () $ MExtend ()) , ("get" , OpExpr $ PrimEffect () $ MGet) , ("put" , OpExpr $ PrimEffect () $ MPut ()) , ("indexRef" , OpExpr $ IndexRef () ()) @@ -1610,7 +1630,7 @@ builtinNames = M.fromList , ("linearize" , HofExpr $ Linearize ()) , ("linearTranspose" , HofExpr $ Transpose ()) , ("runReader" , HofExpr $ RunReader () ()) - , ("runWriter" , HofExpr $ RunWriter ()) + , ("runWriter" , HofExpr $ RunWriter (BaseMonoid () ()) ()) , ("runState" , HofExpr $ RunState () ()) , ("runIO" , HofExpr $ RunIO ()) , ("catchException" , HofExpr $ CatchException ()) @@ -1665,6 +1685,7 @@ instance Store a => Store (Nest a) instance Store a => Store (ArrowP a) instance Store a => Store (Limit a) instance Store a => Store (PrimEffect a) +instance Store a => Store (BaseMonoidP a) instance Store a => Store (LabeledItems a) instance (Store a, Store b) => Store (ExtLabeledItems a b) instance Store ForAnn diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 29248533a..4fd549e21 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -272,10 +272,10 @@ exprEffs expr = case expr of App f _ -> functionEffs f Op op -> case op of PrimEffect ref m -> case m of - MGet -> oneEffect (RWSEffect State h) - MPut _ -> oneEffect (RWSEffect State h) - MAsk -> oneEffect (RWSEffect Reader h) - MTell _ -> oneEffect (RWSEffect Writer h) + MGet -> oneEffect (RWSEffect State h) + MPut _ -> oneEffect (RWSEffect State h) + MAsk -> oneEffect (RWSEffect Reader h) + MExtend _ -> oneEffect (RWSEffect Writer h) where RefTy (Var (h:>_)) _ = getType ref ThrowException _ -> oneEffect ExceptionEffect IOAlloc _ _ -> oneEffect IOEffect @@ -291,9 +291,9 @@ exprEffs expr = case expr of Linearize _ -> mempty -- Body has to be a pure function Transpose _ -> mempty -- Body has to be a pure function RunReader _ f -> handleRWSRunner Reader f - RunWriter f -> handleRWSRunner Writer f + RunWriter _ f -> handleRWSRunner Writer f RunState _ f -> handleRWSRunner State f - PTileReduce _ _ -> mempty + PTileReduce _ _ _ -> mempty RunIO ~(Lam (Abs _ (PlainArrow (EffectRow effs t), _))) -> EffectRow (S.delete IOEffect effs) t CatchException ~(Lam (Abs _ (PlainArrow (EffectRow effs t), _))) -> @@ -445,14 +445,14 @@ instance CoreVariant (PrimHof a) where For _ _ -> alwaysAllowed While _ -> alwaysAllowed RunReader _ _ -> alwaysAllowed - RunWriter _ -> alwaysAllowed + RunWriter _ _ -> alwaysAllowed RunState _ _ -> alwaysAllowed RunIO _ -> alwaysAllowed Linearize _ -> goneBy Simp Transpose _ -> goneBy Simp Tile _ _ _ -> alwaysAllowed - PTileReduce _ _ -> absentUntil Simp -- really absent until parallelization - CatchException _ -> goneBy Simp + PTileReduce _ _ _ -> absentUntil Simp -- really absent until parallelization + CatchException _ -> goneBy Simp -- TODO: namespace restrictions? alwaysAllowed :: VariantM () @@ -704,10 +704,10 @@ typeCheckOp op = case op of PrimEffect ref m -> do TC (RefType ~(Just (Var (h':>TyKind))) s) <- typeCheck ref case m of - MGet -> declareEff (RWSEffect State h') $> s - MPut x -> x|:s >> declareEff (RWSEffect State h') $> UnitTy - MAsk -> declareEff (RWSEffect Reader h') $> s - MTell x -> x|:s >> declareEff (RWSEffect Writer h') $> UnitTy + MGet -> declareEff (RWSEffect State h') $> s + MPut x -> x|:s >> declareEff (RWSEffect State h') $> UnitTy + MAsk -> declareEff (RWSEffect Reader h') $> s + MExtend x -> x|:(s --> s) >> declareEff (RWSEffect Writer h') $> UnitTy IndexRef ref i -> do RefTy h (TabTyAbs a) <- typeCheck ref i |: absArgType a @@ -855,15 +855,16 @@ typeCheckHof hof = case hof of replaceDim 0 (TabTy _ b) n = TabTy (Ignore n) b replaceDim d (TabTy dv b) n = TabTy dv $ replaceDim (d-1) b n replaceDim _ _ _ = error "This should be checked before" - PTileReduce n mapping -> do - -- mapping : gtid:IdxRepTy -> nthr:IdxRepTy -> ((ParIndexRange n gtid nthr)=>a, r) + PTileReduce baseMonoids n mapping -> do + -- mapping : gtid:IdxRepTy -> nthr:IdxRepTy -> (...((ParIndexRange n gtid nthr)=>a, acc{n})..., acc1) BinaryFunTy (Bind gtid) (Bind nthr) Pure mapResultTy <- typeCheck mapping - PairTy tiledArrTy accTy <- return mapResultTy + (tiledArrTy, accTys) <- fromLeftLeaningConsListTy (length baseMonoids) mapResultTy let threadRange = TC $ ParIndexRange n (Var gtid) (Var nthr) TabTy threadRange' tileElemTy <- return tiledArrTy checkEq threadRange (binderType threadRange') - -- PTileReduce n mapping : (n=>a, ro) - return $ PairTy (TabTy (Ignore n) tileElemTy) accTy + -- TODO: Check compatibility of baseMonoids and accTys (need to be careful about lifting!) + -- PTileReduce n mapping : (n=>a, (acc1, ..., acc{n})) + return $ PairTy (TabTy (Ignore n) tileElemTy) $ mkConsListTy accTys While body -> do Pi (Abs (Ignore UnitTy) (arr , condTy)) <- typeCheck body declareEffs $ arrowEff arr @@ -879,7 +880,14 @@ typeCheckHof hof = case hof of (resultTy, readTy) <- checkRWSAction Reader f r |: readTy return resultTy - RunWriter f -> uncurry PairTy <$> checkRWSAction Writer f + RunWriter _ f -> do + -- XXX: We can't verify compatibility between the base monoid and f, because + -- the only way in which they are related in the runAccum definition is via + -- the AccumMonoid typeclass. The frontend constraints should be sufficient + -- to ensure that only well typed programs are accepted, but it is a bit + -- disappointing that we cannot verify that internally. We might want to consider + -- e.g. only disabling this check for prelude. + uncurry PairTy <$> checkRWSAction Writer f RunState s f -> do (resultTy, stateTy) <- checkRWSAction State f s |: stateTy diff --git a/tests/ad-tests.dx b/tests/ad-tests.dx index 6affc69f6..a843a49ad 100644 --- a/tests/ad-tests.dx +++ b/tests/ad-tests.dx @@ -1,6 +1,6 @@ -- TODO: use prelude sum instead once we can differentiate state effect -def sum' (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs.i +def sum' (xs:n=>Float) : Float = yieldAccum (AddMonoid Float) \ref. for i. ref += xs.i :p f : Float -> Float = \x. x @@ -69,7 +69,7 @@ def sum' (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs.i :p jvp sum' [1., 2.] [10.0, 20.0] > 30. -f : Float -> Float = \x. yieldAccum \ref. ref += x +f : Float -> Float = \x. yieldAccum (AddMonoid Float) \ref. ref += x :p jvp f 1.0 1.0 > 1. @@ -167,7 +167,7 @@ tripleit : Float --o Float = \x. x + x + x > [2., 4.] myOtherSquare : Float -> Float = - \x. yieldAccum \w. w += x * x + \x. yieldAccum (AddMonoid Float) \w. w += x * x :p checkDeriv myOtherSquare 3.0 > True @@ -225,7 +225,7 @@ vec = [1.] :p f : Float -> Float = \x. y = x * 2.0 - yieldAccum \a. + yieldAccum (AddMonoid Float) \a. a += x * 2.0 a += y grad f 1.0 diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx index 08a28c6d0..42f84c043 100644 --- a/tests/adt-tests.dx +++ b/tests/adt-tests.dx @@ -102,7 +102,7 @@ myTab = [MyLeft 1, MyRight 3.5, MyLeft 123, MyLeft 456] > Runtime error :p - yieldAccum \ref. + yieldAccum (AddMonoid Float) \ref. for i. case myTab.i of MyLeft tmp -> () MyRight val -> ref += 1.0 + val @@ -110,7 +110,7 @@ myTab = [MyLeft 1, MyRight 3.5, MyLeft 123, MyLeft 456] :p -- check that the order of the case alternatives doesn't matter - yieldAccum \ref. + yieldAccum (AddMonoid Float) \ref. for i. case myTab.i of MyRight val -> ref += 1.0 + val MyLeft tmp -> () @@ -128,7 +128,7 @@ threeCaseTab : (Fin 4)=>ThreeCases = > [(TheIntCase 3), TheEmptyCase, (ThePairCase 2 0.1), TheEmptyCase] :p - yieldAccum \ref. + yieldAccum (AddMonoid Float) \ref. for i. case threeCaseTab.i of TheEmptyCase -> ref += 1000.0 ThePairCase x y -> ref += 100.0 + y + IToF x diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx index f853dceab..2426304f7 100644 --- a/tests/eval-tests.dx +++ b/tests/eval-tests.dx @@ -698,7 +698,7 @@ def newtonSolve (tol:Float) (f : Float -> Float) (x0:Float) : Float = > 415 :p - (f, w) = runAccum \ref. + (f, w) = runAccum (AddMonoid Float) \ref. ref += 2.0 w = 2 \z. z + w @@ -716,17 +716,6 @@ arr2d = for i:(Fin 2). for j:(Fin 2). (iota _).(i,j) arr2d.(1@_) > [2, 3] -:p - runState (1,2) \ref. - r1 = fstRef ref - r2 = sndRef ref - x = get r1 - y = get r2 - r2 := x - r1 := y -> ((), (2, 1)) - - :p any [True, False] > True :p any [False, False] diff --git a/tests/monad-tests.dx b/tests/monad-tests.dx index 8f9994252..58e116d7f 100644 --- a/tests/monad-tests.dx +++ b/tests/monad-tests.dx @@ -27,6 +27,7 @@ :p def rwsAction (rh:Type) ?-> (wh:Type) ?-> (sh:Type) ?-> + (_:AccumMonoid wh Float) ?=> (r:Ref rh Int) (w:Ref wh Float) (s:Ref sh Bool) : {Read rh, Accum wh, State sh} Int = x = get s @@ -38,7 +39,7 @@ withReader 2 \r. runState True \s. - runAccum \w. + runAccum (AddMonoid Float) \w. rwsAction r w s > ((4, 6.), False) @@ -56,29 +57,31 @@ :p def m (wh:Type) ?-> (sh:Type) ?-> + (_:AccumMonoid wh Float) ?=> (w:Ref wh Float) (s:Ref sh Float) : {Accum wh, State sh} Unit = x = get s w += x - runState 1.0 \s. runAccum \w . m w s + runState 1.0 \s. runAccum (AddMonoid Float) \w . m w s > (((), 1.), 1.) -def myAction (w:Ref hw Float) (r:Ref hr Float) : {Read hr, Accum hw} Unit = +def myAction [AccumMonoid hw Float] (w:Ref hw Float) (r:Ref hr Float) : {Read hr, Accum hw} Unit = x = ask r w += x w += 2.0 -:p withReader 1.5 \r. runAccum \w. myAction w r +:p withReader 1.5 \r. runAccum (AddMonoid Float) \w. myAction w r > ((), 3.5) :p def m (h1:Type) ?-> (h2:Type) ?-> + (_:AccumMonoid h1 Float) ?=> (_:AccumMonoid h2 Float) ?=> (w1:Ref h1 Float) (w2:Ref h2 Float) : {Accum h1, Accum h2} Unit = w1 += 1.0 w2 += 3.0 w1 += 1.0 - runAccum \w1. runAccum \w2. m w1 w2 + runAccum (AddMonoid Float) \w1. runAccum (AddMonoid Float) \w2. m w1 w2 > (((), 3.), 2.) def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = @@ -125,8 +128,8 @@ def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = -- (maybe just explicit implicit args) :p withReader 2.0 \r. - runAccum \w. - runAccum \w'. + runAccum (AddMonoid Float) \w. + runAccum (AddMonoid Float) \w'. runState 3 \s. x = ask r y = get s @@ -151,19 +154,19 @@ symmetrizeInPlace [[1.,2.],[3.,4.]] :p withReader 5 \r. () > () -:p yieldAccum \w. +:p yieldAccum (AddMonoid Float) \w. for i:(Fin 2). w += 1.0 w += 1.0 > 4. -:p yieldAccum \w. +:p yieldAccum (AddMonoid Float) \w. for i:(Fin 2). w += 1.0 w += 1.0 > 3. -:p yieldAccum \ref. +:p yieldAccum (AddMonoid Float) \ref. ref += [1.,2.,3.] ref += [2.,4.,5.] > [3., 6., 8.] diff --git a/tests/parser-tests.dx b/tests/parser-tests.dx index 93d216354..3729dae35 100644 --- a/tests/parser-tests.dx +++ b/tests/parser-tests.dx @@ -113,7 +113,7 @@ def myInt : {State h} Int = 1 > Nullary def can't have effects :p - yieldAccum \ref. + yieldAccum (AddMonoid Float) \ref. x = if True then 1. else 3. if True then ref += x