diff --git a/examples/isomorphisms.dx b/examples/isomorphisms.dx index 9668eac42..102f644d0 100644 --- a/examples/isomorphisms.dx +++ b/examples/isomorphisms.dx @@ -46,7 +46,7 @@ that produce isos. We will start with the first two: :t #b : Iso {a:Int & b:Float & c:Unit} _ > (Iso {a: Int32 & b: Float32 & c: Unit} (Float32 & {a: Int32 & c: Unit})) > === parse === -> _ans_ = +> _ans_ = > MkIso {bwd = \(x, r). {b = x, ...r}, fwd = \{b = x, ...r}. (,) x r} > : Iso {a: Int & b: Float & c: Unit} _ @@ -54,7 +54,7 @@ that produce isos. We will start with the first two: :t #?b : Iso {a:Int | b:Float | c:Unit} _ > (Iso {a: Int32 | b: Float32 | c: Unit} (Float32 | {a: Int32 | c: Unit})) > === parse === -> _ans_ = +> _ans_ = > MkIso > { bwd = \v. case v > ((Left x)) -> {| b = x |} @@ -142,7 +142,7 @@ another. For instance: > ({ &} & {a: Int32 & b: Float32 & c: Unit}) > ({a: Int32} & {b: Float32 & c: Unit})) > === parse === -> _ans_ = +> _ans_ = > MkIso > { bwd = \({a = x, ...l}, {, ...r}). (,) {, ...l} {a = x, ...r} > , fwd = \({, ...l}, {a = x, ...r}). (,) {a = x, ...l} {, ...r}} @@ -212,7 +212,7 @@ zipper isomorphisms: > ({ |} | {a: Int32 | b: Float32 | c: Unit}) > ({a: Int32} | {b: Float32 | c: Unit})) > === parse === -> _ans_ = +> _ans_ = > MkIso > { bwd = \v. case v > ((Left w)) -> (case w diff --git a/examples/raytrace.dx b/examples/raytrace.dx index c6678101f..b0a7e97f6 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -212,7 +212,7 @@ def sampleLightRadiance (surfNor, surf) = osurf (rayPos, _) = inRay (MkScene objs) = scene - yieldAccum \radiance. + yieldAccum (AddMonoid Float) \radiance. for i. case objs.i of PassiveObject _ _ -> () Light lightPos hw _ -> @@ -227,7 +227,7 @@ def sampleLightRadiance def trace (params:Params) (scene:Scene n) (initRay:Ray) (k:Key) : Color = noFilter = [1.0, 1.0, 1.0] - yieldAccum \radiance. + yieldAccum (AddMonoid Float) \radiance. runState noFilter \filter. runState initRay \ray. boundedIter (getAt #maxBounces params) () \i. diff --git a/examples/tiled-matmul.dx b/examples/tiled-matmul.dx index 7238d671e..677a80746 100644 --- a/examples/tiled-matmul.dx +++ b/examples/tiled-matmul.dx @@ -16,7 +16,7 @@ def matmul (k : Type) ?-> (n : Type) ?-> (m : Type) ?-> vectorTile = Fin VectorWidth colTile = (colVectors & vectorTile) (tile2d (\nt:(Tile n rowTile). \mt:(Tile m colTile). - ct = yieldAccum \acc. + ct = yieldAccum (AddMonoid Float) \acc. for l:k. for i:rowTile. ail = broadcastVector a.(nt +> i).l diff --git a/lib/prelude.dx b/lib/prelude.dx index ccd80ee16..2302d4fcf 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -309,8 +309,23 @@ def MulMonoid (a:Type) -> (_:Mul a) ?=> : Monoid a = def Ref (r:Type) (a:Type) : Type = %Ref r a def get (ref:Ref h s) : {State h} s = %get ref def (:=) (ref:Ref h s) (x:s) : {State h} Unit = %put ref x + def ask (ref:Ref h r) : {Read h} r = %ask ref -def (+=) (ref:Ref h w) (x:w) : {Accum h} Unit = %tell ref x + +data AccumMonoid h w = UnsafeMkAccumMonoid (Monoid w) + +@instance +def tableAccumMonoid ((UnsafeMkAccumMonoid m):AccumMonoid h w) ?=> : AccumMonoid h (n=>w) = + %instance mHint = m + def liftTableMonoid (tm:Monoid (n=>w)) ?=> : Monoid (n=>w) = tm + UnsafeMkAccumMonoid liftTableMonoid + +def (+=) (am:AccumMonoid h w) ?=> (ref:Ref h w) (x:w) : {Accum h} Unit = + (UnsafeMkAccumMonoid m) = am + %instance mHint = m + updater = \v. mcombine v x + %mextend ref updater + def (!) (ref:Ref h (n=>a)) (i:n) : Ref h a = %indexRef ref i def fstRef (ref: Ref h (a & b)) : Ref h a = %fstRef ref def sndRef (ref: Ref h (a & b)) : Ref h b = %sndRef ref @@ -328,16 +343,29 @@ def withReader : {|eff} a = runReader init action +def MonoidLifter (b:Type) (w:Type) : Type = h:Type -> AccumMonoid h b ?=> AccumMonoid h w + def runAccum - (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) + (mlift:MonoidLifter b w) ?=> + (bm:Monoid b) + (action: (h:Type ?-> AccumMonoid h b ?=> Ref h w -> {Accum h|eff} a)) : {|eff} (a & w) = - def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = action ref - %runWriter explicitAction + -- Normally, only the ?=> lambda binders participate in dictionary synthesis, + -- so we need to explicitly declare `m` as a hint. + %instance bmHint = bm + empty : b = mempty + combine : b -> b -> b = mcombine + def explicitAction (h':Type) (ref:Ref h' w) : {Accum h'|eff} a = + %instance accumBaseMonoidHint : AccumMonoid h' b = UnsafeMkAccumMonoid bm + action ref + %runWriter empty combine explicitAction def yieldAccum - (action: (h:Type ?-> Ref h w -> {Accum h|eff} a)) + (mlift:MonoidLifter b w) ?=> + (m:Monoid b) + (action: (h:Type ?-> AccumMonoid h b ?=> Ref h w -> {Accum h|eff} a)) : {|eff} w = - snd $ runAccum action + snd $ runAccum m action def runState (init:s) @@ -471,13 +499,10 @@ instance Monoid Ordering GT -> GT EQ -> y --- TODO: accumulate using the True/&& monoid instance [Eq a] Eq (n=>a) (==) = \xs ys. - numDifferent : Float = - yieldAccum \ref. for i. - ref += (IToF (BToI (xs.i /= ys.i))) - numDifferent == 0.0 + yieldAccum AndMonoid \ref. + for i. ref += xs.i == ys.i instance [Ord a] Ord (n=>a) (>) = \xs ys. @@ -716,7 +741,7 @@ def reduce (identity:a) (combine:(a->a->a)) (xs:n=>a) : a = -- TODO: call this `scan` and call the current `scan` something else def scan' (init:a) (body:n->a->a) : n=>a = snd $ scan init \i x. dup (body i x) -- TODO: allow tables-via-lambda and get rid of this -def fsum (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs i +def fsum (xs:n=>Float) : Float = yieldAccum (AddMonoid Float) \ref. for i. ref += xs i def sum [Add v] (xs:n=>v) : v = reduce zero (+) xs def prod [Mul v] (xs:n=>v) : v = reduce one (*) xs def mean [VSpace v] (xs:n=>v) : v = sum xs / IToF (size n) diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index 1da5eac39..c6b0b7b41 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -124,13 +124,26 @@ linearizeOp :: Op -> LinA Atom linearizeOp op = case op of ScalarUnOp uop x -> linearizeUnOp uop x ScalarBinOp bop x y -> linearizeBinOp bop x y + PrimEffect refArg (MExtend ~(LamVal b body)) -> LinA $ do + (primalRef, mkTangentRef) <- runLinA $ la refArg + (primalUpdate, mkTangentUpdate) <- + buildLamAux b (const $ return PureArrow) \x@(Var v) -> + extendWrt [v] [] $ runLinA $ linearizeBlock (b @> x) body + let LamVal (Bind primalStateVar) _ = primalUpdate + ans <- emitOp $ PrimEffect primalRef $ MExtend primalUpdate + return (ans, do + tangentRef <- mkTangentRef + -- TODO: Assert that tangent update doesn't close over anything? + tangentUpdate <- buildLam (Bind $ "t":>tangentType (varType primalStateVar)) PureArrow \tx -> + extendTangentEnv (primalStateVar @> tx) [] $ mkTangentUpdate + emitOp $ PrimEffect tangentRef $ MExtend tangentUpdate) PrimEffect refArg m -> liftA2 PrimEffect (la refArg) (case m of - MAsk -> pure MAsk - MTell x -> liftA MTell $ la x - MGet -> pure MGet - MPut x -> liftA MPut $ la x) `bindLin` emitOp + MAsk -> pure MAsk + MExtend _ -> error "Unhandled MExtend" + MGet -> pure MGet + MPut x -> liftA MPut $ la x) `bindLin` emitOp IndexRef ref i -> (IndexRef <$> la ref <*> pure i) `bindLin` emitOp FstRef ref -> (FstRef <$> la ref ) `bindLin` emitOp SndRef ref -> (SndRef <$> la ref ) `bindLin` emitOp @@ -261,9 +274,17 @@ linearizeHof env hof = case hof of (ans, linTab) <- unzipTab ansWithLinTab return (ans, buildFor d i' \i'' -> provideRemat vi'' i'' $ app linTab i'' >>= applyLinToTangents) Tile _ _ _ -> notImplemented - RunWriter lam -> linearizeEff Nothing lam True (const RunWriter) (emitRunWriter "r") Writer - RunReader val lam -> linearizeEff (Just val) lam False RunReader (emitRunReader "r") Reader - RunState val lam -> linearizeEff (Just val) lam True RunState (emitRunState "r") State + RunWriter bm ~lam@(BinaryFunVal _ refBinder _ _) -> LinA $ do + unless (checkZeroPlusFloatMonoid bm) $ + error "AD of Accum effect only supported when the base monoid is (0, +) on Float" + let RefTy _ accTy = binderType refBinder + linearizeEff lam Writer (RunWriter bm) (emitRunWriter "r" accTy bm) + RunReader val lam -> LinA $ do + (val', mkLinInit) <- runLinA <$> linearizeAtom =<< substEmbed env val + linearizeEff lam Reader (RunReader val') \f -> mkLinInit >>= emitRunReader "r" `flip` f + RunState val lam -> LinA $ do + (val', mkLinInit) <- runLinA <$> linearizeAtom =<< substEmbed env val + linearizeEff lam State (RunState val') \f -> mkLinInit >>= emitRunState "r" `flip` f RunIO ~(Lam (Abs _ (arrow, body))) -> LinA $ do arrow' <- substEmbed env arrow -- TODO: consider the possibility of other effects here besides IO @@ -279,30 +300,19 @@ linearizeHof env hof = case hof of CatchException _ -> notImplemented Linearize _ -> error "Unexpected linearization" Transpose _ -> error "Unexpected transposition" - PTileReduce _ _ -> error "Unexpected PTileReduce" + PTileReduce _ _ _ -> error "Unexpected PTileReduce" where - linearizeEff maybeInit lam hasResult hofMaker emitter eff = LinA $ do - (valHofMaker, maybeLinInit) <- case maybeInit of - Just val -> do - (val', linVal) <- runLinA <$> linearizeAtom =<< substEmbed env val - return (hofMaker val', Just linVal) - Nothing -> return (hofMaker undefined, Nothing) + linearizeEff lam eff primalHofCon tangentEmitter = do (lam', ref) <- linearizeEffectFun eff lam - (ans, linBody) <- case hasResult of - True -> do - (ansLin, w) <- fromPair =<< emit (Hof $ valHofMaker lam') + -- The reader effect doesn't return any additional values + (ans, linBody) <- case eff of + Reader -> fromPair =<< emit (Hof $ primalHofCon lam') + _ -> do + (ansLin, w) <- fromPair =<< emit (Hof $ primalHofCon lam') (ans, linBody) <- fromPair ansLin return (PairVal ans w, linBody) - False -> fromPair =<< emit (Hof $ valHofMaker lam') - let lin = do - valEmitter <- case maybeLinInit of - Just linVal -> emitter <$> linVal - Nothing -> do - let (BinaryFunTy _ b _ _) = getType lam' - let RefTy _ wTy = binderType b - return $ emitter $ tangentType wTy - valEmitter \ref'@(Var (_:> RefTy (Var (h:>_)) _)) -> do - extendTangentEnv (ref @> ref') [h] $ applyLinToTangents linBody + let lin = tangentEmitter \ref'@(Var (_:> RefTy (Var (h:>_)) _)) -> do + extendTangentEnv (ref @> ref') [h] $ applyLinToTangents linBody return (ans, lin) linearizeEffectFun :: RWS -> Atom -> PrimalM (Atom, Var) @@ -391,24 +401,6 @@ tangentType ty = case ty of _ -> unsupported where unsupported = error $ "Can't differentiate wrt type " ++ pprint ty -addTangent :: MonadEmbed m => Atom -> Atom -> m Atom -addTangent x y = case getType x of - RecordTy (NoExt tys) -> do - elems <- bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y) - return $ Record $ restructure elems tys - TabTy b _ -> buildFor Fwd b \i -> bindM2 addTangent (tabGet x i) (tabGet y i) - TC con -> case con of - BaseType (Scalar _) -> emitOp $ ScalarBinOp FAdd x y - BaseType (Vector _) -> emitOp $ VectorBinOp FAdd x y - UnitType -> return UnitVal - PairType _ _ -> do - (xa, xb) <- fromPair x - (ya, yb) <- fromPair y - PairVal <$> addTangent xa ya <*> addTangent xb yb - _ -> notTangent - _ -> notTangent - where notTangent = error $ "Not a tangent type: " ++ pprint (getType x) - isTrivialForAD :: Expr -> Bool isTrivialForAD expr = isSingletonType tangentTy && exprEffs expr == mempty where tangentTy = tangentType $ getType expr @@ -614,12 +606,16 @@ transposeOp op ct = case op of refArg' <- substTranspose linRefSubst refArg let emitEff = emitOp . PrimEffect refArg' case m of - MAsk -> void $ emitEff $ MTell ct - MTell x -> transposeAtom x =<< emitEff MAsk + MAsk -> void $ emitEff . MExtend =<< (updateAddAt ct) + -- XXX: This assumes that the update function uses a tangent (0, +) monoid, + -- which is why we can ignore the binder (we even can't; we only have a + -- reader reference!). This should have been checked in the transposeHof + -- rule for RunWriter. + MExtend ~(LamVal _ body) -> transposeBlock body =<< emitEff MAsk -- TODO: Do something more efficient for state. We should be able -- to do in-place addition, just like we do for the Writer effect. - MGet -> void $ emitEff . MPut =<< addTangent ct =<< emitEff MGet - MPut x -> do + MGet -> void $ emitEff . MPut =<< addTangent ct =<< emitEff MGet + MPut x -> do transposeAtom x =<< emitEff MGet void $ emitEff $ MPut $ zeroAt $ getType x TabCon ~(TabTy b _) es -> forM_ (enumerate es) \(i, e) -> do @@ -685,11 +681,16 @@ transposeHof hof ct = case hof of return UnitVal where flipDir dir = case dir of Fwd -> Rev; Rev -> Fwd RunReader r ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do - (_, ctr) <- (fromPair =<<) $ emitRunWriter "w" (getType r) \ref -> do + let RefTy _ valTy = binderType b + let baseTy = getBaseMonoidType valTy + baseMonoid <- tangentBaseMonoidFor baseTy + (_, ctr) <- (fromPair =<<) $ emitRunWriter "w" valTy baseMonoid \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ct return UnitVal transposeAtom r ctr - RunWriter ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do + RunWriter bm ~(BinaryFunVal (Bind (h:>_)) b _ body) -> do + unless (checkZeroPlusFloatMonoid bm) $ + error "AD of Accum effect only supported when the base monoid is (0, +) on Float" (ctBody, ctEff) <- fromPair ct void $ emitRunReader "r" ctEff \ref -> do localLinRegion h $ localLinRefSubst (b@>ref) $ transposeBlock body ctBody @@ -706,7 +707,7 @@ transposeHof hof ct = case hof of CatchException _ -> notImplemented Linearize _ -> error "Unexpected linearization" Transpose _ -> error "Unexpected transposition" - PTileReduce _ _ -> error "Unexpected PTileReduce" + PTileReduce _ _ _ -> error "Unexpected PTileReduce" transposeAtom :: Atom -> Atom -> TransposeM () transposeAtom atom ct = case atom of @@ -776,7 +777,7 @@ isLinEff _ = error "Can't transpose polymorphic effects" emitCTToRef :: Maybe Atom -> Atom -> TransposeM () emitCTToRef mref ct = case mref of - Just ref -> void $ emitOp $ PrimEffect ref (MTell ct) + Just ref -> void . emitOp . PrimEffect ref . MExtend =<< updateAddAt ct Nothing -> return () substTranspose :: Subst a => (TransposeEnv -> SubstEnv) -> a -> TransposeM a @@ -789,13 +790,15 @@ substNonlin :: Subst a => a -> TransposeM a substNonlin = substTranspose nonlinSubst withLinVar :: Binder -> TransposeM a -> TransposeM (a, Atom) -withLinVar b body = case - singletonTypeVal (binderType b) of - Nothing -> flip evalStateT Nothing $ do - ans <- emitRunWriter "ref" (binderType b) \ref -> do - lift (localLinRef (b@>Just ref) body) >>= put . Just >> return UnitVal - (,) <$> (fromJust <$> get) <*> (getSnd ans) - Just x -> (,x) <$> (localLinRef (b@>Nothing) body) -- optimization to avoid accumulating into unit +withLinVar b body = case singletonTypeVal (binderType b) of + Nothing -> flip evalStateT Nothing $ do + let accTy = binderType b + let baseTy = getBaseMonoidType accTy + baseMonoid <- tangentBaseMonoidFor baseTy + ans <- emitRunWriter "ref" accTy baseMonoid \ref -> do + lift (localLinRef (b@>Just ref) body) >>= put . Just >> return UnitVal + (,) <$> (fromJust <$> get) <*> (getSnd ans) + Just x -> (,x) <$> (localLinRef (b@>Nothing) body) -- optimization to avoid accumulating into unit localLinRef :: Env (Maybe Atom) -> TransposeM a -> TransposeM a localLinRef refs = local (<> mempty { linRefs = refs }) @@ -808,3 +811,54 @@ localLinRefSubst s = local (<> mempty { linRefSubst = s }) localNonlinSubst :: SubstEnv -> TransposeM a -> TransposeM a localNonlinSubst s = local (<> mempty { nonlinSubst = s }) + +-- === The (0, +) monoid for tangent types === + +zeroAt :: Type -> Atom +zeroAt ty = case ty of + BaseTy bt -> Con $ Lit $ zeroLit bt + TabTy i a -> TabValA i $ zeroAt a + UnitTy -> UnitVal + PairTy a b -> PairVal (zeroAt a) (zeroAt b) + RecordTy (Ext tys Nothing) -> Record $ fmap zeroAt tys + _ -> unreachable + where + unreachable = error $ "Missing zero case for a tangent type: " ++ pprint ty + zeroLit bt = case bt of + Scalar Float64Type -> Float64Lit 0.0 + Scalar Float32Type -> Float32Lit 0.0 + Vector st -> VecLit $ replicate vectorWidth $ zeroLit $ Scalar st + _ -> unreachable + +updateAddAt :: MonadEmbed m => Atom -> m Atom +updateAddAt x = buildLam (Bind ("t":>getType x)) PureArrow $ addTangent x + +addTangent :: MonadEmbed m => Atom -> Atom -> m Atom +addTangent x y = case getType x of + RecordTy (NoExt tys) -> do + elems <- bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y) + return $ Record $ restructure elems tys + TabTy b _ -> buildFor Fwd b \i -> bindM2 addTangent (tabGet x i) (tabGet y i) + TC con -> case con of + BaseType (Scalar _) -> emitOp $ ScalarBinOp FAdd x y + BaseType (Vector _) -> emitOp $ VectorBinOp FAdd x y + UnitType -> return UnitVal + PairType _ _ -> do + (xa, xb) <- fromPair x + (ya, yb) <- fromPair y + PairVal <$> addTangent xa ya <*> addTangent xb yb + _ -> notTangent + _ -> notTangent + where notTangent = error $ "Not a tangent type: " ++ pprint (getType x) + +tangentBaseMonoidFor :: MonadEmbed m => Type -> m BaseMonoid +tangentBaseMonoidFor ty = BaseMonoid (zeroAt ty) <$> buildLam (Bind ("t":>ty)) PureArrow updateAddAt + +checkZeroPlusFloatMonoid :: BaseMonoid -> Bool +checkZeroPlusFloatMonoid (BaseMonoid zero plus) = checkZero zero && checkPlus plus + where + checkZero z = z == (Con (Lit (Float32Lit 0.0))) + checkPlus f = case f of + BinaryFunVal (Bind x) (Bind y) Pure (Block Empty (Op (ScalarBinOp FAdd (Var x') (Var y')))) -> + (x == x' && y == y') || (x == y' && y == x') + _ -> False diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index f1d735829..cc2dec229 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -1,4 +1,4 @@ --- Copyright 2019 Google LLC +-- Copyright 2021 Google LLC -- -- Use of this source code is governed by a BSD-style -- license that can be found in the LICENSE file or at @@ -24,14 +24,16 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP buildFor, buildForAux, buildForAnn, buildForAnnAux, emitBlock, unzipTab, isSingletonType, emitDecl, withNameHint, singletonTypeVal, scopedDecls, embedScoped, extendScope, checkEmbed, - embedExtend, unpackConsList, emitRunWriter, applyPreludeFunction, - emitRunState, emitMaybeCase, emitWhile, buildDataDef, + embedExtend, applyPreludeFunction, + unpackConsList, unpackLeftLeaningConsList, + emitRunWriter, emitRunWriters, mextendForRef, monoidLift, + emitRunState, emitMaybeCase, emitWhile, buildDataDef, emitRunReader, tabGet, SubstEmbedT, SubstEmbed, runSubstEmbedT, ptrOffset, ptrLoad, unsafePtrLoad, evalBlockE, substTraversalDef, TraversalDef, traverseDecls, traverseDecl, traverseDeclsOpen, traverseBlock, traverseExpr, traverseAtom, - clampPositive, buildNAbs, buildNAbsAux, buildNestedLam, zeroAt, + clampPositive, buildNAbs, buildNAbsAux, buildNestedLam, transformModuleAsBlock, dropSub, appReduceTraversalDef, indexSetSizeE, indexToIntE, intToIndexE, freshVarE) where @@ -231,22 +233,6 @@ inlineLastDecl block@(Block decls result) = Block (toNest (reverse rest)) expr _ -> block -zeroAt :: Type -> Atom -zeroAt ty = case ty of - BaseTy bt -> Con $ Lit $ zeroLit bt - TabTy i a -> TabValA i $ zeroAt a - UnitTy -> UnitVal - PairTy a b -> PairVal (zeroAt a) (zeroAt b) - RecordTy (Ext tys Nothing) -> Record $ fmap zeroAt tys - _ -> unreachable - where - unreachable = error $ "Missing zero case for a tangent type: " ++ pprint ty - zeroLit bt = case bt of - Scalar Float64Type -> Float64Lit 0.0 - Scalar Float32Type -> Float32Lit 0.0 - Vector st -> VecLit $ replicate vectorWidth $ zeroLit $ Scalar st - _ -> unreachable - fLitLike :: Double -> Atom -> Atom fLitLike x t = case getType t of BaseTy (Scalar Float64Type) -> Con $ Lit $ Float64Lit x @@ -384,6 +370,17 @@ unpackConsList xs = case getType xs of liftM (x:) $ unpackConsList rest _ -> error $ "Not a cons list: " ++ pprint (getType xs) +-- ((...((ans, x{n}), x{n-1})..., x2), x1) -> (ans, [x1, ..., x{n}]) +-- This is useful for unpacking results of stacked effect handlers (as produced +-- by e.g. emitRunWriters). +unpackLeftLeaningConsList :: MonadEmbed m => Int -> Atom -> m (Atom, [Atom]) +unpackLeftLeaningConsList depth atom = go depth atom [] + where + go 0 curAtom xs = return (curAtom, reverse xs) + go remDepth curAtom xs = do + (consTail, x) <- fromPair curAtom + go (remDepth - 1) consTail (x : xs) + emitWhile :: MonadEmbed m => m Atom -> m () emitWhile body = do eff <- getAllowedEffects @@ -399,9 +396,35 @@ emitMaybeCase scrut nothingCase justCase = do let resultTy = getType nothingBody emit $ Case scrut [nothingAlt, justAlt] resultTy -emitRunWriter :: MonadEmbed m => Name -> Type -> (Atom -> m Atom) -> m Atom -emitRunWriter v ty body = do - emit . Hof . RunWriter =<< mkBinaryEffFun Writer v ty body +monoidLift :: Type -> Type -> Nest Binder +monoidLift baseTy accTy = case baseTy == accTy of + True -> Empty + False -> case accTy of + TabTy n b -> Nest n $ monoidLift baseTy b + _ -> error $ "Base monoid type mismatch: can't lift " ++ + pprint baseTy ++ " to " ++ pprint accTy + +mextendForRef :: MonadEmbed m => Atom -> BaseMonoid -> Atom -> m Atom +mextendForRef ref (BaseMonoid _ combine) update = do + buildLam (Bind $ "refVal":>accTy) PureArrow \refVal -> + buildNestedFor (fmap (Fwd,) $ toList liftIndices) $ \indices -> do + refElem <- tabGetNd refVal indices + updateElem <- tabGetNd update indices + bindM2 appTryReduce (appTryReduce combine refElem) (return updateElem) + where + TC (RefType _ accTy) = getType ref + FunTy (BinderAnn baseTy) _ _ = getType combine + liftIndices = monoidLift baseTy accTy + +emitRunWriter :: MonadEmbed m => Name -> Type -> BaseMonoid -> (Atom -> m Atom) -> m Atom +emitRunWriter v accTy bm body = do + emit . Hof . RunWriter bm =<< mkBinaryEffFun Writer v accTy body + +emitRunWriters :: MonadEmbed m => [(Name, Type, BaseMonoid)] -> ([Atom] -> m Atom) -> m Atom +emitRunWriters inits body = go inits [] + where + go [] refs = body $ reverse refs + go ((v, accTy, bm):rest) refs = emitRunWriter v accTy bm $ \ref -> go rest (ref:refs) emitRunReader :: MonadEmbed m => Name -> Atom -> (Atom -> m Atom) -> m Atom emitRunReader v x0 body = do @@ -435,13 +458,23 @@ buildForAux = buildForAnnAux . RegularFor buildFor :: MonadEmbed m => Direction -> Binder -> (Atom -> m Atom) -> m Atom buildFor = buildForAnn . RegularFor +buildNestedFor :: forall m. MonadEmbed m => [(Direction, Binder)] -> ([Atom] -> m Atom) -> m Atom +buildNestedFor specs body = go specs [] + where + go :: [(Direction, Binder)] -> [Atom] -> m Atom + go [] indices = body $ reverse indices + go ((d,b):t) indices = buildFor d b $ \i -> go t (i:indices) + buildNestedLam :: MonadEmbed m => Arrow -> [Binder] -> ([Atom] -> m Atom) -> m Atom buildNestedLam _ [] f = f [] buildNestedLam arr (b:bs) f = buildLam b arr \x -> buildNestedLam arr bs \xs -> f (x:xs) tabGet :: MonadEmbed m => Atom -> Atom -> m Atom -tabGet x i = emit $ App x i +tabGet tab idx = emit $ App tab idx + +tabGetNd :: MonadEmbed m => Atom -> [Atom] -> m Atom +tabGetNd tab idxs = foldM (flip tabGet) tab idxs unzipTab :: MonadEmbed m => Atom -> m (Atom, Atom) unzipTab tab = do diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index d95abe9a6..adea44bfe 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -248,9 +248,14 @@ toImpOp (maybeDest, op) = case op of destToAtom dest PrimEffect refDest m -> do case m of - MAsk -> returnVal =<< destToAtom refDest - MTell x -> addToAtom refDest x >> returnVal UnitVal - MPut x -> copyAtom refDest x >> returnVal UnitVal + MAsk -> returnVal =<< destToAtom refDest + MExtend ~(Lam f) -> do + -- TODO: Update in-place? + refValue <- destToAtom refDest + result <- translateBlock mempty (Nothing, snd $ applyAbs f refValue) + copyAtom refDest result + returnVal UnitVal + MPut x -> copyAtom refDest x >> returnVal UnitVal MGet -> do dest <- allocDest maybeDest resultTy -- It might be more efficient to implement a specialized copy for dests @@ -422,28 +427,29 @@ toImpHof env (maybeDest, hof) = do sDest <- fromEmbed $ indexDestDim d dest idx void $ translateBlock (env <> sb @> idx) (Just sDest, sBody) destToAtom dest - PTileReduce idxTy' ~(BinaryFunVal gtidB nthrB _ body) -> do + PTileReduce baseMonoids idxTy' ~(BinaryFunVal gtidB nthrB _ body) -> do idxTy <- impSubst env idxTy' (mappingDest, finalAccDest) <- destPairUnpack <$> allocDest maybeDest resultTy - let PairTy _ accType = resultTy - (numTileWorkgroups, wgResArr, widIdxTy) <- buildKernel idxTy \LaunchInfo{..} buildBody -> do + let PairTy _ accTypes = resultTy + (numTileWorkgroups, wgAccsArr, widIdxTy) <- buildKernel idxTy \LaunchInfo{..} buildBody -> do let widIdxTy = Fin $ toScalarAtom numWorkgroups let tidIdxTy = Fin $ toScalarAtom workgroupSize - wgResArr <- alloc $ TabTy (Ignore widIdxTy) accType - thrAccArr <- alloc $ TabTy (Ignore widIdxTy) $ TabTy (Ignore tidIdxTy) accType + wgAccsArr <- alloc $ TabTy (Ignore widIdxTy) accTypes + thrAccsArr <- alloc $ TabTy (Ignore widIdxTy) $ TabTy (Ignore tidIdxTy) accTypes mappingKernelBody <- buildBody \ThreadInfo{..} -> do let TC (ParIndexRange _ gtid nthr) = threadRange - let scope = freeVars mappingDest - let tileDest = Con $ TabRef $ fst $ flip runSubstEmbed scope $ do + let tileDest = Con $ TabRef $ fst $ flip runSubstEmbed (freeVars mappingDest) $ do buildLam (Bind $ "hwidx":>threadRange) TabArrow \hwidx -> do indexDest mappingDest =<< (emitOp $ Inject hwidx) - wgAccs <- destGet thrAccArr =<< intToIndex widIdxTy wid - thrAcc <- destGet wgAccs =<< intToIndex tidIdxTy tid - let threadDest = Con $ ConRef $ PairCon tileDest thrAcc + wgThrAccs <- destGet thrAccsArr =<< intToIndex widIdxTy wid + thrAccs <- destGet wgThrAccs =<< intToIndex tidIdxTy tid + let thrAccsList = fromDestConsList thrAccs + let threadDest = foldr ((Con . ConRef) ... flip PairCon) tileDest thrAccsList + -- TODO: Make sure that threadDest has the right type void $ translateBlock (env <> gtidB @> gtid <> nthrB @> nthr) (Just threadDest, body) - wgRes <- destGet wgResArr =<< intToIndex widIdxTy wid - workgroupReduce tid wgRes wgAccs workgroupSize - return (mappingKernelBody, (numWorkgroups, wgResArr, widIdxTy)) + wgAccs <- destGet wgAccsArr =<< intToIndex widIdxTy wid + workgroupReduce tid wgAccs wgThrAccs workgroupSize + return (mappingKernelBody, (numWorkgroups, wgAccsArr, widIdxTy)) -- TODO: Skip the reduction kernel if unnecessary? -- TODO: Reduce sequentially in the CPU backend? -- TODO: Actually we only need the previous-power-of-2 many threads @@ -453,13 +459,14 @@ toImpHof env (maybeDest, hof) = do moreThanOneGroup <- (IIdxRepVal 1) `iltI` numWorkgroups guardBlock moreThanOneGroup $ emitStatement IThrowError redKernelBody <- buildBody \ThreadInfo{..} -> - workgroupReduce tid finalAccDest wgResArr numTileWorkgroups + workgroupReduce tid finalAccDest wgAccsArr numTileWorkgroups return (redKernelBody, ()) PairVal <$> destToAtom mappingDest <*> destToAtom finalAccDest where guardBlock cond m = do block <- scopedErrBlock m emitStatement $ ICond cond block (ImpBlock mempty mempty) + -- XXX: Overwrites the contents of arrDest, writes result in resDest workgroupReduce tid resDest arrDest elemCount = do elemCountDown2 <- prevPowerOf2 elemCount let RawRefTy (TabTy arrIdxB _) = getType arrDest @@ -472,7 +479,7 @@ toImpHof env (maybeDest, hof) = do shouldAdd <- bindM2 bandI (tid `iltI` off) (loadIdx `iltI` elemCount) guardBlock shouldAdd $ do threadDest <- destGet arrDest =<< intToIndex arrIdxTy tid - addToAtom threadDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy loadIdx + combineWithDest threadDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy loadIdx emitStatement ISyncWorkgroup copyAtom offPtr . toScalarAtom =<< off `idivI` (IIdxRepVal 2) cond <- liftM snd $ scopedBlock $ do @@ -484,6 +491,13 @@ toImpHof env (maybeDest, hof) = do firstThread <- tid `iltI` (IIdxRepVal 1) guardBlock firstThread $ copyAtom resDest =<< destToAtom =<< destGet arrDest =<< intToIndex arrIdxTy tid + combineWithDest :: Dest -> Atom -> ImpM () + combineWithDest accsDest accsUpdates = do + let accsDestList = fromDestConsList accsDest + let Right accsUpdatesList = fromConsList accsUpdates + forM_ (zip3 accsDestList baseMonoids accsUpdatesList) $ \(dest, bm, update) -> do + extender <- fromEmbed $ mextendForRef dest bm update + void $ toImpOp (Nothing, PrimEffect dest $ MExtend extender) -- TODO: Do some popcount tricks? prevPowerOf2 :: IExpr -> ImpM IExpr prevPowerOf2 x = do @@ -506,12 +520,13 @@ toImpHof env (maybeDest, hof) = do rDest <- alloc $ getType r copyAtom rDest =<< impSubst env r translateBlock (env <> ref @> rDest) (maybeDest, body) - RunWriter ~(BinaryFunVal _ ref _ body) -> do + RunWriter (BaseMonoid e' _) ~(BinaryFunVal _ ref _ body) -> do + let PairTy _ accTy = resultTy (aDest, wDest) <- destPairUnpack <$> allocDest maybeDest resultTy - let RefTy _ wTy = getType ref - copyAtom wDest (zeroAt wTy) + copyAtom wDest =<< (liftNeutral accTy <$> impSubst env e') void $ translateBlock (env <> ref @> wDest) (Just aDest, body) PairVal <$> destToAtom aDest <*> destToAtom wDest + where liftNeutral accTy e = foldr TabValA e $ monoidLift (getType e) accTy RunState s ~(BinaryFunVal _ ref _ body) -> do (aDest, sDest) <- destPairUnpack <$> allocDest maybeDest resultTy copyAtom sDest =<< impSubst env s @@ -791,6 +806,12 @@ destPairUnpack :: Dest -> (Dest, Dest) destPairUnpack (Con (ConRef (PairCon l r))) = (l, r) destPairUnpack d = error $ "Not a pair destination: " ++ show d +fromDestConsList :: Dest -> [Dest] +fromDestConsList dest = case dest of + Con (ConRef (PairCon h t)) -> h : fromDestConsList t + Con (ConRef UnitCon) -> [] + _ -> error $ "Not a dest cons list: " ++ pprint dest + makeAllocDest :: AllocType -> Type -> ImpM Dest makeAllocDest allocTy ty = fst <$> makeAllocDestWithPtrs allocTy ty @@ -963,29 +984,6 @@ zipWithRefConM f destCon srcCon = case (destCon, srcCon) of (IndexRangeVal _ _ _ iRef, IndexRangeVal _ _ _ i) -> f iRef i _ -> error $ "Unexpected ref/val " ++ pprint (destCon, srcCon) --- TODO: put this in userspace using type classes -addToAtom :: Dest -> Atom -> ImpM () -addToAtom dest src = case (dest, src) of - (Con (BaseTypeRef ptr), x) -> do - let ptr' = fromScalarAtom ptr - let x' = fromScalarAtom x - cur <- loadAnywhere ptr' - let op = case getIType cur of - Scalar _ -> ScalarBinOp - Vector _ -> VectorBinOp - _ -> error $ "The result of load cannot be a reference" - updated <- emitInstr $ IPrimOp $ op FAdd cur x' - storeAnywhere ptr' updated - (Con (TabRef _), TabVal _ _) -> zipTabDestAtom addToAtom dest src - (Con (ConRef (SumAsProd _ _ payloadDest)), Con (SumAsProd _ tag payload)) -> do - unless (all null payload) $ -- optimization - emitSwitch (fromScalarAtom tag) $ - zipWith (zipWithM_ addToAtom) payloadDest payload - (Con (ConRef destCon), Con srcCon) -> zipWithRefConM addToAtom destCon srcCon - (Con (RecordRef dests), Record srcs) -> - zipWithM_ addToAtom (toList dests) (toList srcs) - _ -> error $ "Not implemented " ++ pprint (dest, src) - loadAnywhere :: IExpr -> ImpM IExpr loadAnywhere ptr = do curDev <- asks curDevice diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 76e726593..cdf9cacfc 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -865,8 +865,8 @@ synthDict ty = case ty of -- TODO: this doesn't de-dup, so we'll get multiple results if we have a -- diamond-shaped hierarchy. -superclass :: Atom -> SynthDictM Atom -superclass dict = return dict <|> do +withSuperclasses :: Atom -> SynthDictM Atom +withSuperclasses dict = return dict <|> do (f, LetBound SuperclassLet _) <- getBinding inferToSynth $ tryApply f dict diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 655bc9195..60039d3d1 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -252,8 +252,8 @@ prettyPrecPrimCon con = case con of instance PrettyPrec e => Pretty (PrimOp e) where pretty = prettyFromPrettyPrec instance PrettyPrec e => PrettyPrec (PrimOp e) where prettyPrec op = case op of - PrimEffect ref (MPut val ) -> atPrec LowestPrec $ pApp ref <+> ":=" <+> pApp val - PrimEffect ref (MTell val) -> atPrec LowestPrec $ pApp ref <+> "+=" <+> pApp val + PrimEffect ref (MPut val ) -> atPrec LowestPrec $ pApp ref <+> ":=" <+> pApp val + PrimEffect ref (MExtend update) -> atPrec LowestPrec $ "extend" <+> pApp ref <+> "using" <+> pLowest update PtrOffset ptr idx -> atPrec LowestPrec $ pApp ptr <+> "+>" <+> pApp idx PtrLoad ptr -> atPrec AppPrec $ pAppArg "load" [ptr] RecordCons items rest -> diff --git a/src/lib/Parallelize.hs b/src/lib/Parallelize.hs index e11842020..b273faca7 100644 --- a/src/lib/Parallelize.hs +++ b/src/lib/Parallelize.hs @@ -20,6 +20,7 @@ import Cat import Env import Type import PPrint +import Util (for) -- TODO: extractParallelism can benefit a lot from horizontal fusion (can happen be after) -- TODO: Parallelism extraction can emit some really cheap (but not trivial) @@ -45,7 +46,7 @@ data LoopEnv = LoopEnv , delayedApps :: Env (Atom, [Atom]) -- (n @> (arr, bs)), n and bs in scope of the original program -- arr in scope of the newly constructed program! } -data AccEnv = AccEnv { activeAccs :: Env Var } +data AccEnv = AccEnv { activeAccs :: Env (Var, BaseMonoid) } -- (reference, its base monoid) type TLParallelM = SubstEmbedT (State AccEnv) -- Top-level non-parallel statements type LoopM = ReaderT LoopEnv TLParallelM -- Generation of (parallel) loop nests @@ -69,7 +70,7 @@ parallelTraverseExpr expr = case expr of Hof (For (RegularFor _) fbody@(LamVal b body)) -> do -- TODO: functionEffs is an overapproximation of the effects that really appear inside refs <- gets activeAccs - let allowedRegions = foldMap (\(varType -> RefTy (Var reg) _) -> reg @> ()) refs + let allowedRegions = foldMap (\(varType . fst -> RefTy (Var reg) _) -> reg @> ()) refs (EffectRow bodyEffs t) <- substEmbedR $ functionEffs fbody let onlyAllowedEffects = all (parallelizableEffect allowedRegions) $ toList bodyEffs case t == Nothing && onlyAllowedEffects of @@ -77,11 +78,11 @@ parallelTraverseExpr expr = case expr of b' <- substEmbedR b liftM Atom $ runLoopM $ withLoopBinder b' $ buildParallelBlock $ asABlock body False -> nothingSpecial - Hof (RunWriter (BinaryFunVal h b _ body)) -> do + Hof (RunWriter bm (BinaryFunVal h b _ body)) -> do ~(RefTy _ accTy) <- traverseAtom substTraversalDef $ binderType b - liftM Atom $ emitRunWriter (binderNameHint b) accTy \ref@(Var refVar) -> do + liftM Atom $ emitRunWriter (binderNameHint b) accTy bm \ref@(Var refVar) -> do let RefTy h' _ = varType refVar - modify \accEnv -> accEnv { activeAccs = activeAccs accEnv <> b @> refVar } + modify \accEnv -> accEnv { activeAccs = activeAccs accEnv <> b @> (refVar, bm) } extendR (h @> h' <> b @> ref) $ evalBlockE parallelTrav body -- TODO: Do some alias analysis. This is not fundamentally hard, but it is a little annoying. -- We would have to track not only the base references, but also all the aliases, along @@ -214,23 +215,16 @@ emitLoops buildPureLoop (ABlock decls result) = do buildLam (Bind $ "gtid" :> IdxRepTy) PureArrow \gtid -> do buildLam (Bind $ "nthr" :> IdxRepTy) PureArrow \nthr -> do let threadRange = TC $ ParIndexRange iterTy gtid nthr - let accTys = mkConsListTy $ fmap (derefType . varType) newRefs - emitRunWriter "refsList" accTys \localRefsList -> do - localRefs <- unpackRefConsList localRefsList + let writerSpecs = for newRefs \(ref, bm) -> (varName ref, derefType (varType ref), bm) + emitRunWriters writerSpecs $ \localRefs -> do buildFor Fwd (Bind $ "tidx" :> threadRange) \tidx -> do pari <- emitOp $ Inject tidx extendR (newEnv oldRefNames localRefs) $ buildBody pari - (ans, updateList) <- fromPair =<< (emit $ Hof $ PTileReduce iterTy body) + (ans, updateList) <- fromPair =<< (emit $ Hof $ PTileReduce (fmap snd newRefs) iterTy body) updates <- unpackConsList updateList - forM_ (zip newRefs updates) \(ref, update) -> - emitOp $ PrimEffect (Var ref) $ MTell update + forM_ (zip newRefs updates) $ \((ref, bm), update) -> do + updater <- mextendForRef (Var ref) bm update + emitOp $ PrimEffect (Var ref) $ MExtend updater return ans - where - derefType ~(RefTy _ accTy) = accTy - unpackRefConsList xs = case derefType $ getType xs of - UnitTy -> return [] - PairTy _ _ -> do - x <- getFstRef xs - rest <- getSndRef xs - (x:) <$> unpackRefConsList rest - _ -> error $ "Not a ref cons list: " ++ pprint (getType xs) + where + derefType ~(RefTy _ accTy) = accTy diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 51aab25cb..5ae910659 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -432,6 +432,9 @@ simplifyOp op = case op of -- Simplify the case away if we can. dropSub $ simplifyExpr $ Case full alts $ VariantTy resultRow _ -> emitOp op + PrimEffect ref (MExtend f) -> dropSub $ do + ~(f', Nothing) <- simplifyLam f + emitOp $ PrimEffect ref $ MExtend f' _ -> emitOp op simplifyHof :: Hof -> SimplifyM Atom @@ -446,7 +449,7 @@ simplifyHof hof = case hof of ~(fT', Nothing) <- simplifyLam fT ~(fS', Nothing) <- simplifyLam fS emit $ Hof $ Tile d fT' fS' - PTileReduce _ _ -> error "Unexpected PTileReduce" + PTileReduce _ _ _ -> error "Unexpected PTileReduce" While body -> do ~(body', Nothing) <- simplifyLam body emit $ Hof $ While body' @@ -463,9 +466,11 @@ simplifyHof hof = case hof of r' <- simplifyAtom r ~(lam', recon) <- simplifyBinaryLam lam applyRecon recon =<< (emit $ Hof $ RunReader r' lam') - RunWriter lam -> do + RunWriter (BaseMonoid e combine) lam -> do + e' <- simplifyAtom e + ~(combine', Nothing) <- simplifyBinaryLam combine ~(lam', recon) <- simplifyBinaryLam lam - (ans, w) <- fromPair =<< (emit $ Hof $ RunWriter lam') + (ans, w) <- fromPair =<< (emit $ Hof $ RunWriter (BaseMonoid e' combine') lam') ans' <- applyRecon recon ans return $ PairVal ans' w RunState s lam -> do diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 707be5460..54a57925e 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -45,9 +45,11 @@ module Syntax ( DataDef (..), DataConDef (..), UConDef (..), Nest (..), toNest, subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, applyNaryAbs, applyDataDefParams, freshSkolemVar, IndexStructure, - mkConsList, mkConsListTy, fromConsList, fromConsListTy, extendEffRow, + mkConsList, mkConsListTy, fromConsList, fromConsListTy, fromLeftLeaningConsListTy, + extendEffRow, getProjection, outputStreamPtrName, initBindings, varType, binderType, isTabTy, LogLevel (..), IRVariant (..), + BaseMonoidP (..), BaseMonoid, getBaseMonoidType, applyIntBinOp, applyIntCmpOp, applyFloatBinOp, applyFloatUnOp, getIntLit, getFloatLit, sizeOf, ptrSize, vectorWidth, pattern MaybeTy, pattern JustAtom, pattern NothingAtom, @@ -384,16 +386,20 @@ data PrimHof e = | Tile Int e e -- dimension number, tiled body, scalar body | While e | RunReader e e - | RunWriter e + | RunWriter (BaseMonoidP e) e | RunState e e | RunIO e | CatchException e | Linearize e | Transpose e - | PTileReduce e e -- index set, thread body + | PTileReduce [BaseMonoidP e] e e -- accumulator monoids, index set, thread body deriving (Show, Eq, Generic, Functor, Foldable, Traversable) -data PrimEffect e = MAsk | MTell e | MGet | MPut e +data BaseMonoidP e = BaseMonoid { baseEmpty :: e, baseCombine :: e } + deriving (Show, Eq, Generic, Functor, Foldable, Traversable) +type BaseMonoid = BaseMonoidP Atom + +data PrimEffect e = MAsk | MExtend e | MGet | MPut e deriving (Show, Eq, Generic, Functor, Foldable, Traversable) data BinOp = IAdd | ISub | IMul | IDiv | ICmp CmpOp @@ -440,6 +446,11 @@ primNameToStr prim = case lookup prim $ map swap $ M.toList builtinNames of showPrimName :: PrimExpr e -> String showPrimName prim = primNameToStr $ fmap (const ()) prim +getBaseMonoidType :: Type -> Type +getBaseMonoidType ty = case ty of + TabTy _ b -> getBaseMonoidType b + _ -> ty + -- === effects === data EffectRow = EffectRow (S.Set Effect) (Maybe Name) @@ -1481,6 +1492,15 @@ fromConsListTy ty = case ty of PairTy t rest -> (t:) <$> fromConsListTy rest _ -> throw CompilerErr $ "Not a pair or unit: " ++ show ty +-- ((...((ans & x{n}) & x{n-1})... & x2) & x1) -> (ans, [x1, ..., x{n}]) +fromLeftLeaningConsListTy :: MonadError Err m => Int -> Type -> m (Type, [Type]) +fromLeftLeaningConsListTy depth initTy = go depth initTy [] + where + go 0 ty xs = return (ty, reverse xs) + go remDepth ty xs = case ty of + PairTy lt rt -> go (remDepth - 1) lt (rt : xs) + _ -> throw CompilerErr $ "Not a pair: " ++ show xs + fromConsList :: MonadError Err m => Atom -> m [Atom] fromConsList xs = case xs of UnitVal -> return [] @@ -1600,7 +1620,7 @@ builtinNames = M.fromList , ("throwError" , OpExpr $ ThrowError ()) , ("throwException" , OpExpr $ ThrowException ()) , ("ask" , OpExpr $ PrimEffect () $ MAsk) - , ("tell" , OpExpr $ PrimEffect () $ MTell ()) + , ("mextend" , OpExpr $ PrimEffect () $ MExtend ()) , ("get" , OpExpr $ PrimEffect () $ MGet) , ("put" , OpExpr $ PrimEffect () $ MPut ()) , ("indexRef" , OpExpr $ IndexRef () ()) @@ -1610,7 +1630,7 @@ builtinNames = M.fromList , ("linearize" , HofExpr $ Linearize ()) , ("linearTranspose" , HofExpr $ Transpose ()) , ("runReader" , HofExpr $ RunReader () ()) - , ("runWriter" , HofExpr $ RunWriter ()) + , ("runWriter" , HofExpr $ RunWriter (BaseMonoid () ()) ()) , ("runState" , HofExpr $ RunState () ()) , ("runIO" , HofExpr $ RunIO ()) , ("catchException" , HofExpr $ CatchException ()) @@ -1665,6 +1685,7 @@ instance Store a => Store (Nest a) instance Store a => Store (ArrowP a) instance Store a => Store (Limit a) instance Store a => Store (PrimEffect a) +instance Store a => Store (BaseMonoidP a) instance Store a => Store (LabeledItems a) instance (Store a, Store b) => Store (ExtLabeledItems a b) instance Store ForAnn diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 29248533a..4fd549e21 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -272,10 +272,10 @@ exprEffs expr = case expr of App f _ -> functionEffs f Op op -> case op of PrimEffect ref m -> case m of - MGet -> oneEffect (RWSEffect State h) - MPut _ -> oneEffect (RWSEffect State h) - MAsk -> oneEffect (RWSEffect Reader h) - MTell _ -> oneEffect (RWSEffect Writer h) + MGet -> oneEffect (RWSEffect State h) + MPut _ -> oneEffect (RWSEffect State h) + MAsk -> oneEffect (RWSEffect Reader h) + MExtend _ -> oneEffect (RWSEffect Writer h) where RefTy (Var (h:>_)) _ = getType ref ThrowException _ -> oneEffect ExceptionEffect IOAlloc _ _ -> oneEffect IOEffect @@ -291,9 +291,9 @@ exprEffs expr = case expr of Linearize _ -> mempty -- Body has to be a pure function Transpose _ -> mempty -- Body has to be a pure function RunReader _ f -> handleRWSRunner Reader f - RunWriter f -> handleRWSRunner Writer f + RunWriter _ f -> handleRWSRunner Writer f RunState _ f -> handleRWSRunner State f - PTileReduce _ _ -> mempty + PTileReduce _ _ _ -> mempty RunIO ~(Lam (Abs _ (PlainArrow (EffectRow effs t), _))) -> EffectRow (S.delete IOEffect effs) t CatchException ~(Lam (Abs _ (PlainArrow (EffectRow effs t), _))) -> @@ -445,14 +445,14 @@ instance CoreVariant (PrimHof a) where For _ _ -> alwaysAllowed While _ -> alwaysAllowed RunReader _ _ -> alwaysAllowed - RunWriter _ -> alwaysAllowed + RunWriter _ _ -> alwaysAllowed RunState _ _ -> alwaysAllowed RunIO _ -> alwaysAllowed Linearize _ -> goneBy Simp Transpose _ -> goneBy Simp Tile _ _ _ -> alwaysAllowed - PTileReduce _ _ -> absentUntil Simp -- really absent until parallelization - CatchException _ -> goneBy Simp + PTileReduce _ _ _ -> absentUntil Simp -- really absent until parallelization + CatchException _ -> goneBy Simp -- TODO: namespace restrictions? alwaysAllowed :: VariantM () @@ -704,10 +704,10 @@ typeCheckOp op = case op of PrimEffect ref m -> do TC (RefType ~(Just (Var (h':>TyKind))) s) <- typeCheck ref case m of - MGet -> declareEff (RWSEffect State h') $> s - MPut x -> x|:s >> declareEff (RWSEffect State h') $> UnitTy - MAsk -> declareEff (RWSEffect Reader h') $> s - MTell x -> x|:s >> declareEff (RWSEffect Writer h') $> UnitTy + MGet -> declareEff (RWSEffect State h') $> s + MPut x -> x|:s >> declareEff (RWSEffect State h') $> UnitTy + MAsk -> declareEff (RWSEffect Reader h') $> s + MExtend x -> x|:(s --> s) >> declareEff (RWSEffect Writer h') $> UnitTy IndexRef ref i -> do RefTy h (TabTyAbs a) <- typeCheck ref i |: absArgType a @@ -855,15 +855,16 @@ typeCheckHof hof = case hof of replaceDim 0 (TabTy _ b) n = TabTy (Ignore n) b replaceDim d (TabTy dv b) n = TabTy dv $ replaceDim (d-1) b n replaceDim _ _ _ = error "This should be checked before" - PTileReduce n mapping -> do - -- mapping : gtid:IdxRepTy -> nthr:IdxRepTy -> ((ParIndexRange n gtid nthr)=>a, r) + PTileReduce baseMonoids n mapping -> do + -- mapping : gtid:IdxRepTy -> nthr:IdxRepTy -> (...((ParIndexRange n gtid nthr)=>a, acc{n})..., acc1) BinaryFunTy (Bind gtid) (Bind nthr) Pure mapResultTy <- typeCheck mapping - PairTy tiledArrTy accTy <- return mapResultTy + (tiledArrTy, accTys) <- fromLeftLeaningConsListTy (length baseMonoids) mapResultTy let threadRange = TC $ ParIndexRange n (Var gtid) (Var nthr) TabTy threadRange' tileElemTy <- return tiledArrTy checkEq threadRange (binderType threadRange') - -- PTileReduce n mapping : (n=>a, ro) - return $ PairTy (TabTy (Ignore n) tileElemTy) accTy + -- TODO: Check compatibility of baseMonoids and accTys (need to be careful about lifting!) + -- PTileReduce n mapping : (n=>a, (acc1, ..., acc{n})) + return $ PairTy (TabTy (Ignore n) tileElemTy) $ mkConsListTy accTys While body -> do Pi (Abs (Ignore UnitTy) (arr , condTy)) <- typeCheck body declareEffs $ arrowEff arr @@ -879,7 +880,14 @@ typeCheckHof hof = case hof of (resultTy, readTy) <- checkRWSAction Reader f r |: readTy return resultTy - RunWriter f -> uncurry PairTy <$> checkRWSAction Writer f + RunWriter _ f -> do + -- XXX: We can't verify compatibility between the base monoid and f, because + -- the only way in which they are related in the runAccum definition is via + -- the AccumMonoid typeclass. The frontend constraints should be sufficient + -- to ensure that only well typed programs are accepted, but it is a bit + -- disappointing that we cannot verify that internally. We might want to consider + -- e.g. only disabling this check for prelude. + uncurry PairTy <$> checkRWSAction Writer f RunState s f -> do (resultTy, stateTy) <- checkRWSAction State f s |: stateTy diff --git a/tests/ad-tests.dx b/tests/ad-tests.dx index 6affc69f6..a843a49ad 100644 --- a/tests/ad-tests.dx +++ b/tests/ad-tests.dx @@ -1,6 +1,6 @@ -- TODO: use prelude sum instead once we can differentiate state effect -def sum' (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs.i +def sum' (xs:n=>Float) : Float = yieldAccum (AddMonoid Float) \ref. for i. ref += xs.i :p f : Float -> Float = \x. x @@ -69,7 +69,7 @@ def sum' (xs:n=>Float) : Float = yieldAccum \ref. for i. ref += xs.i :p jvp sum' [1., 2.] [10.0, 20.0] > 30. -f : Float -> Float = \x. yieldAccum \ref. ref += x +f : Float -> Float = \x. yieldAccum (AddMonoid Float) \ref. ref += x :p jvp f 1.0 1.0 > 1. @@ -167,7 +167,7 @@ tripleit : Float --o Float = \x. x + x + x > [2., 4.] myOtherSquare : Float -> Float = - \x. yieldAccum \w. w += x * x + \x. yieldAccum (AddMonoid Float) \w. w += x * x :p checkDeriv myOtherSquare 3.0 > True @@ -225,7 +225,7 @@ vec = [1.] :p f : Float -> Float = \x. y = x * 2.0 - yieldAccum \a. + yieldAccum (AddMonoid Float) \a. a += x * 2.0 a += y grad f 1.0 diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx index 08a28c6d0..42f84c043 100644 --- a/tests/adt-tests.dx +++ b/tests/adt-tests.dx @@ -102,7 +102,7 @@ myTab = [MyLeft 1, MyRight 3.5, MyLeft 123, MyLeft 456] > Runtime error :p - yieldAccum \ref. + yieldAccum (AddMonoid Float) \ref. for i. case myTab.i of MyLeft tmp -> () MyRight val -> ref += 1.0 + val @@ -110,7 +110,7 @@ myTab = [MyLeft 1, MyRight 3.5, MyLeft 123, MyLeft 456] :p -- check that the order of the case alternatives doesn't matter - yieldAccum \ref. + yieldAccum (AddMonoid Float) \ref. for i. case myTab.i of MyRight val -> ref += 1.0 + val MyLeft tmp -> () @@ -128,7 +128,7 @@ threeCaseTab : (Fin 4)=>ThreeCases = > [(TheIntCase 3), TheEmptyCase, (ThePairCase 2 0.1), TheEmptyCase] :p - yieldAccum \ref. + yieldAccum (AddMonoid Float) \ref. for i. case threeCaseTab.i of TheEmptyCase -> ref += 1000.0 ThePairCase x y -> ref += 100.0 + y + IToF x diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx index f853dceab..2426304f7 100644 --- a/tests/eval-tests.dx +++ b/tests/eval-tests.dx @@ -698,7 +698,7 @@ def newtonSolve (tol:Float) (f : Float -> Float) (x0:Float) : Float = > 415 :p - (f, w) = runAccum \ref. + (f, w) = runAccum (AddMonoid Float) \ref. ref += 2.0 w = 2 \z. z + w @@ -716,17 +716,6 @@ arr2d = for i:(Fin 2). for j:(Fin 2). (iota _).(i,j) arr2d.(1@_) > [2, 3] -:p - runState (1,2) \ref. - r1 = fstRef ref - r2 = sndRef ref - x = get r1 - y = get r2 - r2 := x - r1 := y -> ((), (2, 1)) - - :p any [True, False] > True :p any [False, False] diff --git a/tests/monad-tests.dx b/tests/monad-tests.dx index 8f9994252..58e116d7f 100644 --- a/tests/monad-tests.dx +++ b/tests/monad-tests.dx @@ -27,6 +27,7 @@ :p def rwsAction (rh:Type) ?-> (wh:Type) ?-> (sh:Type) ?-> + (_:AccumMonoid wh Float) ?=> (r:Ref rh Int) (w:Ref wh Float) (s:Ref sh Bool) : {Read rh, Accum wh, State sh} Int = x = get s @@ -38,7 +39,7 @@ withReader 2 \r. runState True \s. - runAccum \w. + runAccum (AddMonoid Float) \w. rwsAction r w s > ((4, 6.), False) @@ -56,29 +57,31 @@ :p def m (wh:Type) ?-> (sh:Type) ?-> + (_:AccumMonoid wh Float) ?=> (w:Ref wh Float) (s:Ref sh Float) : {Accum wh, State sh} Unit = x = get s w += x - runState 1.0 \s. runAccum \w . m w s + runState 1.0 \s. runAccum (AddMonoid Float) \w . m w s > (((), 1.), 1.) -def myAction (w:Ref hw Float) (r:Ref hr Float) : {Read hr, Accum hw} Unit = +def myAction [AccumMonoid hw Float] (w:Ref hw Float) (r:Ref hr Float) : {Read hr, Accum hw} Unit = x = ask r w += x w += 2.0 -:p withReader 1.5 \r. runAccum \w. myAction w r +:p withReader 1.5 \r. runAccum (AddMonoid Float) \w. myAction w r > ((), 3.5) :p def m (h1:Type) ?-> (h2:Type) ?-> + (_:AccumMonoid h1 Float) ?=> (_:AccumMonoid h2 Float) ?=> (w1:Ref h1 Float) (w2:Ref h2 Float) : {Accum h1, Accum h2} Unit = w1 += 1.0 w2 += 3.0 w1 += 1.0 - runAccum \w1. runAccum \w2. m w1 w2 + runAccum (AddMonoid Float) \w1. runAccum (AddMonoid Float) \w2. m w1 w2 > (((), 3.), 2.) def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = @@ -125,8 +128,8 @@ def foom (h:Type) ?-> (s:Ref h ((Fin 3)=>Int)) : {State h} Unit = -- (maybe just explicit implicit args) :p withReader 2.0 \r. - runAccum \w. - runAccum \w'. + runAccum (AddMonoid Float) \w. + runAccum (AddMonoid Float) \w'. runState 3 \s. x = ask r y = get s @@ -151,19 +154,19 @@ symmetrizeInPlace [[1.,2.],[3.,4.]] :p withReader 5 \r. () > () -:p yieldAccum \w. +:p yieldAccum (AddMonoid Float) \w. for i:(Fin 2). w += 1.0 w += 1.0 > 4. -:p yieldAccum \w. +:p yieldAccum (AddMonoid Float) \w. for i:(Fin 2). w += 1.0 w += 1.0 > 3. -:p yieldAccum \ref. +:p yieldAccum (AddMonoid Float) \ref. ref += [1.,2.,3.] ref += [2.,4.,5.] > [3., 6., 8.] diff --git a/tests/parser-tests.dx b/tests/parser-tests.dx index 93d216354..3729dae35 100644 --- a/tests/parser-tests.dx +++ b/tests/parser-tests.dx @@ -113,7 +113,7 @@ def myInt : {State h} Int = 1 > Nullary def can't have effects :p - yieldAccum \ref. + yieldAccum (AddMonoid Float) \ref. x = if True then 1. else 3. if True then ref += x