From 95c4c8fd4b4360cd7425327a7925924fe8c55915 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Wed, 18 Nov 2020 15:29:49 -0500 Subject: [PATCH 01/13] Add ProjectElt atom, which extracts specific elements from constructors. To try to reduce the surface for bugs, ProjectElt syntactically requires its argument to be a Var, indexed by a nonempty list of indices. This means that substituting into a ProjectElt requires immediately reducing it. Note that getting the type of a ProjectElt atom is a bit subtle, because if we are extracting a value from an existential ADT DataCon, the type of the projected result may itelf include earlier bindings in the DataCon, which must also be converted to ProjectElt atoms. --- src/lib/Autodiff.hs | 2 ++ src/lib/Embed.hs | 1 + src/lib/PPrint.hs | 2 ++ src/lib/Simplify.hs | 1 + src/lib/Syntax.hs | 21 +++++++++++++++++++++ src/lib/Type.hs | 24 ++++++++++++++++++++++++ 6 files changed, 51 insertions(+) diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index f3bc926ac..c56670b12 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -341,6 +341,7 @@ linearizeAtom atom = case atom of Pi _ -> emitWithZero TC _ -> emitWithZero Eff _ -> emitWithZero + ProjectElt _ _ -> error "TODO: linearize projections" -- Those should be gone after simplification Lam _ -> error "Unexpected non-table lambda" ACase _ _ _ -> error "Unexpected ACase" @@ -700,6 +701,7 @@ transposeAtom atom ct = case atom of ACase _ _ _ -> error "Unexpected ACase" DataConRef _ _ _ -> error "Unexpected ref" BoxedRef _ _ _ _ -> error "Unexpected ref" + ProjectElt _ _ -> error "TODO: projection transpose types" where notTangent = error $ "Not a tangent atom: " ++ pprint atom transposeCon :: Con -> Atom -> TransposeM () diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 62bb5a169..0cced290f 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -697,6 +697,7 @@ traverseAtom def@(_, _, fAtom) atom = case atom of case decls of Empty -> return $ BoxedRef b' ptr' size' body' _ -> error "Traversing the body atom shouldn't produce decls" + ProjectElt _ _ -> substEmbedR atom where traverseNestedArgs :: Nest DataConRefBinding -> m (Nest DataConRefBinding) traverseNestedArgs Empty = return Empty diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 12eee01e9..a76df6bd1 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -370,6 +370,8 @@ instance PrettyPrec Atom where "DataConRef" <+> p params <+> p args BoxedRef b ptr size body -> atPrec AppPrec $ "Box" <+> p b <+> "<-" <+> p ptr <+> "[" <> p size <> "]" <+> hardline <> "in" <+> p body + ProjectElt idxs (x:>_) -> atPrec AppPrec $ + "ProjectElt" <+> p idxs <+> p x instance Pretty DataConRefBinding where pretty = prettyFromPrettyPrec instance PrettyPrec DataConRefBinding where diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index e096c8289..62c3fa3f7 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -144,6 +144,7 @@ simplifyAtom atom = case atom of ACase e' alts' <$> (substEmbedR rty) DataConRef _ _ _ -> error "Should only occur in Imp lowering" BoxedRef _ _ _ _ -> error "Should only occur in Imp lowering" + ProjectElt idxs v -> reduceProjection (toList idxs) <$> simplifyAtom (Var v) simplifyCase :: Atom -> [AltP a] -> Maybe (SubstEnv, a) simplifyCase e alts = case e of diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 92d3362c8..f250b91e1 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -46,6 +46,7 @@ module Syntax ( subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, applyNaryAbs, applyDataDefParams, freshSkolemVar, IndexStructure, mkConsList, mkConsListTy, fromConsList, fromConsListTy, extendEffRow, + reduceProjection, varType, binderType, isTabTy, LogLevel (..), IRVariant (..), applyIntBinOp, applyIntCmpOp, applyFloatBinOp, applyFloatUnOp, getIntLit, getFloatLit, sizeOf, vectorWidth, pattern CharLit, @@ -103,6 +104,7 @@ data Atom = Var Var -- single-constructor only for now | DataConRef DataDef [Atom] (Nest DataConRefBinding) | BoxedRef Binder Atom Block Atom -- binder, ptr, size, body + | ProjectElt (NE.NonEmpty Int) Var -- access a nested member of a binder deriving (Show, Generic) data Expr = App Atom Atom @@ -945,6 +947,7 @@ instance Eq Atom where Con con == Con con' = con == con' TC con == TC con' = con == con' Eff eff == Eff eff' = eff == eff' + ProjectElt idxs v == ProjectElt idxs' v' = (idxs, v) == (idxs', v') _ == _ = False instance Eq DataDef where @@ -1077,6 +1080,7 @@ instance HasVars Atom where DataConRef _ params args -> freeVars params <> freeVars args BoxedRef b ptr size body -> freeVars ptr <> freeVars size <> freeVars (Abs b body) + ProjectElt _ v -> freeVars (Var v) instance Subst Atom where subst env atom = case atom of @@ -1098,6 +1102,7 @@ instance Subst Atom where where Abs args' () = subst env $ Abs args () BoxedRef b ptr size body -> BoxedRef b' (subst env ptr) (subst env size) body' where Abs b' body' = subst env $ Abs b body + ProjectElt idxs v -> substProjectElt (fst env) idxs v instance HasVars Module where freeVars (Module _ decls bindings) = freeVars $ Abs decls bindings @@ -1170,6 +1175,22 @@ substExtLabeledItemsTail env (Just v) = case envLookup env (v:>()) of Just (LabeledRow row) -> row _ -> error "Not a valid labeled row substitution" +substProjectElt :: SubstEnv -> NE.NonEmpty Int -> Var -> Atom +substProjectElt env idxs v = case envLookup env v of + Nothing -> ProjectElt idxs v + Just (Var v') -> ProjectElt idxs v' + Just atom -> reduceProjection (toList idxs) atom + +reduceProjection :: [Int] -> Atom -> Atom +reduceProjection [] a = a +reduceProjection (i:is) a = case reduceProjection is a of + ProjectElt idxs' a' -> ProjectElt (NE.cons i idxs') a' + DataCon _ _ _ xs -> xs !! i + Record items -> (toList items) !! i + PairVal x _ | i == 0 -> x + PairVal _ y | i == 1 -> y + _ -> error "Not a valid projection" + instance HasVars () where freeVars () = mempty instance Subst () where subst _ () = () diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 939eb93f4..2e51acf77 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -162,6 +162,29 @@ instance HasType Atom where numel |: IdxRepTy void $ typeCheck b withBinder b $ typeCheck body + ProjectElt (i NE.:| is) v -> do + ty <- typeCheck $ case NE.nonEmpty is of + Nothing -> Var v + Just is' -> ProjectElt is' v + case ty of + TypeCon def params -> do + [DataConDef _ bs'] <- return $ applyDataDefParams def params + -- Users might be accessing a value whose type depends on earlier + -- projected values from this constructor. Rewrite them to also + -- use projections. + let go :: Int -> Nest Binder -> Type + go j (Nest b _) | i == j = binderAnn b + -- TODO: is scopelessSubst correct here? + go j (Nest b rest) = go (j+1) (scopelessSubst (b @> proj) rest) + where proj = ProjectElt (j NE.:| is) v + go _ _ = error "Bad projection index" + return $ go 0 bs' + RecordTy (NoExt types) -> return $ toList types !! i + RecordTy _ -> throw CompilerErr "Can't project partially-known records" + PairTy x _ | i == 0 -> return x + PairTy _ y | i == 1 -> return y + _ -> throw TypeErr "Only single-member ADTs and record types can be projected" + checkDataConRefBindings :: Nest Binder -> Nest DataConRefBinding -> TypeM () checkDataConRefBindings Empty Empty = return () @@ -374,6 +397,7 @@ instance CoreVariant Atom where ACase _ _ _ -> goneBy Simp DataConRef _ _ _ -> neverAllowed -- only used internally in Imp lowering BoxedRef _ _ _ _ -> neverAllowed -- only used internally in Imp lowering + ProjectElt _ (_:>ty) -> checkVariant ty instance CoreVariant BinderInfo where checkVariant info = case info of From ec5017ffda85a9bd29db8acdc65acaa2dbc047f6 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Wed, 18 Nov 2020 17:59:48 -0500 Subject: [PATCH 02/13] Generate projections instead of Unpack decls. After this change, we should use projections instead of unpacks in most places. There's a few remaining bugs due to differences in simplifications between decls and atoms. --- examples/adt-tests.dx | 22 ++++++++++++---------- src/lib/Embed.hs | 28 +++++++++++----------------- src/lib/Simplify.hs | 2 +- src/lib/Syntax.hs | 13 +++++++------ 4 files changed, 31 insertions(+), 34 deletions(-) diff --git a/examples/adt-tests.dx b/examples/adt-tests.dx index f59e18d7b..aa22ade2c 100644 --- a/examples/adt-tests.dx +++ b/examples/adt-tests.dx @@ -162,10 +162,11 @@ data MyIntish = MkIntish Int xsList = AsList _ [1,2,3] -:p - (AsList _ xsTab) = xsList - sum xsTab -> 6 +-- TODO: reenable! +-- :p +-- (AsList _ xsTab) = xsList +-- sum xsTab +-- > 6 (AsList _ xsTab) = xsList @@ -179,12 +180,13 @@ xsList = AsList _ [1,2,3] sum ans > 15 -:p - (MkMyPair x y) = case 3 < 2 of - True -> MkMyPair 1 2 - False -> MkMyPair 3 4 - (x, y) -> (3, 4) +-- TODO: reenable! +-- :p +-- (MkMyPair x y) = case 3 < 2 of +-- True -> MkMyPair 1 2 +-- False -> MkMyPair 3 4 +-- (x, y) +-- > (3, 4) def catLists (xs:List a) (ys:List a) : List a = (AsList nx xs') = xs diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 0cced290f..e012f525a 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -103,20 +103,7 @@ emitOp op = emit $ Op op emitUnpack :: MonadEmbed m => Expr -> m [Atom] emitUnpack expr = do - bs <- case getType expr of - TypeCon def params -> do - let [DataConDef _ bs] = applyDataDefParams def params - return bs - RecordTy (NoExt types) -> do - -- TODO: is using Ignore here appropriate? We don't have any existing - -- binders to bind, but we still plan to use the results. - let bs = toNest $ map Ignore $ toList types - return bs - _ -> error $ "Unpacking a type that doesn't support unpacking: " ++ pprint (getType expr) - expr' <- deShadow expr <$> getScope - vs <- freshNestedBinders bs - embedExtend $ asSnd $ Nest (Unpack (fmap Bind vs) expr') Empty - return $ map Var $ toList vs + getUnpacked =<< emit expr -- Assumes the decl binders are already fresh wrt current scope emitBlock :: MonadEmbed m => Block -> m Atom @@ -292,6 +279,7 @@ ieq :: MonadEmbed m => Atom -> Atom -> m Atom ieq x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (==) x y ieq x y = emitOp $ ScalarBinOp (ICmp Equal) x y +-- TODO: make pairs also use projection atoms? getFst :: MonadEmbed m => Atom -> m Atom getFst (PairVal x _) = return x getFst p = emitOp $ Fst p @@ -306,10 +294,16 @@ getFstRef r = emitOp $ FstRef r getSndRef :: MonadEmbed m => Atom -> m Atom getSndRef r = emitOp $ SndRef r +-- TODO: refactor? getUnpacked :: MonadEmbed m => Atom -> m [Atom] -getUnpacked (DataCon _ _ _ xs) = return xs -getUnpacked (Record items) = return $ toList items -getUnpacked a = emitUnpack (Atom a) +getUnpacked atom = return res where + len = case getType atom of + TypeCon def params -> + let [DataConDef _ bs] = applyDataDefParams def params + in length bs + RecordTy (NoExt types) -> length types + ty -> error $ "Unpacking a type that doesn't support unpacking: " ++ pprint ty + res = map (\i -> getProjection [i] atom) [0..(len-1)] app :: MonadEmbed m => Atom -> Atom -> m Atom app x i = emit $ App x i diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 62c3fa3f7..720fdf3a6 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -144,7 +144,7 @@ simplifyAtom atom = case atom of ACase e' alts' <$> (substEmbedR rty) DataConRef _ _ _ -> error "Should only occur in Imp lowering" BoxedRef _ _ _ _ -> error "Should only occur in Imp lowering" - ProjectElt idxs v -> reduceProjection (toList idxs) <$> simplifyAtom (Var v) + ProjectElt idxs v -> getProjection (toList idxs) <$> simplifyAtom (Var v) simplifyCase :: Atom -> [AltP a] -> Maybe (SubstEnv, a) simplifyCase e alts = case e of diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index f250b91e1..4f50876ae 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -46,7 +46,7 @@ module Syntax ( subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, applyNaryAbs, applyDataDefParams, freshSkolemVar, IndexStructure, mkConsList, mkConsListTy, fromConsList, fromConsListTy, extendEffRow, - reduceProjection, + getProjection, varType, binderType, isTabTy, LogLevel (..), IRVariant (..), applyIntBinOp, applyIntCmpOp, applyFloatBinOp, applyFloatUnOp, getIntLit, getFloatLit, sizeOf, vectorWidth, pattern CharLit, @@ -1179,17 +1179,18 @@ substProjectElt :: SubstEnv -> NE.NonEmpty Int -> Var -> Atom substProjectElt env idxs v = case envLookup env v of Nothing -> ProjectElt idxs v Just (Var v') -> ProjectElt idxs v' - Just atom -> reduceProjection (toList idxs) atom + Just atom -> getProjection (toList idxs) atom -reduceProjection :: [Int] -> Atom -> Atom -reduceProjection [] a = a -reduceProjection (i:is) a = case reduceProjection is a of +getProjection :: [Int] -> Atom -> Atom +getProjection [] a = a +getProjection (i:is) a = case getProjection is a of + Var v -> ProjectElt (NE.fromList [i]) v ProjectElt idxs' a' -> ProjectElt (NE.cons i idxs') a' DataCon _ _ _ xs -> xs !! i Record items -> (toList items) !! i PairVal x _ | i == 0 -> x PairVal _ y | i == 1 -> y - _ -> error "Not a valid projection" + _ -> error $ "Not a valid projection: " ++ show i ++ " of " ++ show a instance HasVars () where freeVars () = mempty instance Subst () where subst _ () = () From 5f18c411cde3fb526949fa24539b77f6a976f0ec Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Wed, 18 Nov 2020 18:07:10 -0500 Subject: [PATCH 03/13] Simplify through types in type constructors. Now that types can involve projections, we need to fully simplify the type arguments to type constructors instead of simply using substEmbedR. This fixes one of the broken tests. --- examples/adt-tests.dx | 9 ++++----- src/lib/Simplify.hs | 18 ++++++++++++------ 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/examples/adt-tests.dx b/examples/adt-tests.dx index aa22ade2c..5ea4bd58d 100644 --- a/examples/adt-tests.dx +++ b/examples/adt-tests.dx @@ -162,11 +162,10 @@ data MyIntish = MkIntish Int xsList = AsList _ [1,2,3] --- TODO: reenable! --- :p --- (AsList _ xsTab) = xsList --- sum xsTab --- > 6 +:p + (AsList _ xsTab) = xsList + sum xsTab +> 6 (AsList _ xsTab) = xsList diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 720fdf3a6..502cbf817 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -119,17 +119,17 @@ simplifyAtom atom = case atom of Lam _ -> substEmbedR atom Pi _ -> substEmbedR atom Con con -> Con <$> mapM simplifyAtom con - TC tc -> TC <$> mapM substEmbedR tc + TC tc -> TC <$> mapM simplifyAtom tc Eff eff -> Eff <$> substEmbedR eff - TypeCon def params -> TypeCon def <$> substEmbedR params - DataCon def params con args -> DataCon def <$> substEmbedR params + TypeCon def params -> TypeCon def <$> mapM simplifyAtom params + DataCon def params con args -> DataCon def <$> mapM simplifyAtom params <*> pure con <*> mapM simplifyAtom args Record items -> Record <$> mapM simplifyAtom items - RecordTy items -> RecordTy <$> substEmbedR items + RecordTy items -> RecordTy <$> simplifyExtLabeledItems items Variant types label i value -> Variant <$> substEmbedR types <*> pure label <*> pure i <*> simplifyAtom value - VariantTy items -> VariantTy <$> substEmbedR items - LabeledRow items -> LabeledRow <$> substEmbedR items + VariantTy items -> VariantTy <$> simplifyExtLabeledItems items + LabeledRow items -> LabeledRow <$> simplifyExtLabeledItems items ACase e alts rty -> do e' <- substEmbedR e case simplifyCase e' alts of @@ -146,6 +146,12 @@ simplifyAtom atom = case atom of BoxedRef _ _ _ _ -> error "Should only occur in Imp lowering" ProjectElt idxs v -> getProjection (toList idxs) <$> simplifyAtom (Var v) +simplifyExtLabeledItems :: ExtLabeledItems Atom Name -> SimplifyM (ExtLabeledItems Atom Name) +simplifyExtLabeledItems (Ext items ext) = do + items' <- mapM simplifyAtom items + ext' <- substEmbedR (Ext NoLabeledItems ext) + return $ prefixExtLabeledItems items' ext' + simplifyCase :: Atom -> [AltP a] -> Maybe (SubstEnv, a) simplifyCase e alts = case e of DataCon _ _ con args -> do From c39ae751d5a7a43b53be5526e89d9ce83019bea7 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Wed, 18 Nov 2020 18:15:08 -0500 Subject: [PATCH 04/13] Add a case in gatherVarDests for ProjectElt. There is likely a better implementation that re-uses an existing destination if we have the right structure. But this seems to work for now, and fixes the broken tests. --- examples/adt-tests.dx | 13 ++++++------- src/lib/Imp.hs | 1 + 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/adt-tests.dx b/examples/adt-tests.dx index 5ea4bd58d..f59e18d7b 100644 --- a/examples/adt-tests.dx +++ b/examples/adt-tests.dx @@ -179,13 +179,12 @@ xsList = AsList _ [1,2,3] sum ans > 15 --- TODO: reenable! --- :p --- (MkMyPair x y) = case 3 < 2 of --- True -> MkMyPair 1 2 --- False -> MkMyPair 3 4 --- (x, y) --- > (3, 4) +:p + (MkMyPair x y) = case 3 < 2 of + True -> MkMyPair 1 2 + False -> MkMyPair 3 4 + (x, y) +> (3, 4) def catLists (xs:List a) (ys:List a) : List a = (AsList nx xs') = xs diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 7717127b7..6caa0e93d 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -773,6 +773,7 @@ splitDest (maybeDest, (Block decls ans)) = do | fmap (const ()) items == fmap (const ()) items' -> do zipWithM_ gatherVarDests (toList items) (toList items') (Con (ConRef (SumAsProd _ _ _)), _) -> tell [(dest, result)] -- TODO + (_, ProjectElt _ _) -> tell [(dest, result)] -- TODO: is this reasonable? _ -> unreachable where unreachable = error $ "Invalid dest-result pair:\n" From a01570c7e1d97fd6a0f80cef8b0f4e052caaf059 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Thu, 19 Nov 2020 20:49:58 -0500 Subject: [PATCH 05/13] Remove Unpack decls from the core IR. All unpacks can be represented as let decls followed by projections. --- src/lib/Autodiff.hs | 3 +-- src/lib/Embed.hs | 5 ----- src/lib/Imp.hs | 9 --------- src/lib/Interpreter.hs | 8 -------- src/lib/Optimize.hs | 2 -- src/lib/PPrint.hs | 1 - src/lib/Simplify.hs | 7 ------- src/lib/Syntax.hs | 9 +-------- src/lib/Type.hs | 15 --------------- 9 files changed, 2 insertions(+), 57 deletions(-) diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index c56670b12..02a69a8d2 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -54,10 +54,10 @@ linearizeBlock env (Block decls result) = case decls of Empty -> linearizeExpr env result Nest decl rest -> case decl of (Let _ b expr) -> linearizeBinding False [b] expr - (Unpack bs expr) -> linearizeBinding True (toList bs) expr where body = Block rest result takeWhere l m = fmap snd $ filter fst $ zip m l + -- TODO: refactor this to not have isUnpack linearizeBinding :: Bool -> [Binder] -> Expr -> LinA Atom linearizeBinding isUnpack bs expr = LinA $ do -- Don't linearize expressions with no free active variables. @@ -538,7 +538,6 @@ transposeBlock (Block decls result) ct = case decls of Empty -> transposeExpr result ct Nest decl rest -> case decl of (Let _ b expr) -> transposeBinding False [b] expr - (Unpack bs expr) -> transposeBinding True (toList bs) expr where body = Block rest result transposeBinding isUnpack bs expr = do diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index e012f525a..d21f18d2e 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -565,7 +565,6 @@ emitDecl :: MonadEmbed m => Decl -> m () emitDecl decl = embedExtend (bindings, Nest decl Empty) where bindings = case decl of Let ann b expr -> b @> (binderType b, LetBound ann expr) - Unpack bs _ -> foldMap (\b -> b @> (binderType b, PatBound)) bs scopedDecls :: MonadEmbed m => m a -> m (a, Nest Decl) scopedDecls m = do @@ -620,10 +619,6 @@ traverseDecl (_, fExpr, _) decl = case decl of Atom a | not (isGlobalBinder b) -> return $ b @> a -- TODO: Do we need to use the name hint here? _ -> (b@>) <$> emitTo (binderNameHint b) letAnn expr' - Unpack bs expr -> do - expr' <- fExpr expr - xs <- emitUnpack expr' - return $ newEnv bs xs traverseBlock :: (MonadEmbed m, MonadReader SubstEnv m) => TraversalDef m -> Block -> m Block diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 6caa0e93d..12417e2da 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -123,13 +123,6 @@ translateDecl env (maybeDest, (Let _ b bound)) = do b' <- traverse (impSubst env) b ans <- translateExpr env (maybeDest, bound) return $ b' @> ans -translateDecl env (maybeDest, (Unpack bs bound)) = do - bs' <- mapM (traverse (impSubst env)) bs - expr <- translateExpr env (maybeDest, bound) - case expr of - DataCon _ _ _ ans -> return $ newEnv bs' ans - Record items -> return $ newEnv bs $ toList items - _ -> error "Unsupported type in an Unpack binding" translateExpr :: SubstEnv -> WithDest Expr -> ImpM Atom translateExpr env (maybeDest, expr) = case expr of @@ -750,7 +743,6 @@ splitDest (maybeDest, (Block decls ans)) = do let destDecls = flip fmap (toList decls) $ \d -> case d of Let _ b _ -> (fst <$> varDests `envLookup` b, d) - Unpack _ _ -> (Nothing, d) (destDecls, (Nothing, ans), gatherCopies ++ closureCopies) _ -> (fmap (Nothing,) $ toList decls, (maybeDest, ans), []) where @@ -781,7 +773,6 @@ splitDest (maybeDest, (Block decls ans)) = do letBoundVars :: Decl -> Env () letBoundVars (Let _ b _) = b @> () -letBoundVars (Unpack _ _) = mempty copyDest :: Maybe Dest -> Atom -> ImpM Atom copyDest maybeDest atom = case maybeDest of diff --git a/src/lib/Interpreter.hs b/src/lib/Interpreter.hs index 847ec0746..4dc8eefee 100644 --- a/src/lib/Interpreter.hs +++ b/src/lib/Interpreter.hs @@ -46,14 +46,6 @@ evalBlock env (Block decls result) = do evalDecl :: SubstEnv -> Decl -> InterpM SubstEnv evalDecl env (Let _ v rhs) = liftM (v @>) $ evalExpr env rhs' where rhs' = subst (env, mempty) rhs -evalDecl env (Unpack vs rhs) = do - let rhs' = subst (env, mempty) rhs - ans <- evalExpr env rhs' - let atoms = case ans of - DataCon _ _ _ atoms' -> atoms' - Record atoms' -> toList atoms' - _ -> error $ "Can't unpack: " <> pprint rhs' - return $ fold $ map (uncurry (@>)) $ zip (toList vs) atoms evalExpr :: SubstEnv -> Expr -> InterpM Atom evalExpr env expr = case expr of diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 616db19fd..dd88dac68 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -52,7 +52,6 @@ dceDecl :: Decl -> DceM (Maybe Decl) dceDecl decl = do newDecl <- case decl of Let ann b expr -> go [b] expr $ Let ann b - Unpack bs expr -> go bs expr $ Unpack bs modify (<> freeVars newDecl) return newDecl where @@ -155,7 +154,6 @@ computeInlineHints m@(Module _ _ bindings) = hintDecl decl = case decl of Let ann b expr -> go [b] expr $ Let ann . head - Unpack bs expr -> go bs expr $ Unpack where go bs expr mkDecl = do void $ noInlineFree bs diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index a76df6bd1..c61241e11 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -296,7 +296,6 @@ instance Pretty Decl where -- This is just to reduce clutter a bit. We can comment it out when needed. -- Let (v:>Pi _) bound -> p v <+> "=" <+> p bound Let _ b rhs -> align $ p b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) - Unpack bs rhs -> align $ p (toList bs) <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) prettyPiTypeHelper :: PiType -> Doc ann prettyPiTypeHelper (Abs binder (arr, body)) = let diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 502cbf817..d82d8484c 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -82,13 +82,6 @@ simplifyDecl (Let ann b expr) = do if isGlobalBinder b then emitTo name ann (Atom x) $> mempty else return $ b @> x -simplifyDecl (Unpack bs expr) = do - x <- simplifyExpr expr - xs <- case x of - DataCon _ _ _ xs -> return xs - Record items -> return $ toList items - _ -> emitUnpack $ Atom x - return $ newEnv bs xs simplifyBlock :: Block -> SimplifyM Atom simplifyBlock (Block decls result) = do diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 4f50876ae..6d65c44bd 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -114,8 +114,7 @@ data Expr = App Atom Atom | Hof Hof deriving (Show, Generic) -data Decl = Let LetAnn Binder Expr - | Unpack (Nest Binder) Expr deriving (Show, Generic) +data Decl = Let LetAnn Binder Expr deriving (Show, Generic) data DataConRefBinding = DataConRefBinding Binder Atom deriving (Show, Generic) @@ -1033,25 +1032,19 @@ instance Subst Expr where instance HasVars Decl where freeVars decl = case decl of Let _ b expr -> freeVars expr <> freeVars b - Unpack bs expr -> freeVars expr <> freeVars bs instance Subst Decl where subst env decl = case decl of Let ann b expr -> Let ann (fmap (subst env) b) $ subst env expr - Unpack bs expr -> Unpack (subst env bs) $ subst env expr instance BindsVars Decl where boundVars decl = case decl of Let ann b expr -> b @> (binderType b, LetBound ann expr) - Unpack bs _ -> boundVars bs renamingSubst env decl = case decl of Let ann b expr -> (Let ann b' expr', env') where expr' = subst env expr (b', env') = renamingSubst env b - Unpack bs expr -> (Unpack bs' expr', env') - where expr' = subst env expr - (bs', env') = renamingSubst env bs instance HasVars Block where freeVars (Block decls result) = freeVars $ Abs decls result diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 2e51acf77..87377e158 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -260,7 +260,6 @@ blockEffs :: Block -> EffectSummary blockEffs (Block decls result) = foldMap declEffs decls <> exprEffs result where declEffs (Let _ _ expr) = exprEffs expr - declEffs (Unpack _ expr) = exprEffs expr isPure :: Expr -> Bool isPure expr = exprEffs expr == mempty @@ -324,19 +323,6 @@ checkDecl env decl = withTypeEnv env $ addContext ctxStr $ case decl of ty' <- typeCheck rhs checkEq ty ty' return $ boundVars b - Unpack bs rhs -> do - void $ checkNestedBinders bs - ty <- typeCheck rhs - bs' <- case ty of - TypeCon def params -> do - [DataConDef _ bs'] <- return $ applyDataDefParams def params - return bs' - RecordTy (NoExt types) -> - return $ toNest $ map Ignore $ toList types - RecordTy _ -> throw CompilerErr "Can't unpack partially-known records" - _ -> throw TypeErr $ "Only single-member ADTs and record types can be unpacked in let bindings" - checkEq bs bs' - return $ foldMap boundVars bs where ctxStr = "checking decl:\n" ++ pprint decl checkNestedBinders :: Nest Binder -> TypeM (Nest Type) @@ -419,7 +405,6 @@ instance CoreVariant Expr where instance CoreVariant Decl where -- let annotation restrictions? checkVariant (Let _ b e) = checkVariant b >> checkVariant e - checkVariant (Unpack bs e) = mapM checkVariant bs >> checkVariant e instance CoreVariant Block where checkVariant (Block ds e) = mapM_ checkVariant ds >> checkVariant e From a11aca7948e0bfbcb7d7df134c06290e3545bc3c Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Thu, 19 Nov 2020 21:55:33 -0500 Subject: [PATCH 06/13] Modify implicit args, reduce projections, get list length example working. Modify implicit args: To allow function calls inside type annotations, this change makes it so that any lowercase name that is used in function position of an application (e.g. `f x`) will NOT be added as an automatic implicit type parameter. This is an improvement because such a type application would never typecheck anyway if `f` was inferred to have type `Type`. This also matches the behavior of Idris. Reduce projections: We assume during typechecking that `getType` always returns a fully reduced type. But `getType` of a projection may produce another projection with the same root variable. Thus, whenver we create a new projection, we have to reduce the variable. (It's not clear that this is the best way to do this, but it seems to work for now.) Get list extraction example working: With these two changes, it becomes possible to construct functions that do dependent projections, for instance by converting a `List a` into a table. The syntax is a bit unwieldy for this, but that should be easy to fix. --- examples/adt-tests.dx | 17 +++++++++ src/lib/Embed.hs | 82 ++++++++++++++++++++++++++++++++++++++----- src/lib/Inference.hs | 58 ------------------------------ src/lib/Parser.hs | 49 +++++++++++++++++++++++--- src/lib/Syntax.hs | 5 ++- 5 files changed, 139 insertions(+), 72 deletions(-) diff --git a/examples/adt-tests.dx b/examples/adt-tests.dx index f59e18d7b..59073bb90 100644 --- a/examples/adt-tests.dx +++ b/examples/adt-tests.dx @@ -209,3 +209,20 @@ def catLists (xs:List a) (ys:List a) : List a = n = 1 + 4 AsList _ (for i:(Fin n). ordinal i) > (AsList 5 [0, 1, 2, 3, 4]) + +def listLength ((AsList length xs):List a) : Int = length + +-- TODO: not yet supported +-- def listToTable1 ((AsList n xs): List a) : (Fin n)=>a = xs + +def listToTable2 (l: List a) : (Fin (listLength l))=>a = + (AsList _ xs) = l + xs + +:t listToTable2 +> ((a:Type) ?-> (l:(List a)) -> (Fin (ProjectElt [0] l)) => a) + +:p + l = AsList _ [1, 2, 3] + sum $ listToTable2 l +> 6 diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index d21f18d2e..33d9a0787 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -30,7 +30,8 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP TraversalDef, traverseDecls, traverseDecl, traverseBlock, traverseExpr, clampPositive, buildNAbs, buildNAbsAux, buildNestedLam, zeroAt, transformModuleAsBlock, dropSub, appReduceTraversalDef, - indexSetSizeE, indexToIntE, intToIndexE, freshVarE) where + indexSetSizeE, indexToIntE, intToIndexE, freshVarE, + reduceScoped, reduceBlock, reduceAtom, reduceExpr) where import Control.Applicative import Control.Monad @@ -50,6 +51,7 @@ import Cat import Type import PPrint import Util (bindM2, scanM, restructure) +import Data.Maybe (fromMaybe) newtype EmbedT m a = EmbedT (ReaderT EmbedEnvR (CatT EmbedEnvC m) a) deriving (Functor, Applicative, Monad, MonadIO, MonadFail, Alternative) @@ -295,15 +297,19 @@ getSndRef :: MonadEmbed m => Atom -> m Atom getSndRef r = emitOp $ SndRef r -- TODO: refactor? +-- TODO: is this the best place for the reduction? getUnpacked :: MonadEmbed m => Atom -> m [Atom] -getUnpacked atom = return res where - len = case getType atom of - TypeCon def params -> - let [DataConDef _ bs] = applyDataDefParams def params - in length bs - RecordTy (NoExt types) -> length types - ty -> error $ "Unpacking a type that doesn't support unpacking: " ++ pprint ty - res = map (\i -> getProjection [i] atom) [0..(len-1)] +getUnpacked atom = do + scope <- getScope + let len = case getType atom of + TypeCon def params -> + let [DataConDef _ bs] = applyDataDefParams def params + in length bs + RecordTy (NoExt types) -> length types + ty -> error $ "Unpacking a type that doesn't support unpacking: " ++ pprint ty + atom' = reduceAtom scope atom + res = map (\i -> getProjection [i] atom') [0..(len-1)] + return res app :: MonadEmbed m => Atom -> Atom -> m Atom app x i = emit $ App x i @@ -820,3 +826,61 @@ intToIndexE (VariantTy (NoExt types)) i = do start <- Variant (NoExt types) l0 0 <$> intToIndexE ty0 i foldM go start zs intToIndexE ty _ = error $ "Unexpected type " ++ pprint ty + +-- === Reduction === + +reduceScoped :: MonadEmbed m => m Atom -> m (Maybe Atom) +reduceScoped m = do + block <- buildScoped m + scope <- getScope + return $ reduceBlock scope block + +reduceBlock :: Scope -> Block -> Maybe Atom +reduceBlock scope (Block decls result) = do + let localScope = foldMap boundVars decls + ans <- reduceExpr (scope <> localScope) result + [] <- return $ toList $ localScope `envIntersect` freeVars ans + return ans + +-- XXX: This should handle all terms of type Type. Otherwise type equality checking +-- will get broken. +reduceAtom :: Scope -> Atom -> Atom +reduceAtom scope x = case x of + Var (Name InferenceName _ _ :> _) -> x + Var v -> case snd (scope ! v) of + -- TODO: worry about effects! + LetBound PlainLet expr -> fromMaybe x $ reduceExpr scope expr + _ -> x + TC con -> TC $ fmap (reduceAtom scope) con + Pi (Abs b (arr, ty)) -> Pi $ Abs b (arr, reduceAtom (scope <> (fmap (,PiBound) $ binderAsEnv b)) ty) + TypeCon def params -> TypeCon (reduceDataDef def) (fmap rec params) + RecordTy (Ext tys ext) -> RecordTy $ Ext (fmap rec tys) ext + VariantTy (Ext tys ext) -> VariantTy $ Ext (fmap rec tys) ext + ACase _ _ _ -> error "Not implemented" + _ -> x + where + rec = reduceAtom scope + reduceNest s n = case n of + Empty -> Empty + -- Technically this should use a more concrete type than UnknownBinder, but anything else + -- than LetBound is indistinguishable for this reduction anyway. + Nest b rest -> Nest b' $ reduceNest (s <> (fmap (,UnknownBinder) $ binderAsEnv b)) rest + where b' = fmap (reduceAtom s) b + reduceDataDef (DataDef n bs cons) = + DataDef n (reduceNest scope bs) + (fmap (reduceDataConDef (scope <> (foldMap (fmap (,UnknownBinder) . binderAsEnv) bs))) cons) + reduceDataConDef s (DataConDef n bs) = DataConDef n $ reduceNest s bs + +reduceExpr :: Scope -> Expr -> Maybe Atom +reduceExpr scope expr = case expr of + Atom val -> return $ reduceAtom scope val + App f x -> do + let f' = reduceAtom scope f + let x' = reduceAtom scope x + -- TODO: Worry about variable capture. Should really carry a substitution. + case f' of + Lam (Abs b (arr, block)) | arr == PureArrow || arr == ImplicitArrow -> + reduceBlock scope $ subst (b@>x', scope) block + TypeCon con xs -> Just $ TypeCon con $ xs ++ [x'] + _ -> Nothing + _ -> Nothing diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index d48e96b3b..d6090b062 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -907,61 +907,3 @@ instance Semigroup SolverEnv where instance Monoid SolverEnv where mempty = SolverEnv mempty mempty mappend = (<>) - --- === Inference-time reduction === - -reduceScoped :: MonadEmbed m => m Atom -> m (Maybe Atom) -reduceScoped m = do - block <- buildScoped m - scope <- getScope - return $ reduceBlock scope block - -reduceBlock :: Scope -> Block -> Maybe Atom -reduceBlock scope (Block decls result) = do - let localScope = foldMap boundVars decls - ans <- reduceExpr (scope <> localScope) result - [] <- return $ toList $ localScope `envIntersect` freeVars ans - return ans - --- XXX: This should handle all terms of type Type. Otherwise type equality checking --- will get broken. -reduceAtom :: Scope -> Atom -> Atom -reduceAtom scope x = case x of - Var (Name InferenceName _ _ :> _) -> x - Var v -> case snd (scope ! v) of - -- TODO: worry about effects! - LetBound PlainLet expr -> fromMaybe x $ reduceExpr scope expr - _ -> x - TC con -> TC $ fmap (reduceAtom scope) con - Pi (Abs b (arr, ty)) -> Pi $ Abs b (arr, reduceAtom (scope <> (fmap (,PiBound) $ binderAsEnv b)) ty) - TypeCon def params -> TypeCon (reduceDataDef def) (fmap rec params) - RecordTy (Ext tys ext) -> RecordTy $ Ext (fmap rec tys) ext - VariantTy (Ext tys ext) -> VariantTy $ Ext (fmap rec tys) ext - ACase _ _ _ -> error "Not implemented" - _ -> x - where - rec = reduceAtom scope - reduceNest s n = case n of - Empty -> Empty - -- Technically this should use a more concrete type than UnknownBinder, but anything else - -- than LetBound is indistinguishable for this reduction anyway. - Nest b rest -> Nest b' $ reduceNest (s <> (fmap (,UnknownBinder) $ binderAsEnv b)) rest - where b' = fmap (reduceAtom s) b - reduceDataDef (DataDef n bs cons) = - DataDef n (reduceNest scope bs) - (fmap (reduceDataConDef (scope <> (foldMap (fmap (,UnknownBinder) . binderAsEnv) bs))) cons) - reduceDataConDef s (DataConDef n bs) = DataConDef n $ reduceNest s bs - -reduceExpr :: Scope -> Expr -> Maybe Atom -reduceExpr scope expr = case expr of - Atom val -> return $ reduceAtom scope val - App f x -> do - let f' = reduceAtom scope f - let x' = reduceAtom scope x - -- TODO: Worry about variable capture. Should really carry a substitution. - case f' of - Lam (Abs b (PureArrow, block)) -> - reduceBlock scope $ subst (b@>x', scope) block - TypeCon con xs -> Just $ TypeCon con $ xs ++ [x'] - _ -> Nothing - _ -> Nothing diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index a9a4177eb..1d844dd3d 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -252,16 +252,57 @@ topLet = do let (ann', rhs') = addImplicitImplicitArgs pos ann rhs return $ ULet lAnn (p, ann') rhs' +-- Given a type signature, find all "implicit implicit args": lower-case +-- identifiers, not explicitly bound by Pi binders, not appearing on the left +-- hand side of an application. These identifiers are implicit in the sense +-- that they will be solved for by type inference, and also implicit in the +-- sense that the user did NOT explicitly annotate them as implicit. +findImplicitImplicitArgNames :: UType -> [Name] +findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ + freeUVars typ `envDiff` findFunctionVars typ + where + + isLowerCaseName :: Name -> Bool + isLowerCaseName (Name _ tag _) = isLower $ head $ tagToStr tag + isLowerCaseName _ = False + + -- Finds all variables used in function position, which should be pulled out. + findFunctionVars :: UType -> Env () + findFunctionVars (WithSrc _ typ') = case typ' of + UVar _ -> mempty + UPi b _ ty -> + findFunctionVars (binderAnn b) <> (findFunctionVars ty `envDiff` freeUVars b) + UApp _ (WithSrc _ (UVar (v:>_))) x -> (v @> ()) <> findFunctionVars x + UApp _ f x -> findFunctionVars f <> findFunctionVars x + ULam (p, ann) _ x -> + foldMap findFunctionVars ann <> (findFunctionVars x `envDiff` boundUVars p) + UDecl _ _ -> error "Unexpected let binding in type annotation" + UFor _ _ _ -> error "Unexpected for in type annotation" + UHole -> mempty + UTypeAnn v ty -> findFunctionVars v <> findFunctionVars ty + UTabCon _ -> error "Unexpected table in type annotation" + UIndexRange low high -> + foldMap findFunctionVars low <> foldMap findFunctionVars high + UPrimExpr prim -> foldMap findFunctionVars prim + UCase _ _ -> error "Unexpected case in type annotation" + URecord (Ext ulr _) -> foldMap findFunctionVars ulr + UVariant _ _ val -> findFunctionVars val + URecordTy (Ext ulr v) -> + foldMap findFunctionVars ulr <> foldMap findFunctionVars v + UVariantTy (Ext ulr v) -> + foldMap findFunctionVars ulr <> foldMap findFunctionVars v + UVariantLift _ val -> findFunctionVars val + UIntLit _ -> mempty + UCharLit _ -> mempty + UFloatLit _ -> mempty + addImplicitImplicitArgs :: SrcPos -> Maybe UType -> UExpr -> (Maybe UType, UExpr) addImplicitImplicitArgs _ Nothing e = (Nothing, e) addImplicitImplicitArgs sourcePos (Just typ) ex = let (ty', e') = foldr (addImplicitArg sourcePos) (typ, ex) implicitVars in (Just ty', e') where - implicitVars = filter isLowerCaseName $ envNames $ freeUVars typ - isLowerCaseName :: Name -> Bool - isLowerCaseName (Name _ tag _) = isLower $ head $ tagToStr tag - isLowerCaseName _ = False + implicitVars = findImplicitImplicitArgNames typ addImplicitArg :: SrcPos -> Name -> (UType, UExpr) -> (UType, UExpr) addImplicitArg pos v (ty, e) = diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 6d65c44bd..765822f44 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -104,7 +104,10 @@ data Atom = Var Var -- single-constructor only for now | DataConRef DataDef [Atom] (Nest DataConRefBinding) | BoxedRef Binder Atom Block Atom -- binder, ptr, size, body - | ProjectElt (NE.NonEmpty Int) Var -- access a nested member of a binder + -- access a nested member of a binder + -- XXX: Variable name MUST be fully reduced, it cannot be a synonym! + -- This is because the variable name may also appear in the type. + | ProjectElt (NE.NonEmpty Int) Var deriving (Show, Generic) data Expr = App Atom Atom From 576dbf8fb6faa666bf6f7c6f56e70d8879e25242 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Fri, 20 Nov 2020 19:10:12 -0500 Subject: [PATCH 07/13] Allow patterns in pi types in the untyped syntax. UPi atoms now can take a pattern instead of a single binder. If the pattern is more complex than a single binder, that pattern is then bound while converting the UPi into a Pi atom. Note that the Pi representation is the same as it was; the returned type must still be reducible to an atom for the conversion to succeed. However, with the new projection atoms, unpacking of ADTs will still be reducible. The parser implementation for "def"-style functions has been modified to allow using patterns, which means it is now possible to bring values from an ADT into scope in the type for a "def"-style function. For now, the parser does not support patterns in explicit pi type expressions. This should be fairly straightforward to add but might require some care regarding ambiguity of the grammar (see #282). --- examples/adt-tests.dx | 13 ++++--------- src/lib/Embed.hs | 10 +++++++--- src/lib/Inference.hs | 16 ++++++++++++---- src/lib/PPrint.hs | 9 ++++++++- src/lib/Parser.hs | 44 ++++++++++++++++++++++--------------------- src/lib/Syntax.hs | 7 ++++--- 6 files changed, 58 insertions(+), 41 deletions(-) diff --git a/examples/adt-tests.dx b/examples/adt-tests.dx index 59073bb90..a320431b0 100644 --- a/examples/adt-tests.dx +++ b/examples/adt-tests.dx @@ -210,19 +210,14 @@ def catLists (xs:List a) (ys:List a) : List a = AsList _ (for i:(Fin n). ordinal i) > (AsList 5 [0, 1, 2, 3, 4]) -def listLength ((AsList length xs):List a) : Int = length --- TODO: not yet supported --- def listToTable1 ((AsList n xs): List a) : (Fin n)=>a = xs -def listToTable2 (l: List a) : (Fin (listLength l))=>a = - (AsList _ xs) = l - xs +def listToTable ((AsList n xs): List a) : (Fin n)=>a = xs -:t listToTable2 -> ((a:Type) ?-> (l:(List a)) -> (Fin (ProjectElt [0] l)) => a) +:t listToTable +> ((a:Type) ?-> (pat:(List a)) -> (Fin (ProjectElt [0] pat)) => a) :p l = AsList _ [1, 2, 3] - sum $ listToTable2 l + sum $ listToTable l > 6 diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 33d9a0787..974c62674 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -139,12 +139,16 @@ freshNestedBindersRec substEnv (Nest b bs) = do buildPi :: (MonadError Err m, MonadEmbed m) => Binder -> (Atom -> m (Arrow, Type)) -> m Atom buildPi b f = do - (piTy, decls) <- scopedDecls $ do + scope <- getScope + (ans, decls) <- scopedDecls $ do v <- freshVarE PiBound b (arr, ans) <- f $ Var v return $ Pi $ makeAbs (Bind v) (arr, ans) - unless (null decls) $ throw CompilerErr $ "Unexpected decls: " ++ pprint decls - return piTy + let block = wrapDecls decls ans + case reduceBlock scope block of + Just piTy -> return piTy + Nothing -> throw CompilerErr $ + "Unexpected irreducible decls in pi type: " ++ pprint decls buildAbs :: MonadEmbed m => Binder -> (Atom -> m a) -> m (Abs Binder (Nest Decl, a)) buildAbs b f = do diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index d6090b062..b3b6c9976 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -152,12 +152,20 @@ checkOrInferRho (WithSrc pos expr) reqTy = do addEffects $ arrowEff arr' appVal <- emitZonked $ App fVal xVal' instantiateSigma appVal >>= matchRequirement - UPi b arr ty -> do + UPi (pat, kind) arr ty -> do -- TODO: make sure there's no effect if it's an implicit or table arrow -- TODO: check leaks - b' <- mapM checkUType b - piTy <- buildPi b' $ \x -> extendR (b@>x) $ - (,) <$> mapM checkUEff arr <*> checkUType ty + kind' <- checkUType kind + piTy <- case pat of + Just pat' -> withNameHint ("pat" :: Name) $ buildPi b $ \x -> + withBindPat pat' x $ (,) <$> mapM checkUEff arr <*> checkUType ty + where b = case pat' of + -- Note: must bind it by name if the user gives an explicit + -- name, since the binder name becomes part of the type. + WithSrc _ (UPatBinder (Bind (v:>()))) -> Bind (v:>kind') + _ -> Ignore kind' + Nothing -> buildPi (Ignore kind') $ const $ + (,) <$> mapM checkUEff arr <*> checkUType ty matchRequirement piTy UDecl decl body -> do env <- inferUDecl False decl diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index c61241e11..4484a05b0 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -561,7 +561,8 @@ instance PrettyPrec UExpr' where <+> nest 2 (pLowest body) where kw = case dir of Fwd -> "for" Rev -> "rof" - UPi a arr b -> atPrec LowestPrec $ p a <+> pretty arr <+> pLowest b + UPi binder arr ty -> atPrec LowestPrec $ + prettyUPiBinder binder <+> pretty arr <+> pLowest ty UDecl decl body -> atPrec LowestPrec $ align $ p decl <> hardline <> pLowest body UHole -> atPrec ArgPrec "_" @@ -623,6 +624,12 @@ prettyUBinder (pat, ann) = p pat <> annDoc where Just ty -> ":" <> pApp ty Nothing -> mempty +prettyUPiBinder :: UPiPatAnn -> Doc ann +prettyUPiBinder (pat, ann) = patDoc <> p ann where + patDoc = case pat of + Just pat' -> pApp pat' <> ":" + Nothing -> mempty + spaced :: (Foldable f, Pretty a) => f a -> Doc ann spaced xs = hsep $ map p $ toList xs diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 1d844dd3d..1a1b622c1 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -270,8 +270,8 @@ findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ findFunctionVars :: UType -> Env () findFunctionVars (WithSrc _ typ') = case typ' of UVar _ -> mempty - UPi b _ ty -> - findFunctionVars (binderAnn b) <> (findFunctionVars ty `envDiff` freeUVars b) + UPi (p, ann) _ ty -> + findFunctionVars ann <> (findFunctionVars ty `envDiff` boundUVars p) UApp _ (WithSrc _ (UVar (v:>_))) x -> (v @> ()) <> findFunctionVars x UApp _ f x -> findFunctionVars f <> findFunctionVars x ULam (p, ann) _ x -> @@ -306,9 +306,10 @@ addImplicitImplicitArgs sourcePos (Just typ) ex = addImplicitArg :: SrcPos -> Name -> (UType, UExpr) -> (UType, UExpr) addImplicitArg pos v (ty, e) = - ( WithSrc (Just pos) $ UPi (Bind (v:>uTyKind)) ImplicitArrow ty - , WithSrc (Just pos) $ ULam (WithSrc (Just pos) (UPatBinder (Bind (v:>()))), Just uTyKind) ImplicitArrow e) + ( WithSrc (Just pos) $ UPi (Just uPat, uTyKind) ImplicitArrow ty + , WithSrc (Just pos) $ ULam (uPat, Just uTyKind) ImplicitArrow e) where + uPat = WithSrc (Just pos) $ UPatBinder $ Bind $ v:>() k = if v == mkName "eff" then EffectRowKind else TypeKind uTyKind = WithSrc (Just pos) $ UPrimExpr $ TCExpr k @@ -342,10 +343,10 @@ interfaceDef = do mkOneFunDef (pos, typeVarNames, interfaceName) (fLabel, fType) = ULet PlainLet (p, ann') rhs' where - uAnnBinder = Bind $ - instanceName :> (foldl mkUApp (var interfaceName) typeVarNames) + uAnnPat = ( Just $ WithSrc (Just pos) $ UPatBinder $ Bind $ instanceName :> () + , foldl mkUApp (var interfaceName) typeVarNames) p = patb fLabel - ann = Just $ ns $ UPi uAnnBinder ClassArrow fType + ann = Just $ ns $ UPi uAnnPat ClassArrow fType mkUApp func typeVarName = ns $ UApp (PlainArrow ()) func (var typeVarName) @@ -454,8 +455,7 @@ funDefLet = label "function definition" $ mayBreak $ do v <- letPat bs <- some arg (eff, ty) <- label "result type annotation" $ annot effectiveType - let piBinders = flip map bs $ \(p, ann, arr) -> (patAsBinder p ann, arr) - let funTy = buildPiType piBinders eff ty + let funTy = buildPiType bs eff ty let letBinder = (v, Just funTy) let lamBinders = flip map bs $ \(p,_, arr) -> ((p,Nothing), arr) return $ \body -> ULet PlainLet letBinder (buildLam lamBinders body) @@ -466,19 +466,15 @@ funDefLet = label "function definition" $ mayBreak $ do arr <- arrow (return ()) <|> return (PlainArrow ()) return (p, ty, arr) -patAsBinder :: UPat -> UType -> UAnnBinder -patAsBinder (WithSrc _ (UPatBinder (Bind (v:>())))) ty = Bind $ v:>ty -patAsBinder _ ty = Ignore ty - nameAsPat :: Parser Name -> Parser UPat nameAsPat p = withSrc $ (UPatBinder . Bind . (:>())) <$> p -buildPiType :: [(UAnnBinder, UArrow)] -> EffectRow -> UType -> UType +buildPiType :: [(UPat, UType, UArrow)] -> EffectRow -> UType -> UType buildPiType [] _ _ = error "shouldn't be possible" -buildPiType ((b,arr):bs) eff ty = WithSrc pos $ case bs of - [] -> UPi b (fmap (const eff ) arr) ty - _ -> UPi b (fmap (const Pure) arr) $ buildPiType bs eff ty - where WithSrc pos _ = binderAnn b +buildPiType ((p, patTy, arr):bs) eff resTy = WithSrc pos $ case bs of + [] -> UPi (Just p, patTy) (fmap (const eff ) arr) resTy + _ -> UPi (Just p, patTy) (fmap (const Pure) arr) $ buildPiType bs eff resTy + where WithSrc pos _ = patTy effectiveType :: Parser (EffectRow, UType) effectiveType = (,) <$> effects <*> uType @@ -565,7 +561,13 @@ uStatement = withPos $ liftM Left decl -- TODO: put the `try` only around the `x:` not the annotation itself uPiType :: Parser UExpr -uPiType = withSrc $ UPi <$> annBinder <*> arrow effects <*> uType +uPiType = withSrc $ UPi <$> piBinderPat <*> arrow effects <*> uType + where piBinderPat = do + b <- annBinder + return $ case b of + Bind (n:>a@(WithSrc pos _)) -> + (Just $ WithSrc pos $ UPatBinder $ Bind $ n:>(), a) + Ignore a -> (Nothing, a) annBinder :: Parser UAnnBinder annBinder = try $ namedBinder <|> anonBinder @@ -904,10 +906,10 @@ infixArrow :: Parser (UType -> UType -> UType) infixArrow = do notFollowedBy (sym "=>") -- table arrows have special fixity (arr, pos) <- withPos $ arrow effects - return $ \a b -> WithSrc (Just pos) $ UPi (Ignore a) arr b + return $ \a b -> WithSrc (Just pos) $ UPi (Nothing, a) arr b mkArrow :: Arrow -> UExpr -> UExpr -> UExpr -mkArrow arr a b = joinSrc a b $ UPi (Ignore a) arr b +mkArrow arr a b = joinSrc a b $ UPi (Nothing, a) arr b withSrc :: Parser a -> Parser (WithSrc a) withSrc p = do diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 765822f44..076c66351 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -40,7 +40,7 @@ module Syntax ( freeVars, freeUVars, Subst, HasVars, BindsVars, Ptr, PtrType, AddressSpace (..), PtrOrigin (..), showPrimName, strToPrimName, primNameToStr, monMapSingle, monMapLookup, Direction (..), Limit (..), - UExpr, UExpr' (..), UType, UPatAnn, UAnnBinder, UVar, + UExpr, UExpr' (..), UType, UPatAnn, UPiPatAnn, UAnnBinder, UVar, UPat, UPat' (..), UModule (..), UDecl (..), UArrow, arrowEff, DataDef (..), DataConDef (..), UConDef (..), Nest (..), toNest, subst, deShadow, scopelessSubst, absArgType, applyAbs, makeAbs, @@ -213,7 +213,7 @@ prefixExtLabeledItems items (Ext items' rest) = Ext (items <> items') rest type UExpr = WithSrc UExpr' data UExpr' = UVar UVar | ULam UPatAnn UArrow UExpr - | UPi UAnnBinder Arrow UType + | UPi UPiPatAnn Arrow UType | UApp UArrow UExpr UExpr | UDecl UDecl UExpr | UFor Direction UPatAnn UExpr @@ -244,6 +244,7 @@ type UVar = VarP () type UBinder = BinderP () type UPatAnn = (UPat, Maybe UType) +type UPiPatAnn = (Maybe UPat, UType) type UAnnBinder = BinderP UType data UAlt = UAlt UPat UExpr deriving (Show, Generic) @@ -710,7 +711,7 @@ instance HasUVars UExpr' where freeUVars expr = case expr of UVar v -> v @>() ULam (pat,ty) _ body -> freeUVars ty <> freeUVars (Abs pat body) - UPi b arr ty -> freeUVars $ Abs b (arr, ty) + UPi (pat,kind) arr ty -> freeUVars kind <> freeUVars (Abs pat (arr, ty)) -- TODO: maybe distinguish table arrow application -- (otherwise `x.i` and `x i` are the same) UApp _ f x -> freeUVars f <> freeUVars x From e9b00cfed21de1a91a3b223a0c30a78d8cfb9c55 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Tue, 1 Dec 2020 18:35:03 -0500 Subject: [PATCH 08/13] Replace Fst/Snd with ProjectElt, get preliminary autodiff working. Support for autodiff is partial, because we currently don't seem to have a way to unpack a reference to a record or an ADT (we do have `FstRef` and `SndRef` for pairs, though). However, this is enough to get the tests to pass. --- prelude.dx | 4 ++-- src/lib/Autodiff.hs | 26 ++++++++++++++++++-------- src/lib/Embed.hs | 23 +++++++++-------------- src/lib/Imp.hs | 2 -- src/lib/Interpreter.hs | 2 -- src/lib/Simplify.hs | 2 -- src/lib/Syntax.hs | 16 +++------------- src/lib/Type.hs | 16 ++++++++++++---- 8 files changed, 44 insertions(+), 47 deletions(-) diff --git a/prelude.dx b/prelude.dx index dd4fcecee..cab138c3a 100644 --- a/prelude.dx +++ b/prelude.dx @@ -129,8 +129,8 @@ def fdiv (d:Fractional a) ?=> : a -> a -> a = case d of MkFractional div -> div def (&) (a:Type) (b:Type) : Type = %PairType a b def (,) (x:a) (y:b) : (a & b) = %pair x y -def fst (p: (a & b)) : a = %fst p -def snd (p: (a & b)) : b = %snd p +def fst ((x, _): (a & b)) : a = x +def snd ((_, y): (a & b)) : b = y def swap (p:(a&b)) : (b&a) = (snd p, fst p) def (<<<) (f: b -> c) (g: a -> b) : a -> c = \x. f (g x) diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index 02a69a8d2..20f365b78 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -13,6 +13,7 @@ import Control.Applicative import Control.Monad import Control.Monad.Reader import Control.Monad.State.Strict +import qualified Data.List.NonEmpty as NE import Data.Maybe import Data.Foldable import Data.Traversable @@ -24,7 +25,7 @@ import Syntax import PPrint import Embed import Cat -import Util (bindM2, zipWithT, enumerate) +import Util (bindM2, zipWithT, enumerate, restructure) import GHC.Stack -- === linearization === @@ -132,8 +133,6 @@ linearizeOp op = case op of MTell x -> liftA MTell $ la x MGet -> pure MGet MPut x -> liftA MPut $ la x) `bindLin` emitOp - Fst x -> (Fst <$> la x) `bindLin` emitOp - Snd x -> (Snd <$> la x) `bindLin` emitOp IndexRef ref i -> (IndexRef <$> la ref <*> pure i) `bindLin` emitOp FstRef ref -> (FstRef <$> la ref ) `bindLin` emitOp SndRef ref -> (SndRef <$> la ref ) `bindLin` emitOp @@ -341,7 +340,7 @@ linearizeAtom atom = case atom of Pi _ -> emitWithZero TC _ -> emitWithZero Eff _ -> emitWithZero - ProjectElt _ _ -> error "TODO: linearize projections" + ProjectElt idxs v -> getProjection (toList idxs) <$> linearizeAtom (Var v) -- Those should be gone after simplification Lam _ -> error "Unexpected non-table lambda" ACase _ _ _ -> error "Unexpected ACase" @@ -396,8 +395,9 @@ addTangent x y = case getType x of pack :: Type -> [Atom] -> Atom pack ty elems = case ty of + PairTy _ _ -> let [x, y] = elems in PairVal x y TypeCon def params -> DataCon def params 0 elems - RecordTy (NoExt types) -> Record $ snd $ mapAccumL (\(h:t) _ -> (t, h)) elems types + RecordTy (NoExt types) -> Record $ restructure elems types _ -> error $ "Unexpected Unpack argument type: " ++ pprint ty isTrivialForAD :: Expr -> Bool @@ -582,8 +582,6 @@ transposeExpr expr ct = case expr of transposeOp :: Op -> Atom -> TransposeM () transposeOp op ct = case op of - Fst x -> flip emitCTToRef ct =<< (traverse $ emitOp . FstRef) =<< linAtomRef x - Snd x -> flip emitCTToRef ct =<< (traverse $ emitOp . SndRef) =<< linAtomRef x ScalarUnOp FNeg x -> transposeAtom x =<< neg ct ScalarUnOp _ _ -> notLinear ScalarBinOp FAdd x y -> transposeAtom x ct >> transposeAtom y ct @@ -641,6 +639,16 @@ linAtomRef (Var x) = do case envLookup refs x of Just ref -> return ref _ -> error $ "Not a linear var: " ++ pprint (Var x) +linAtomRef (ProjectElt (i NE.:| is) x) = do + let subproj = getProjection is (Var x) + case getType subproj of + PairTy _ _ -> do + ref <- linAtomRef subproj + (traverse $ emitOp . getter) ref + where getter = case i of 0 -> FstRef + 1 -> SndRef + _ -> error "bad pair projection" + ty -> error $ "Projecting references not implemented for type " <> pprint ty linAtomRef a = error $ "Not a linear var: " ++ pprint a transposeHof :: Hof -> Atom -> TransposeM () @@ -700,7 +708,9 @@ transposeAtom atom ct = case atom of ACase _ _ _ -> error "Unexpected ACase" DataConRef _ _ _ -> error "Unexpected ref" BoxedRef _ _ _ _ -> error "Unexpected ref" - ProjectElt _ _ -> error "TODO: projection transpose types" + ProjectElt _ v -> do + lin <- isLin $ Var v + when lin $ flip emitCTToRef ct =<< linAtomRef atom where notTangent = error $ "Not a tangent atom: " ++ pprint atom transposeCon :: Con -> Atom -> TransposeM () diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 974c62674..da7149aeb 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -285,14 +285,17 @@ ieq :: MonadEmbed m => Atom -> Atom -> m Atom ieq x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (==) x y ieq x y = emitOp $ ScalarBinOp (ICmp Equal) x y --- TODO: make pairs also use projection atoms? +fromPair :: MonadEmbed m => Atom -> m (Atom, Atom) +fromPair pair = do + scope <- getScope + let pair' = reduceAtom scope pair + return (getProjection [0] pair', getProjection [1] pair') + getFst :: MonadEmbed m => Atom -> m Atom -getFst (PairVal x _) = return x -getFst p = emitOp $ Fst p +getFst p = fst <$> fromPair p getSnd :: MonadEmbed m => Atom -> m Atom -getSnd (PairVal _ y) = return y -getSnd p = emitOp $ Snd p +getSnd p = snd <$> fromPair p getFstRef :: MonadEmbed m => Atom -> m Atom getFstRef r = emitOp $ FstRef r @@ -305,12 +308,7 @@ getSndRef r = emitOp $ SndRef r getUnpacked :: MonadEmbed m => Atom -> m [Atom] getUnpacked atom = do scope <- getScope - let len = case getType atom of - TypeCon def params -> - let [DataConDef _ bs] = applyDataDefParams def params - in length bs - RecordTy (NoExt types) -> length types - ty -> error $ "Unpacking a type that doesn't support unpacking: " ++ pprint ty + let len = projectLength $ getType atom atom' = reduceAtom scope atom res = map (\i -> getProjection [i] atom') [0..(len-1)] return res @@ -337,9 +335,6 @@ ptrOffset x i = emitOp $ PtrOffset x i ptrLoad :: MonadEmbed m => Atom -> m Atom ptrLoad x = emitOp $ PtrLoad x -fromPair :: MonadEmbed m => Atom -> m (Atom, Atom) -fromPair pair = (,) <$> getFst pair <*> getSnd pair - unpackConsList :: MonadEmbed m => Atom -> m [Atom] unpackConsList xs = case getType xs of UnitTy -> return [] diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 12417e2da..ef2471aab 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -191,8 +191,6 @@ toImpOp (maybeDest, op) = case op of ithDest <- destGet dest =<< intToIndex (binderType b) (IIdxRepVal i) copyAtom ithDest row destToAtom dest - Fst ~(PairVal x _) -> returnVal x - Snd ~(PairVal _ y) -> returnVal y PrimEffect refDest m -> do case m of MAsk -> returnVal =<< destToAtom refDest diff --git a/src/lib/Interpreter.hs b/src/lib/Interpreter.hs index 4dc8eefee..8a71fc9ae 100644 --- a/src/lib/Interpreter.hs +++ b/src/lib/Interpreter.hs @@ -107,8 +107,6 @@ evalOp expr = case expr of Con (IntRangeVal _ _ i) -> return i Con (IndexRangeVal _ _ _ i) -> return i _ -> evalEmbed (indexToIntE idxArg) - Fst p -> return x where (PairVal x _) = p - Snd p -> return y where (PairVal _ y) = p _ -> error $ "Not implemented: " ++ pprint expr -- We can use this when we know we won't be dereferencing pointers. A better diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index d82d8484c..beddc38be 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -367,8 +367,6 @@ simplifyExpr expr = case expr of -- TODO: come up with a coherent strategy for ordering these various reductions simplifyOp :: Op -> SimplifyM Atom simplifyOp op = case op of - Fst (PairVal x _) -> return x - Snd (PairVal _ y) -> return y RecordCons left right -> case getType right of RecordTy (NoExt rightTys) -> do -- Unpack, then repack with new arguments (possibly in the middle). diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 076c66351..a3302ab96 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -312,9 +312,7 @@ data PrimCon e = deriving (Show, Eq, Generic, Functor, Foldable, Traversable) data PrimOp e = - Fst e - | Snd e - | TabCon e [e] -- table type elements + TabCon e [e] -- table type elements | ScalarBinOp BinOp e e | ScalarUnOp UnOp e | Select e e e -- predicate, val-if-true, val-if-false @@ -1099,7 +1097,7 @@ instance Subst Atom where where Abs args' () = subst env $ Abs args () BoxedRef b ptr size body -> BoxedRef b' (subst env ptr) (subst env size) body' where Abs b' body' = subst env $ Abs b body - ProjectElt idxs v -> substProjectElt (fst env) idxs v + ProjectElt idxs v -> getProjection (toList idxs) $ substVar env v instance HasVars Module where freeVars (Module _ decls bindings) = freeVars $ Abs decls bindings @@ -1172,19 +1170,13 @@ substExtLabeledItemsTail env (Just v) = case envLookup env (v:>()) of Just (LabeledRow row) -> row _ -> error "Not a valid labeled row substitution" -substProjectElt :: SubstEnv -> NE.NonEmpty Int -> Var -> Atom -substProjectElt env idxs v = case envLookup env v of - Nothing -> ProjectElt idxs v - Just (Var v') -> ProjectElt idxs v' - Just atom -> getProjection (toList idxs) atom - getProjection :: [Int] -> Atom -> Atom getProjection [] a = a getProjection (i:is) a = case getProjection is a of Var v -> ProjectElt (NE.fromList [i]) v ProjectElt idxs' a' -> ProjectElt (NE.cons i idxs') a' DataCon _ _ _ xs -> xs !! i - Record items -> (toList items) !! i + Record items -> toList items !! i PairVal x _ | i == 0 -> x PairVal _ y | i == 1 -> y _ -> error $ "Not a valid projection: " ++ show i ++ " of " ++ show a @@ -1525,8 +1517,6 @@ builtinNames = M.fromList , ("LabeledRowKind", TCExpr $ LabeledRowKindTC) , ("IndexSlice", TCExpr $ IndexSlice () ()) , ("pair", ConExpr $ PairCon () ()) - , ("fst", OpExpr $ Fst ()) - , ("snd", OpExpr $ Snd ()) , ("fstRef", OpExpr $ FstRef ()) , ("sndRef", OpExpr $ SndRef ()) -- TODO: Lift vectors to constructors diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 87377e158..b2fd0a6ed 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -13,7 +13,7 @@ module Type ( getType, checkType, HasType (..), Checkable (..), litType, isPure, functionEffs, exprEffs, blockEffs, extendEffect, isData, checkBinOp, checkUnOp, checkIntBaseType, checkFloatBaseType, withBinder, isDependent, - indexSetConcreteSize, checkNoShadow, traceCheckM, traceCheck) where + indexSetConcreteSize, checkNoShadow, traceCheckM, traceCheck, projectLength) where import Prelude hiding (pi) import Control.Monad @@ -183,7 +183,8 @@ instance HasType Atom where RecordTy _ -> throw CompilerErr "Can't project partially-known records" PairTy x _ | i == 0 -> return x PairTy _ y | i == 1 -> return y - _ -> throw TypeErr "Only single-member ADTs and record types can be projected" + Var _ -> throw CompilerErr $ "Tried to project value of unreduced type " <> pprint ty + _ -> throw TypeErr $ "Only single-member ADTs and record types can be projected. Got " <> pprint ty checkDataConRefBindings :: Nest Binder -> Nest DataConRefBinding -> TypeM () @@ -661,8 +662,6 @@ typeCheckOp op = case op of mapM_ (uncurry (|:)) $ zip xs (fmap (snd . applyAbs a) idxs) assertEq (length idxs) (length xs) "Index set size mismatch" return ty - Fst p -> do { PairTy x _ <- typeCheck p; return x} - Snd p -> do { PairTy _ y <- typeCheck p; return y} ScalarBinOp binop x y -> bindM2 (checkBinOp binop) (typeCheck x) (typeCheck y) ScalarUnOp unop x -> checkUnOp unop =<< typeCheck x Select p x y -> do @@ -971,3 +970,12 @@ isData :: Type -> Bool isData ty = case checkData ty of Left _ -> False Right _ -> True + +projectLength :: Type -> Int +projectLength ty = case ty of + TypeCon def params -> + let [DataConDef _ bs] = applyDataDefParams def params + in length bs + RecordTy (NoExt types) -> length types + PairTy _ _ -> 2 + _ -> error $ "Projecting a type that doesn't support projecting: " ++ pprint ty From 43a05b1b85db617a776001fadc68a6383b4c5e8f Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Tue, 1 Dec 2020 19:54:02 -0500 Subject: [PATCH 09/13] Refactoring and adding tests; also fix a pattern parse bug. Refactors autodiff to not use `isUnpack` since `Unpack` decls are no longer part of the core IR. Adds tests for more complicated uses of dependent projections and deeply nestesd projections. Also fixes a bug where (,) was not respecting precedence correctly, by using `mayPair`/`mayNotPair` for patterns as well as expressions. --- examples/adt-tests.dx | 44 +++++++++++++++++++++++++++++++++++++++++++ src/lib/Autodiff.hs | 37 +++++++++++++++++------------------- src/lib/Embed.hs | 5 +++-- src/lib/Parser.hs | 13 ++++++++++--- src/lib/Type.hs | 1 - 5 files changed, 74 insertions(+), 26 deletions(-) diff --git a/examples/adt-tests.dx b/examples/adt-tests.dx index a320431b0..af1174f93 100644 --- a/examples/adt-tests.dx +++ b/examples/adt-tests.dx @@ -221,3 +221,47 @@ def listToTable ((AsList n xs): List a) : (Fin n)=>a = xs l = AsList _ [1, 2, 3] sum $ listToTable l > 6 + +def listLength ((AsList length xs):List a) : Int = length +def listToTable2 (l: List a) : (Fin (listLength l))=>a = + (AsList _ xs) = l + xs + +:t listToTable2 +> ((a:Type) ?-> (l:(List a)) -> (Fin (ProjectElt [0] l)) => a) + +:p + l = AsList _ [1, 2, 3] + sum $ listToTable2 l +> 6 + +data Graph a:Type = + MkGraph n:Type nodes:(n=>a) m:Type edges:(m=>(n & n)) + +def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = + init = for i j. False + snd $ withState init \mRef. + for i:m. + (from, to) = edges.i + mRef!from!to := True + +:t graphToAdjacencyMatrix +> ((a:Type) +> ?-> (pat:(Graph a)) -> (ProjectElt [0] pat) => (ProjectElt [0] pat) => Bool) + +:p + g : Graph Int = MkGraph (Fin 3) [5, 6, 7] (Fin 4) [(0@_, 1@_), (0@_, 2@_), (2@_, 0@_), (1@_, 1@_)] + graphToAdjacencyMatrix g +> [[False, True, True], [False, True, False], [True, False, False]] + + +def deepUnpack (x:MyPair a (MyPair (MyIntish & b) c)) : Int = + (MkMyPair _ (MkMyPair (MkIntish y, _) _)) = x + y + +:t deepUnpack +> ((a:Type) +> ?-> (b:Type) ?-> (c:Type) ?-> (MyPair a (MyPair (MyIntish & b) c)) -> Int32) + +:p deepUnpack (MkMyPair 3 (MkMyPair (MkIntish 4, 5) 6)) +> 4 diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index 20f365b78..14cd54909 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -54,13 +54,11 @@ linearizeBlock :: SubstEnv -> Block -> LinA Atom linearizeBlock env (Block decls result) = case decls of Empty -> linearizeExpr env result Nest decl rest -> case decl of - (Let _ b expr) -> linearizeBinding False [b] expr + (Let _ b expr) -> linearizeBinding b expr where body = Block rest result - takeWhere l m = fmap snd $ filter fst $ zip m l - -- TODO: refactor this to not have isUnpack - linearizeBinding :: Bool -> [Binder] -> Expr -> LinA Atom - linearizeBinding isUnpack bs expr = LinA $ do + linearizeBinding :: Binder -> Expr -> LinA Atom + linearizeBinding b expr = LinA $ do -- Don't linearize expressions with no free active variables. -- Technically, we could do this and later run the code through a simplification -- pass that would eliminate a bunch of multiplications with zeros, but this seems @@ -70,8 +68,7 @@ linearizeBlock env (Block decls result) = case decls of if any id varsAreActive then do (x, boundLin) <- runLinA $ linearizeExpr env expr - xs <- if isUnpack then emitUnpack (Atom x) else (:[]) <$> emit (Atom x) - let vs = fmap (\(Var v) -> v) xs + ~(Var v) <- emit $ Atom x -- NB: This can still overestimate the set of active variables (e.g. -- when multiple values are returned from a case statement). -- Don't mark variables with trivial tangent types as active. This lets us avoid @@ -81,20 +78,20 @@ linearizeBlock env (Block decls result) = case decls of -- variables, but I don't think that we want to define them to have tangents. -- We should delete this check, but to do that we would have to support differentiation -- through case statements with active scrutinees. - let nontrivialVsMask = [not $ isSingletonType $ tangentType $ varType v | v <- vs] - let nontrivialVs = vs `takeWhere` nontrivialVsMask - (ans, bodyLin) <- extendWrt nontrivialVs [] $ runLinA $ linearizeBlock (env <> newEnv bs xs) body + let vIsTrivial = isSingletonType $ tangentType $ varType v + let nontrivialVs = if vIsTrivial then [] else [v] + (ans, bodyLin) <- extendWrt nontrivialVs [] $ runLinA $ + linearizeBlock (env <> b @> Var v) body return (ans, do t <- boundLin - ts <- if isUnpack then emitUnpack (Atom t) else return [t] -- Tangent environment needs to be synced between the primal and tangent -- monads (tangentFunAsLambda and applyLinToTangents need that). - let nontrivialTs = ts `takeWhere` nontrivialVsMask + let nontrivialTs = if vIsTrivial then [] else [t] extendTangentEnv (newEnv nontrivialVs nontrivialTs) [] bodyLin) else do expr' <- substEmbed env expr - xs <- if isUnpack then emitUnpack expr' else (:[]) <$> emit expr' - runLinA $ linearizeBlock (env <> newEnv bs xs) body + x <- emit expr' + runLinA $ linearizeBlock (env <> b @> x) body linearizeExpr :: SubstEnv -> Expr -> LinA Atom linearizeExpr env expr = case expr of @@ -537,19 +534,19 @@ transposeBlock :: Block -> Atom -> TransposeM () transposeBlock (Block decls result) ct = case decls of Empty -> transposeExpr result ct Nest decl rest -> case decl of - (Let _ b expr) -> transposeBinding False [b] expr + (Let _ b expr) -> transposeBinding b expr where body = Block rest result - transposeBinding isUnpack bs expr = do + transposeBinding b expr = do isLinearExpr <- (||) <$> isLinEff (exprEffs expr) <*> isLin expr if isLinearExpr then do - cts <- withLinVars bs $ transposeBlock body ct - transposeExpr expr $ if isUnpack then pack (getType expr) cts else head cts + cts <- withLinVars [b] $ transposeBlock body ct + transposeExpr expr $ head cts else do expr' <- substNonlin expr - xs <- if isUnpack then emitUnpack expr' else (:[]) <$> emit expr' - localNonlinSubst (newEnv bs xs) $ transposeBlock body ct + x <- emit expr' + localNonlinSubst (b @> x) $ transposeBlock body ct withLinVars :: [Binder] -> TransposeM () -> TransposeM [Atom] withLinVars [] m = m >> return [] diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index da7149aeb..d85e19884 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -303,8 +303,9 @@ getFstRef r = emitOp $ FstRef r getSndRef :: MonadEmbed m => Atom -> m Atom getSndRef r = emitOp $ SndRef r --- TODO: refactor? --- TODO: is this the best place for the reduction? +-- XXX: getUnpacked must reduce its argument to enforce the invariant that +-- ProjectElt atoms are always fully reduced (to avoid type errors between two +-- equivalent types spelled differently). getUnpacked :: MonadEmbed m => Atom -> m [Atom] getUnpacked atom = do scope <- getScope diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 1a1b622c1..3cfe2f016 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -603,12 +603,12 @@ onePerLine p = liftM (:[]) p <|> (withIndent $ mayNotBreak $ p `sepBy1` try nextLine) pat :: Parser UPat -pat = makeExprParser leafPat patOps +pat = mayNotPair $ makeExprParser leafPat patOps leafPat :: Parser UPat leafPat = (withSrc (symbol "()" $> UPatUnit)) - <|> parens pat + <|> parens (mayPair $ makeExprParser leafPat patOps) <|> (withSrc $ (UPatBinder <$> ( (Bind <$> (:>()) <$> lowerName) <|> (underscore $> Ignore ()))) @@ -624,7 +624,14 @@ leafPat = -- TODO: add user-defined patterns patOps :: [[Operator Parser UPat]] -patOps = [[InfixR $ sym "," $> \x y -> joinSrc x y $ UPatPair x y]] +patOps = [[InfixR patPairOp]] + +patPairOp :: Parser (UPat -> UPat -> UPat) +patPairOp = do + allowed <- asks canPair + if allowed + then sym "," $> \x y -> joinSrc x y $ UPatPair x y + else fail "pair pattern not allowed outside parentheses" annot :: Parser a -> Parser a annot p = label "type annotation" $ sym ":" >> p diff --git a/src/lib/Type.hs b/src/lib/Type.hs index b2fd0a6ed..acee16bc8 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -174,7 +174,6 @@ instance HasType Atom where -- use projections. let go :: Int -> Nest Binder -> Type go j (Nest b _) | i == j = binderAnn b - -- TODO: is scopelessSubst correct here? go j (Nest b rest) = go (j+1) (scopelessSubst (b @> proj) rest) where proj = ProjectElt (j NE.:| is) v go _ _ = error "Bad projection index" From 0acecb3fc567ca8a6e10dc77c86ec9892caa525d Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Tue, 1 Dec 2020 19:58:14 -0500 Subject: [PATCH 10/13] Modify pretty printing of ProjectElt atoms (shown as %projectElt) Since %projectElt can show up in the type of ordinary expressions, the prefix of % should make it clear that this isn't a user-defined ADT but instead an internal type. --- examples/adt-tests.dx | 6 +++--- src/lib/PPrint.hs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/adt-tests.dx b/examples/adt-tests.dx index af1174f93..c5e203b9d 100644 --- a/examples/adt-tests.dx +++ b/examples/adt-tests.dx @@ -215,7 +215,7 @@ def catLists (xs:List a) (ys:List a) : List a = def listToTable ((AsList n xs): List a) : (Fin n)=>a = xs :t listToTable -> ((a:Type) ?-> (pat:(List a)) -> (Fin (ProjectElt [0] pat)) => a) +> ((a:Type) ?-> (pat:(List a)) -> (Fin (%projectElt [0] pat)) => a) :p l = AsList _ [1, 2, 3] @@ -228,7 +228,7 @@ def listToTable2 (l: List a) : (Fin (listLength l))=>a = xs :t listToTable2 -> ((a:Type) ?-> (l:(List a)) -> (Fin (ProjectElt [0] l)) => a) +> ((a:Type) ?-> (l:(List a)) -> (Fin (%projectElt [0] l)) => a) :p l = AsList _ [1, 2, 3] @@ -247,7 +247,7 @@ def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = :t graphToAdjacencyMatrix > ((a:Type) -> ?-> (pat:(Graph a)) -> (ProjectElt [0] pat) => (ProjectElt [0] pat) => Bool) +> ?-> (pat:(Graph a)) -> (%projectElt [0] pat) => (%projectElt [0] pat) => Bool) :p g : Graph Int = MkGraph (Fin 3) [5, 6, 7] (Fin 4) [(0@_, 1@_), (0@_, 2@_), (2@_, 0@_), (1@_, 1@_)] diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 4484a05b0..5385e226f 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -370,7 +370,7 @@ instance PrettyPrec Atom where BoxedRef b ptr size body -> atPrec AppPrec $ "Box" <+> p b <+> "<-" <+> p ptr <+> "[" <> p size <> "]" <+> hardline <> "in" <+> p body ProjectElt idxs (x:>_) -> atPrec AppPrec $ - "ProjectElt" <+> p idxs <+> p x + "%projectElt" <+> p idxs <+> p x instance Pretty DataConRefBinding where pretty = prettyFromPrettyPrec instance PrettyPrec DataConRefBinding where From 0023cbc5a6be637accd304a1fc47a1333c8339f2 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Mon, 14 Dec 2020 18:29:58 -0500 Subject: [PATCH 11/13] Pretty-print projections as immediately-applied lambdas. In the IR, projections are represented with simple integers. But since projections can show up in user expressions, we rewrite them during printing to instead be of the form `(\pat. elt) x` where `pat` is a pattern that does the unpacking. Also changes the way patterns are pretty-printed to remove some redundant visual noise. --- examples/adt-tests.dx | 30 ++++++++++++++++------- examples/isomorphisms.dx | 51 ++++++++++++++++++++-------------------- src/lib/PPrint.hs | 47 ++++++++++++++++++++++++++++++++---- 3 files changed, 90 insertions(+), 38 deletions(-) diff --git a/examples/adt-tests.dx b/examples/adt-tests.dx index 7f795df3d..ca64b7bba 100644 --- a/examples/adt-tests.dx +++ b/examples/adt-tests.dx @@ -215,7 +215,7 @@ def catLists (xs:List a) (ys:List a) : List a = def listToTable ((AsList n xs): List a) : (Fin n)=>a = xs :t listToTable -> ((a:Type) ?-> (pat:(List a)) -> (Fin (%projectElt [0] pat)) => a) +> ((a:Type) ?-> (pat:(List a)) -> (Fin ((\((AsList n _)). n) pat)) => a) :p l = AsList _ [1, 2, 3] @@ -227,7 +227,7 @@ def listToTable2 (l: List a) : (Fin (listLength l))=>a = xs :t listToTable2 -> ((a:Type) ?-> (l:(List a)) -> (Fin (%projectElt [0] l)) => a) +> ((a:Type) ?-> (l:(List a)) -> (Fin ((\((AsList n _)). n) l)) => a) :p l = AsList _ [1, 2, 3] @@ -246,21 +246,35 @@ def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = :t graphToAdjacencyMatrix > ((a:Type) -> ?-> (pat:(Graph a)) -> (%projectElt [0] pat) => (%projectElt [0] pat) => Bool) +> ?-> (pat:(Graph a)) +> -> ((\((MkGraph n _ _ _)). n) pat) => ((\((MkGraph n _ _ _)). n) pat) => Bool) :p g : Graph Int = MkGraph (Fin 3) [5, 6, 7] (Fin 4) [(0@_, 1@_), (0@_, 2@_), (2@_, 0@_), (1@_, 1@_)] graphToAdjacencyMatrix g > [[False, True, True], [False, True, False], [True, False, False]] +-- Test how (nested) projections are handled and pretty-printed in the IR. -def deepUnpack (x:MyPair a (MyPair (MyIntish & b) c)) : Int = +def pairUnpack ((v, _):(Int & Float)) : Int = v +:p pairUnpack +> \pat:(Int32 & Float32). (\(a, _). a) pat + +def adtUnpack ((MkMyPair v _):MyPair Int Float) : Int = v +:p adtUnpack +> \pat:(MyPair Int32 Float32). (\((MkMyPair elt _)). elt) pat + +def recordUnpack ({a=v, b=_}:{a:Int & b:Float}) : Int = v +:p recordUnpack +> \pat:{a: Int32 & b: Float32}. (\{a = a, b = _}. a) pat + +def nestedUnpack (x:MyPair Int (MyPair (MyIntish & Int) Int)) : Int = (MkMyPair _ (MkMyPair (MkIntish y, _) _)) = x y -:t deepUnpack -> ((a:Type) -> ?-> (b:Type) ?-> (c:Type) ?-> (MyPair a (MyPair (MyIntish & b) c)) -> Int32) +:p nestedUnpack +> \x:(MyPair Int32 (MyPair (MyIntish & Int32) Int32)). +> (\((MkIntish (((MkMyPair ((MkMyPair _ elt)) _)), _))). elt) x -:p deepUnpack (MkMyPair 3 (MkMyPair (MkIntish 4, 5) 6)) +:p nestedUnpack (MkMyPair 3 (MkMyPair (MkIntish 4, 5) 6)) > 4 diff --git a/examples/isomorphisms.dx b/examples/isomorphisms.dx index bd7ea2daf..b127c16a2 100644 --- a/examples/isomorphisms.dx +++ b/examples/isomorphisms.dx @@ -46,24 +46,23 @@ that produce isos. We will start with the first two: :t #b : Iso {a:Int & b:Float & c:Unit} _ > (Iso {a: Int32 & b: Float32 & c: Unit} (Float32 & {a: Int32 & c: Unit})) > === parse === -> _ans_:() = -> MkIso -> {bwd = \(x:(), r:()). {b = x, ...r}, fwd = \{b = x:(), ...r:()}. (,) x r} +> _ans_ = +> MkIso {bwd = \(x, r). {b = x, ...r}, fwd = \{b = x, ...r}. (,) x r} > : Iso {a: Int & b: Float & c: Unit} _ %passes parse :t #?b : Iso {a:Int | b:Float | c:Unit} _ > (Iso {a: Int32 | b: Float32 | c: Unit} (Float32 | {a: Int32 | c: Unit})) > === parse === -> _ans_:() = +> _ans_ = > MkIso -> { bwd = \v:(). case v -> ((Left x:())) -> {| b = x |} -> ((Right r:())) -> {|b| ...r |} +> { bwd = \v. case v +> ((Left x)) -> {| b = x |} +> ((Right r)) -> {|b| ...r |} > -> , fwd = \v:(). case v -> {| b = x:() |} -> (Left x) -> {|b| ...r:() |} -> (Right r) +> , fwd = \v. case v +> {| b = x |} -> (Left x) +> {|b| ...r |} -> (Right r) > } > : Iso {a: Int | b: Float | c: Unit} _ @@ -143,10 +142,10 @@ another. For instance: > ({ &} & {a: Int32 & b: Float32 & c: Unit}) > ({a: Int32} & {b: Float32 & c: Unit})) > === parse === -> _ans_:() = +> _ans_ = > MkIso -> { bwd = \({a = x:(), ...l:()}, {, ...r:()}). (,) {, ...l} {a = x, ...r} -> , fwd = \({, ...l:()}, {a = x:(), ...r:()}). (,) {a = x, ...l} {, ...r}} +> { bwd = \({a = x, ...l}, {, ...r}). (,) {, ...l} {a = x, ...r} +> , fwd = \({, ...l}, {a = x, ...r}). (,) {a = x, ...l} {, ...r}} > : Iso ((&) { &} {a: Int & b: Float & c: Unit}) _ :t (#&a &>> #&b) : Iso ({&} & {a:Int & b:Float & c:Unit}) _ @@ -213,21 +212,21 @@ zipper isomorphisms: > ({ |} | {a: Int32 | b: Float32 | c: Unit}) > ({a: Int32} | {b: Float32 | c: Unit})) > === parse === -> _ans_:() = +> _ans_ = > MkIso -> { bwd = \v:(). case v -> ((Left w:())) -> (case w -> {| a = x:() |} -> (Right {| a = x |}) -> {|a| ...r:() |} -> (Left r) -> ) -> ((Right l:())) -> (Right {|a| ...l |}) +> { bwd = \v. case v +> ((Left w)) -> (case w +> {| a = x |} -> (Right {| a = x |}) +> {|a| ...r |} -> (Left r) +> ) +> ((Right l)) -> (Right {|a| ...l |}) > -> , fwd = \v:(). case v -> ((Left l:())) -> (Left {|a| ...l |}) -> ((Right w:())) -> (case w -> {| a = x:() |} -> (Left {| a = x |}) -> {|a| ...r:() |} -> (Right r) -> ) +> , fwd = \v. case v +> ((Left l)) -> (Left {|a| ...l |}) +> ((Right w)) -> (case w +> {| a = x |} -> (Left {| a = x |}) +> {|a| ...r |} -> (Right r) +> ) > } > : Iso ((|) { |} {a: Int | b: Float | c: Unit}) _ diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index ad4d99528..6022fd00d 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -21,6 +21,7 @@ import Data.Foldable (toList) import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M import qualified Data.ByteString.Lazy.Char8 as B +import Data.Maybe (fromMaybe) import Data.String (fromString) import Data.Text.Prettyprint.Doc.Render.Text import Data.Text.Prettyprint.Doc @@ -32,6 +33,7 @@ import Numeric import Env import Syntax +import Util (enumerate) -- Specifies what kinds of operations are allowed to be printed at this point. -- Printing at AppPrec level means that applications can be printed @@ -362,20 +364,57 @@ instance PrettyPrec Atom where "DataConRef" <+> p params <+> p args BoxedRef b ptr size body -> atPrec AppPrec $ "Box" <+> p b <+> "<-" <+> p ptr <+> "[" <> p size <> "]" <+> hardline <> "in" <+> p body - ProjectElt idxs (x:>_) -> atPrec AppPrec $ - "%projectElt" <+> p idxs <+> p x + ProjectElt idxs x -> prettyProjection idxs x instance Pretty DataConRefBinding where pretty = prettyFromPrettyPrec instance PrettyPrec DataConRefBinding where prettyPrec (DataConRefBinding b x) = atPrec AppPrec $ p b <+> "<-" <+> p x - fromInfix :: Text -> Maybe Text fromInfix t = do ('(', t') <- uncons t (t'', ')') <- unsnoc t' return t'' +prettyProjection :: NE.NonEmpty Int -> Var -> DocPrec ann +prettyProjection idxs (name :> ty) = prettyPrec uproj where + -- Builds a source expression that performs the given projection. + uproj = UApp (PlainArrow ()) (nosrc ulam) (nosrc uvar) + ulam = ULam (upat, Nothing) (PlainArrow ()) (nosrc $ UVar $ target :> ()) + uvar = UVar $ name :> () + (_, upat, target) = buildProj idxs + + buildProj :: NE.NonEmpty Int -> (Type, UPat, Name) + buildProj (i NE.:| is) = let + -- Lazy Haskell trick: refer to `target` even though this function is + -- responsible for setting it! + (ty', pat', eltName) = case NE.nonEmpty is of + Just is' -> let (x, y, z) = buildProj is' in (x, y, Just z) + Nothing -> (ty, nosrc $ UPatBinder $ Bind $ target :> (), Nothing) + in case ty' of + TypeCon def params -> let + [DataConDef conName bs] = applyDataDefParams def params + b = toList bs !! i + pats = (\(j,_)-> if i == j then pat' else uignore) <$> enumerate bs + hint = case b of + Bind (n :> _) -> n + Ignore _ -> Name SourceName "elt" 0 + in ( binderAnn b, nosrc $ UPatCon conName pats, fromMaybe hint eltName) + RecordTy (NoExt types) -> let + ty'' = toList types !! i + pats = (\(j,_)-> if i == j then pat' else uignore) <$> enumerate types + (fieldName, _) = toList (reflectLabels types) !! i + hint = Name SourceName (fromString fieldName) 0 + in (ty'', nosrc $ UPatRecord $ NoExt pats, fromMaybe hint eltName) + PairTy x _ | i == 0 -> + (x, nosrc $ UPatPair pat' uignore, fromMaybe "a" eltName) + PairTy _ y | i == 1 -> + (y, nosrc $ UPatPair uignore pat', fromMaybe "b" eltName) + _ -> error "Bad projection" + + nosrc = WithSrc Nothing + uignore = nosrc $ UPatBinder $ Ignore () + prettyExtLabeledItems :: (PrettyPrec a, PrettyPrec b) => ExtLabeledItems a b -> Doc ann -> Doc ann -> DocPrec ann prettyExtLabeledItems (Ext (LabeledItems row) rest) separator bindwith = @@ -599,7 +638,7 @@ instance Pretty UConDef where instance Pretty UPat' where pretty = prettyFromPrettyPrec instance PrettyPrec UPat' where prettyPrec pat = case pat of - UPatBinder x -> atPrec ArgPrec $ p x + UPatBinder x -> atPrec ArgPrec $ prettyBinderNoAnn x UPatPair x y -> atPrec ArgPrec $ parens $ p x <> ", " <> p y UPatUnit -> atPrec ArgPrec $ "()" UPatCon con pats -> atPrec AppPrec $ parens $ p con <+> spaced pats From 9b03750dc9b2a0bbdaa49a4c96d951284074c0f1 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Mon, 14 Dec 2020 18:54:44 -0500 Subject: [PATCH 12/13] Address PR comments for projections --- examples/adt-tests.dx | 12 +++++++++++- prelude.dx | 2 +- src/lib/Autodiff.hs | 19 ++++++++----------- src/lib/Embed.hs | 8 +++----- src/lib/Inference.hs | 4 ++-- src/lib/Parser.hs | 35 +++++++++++++++++++---------------- src/lib/Syntax.hs | 6 ++++-- src/lib/Type.hs | 10 +++++----- 8 files changed, 53 insertions(+), 43 deletions(-) diff --git a/examples/adt-tests.dx b/examples/adt-tests.dx index ca64b7bba..364b73d25 100644 --- a/examples/adt-tests.dx +++ b/examples/adt-tests.dx @@ -234,6 +234,16 @@ def listToTable2 (l: List a) : (Fin (listLength l))=>a = sum $ listToTable2 l > 6 +l2 = AsList _ [1, 2, 3] +:p sum $ listToTable2 l2 +> 6 + +def zerosLikeList (l : List a) : (Fin (listLength l))=>Float = + for i:(Fin $ listLength l). 0.0 + +:p zerosLikeList l2 +> [0.0, 0.0, 0.0] + data Graph a:Type = MkGraph n:Type nodes:(n=>a) m:Type edges:(m=>(n & n)) @@ -254,7 +264,7 @@ def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool = graphToAdjacencyMatrix g > [[False, True, True], [False, True, False], [True, False, False]] --- Test how (nested) projections are handled and pretty-printed in the IR. +-- Test how (nested) projections are handled and pretty-printed. def pairUnpack ((v, _):(Int & Float)) : Int = v :p pairUnpack diff --git a/prelude.dx b/prelude.dx index b69b208fe..3b046f450 100644 --- a/prelude.dx +++ b/prelude.dx @@ -143,7 +143,7 @@ def (&) (a:Type) (b:Type) : Type = %PairType a b def (,) (x:a) (y:b) : (a & b) = %pair x y def fst ((x, _): (a & b)) : a = x def snd ((_, y): (a & b)) : b = y -def swap (p:(a&b)) : (b&a) = (snd p, fst p) +def swap ((x, y):(a&b)) : (b&a) = (y, x) def (<<<) (f: b -> c) (g: a -> b) : a -> c = \x. f (g x) def (>>>) (g: a -> b) (f: b -> c) : a -> c = \x. f (g x) diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index 68cead221..045fef686 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -378,7 +378,9 @@ tangentType ty = case ty of addTangent :: MonadEmbed m => Atom -> Atom -> m Atom addTangent x y = case getType x of - RecordTy _ -> pack (getType x) <$> bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y) + RecordTy (NoExt tys) -> do + elems <- bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y) + return $ Record $ restructure elems tys TabTy b _ -> buildFor Fwd b $ \i -> bindM2 addTangent (tabGet x i) (tabGet y i) TC con -> case con of BaseType (Scalar _) -> emitOp $ ScalarBinOp FAdd x y @@ -392,13 +394,6 @@ addTangent x y = case getType x of _ -> notTangent where notTangent = error $ "Not a tangent type: " ++ pprint (getType x) -pack :: Type -> [Atom] -> Atom -pack ty elems = case ty of - PairTy _ _ -> let [x, y] = elems in PairVal x y - TypeCon def params -> DataCon def params 0 elems - RecordTy (NoExt types) -> Record $ restructure elems types - _ -> error $ "Unexpected Unpack argument type: " ++ pprint ty - isTrivialForAD :: Expr -> Bool isTrivialForAD expr = isSingletonType tangentTy && exprEffs expr == mempty where tangentTy = tangentType $ getType expr @@ -640,9 +635,11 @@ linAtomRef (Var x) = do case envLookup refs x of Just ref -> return ref _ -> error $ "Not a linear var: " ++ pprint (Var x) -linAtomRef (ProjectElt (i NE.:| is) x) = do - let subproj = getProjection is (Var x) - case getType subproj of +linAtomRef (ProjectElt (i NE.:| is) x) = + let subproj = case NE.nonEmpty is of + Just is' -> ProjectElt is' x + Nothing -> Var x + in case getType subproj of PairTy _ _ -> do ref <- linAtomRef subproj (traverse $ emitOp . getter) ref diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 380bc7aa4..4a7d75cdb 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -104,8 +104,7 @@ emitOp :: MonadEmbed m => Op -> m Atom emitOp op = emit $ Op op emitUnpack :: MonadEmbed m => Expr -> m [Atom] -emitUnpack expr = do - getUnpacked =<< emit expr +emitUnpack expr = getUnpacked =<< emit expr -- Assumes the decl binders are already fresh wrt current scope emitBlock :: MonadEmbed m => Block -> m Atom @@ -287,9 +286,8 @@ ieq x y = emitOp $ ScalarBinOp (ICmp Equal) x y fromPair :: MonadEmbed m => Atom -> m (Atom, Atom) fromPair pair = do - scope <- getScope - let pair' = reduceAtom scope pair - return (getProjection [0] pair', getProjection [1] pair') + ~[x, y] <- getUnpacked pair + return (x, y) getFst :: MonadEmbed m => Atom -> m Atom getFst p = fst <$> fromPair p diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index c31a1c3ea..78b6ae5fb 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -160,8 +160,8 @@ checkOrInferRho (WithSrc pos expr) reqTy = do Just pat' -> withNameHint ("pat" :: Name) $ buildPi b $ \x -> withBindPat pat' x $ (,) <$> mapM checkUEff arr <*> checkUType ty where b = case pat' of - -- Note: must bind it by name if the user gives an explicit - -- name, since the binder name becomes part of the type. + -- Note: The binder name becomes part of the type, so we + -- need to keep the same name used in the pattern. WithSrc _ (UPatBinder (Bind (v:>()))) -> Bind (v:>kind') _ -> Ignore kind' Nothing -> buildPi (Ignore kind') $ const $ diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index a0018a271..8e99dcb48 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -258,39 +258,42 @@ topLet = do -- sense that the user did NOT explicitly annotate them as implicit. findImplicitImplicitArgNames :: UType -> [Name] findImplicitImplicitArgNames typ = filter isLowerCaseName $ envNames $ - freeUVars typ `envDiff` findFunctionVars typ + freeUVars typ `envDiff` findVarsInAppLHS typ where isLowerCaseName :: Name -> Bool isLowerCaseName (Name _ tag _) = isLower $ head $ tagToStr tag isLowerCaseName _ = False - -- Finds all variables used in function position, which should be pulled out. - findFunctionVars :: UType -> Env () - findFunctionVars (WithSrc _ typ') = case typ' of + -- Finds all variables used in the left hand of an application, which should + -- be filtered out and not automatically inferred. + findVarsInAppLHS :: UType -> Env () + findVarsInAppLHS (WithSrc _ typ') = case typ' of + -- base case + UApp _ (WithSrc _ (UVar (v:>_))) x -> (v @> ()) <> findVarsInAppLHS x + -- recursive steps UVar _ -> mempty UPi (p, ann) _ ty -> - findFunctionVars ann <> (findFunctionVars ty `envDiff` boundUVars p) - UApp _ (WithSrc _ (UVar (v:>_))) x -> (v @> ()) <> findFunctionVars x - UApp _ f x -> findFunctionVars f <> findFunctionVars x + findVarsInAppLHS ann <> (findVarsInAppLHS ty `envDiff` boundUVars p) + UApp _ f x -> findVarsInAppLHS f <> findVarsInAppLHS x ULam (p, ann) _ x -> - foldMap findFunctionVars ann <> (findFunctionVars x `envDiff` boundUVars p) + foldMap findVarsInAppLHS ann <> (findVarsInAppLHS x `envDiff` boundUVars p) UDecl _ _ -> error "Unexpected let binding in type annotation" UFor _ _ _ -> error "Unexpected for in type annotation" UHole -> mempty - UTypeAnn v ty -> findFunctionVars v <> findFunctionVars ty + UTypeAnn v ty -> findVarsInAppLHS v <> findVarsInAppLHS ty UTabCon _ -> error "Unexpected table in type annotation" UIndexRange low high -> - foldMap findFunctionVars low <> foldMap findFunctionVars high - UPrimExpr prim -> foldMap findFunctionVars prim + foldMap findVarsInAppLHS low <> foldMap findVarsInAppLHS high + UPrimExpr prim -> foldMap findVarsInAppLHS prim UCase _ _ -> error "Unexpected case in type annotation" - URecord (Ext ulr _) -> foldMap findFunctionVars ulr - UVariant _ _ val -> findFunctionVars val + URecord (Ext ulr _) -> foldMap findVarsInAppLHS ulr + UVariant _ _ val -> findVarsInAppLHS val URecordTy (Ext ulr v) -> - foldMap findFunctionVars ulr <> foldMap findFunctionVars v + foldMap findVarsInAppLHS ulr <> foldMap findVarsInAppLHS v UVariantTy (Ext ulr v) -> - foldMap findFunctionVars ulr <> foldMap findFunctionVars v - UVariantLift _ val -> findFunctionVars val + foldMap findVarsInAppLHS ulr <> foldMap findVarsInAppLHS v + UVariantLift _ val -> findVarsInAppLHS val UIntLit _ -> mempty UFloatLit _ -> mempty diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index ed4aeafa7..6d5d04b9c 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -105,8 +105,10 @@ data Atom = Var Var | DataConRef DataDef [Atom] (Nest DataConRefBinding) | BoxedRef Binder Atom Block Atom -- binder, ptr, size, body -- access a nested member of a binder - -- XXX: Variable name MUST be fully reduced, it cannot be a synonym! - -- This is because the variable name may also appear in the type. + -- XXX: Variable name must not be an alias for another name or for + -- a statically-known atom. This is because the variable name used + -- here may also appear in the type of the atom. (We maintain this + -- invariant during substitution and in Embed.hs.) | ProjectElt (NE.NonEmpty Int) Var deriving (Show, Generic) diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 59792f70d..253952e7c 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -171,12 +171,12 @@ instance HasType Atom where -- Users might be accessing a value whose type depends on earlier -- projected values from this constructor. Rewrite them to also -- use projections. - let go :: Int -> Nest Binder -> Type - go j (Nest b _) | i == j = binderAnn b - go j (Nest b rest) = go (j+1) (scopelessSubst (b @> proj) rest) + let go :: Int -> SubstEnv -> Nest Binder -> Type + go j env (Nest b _) | i == j = scopelessSubst env $ binderAnn b + go j env (Nest b rest) = go (j+1) (env <> (b @> proj)) rest where proj = ProjectElt (j NE.:| is) v - go _ _ = error "Bad projection index" - return $ go 0 bs' + go _ _ _ = error "Bad projection index" + return $ go 0 mempty bs' RecordTy (NoExt types) -> return $ toList types !! i RecordTy _ -> throw CompilerErr "Can't project partially-known records" PairTy x _ | i == 0 -> return x From bd4d308adaf87d3981ca8d0d9ae7b127dc76cff1 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Mon, 14 Dec 2020 19:11:59 -0500 Subject: [PATCH 13/13] Move and rename reduction functions. The functions now live in Type.hs and have a `typeReduce` prefix instead of just being called `reduceAtom`/`reduceExpr` etc. --- src/lib/Embed.hs | 70 ++------------------------------------------ src/lib/Inference.hs | 12 ++++++-- src/lib/Type.hs | 68 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 79 insertions(+), 71 deletions(-) diff --git a/src/lib/Embed.hs b/src/lib/Embed.hs index 4a7d75cdb..16a9af2d6 100644 --- a/src/lib/Embed.hs +++ b/src/lib/Embed.hs @@ -30,8 +30,7 @@ module Embed (emit, emitTo, emitAnn, emitOp, buildDepEffLam, buildLamAux, buildP TraversalDef, traverseDecls, traverseDecl, traverseBlock, traverseExpr, clampPositive, buildNAbs, buildNAbsAux, buildNestedLam, zeroAt, transformModuleAsBlock, dropSub, appReduceTraversalDef, - indexSetSizeE, indexToIntE, intToIndexE, freshVarE, - reduceScoped, reduceBlock, reduceAtom, reduceExpr) where + indexSetSizeE, indexToIntE, intToIndexE, freshVarE) where import Control.Applicative import Control.Monad @@ -144,7 +143,7 @@ buildPi b f = do (arr, ans) <- f $ Var v return $ Pi $ makeAbs (Bind v) (arr, ans) let block = wrapDecls decls ans - case reduceBlock scope block of + case typeReduceBlock scope block of Just piTy -> return piTy Nothing -> throw CompilerErr $ "Unexpected irreducible decls in pi type: " ++ pprint decls @@ -308,7 +307,7 @@ getUnpacked :: MonadEmbed m => Atom -> m [Atom] getUnpacked atom = do scope <- getScope let len = projectLength $ getType atom - atom' = reduceAtom scope atom + atom' = typeReduceAtom scope atom res = map (\i -> getProjection [i] atom') [0..(len-1)] return res @@ -824,66 +823,3 @@ intToIndexE (VariantTy (NoExt types)) i = do start <- Variant (NoExt types) l0 0 <$> intToIndexE ty0 i foldM go start zs intToIndexE ty _ = error $ "Unexpected type " ++ pprint ty - --- === Reduction === - -reduceScoped :: MonadEmbed m => m Atom -> m (Maybe Atom) -reduceScoped m = do - block <- buildScoped m - scope <- getScope - return $ reduceBlock scope block - -reduceBlock :: Scope -> Block -> Maybe Atom -reduceBlock scope (Block decls result) = do - let localScope = foldMap boundVars decls - ans <- reduceExpr (scope <> localScope) result - [] <- return $ toList $ localScope `envIntersect` freeVars ans - return ans - --- XXX: This should handle all terms of type Type. Otherwise type equality checking --- will get broken. -reduceAtom :: Scope -> Atom -> Atom -reduceAtom scope x = case x of - Var (Name InferenceName _ _ :> _) -> x - Var v -> case snd (scope ! v) of - -- TODO: worry about effects! - LetBound PlainLet expr -> fromMaybe x $ reduceExpr scope expr - _ -> x - TC con -> TC $ fmap (reduceAtom scope) con - Pi (Abs b (arr, ty)) -> Pi $ Abs b (arr, reduceAtom (scope <> (fmap (,PiBound) $ binderAsEnv b)) ty) - TypeCon def params -> TypeCon (reduceDataDef def) (fmap rec params) - RecordTy (Ext tys ext) -> RecordTy $ Ext (fmap rec tys) ext - VariantTy (Ext tys ext) -> VariantTy $ Ext (fmap rec tys) ext - ACase _ _ _ -> error "Not implemented" - _ -> x - where - rec = reduceAtom scope - reduceNest s n = case n of - Empty -> Empty - -- Technically this should use a more concrete type than UnknownBinder, but anything else - -- than LetBound is indistinguishable for this reduction anyway. - Nest b rest -> Nest b' $ reduceNest (s <> (fmap (,UnknownBinder) $ binderAsEnv b)) rest - where b' = fmap (reduceAtom s) b - reduceDataDef (DataDef n bs cons) = - DataDef n (reduceNest scope bs) - (fmap (reduceDataConDef (scope <> (foldMap (fmap (,UnknownBinder) . binderAsEnv) bs))) cons) - reduceDataConDef s (DataConDef n bs) = DataConDef n $ reduceNest s bs - -reduceExpr :: Scope -> Expr -> Maybe Atom -reduceExpr scope expr = case expr of - Atom val -> return $ reduceAtom scope val - App f x -> do - let f' = reduceAtom scope f - let x' = reduceAtom scope x - -- TODO: Worry about variable capture. Should really carry a substitution. - case f' of - Lam (Abs b (arr, block)) | arr == PureArrow || arr == ImplicitArrow -> - reduceBlock scope $ subst (b@>x', scope) block - TypeCon con xs -> Just $ TypeCon con $ xs ++ [x'] - _ -> Nothing - Op (MakePtrType ty) -> do - let ty' = reduceAtom scope ty - case ty' of - BaseTy b -> return $ PtrTy (AllocatedPtr, Heap CPU, b) - _ -> Nothing - _ -> Nothing diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 78b6ae5fb..f5589d40c 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -142,7 +142,7 @@ checkOrInferRho (WithSrc pos expr) reqTy = do Abs b rhs@(arr', _) -> case b `isin` freeVars rhs of False -> embedExtend embedEnv $> (xVal, arr') True -> do - xValMaybeRed <- flip reduceBlock (Block xDecls (Atom xVal)) <$> getScope + xValMaybeRed <- flip typeReduceBlock (Block xDecls (Atom xVal)) <$> getScope case xValMaybeRed of Just xValRed -> return (xValRed, fst $ applyAbs piTy xValRed) Nothing -> addSrcContext' xPos $ do @@ -259,7 +259,7 @@ checkOrInferRho (WithSrc pos expr) reqTy = do prim' <- forM prim $ \e -> do e' <- inferRho e scope <- getScope - return $ reduceAtom scope e' + return $ typeReduceAtom scope e' val <- case prim' of TCExpr e -> return $ TC e ConExpr e -> return $ Con e @@ -553,7 +553,7 @@ checkAnn ann = case ann of checkUType :: UType -> UInferM Type checkUType ty = do - reduced <- reduceScoped $ withEffects Pure $ checkRho ty TyKind + reduced <- typeReduceScoped $ withEffects Pure $ checkRho ty TyKind case reduced of Just ty' -> return $ ty' Nothing -> throw TypeErr $ "Can't reduce type expression: " ++ pprint ty @@ -920,3 +920,9 @@ instance Semigroup SolverEnv where instance Monoid SolverEnv where mempty = SolverEnv mempty mempty mappend = (<>) + +typeReduceScoped :: MonadEmbed m => m Atom -> m (Maybe Atom) +typeReduceScoped m = do + block <- buildScoped m + scope <- getScope + return $ typeReduceBlock scope block diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 253952e7c..7ef41d1c2 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -12,7 +12,8 @@ module Type ( getType, checkType, HasType (..), Checkable (..), litType, isPure, functionEffs, exprEffs, blockEffs, extendEffect, isData, checkBinOp, checkUnOp, checkIntBaseType, checkFloatBaseType, withBinder, isDependent, - indexSetConcreteSize, checkNoShadow, traceCheckM, traceCheck, projectLength) where + indexSetConcreteSize, checkNoShadow, traceCheckM, traceCheck, projectLength, + typeReduceBlock, typeReduceAtom, typeReduceExpr) where import Prelude hiding (pi) import Control.Monad @@ -32,6 +33,7 @@ import Env import PPrint import Cat import Util (bindM2) +import Data.Maybe (fromMaybe) type TypeEnv = Bindings -- Only care about type payload @@ -974,3 +976,67 @@ projectLength ty = case ty of RecordTy (NoExt types) -> length types PairTy _ _ -> 2 _ -> error $ "Projecting a type that doesn't support projecting: " ++ pprint ty + + +-- === Type-level reduction using variables in scope. === + +-- Note: These are simple reductions that are performed when normalizing a +-- value to use it as a type annotation. If they succeed, these functions should +-- return atoms that can be compared for equality to check whether the types +-- are equivalent. If they fail (return Nothing), this means we cannot use +-- the value as a type in the IR. + +typeReduceBlock :: Scope -> Block -> Maybe Atom +typeReduceBlock scope (Block decls result) = do + let localScope = foldMap boundVars decls + ans <- typeReduceExpr (scope <> localScope) result + [] <- return $ toList $ localScope `envIntersect` freeVars ans + return ans + +-- XXX: This should handle all terms of type Type. Otherwise type equality checking +-- will get broken. +typeReduceAtom :: Scope -> Atom -> Atom +typeReduceAtom scope x = case x of + Var (Name InferenceName _ _ :> _) -> x + Var v -> case snd (scope ! v) of + -- TODO: worry about effects! + LetBound PlainLet expr -> fromMaybe x $ typeReduceExpr scope expr + _ -> x + TC con -> TC $ fmap (typeReduceAtom scope) con + Pi (Abs b (arr, ty)) -> Pi $ Abs b (arr, typeReduceAtom (scope <> (fmap (,PiBound) $ binderAsEnv b)) ty) + TypeCon def params -> TypeCon (reduceDataDef def) (fmap rec params) + RecordTy (Ext tys ext) -> RecordTy $ Ext (fmap rec tys) ext + VariantTy (Ext tys ext) -> VariantTy $ Ext (fmap rec tys) ext + ACase _ _ _ -> error "Not implemented" + _ -> x + where + rec = typeReduceAtom scope + reduceNest s n = case n of + Empty -> Empty + -- Technically this should use a more concrete type than UnknownBinder, but anything else + -- than LetBound is indistinguishable for this reduction anyway. + Nest b rest -> Nest b' $ reduceNest (s <> (fmap (,UnknownBinder) $ binderAsEnv b)) rest + where b' = fmap (typeReduceAtom s) b + reduceDataDef (DataDef n bs cons) = + DataDef n (reduceNest scope bs) + (fmap (reduceDataConDef (scope <> (foldMap (fmap (,UnknownBinder) . binderAsEnv) bs))) cons) + reduceDataConDef s (DataConDef n bs) = DataConDef n $ reduceNest s bs + +typeReduceExpr :: Scope -> Expr -> Maybe Atom +typeReduceExpr scope expr = case expr of + Atom val -> return $ typeReduceAtom scope val + App f x -> do + let f' = typeReduceAtom scope f + let x' = typeReduceAtom scope x + -- TODO: Worry about variable capture. Should really carry a substitution. + case f' of + Lam (Abs b (arr, block)) | arr == PureArrow || arr == ImplicitArrow -> + typeReduceBlock scope $ subst (b@>x', scope) block + TypeCon con xs -> Just $ TypeCon con $ xs ++ [x'] + _ -> Nothing + Op (MakePtrType ty) -> do + let ty' = typeReduceAtom scope ty + case ty' of + BaseTy b -> return $ PtrTy (AllocatedPtr, Heap CPU, b) + _ -> Nothing + _ -> Nothing