From ed986c400abddfc1611f213abf4a5d7939e52cde Mon Sep 17 00:00:00 2001 From: Nikolai Kudasov Date: Thu, 20 Jun 2024 13:33:13 +0300 Subject: [PATCH] Allow deriving scope-safe patterns separately --- .../src/Control/Monad/Foil/TH/MkFoilData.hs | 80 +++++++++++-------- .../Control/Monad/Foil/TH/MkInstancesFoil.hs | 41 ++++++---- .../src/Control/Monad/Foil/TH/MkToFoil.hs | 12 +-- .../src/Language/LambdaPi/Impl/FoilTH.hs | 2 - .../src/Language/LambdaPi/Impl/FreeFoilTH.hs | 7 ++ 5 files changed, 88 insertions(+), 54 deletions(-) diff --git a/haskell/free-foil/src/Control/Monad/Foil/TH/MkFoilData.hs b/haskell/free-foil/src/Control/Monad/Foil/TH/MkFoilData.hs index f6fcb047..d7c0097e 100644 --- a/haskell/free-foil/src/Control/Monad/Foil/TH/MkFoilData.hs +++ b/haskell/free-foil/src/Control/Monad/Foil/TH/MkFoilData.hs @@ -3,7 +3,7 @@ {-# LANGUAGE TemplateHaskellQuotes #-} {-# OPTIONS_GHC -Wno-incomplete-patterns #-} {-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-} -module Control.Monad.Foil.TH.MkFoilData (mkFoilData) where +module Control.Monad.Foil.TH.MkFoilData where import Language.Haskell.TH @@ -20,24 +20,66 @@ mkFoilData mkFoilData termT nameT scopeT patternT = do n <- newName "n" l <- newName "l" - TyConI (DataD _ctx _name patternTVars _kind patternCons _deriv) <- reify patternT TyConI (DataD _ctx _name scopeTVars _kind scopeCons _deriv) <- reify scopeT TyConI (DataD _ctx _name termTVars _kind termCons _deriv) <- reify termT - foilPatternCons <- mapM (toPatternCon patternTVars n) patternCons let foilScopeCons = map (toScopeCon scopeTVars n) scopeCons let foilTermCons = map (toTermCon termTVars n l) termCons - return + patternD <- mkFoilPattern nameT patternT + return $ [ DataD [] foilTermT (termTVars ++ [KindedTV n BndrReq (PromotedT ''Foil.S)]) Nothing foilTermCons [] , DataD [] foilScopeT (scopeTVars ++ [KindedTV n BndrReq (PromotedT ''Foil.S)]) Nothing foilScopeCons [] - , DataD [] foilPatternT (patternTVars ++ [KindedTV n BndrReq (PromotedT ''Foil.S), KindedTV l BndrReq (PromotedT ''Foil.S)]) Nothing foilPatternCons [] - ] + ] ++ patternD where foilTermT = mkName ("Foil" ++ nameBase termT) foilScopeT = mkName ("Foil" ++ nameBase scopeT) foilPatternT = mkName ("Foil" ++ nameBase patternT) + -- | Convert a constructor declaration for a raw scoped term + -- into a constructor for the scope-safe scoped term. + toScopeCon :: [TyVarBndr BndrVis] -> Name -> Con -> Con + toScopeCon _tvars n (NormalC conName params) = + NormalC foilConName (map toScopeParam params) + where + foilConName = mkName ("Foil" ++ nameBase conName) + toScopeParam (_bang, PeelConT tyName tyParams) + | tyName == termT = (_bang, PeelConT foilTermT (tyParams ++ [VarT n])) + toScopeParam _bangType = _bangType + + -- | Convert a constructor declaration for a raw term + -- into a constructor for the scope-safe term. + toTermCon :: [TyVarBndr BndrVis] -> Name -> Name -> Con -> Con + toTermCon tvars n l (NormalC conName params) = + GadtC [foilConName] (map toTermParam params) (PeelConT foilTermT (map (VarT . tvarName) tvars ++ [VarT n])) + where + foilNames = [n, l] + foilConName = mkName ("Foil" ++ nameBase conName) + toTermParam (_bang, PeelConT tyName tyParams) + | tyName == patternT = (_bang, PeelConT foilPatternT (tyParams ++ map VarT foilNames)) + | tyName == nameT = (_bang, AppT (ConT ''Foil.Name) (VarT n)) + | tyName == scopeT = (_bang, PeelConT foilScopeT (tyParams ++ [VarT l])) + | tyName == termT = (_bang, PeelConT foilTermT (tyParams ++ [VarT n])) + toTermParam _bangType = _bangType + +-- | Generate just the scope-safe patterns. +mkFoilPattern + :: Name -- ^ Type name for raw variable identifiers. + -> Name -- ^ Type name for raw patterns. + -> Q [Dec] +mkFoilPattern nameT patternT = do + n <- newName "n" + l <- newName "l" + TyConI (DataD _ctx _name patternTVars _kind patternCons _deriv) <- reify patternT + + foilPatternCons <- mapM (toPatternCon patternTVars n) patternCons + + return + [ DataD [] foilPatternT (patternTVars ++ [KindedTV n BndrReq (PromotedT ''Foil.S), KindedTV l BndrReq (PromotedT ''Foil.S)]) Nothing foilPatternCons [] + ] + where + foilPatternT = mkName ("Foil" ++ nameBase patternT) + -- | Convert a constructor declaration for a raw pattern type -- into a constructor for the scope-safe pattern type. toPatternCon @@ -79,29 +121,3 @@ mkFoilData termT nameT scopeT patternT = do _ -> do (l, conParams') <- toPatternConParams (i+1) p conParams return (l, param : conParams') - - -- | Convert a constructor declaration for a raw scoped term - -- into a constructor for the scope-safe scoped term. - toScopeCon :: [TyVarBndr BndrVis] -> Name -> Con -> Con - toScopeCon _tvars n (NormalC conName params) = - NormalC foilConName (map toScopeParam params) - where - foilConName = mkName ("Foil" ++ nameBase conName) - toScopeParam (_bang, PeelConT tyName tyParams) - | tyName == termT = (_bang, PeelConT foilTermT (tyParams ++ [VarT n])) - toScopeParam _bangType = _bangType - - -- | Convert a constructor declaration for a raw term - -- into a constructor for the scope-safe term. - toTermCon :: [TyVarBndr BndrVis] -> Name -> Name -> Con -> Con - toTermCon tvars n l (NormalC conName params) = - GadtC [foilConName] (map toTermParam params) (PeelConT foilTermT (map (VarT . tvarName) tvars ++ [VarT n])) - where - foilNames = [n, l] - foilConName = mkName ("Foil" ++ nameBase conName) - toTermParam (_bang, PeelConT tyName tyParams) - | tyName == patternT = (_bang, PeelConT foilPatternT (tyParams ++ map VarT foilNames)) - | tyName == nameT = (_bang, AppT (ConT ''Foil.Name) (VarT n)) - | tyName == scopeT = (_bang, PeelConT foilScopeT (tyParams ++ [VarT l])) - | tyName == termT = (_bang, PeelConT foilTermT (tyParams ++ [VarT n])) - toTermParam _bangType = _bangType diff --git a/haskell/free-foil/src/Control/Monad/Foil/TH/MkInstancesFoil.hs b/haskell/free-foil/src/Control/Monad/Foil/TH/MkInstancesFoil.hs index cf256719..3a33f6a3 100644 --- a/haskell/free-foil/src/Control/Monad/Foil/TH/MkInstancesFoil.hs +++ b/haskell/free-foil/src/Control/Monad/Foil/TH/MkInstancesFoil.hs @@ -3,7 +3,7 @@ {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE TemplateHaskellQuotes #-} {-# OPTIONS_GHC -fno-warn-type-defaults #-} -module Control.Monad.Foil.TH.MkInstancesFoil (mkInstancesFoil) where +module Control.Monad.Foil.TH.MkInstancesFoil where import Language.Haskell.TH @@ -18,26 +18,22 @@ mkInstancesFoil -> Name -- ^ Type name for raw patterns. -> Q [Dec] mkInstancesFoil termT nameT scopeT patternT = do - TyConI (DataD _ctx _name patternTVars _kind patternCons _deriv) <- reify patternT TyConI (DataD _ctx _name scopeTVars _kind scopeCons _deriv) <- reify scopeT TyConI (DataD _ctx _name termTVars _kind termCons _deriv) <- reify termT - return + coSinkablePatternD <- deriveCoSinkable nameT patternT + + return $ [ InstanceD Nothing [] (AppT (ConT ''Foil.Sinkable) (PeelConT foilScopeT (map (VarT . tvarName) scopeTVars))) [ FunD 'Foil.sinkabilityProof (map clauseScopedTerm scopeCons) ] - , InstanceD Nothing [] (AppT (ConT ''Foil.CoSinkable) (PeelConT foilPatternT (map (VarT . tvarName) patternTVars))) - [ FunD 'Foil.coSinkabilityProof (map clausePattern patternCons) - , FunD 'Foil.withPattern (map clauseWithPattern patternCons) ] - , InstanceD Nothing [] (AppT (ConT ''Foil.Sinkable) (PeelConT foilTermT (map (VarT . tvarName) termTVars))) [ FunD 'Foil.sinkabilityProof (map clauseTerm termCons)] - ] + ] ++ coSinkablePatternD where foilTermT = mkName ("Foil" ++ nameBase termT) foilScopeT = mkName ("Foil" ++ nameBase scopeT) - foilPatternT = mkName ("Foil" ++ nameBase patternT) clauseScopedTerm :: Con -> Clause clauseScopedTerm = clauseTerm @@ -55,7 +51,7 @@ mkInstancesFoil termT nameT scopeT patternT = do [] where foilConName = mkName ("Foil" ++ nameBase conName) - rename = mkName "rename" + rename = mkName "_rename" conParamPatterns = zipWith mkConParamPattern params [1..] mkConParamPattern _ i = VarP (mkName ("x" ++ show i)) @@ -81,6 +77,23 @@ mkInstancesFoil termT nameT scopeT patternT = do where xi = mkName ("x" ++ show i) +-- | Generate 'Foil.Sinkable' and 'Foil.CoSinkable' instances. +deriveCoSinkable + :: Name -- ^ Type name for raw variable identifiers. + -> Name -- ^ Type name for raw patterns. + -> Q [Dec] +deriveCoSinkable nameT patternT = do + TyConI (DataD _ctx _name patternTVars _kind patternCons _deriv) <- reify patternT + + return + [ InstanceD Nothing [] (AppT (ConT ''Foil.CoSinkable) (PeelConT foilPatternT (map (VarT . tvarName) patternTVars))) + [ FunD 'Foil.coSinkabilityProof (map clausePattern patternCons) + , FunD 'Foil.withPattern (map clauseWithPattern patternCons) ] + ] + + where + foilPatternT = mkName ("Foil" ++ nameBase patternT) + clausePattern :: Con -> Clause clausePattern RecC{} = error "Record constructors (RecC) are not supported yet!" clausePattern InfixC{} = error "Infix constructors (InfixC) are not supported yet!" @@ -128,10 +141,10 @@ mkInstancesFoil termT nameT scopeT patternT = do [] where foilConName = mkName ("Foil" ++ nameBase conName) - withNameBinder = mkName "withNameBinder" + withNameBinder = mkName "_withNameBinder" id' = mkName "id'" - comp = mkName "comp" - scope = mkName "scope" + comp = mkName "_comp" + scope = mkName "_scope" cont = mkName "cont" conParamPatterns = zipWith mkConParamPattern params [1..] mkConParamPattern _ i = VarP (mkName ("x" ++ show i)) @@ -148,7 +161,7 @@ mkInstancesFoil termT nameT scopeT patternT = do xi = mkName ("x" ++ show i) xi' = mkName ("x" ++ show i ++ "'") renamei = mkName ("f" ++ show i) - scopei = mkName ("scope" ++ show i) + scopei = mkName ("_scope" ++ show i) go i scope' rename' p (_ : conPatterns) = go (i + 1) scope' rename' (AppE p (VarE xi)) conPatterns where diff --git a/haskell/free-foil/src/Control/Monad/Foil/TH/MkToFoil.hs b/haskell/free-foil/src/Control/Monad/Foil/TH/MkToFoil.hs index 8df46f3c..1c01c4fe 100644 --- a/haskell/free-foil/src/Control/Monad/Foil/TH/MkToFoil.hs +++ b/haskell/free-foil/src/Control/Monad/Foil/TH/MkToFoil.hs @@ -189,7 +189,7 @@ mkWithRefreshedFoilPattern nameT patternT = do where xi = mkName ("x" <> show i) xi' = mkName ("x" <> show i <> "'") - scopei = mkName ("scope" <> show i) + scopei = mkName ("_scope" <> show i) xsubst = mkName ("subst" <> show i) subst = mkName "subst" fi = LamE [VarP subst] @@ -295,7 +295,7 @@ mkToFoilTerm termT nameT scopeT patternT = do toMatch (NormalC conName params) = Match (ConP conName [] conParamPatterns) (NormalB conMatchBody) [toFoilVarD] where - toFoilVarFunName = mkName "lookupRawVar" + toFoilVarFunName = mkName "_lookupRawVar" toFoilVarFun = VarE toFoilVarFunName x = mkName "x" name = mkName "name" @@ -324,7 +324,7 @@ mkToFoilTerm termT nameT scopeT patternT = do where xi = mkName ("x" <> show i) xi' = mkName ("x" <> show i <> "'") - scopei = mkName ("scope" <> show i) + scopei = mkName ("_scope" <> show i) envi = mkName ("env" <> show i) go i scope' env' p (_ : conParams) = go (i + 1) scope' env' (AppE p (VarE xi)) conParams @@ -377,7 +377,7 @@ mkToFoilTerm termT nameT scopeT patternT = do where xi = mkName ("x" <> show i) xi' = mkName ("x" <> show i <> "'") - scopei = mkName ("scope" <> show i) + scopei = mkName ("_scope" <> show i) envi = mkName ("env" <> show i) go i scope' env' p (_ : conParams) = go (i + 1) scope' env' (AppE p (VarE xi)) conParams @@ -407,7 +407,7 @@ mkToFoilTerm termT nameT scopeT patternT = do toMatch (NormalC conName params) = Match (ConP conName [] conParamPatterns) (NormalB conMatchBody) [toFoilVarD] where - toFoilVarFunName = mkName "lookupRawVar" + toFoilVarFunName = mkName "_lookupRawVar" toFoilVarFun = VarE toFoilVarFunName x = mkName "x" name = mkName "name" @@ -436,7 +436,7 @@ mkToFoilTerm termT nameT scopeT patternT = do where xi = mkName ("x" <> show i) xi' = mkName ("x" <> show i <> "'") - scopei = mkName ("scope" <> show i) + scopei = mkName ("_scope" <> show i) envi = mkName ("env" <> show i) go i scope' env' p (_ : conParams) = go (i + 1) scope' env' (AppE p (VarE xi)) conParams diff --git a/haskell/lambda-pi/src/Language/LambdaPi/Impl/FoilTH.hs b/haskell/lambda-pi/src/Language/LambdaPi/Impl/FoilTH.hs index f36828c2..2b4d06d8 100644 --- a/haskell/lambda-pi/src/Language/LambdaPi/Impl/FoilTH.hs +++ b/haskell/lambda-pi/src/Language/LambdaPi/Impl/FoilTH.hs @@ -12,8 +12,6 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} -{-# OPTIONS_GHC -fno-warn-unused-binds #-} -{-# OPTIONS_GHC -fno-warn-unused-matches #-} -- | Foil implementation of the \(\lambda\Pi\)-calculus (with pairs) -- using Template Haskell to reduce boilerplate. -- diff --git a/haskell/lambda-pi/src/Language/LambdaPi/Impl/FreeFoilTH.hs b/haskell/lambda-pi/src/Language/LambdaPi/Impl/FreeFoilTH.hs index d5e87418..eeaed499 100644 --- a/haskell/lambda-pi/src/Language/LambdaPi/Impl/FreeFoilTH.hs +++ b/haskell/lambda-pi/src/Language/LambdaPi/Impl/FreeFoilTH.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE KindSignatures #-} @@ -31,6 +32,7 @@ module Language.LambdaPi.Impl.FreeFoilTH where import qualified Control.Monad.Foil as Foil import Control.Monad.Free.Foil +import Control.Monad.Foil.TH import Control.Monad.Free.Foil.TH import Data.Bifunctor.TH import Data.Map (Map) @@ -67,6 +69,11 @@ mkPatternSynonyms ''Term'Sig mkConvertToFreeFoil ''Raw.Term' ''Raw.VarIdent ''Raw.ScopedTerm' ''Raw.Pattern' mkConvertFromFreeFoil ''Raw.Term' ''Raw.VarIdent ''Raw.ScopedTerm' ''Raw.Pattern' +-- ** Scope-safe patterns + +mkFoilPattern ''Raw.VarIdent ''Raw.Pattern' +deriveCoSinkable ''Raw.VarIdent ''Raw.Pattern' + -- * User-defined code type Term' a = AST (Term'Sig a)