Skip to content

Commit

Permalink
Allow deriving scope-safe patterns separately
Browse files Browse the repository at this point in the history
  • Loading branch information
fizruk committed Jun 20, 2024
1 parent 400f338 commit ed986c4
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 54 deletions.
80 changes: 48 additions & 32 deletions haskell/free-foil/src/Control/Monad/Foil/TH/MkFoilData.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
module Control.Monad.Foil.TH.MkFoilData (mkFoilData) where
module Control.Monad.Foil.TH.MkFoilData where

import Language.Haskell.TH

Expand All @@ -20,24 +20,66 @@ mkFoilData
mkFoilData termT nameT scopeT patternT = do
n <- newName "n"
l <- newName "l"
TyConI (DataD _ctx _name patternTVars _kind patternCons _deriv) <- reify patternT
TyConI (DataD _ctx _name scopeTVars _kind scopeCons _deriv) <- reify scopeT
TyConI (DataD _ctx _name termTVars _kind termCons _deriv) <- reify termT

foilPatternCons <- mapM (toPatternCon patternTVars n) patternCons
let foilScopeCons = map (toScopeCon scopeTVars n) scopeCons
let foilTermCons = map (toTermCon termTVars n l) termCons

return
patternD <- mkFoilPattern nameT patternT
return $
[ DataD [] foilTermT (termTVars ++ [KindedTV n BndrReq (PromotedT ''Foil.S)]) Nothing foilTermCons []
, DataD [] foilScopeT (scopeTVars ++ [KindedTV n BndrReq (PromotedT ''Foil.S)]) Nothing foilScopeCons []
, DataD [] foilPatternT (patternTVars ++ [KindedTV n BndrReq (PromotedT ''Foil.S), KindedTV l BndrReq (PromotedT ''Foil.S)]) Nothing foilPatternCons []
]
] ++ patternD
where
foilTermT = mkName ("Foil" ++ nameBase termT)
foilScopeT = mkName ("Foil" ++ nameBase scopeT)
foilPatternT = mkName ("Foil" ++ nameBase patternT)

-- | Convert a constructor declaration for a raw scoped term
-- into a constructor for the scope-safe scoped term.
toScopeCon :: [TyVarBndr BndrVis] -> Name -> Con -> Con
toScopeCon _tvars n (NormalC conName params) =
NormalC foilConName (map toScopeParam params)
where
foilConName = mkName ("Foil" ++ nameBase conName)
toScopeParam (_bang, PeelConT tyName tyParams)
| tyName == termT = (_bang, PeelConT foilTermT (tyParams ++ [VarT n]))
toScopeParam _bangType = _bangType

-- | Convert a constructor declaration for a raw term
-- into a constructor for the scope-safe term.
toTermCon :: [TyVarBndr BndrVis] -> Name -> Name -> Con -> Con
toTermCon tvars n l (NormalC conName params) =
GadtC [foilConName] (map toTermParam params) (PeelConT foilTermT (map (VarT . tvarName) tvars ++ [VarT n]))
where
foilNames = [n, l]
foilConName = mkName ("Foil" ++ nameBase conName)
toTermParam (_bang, PeelConT tyName tyParams)
| tyName == patternT = (_bang, PeelConT foilPatternT (tyParams ++ map VarT foilNames))
| tyName == nameT = (_bang, AppT (ConT ''Foil.Name) (VarT n))
| tyName == scopeT = (_bang, PeelConT foilScopeT (tyParams ++ [VarT l]))
| tyName == termT = (_bang, PeelConT foilTermT (tyParams ++ [VarT n]))
toTermParam _bangType = _bangType

-- | Generate just the scope-safe patterns.
mkFoilPattern
:: Name -- ^ Type name for raw variable identifiers.
-> Name -- ^ Type name for raw patterns.
-> Q [Dec]
mkFoilPattern nameT patternT = do
n <- newName "n"
l <- newName "l"
TyConI (DataD _ctx _name patternTVars _kind patternCons _deriv) <- reify patternT

foilPatternCons <- mapM (toPatternCon patternTVars n) patternCons

return
[ DataD [] foilPatternT (patternTVars ++ [KindedTV n BndrReq (PromotedT ''Foil.S), KindedTV l BndrReq (PromotedT ''Foil.S)]) Nothing foilPatternCons []
]
where
foilPatternT = mkName ("Foil" ++ nameBase patternT)

-- | Convert a constructor declaration for a raw pattern type
-- into a constructor for the scope-safe pattern type.
toPatternCon
Expand Down Expand Up @@ -79,29 +121,3 @@ mkFoilData termT nameT scopeT patternT = do
_ -> do
(l, conParams') <- toPatternConParams (i+1) p conParams
return (l, param : conParams')

-- | Convert a constructor declaration for a raw scoped term
-- into a constructor for the scope-safe scoped term.
toScopeCon :: [TyVarBndr BndrVis] -> Name -> Con -> Con
toScopeCon _tvars n (NormalC conName params) =
NormalC foilConName (map toScopeParam params)
where
foilConName = mkName ("Foil" ++ nameBase conName)
toScopeParam (_bang, PeelConT tyName tyParams)
| tyName == termT = (_bang, PeelConT foilTermT (tyParams ++ [VarT n]))
toScopeParam _bangType = _bangType

-- | Convert a constructor declaration for a raw term
-- into a constructor for the scope-safe term.
toTermCon :: [TyVarBndr BndrVis] -> Name -> Name -> Con -> Con
toTermCon tvars n l (NormalC conName params) =
GadtC [foilConName] (map toTermParam params) (PeelConT foilTermT (map (VarT . tvarName) tvars ++ [VarT n]))
where
foilNames = [n, l]
foilConName = mkName ("Foil" ++ nameBase conName)
toTermParam (_bang, PeelConT tyName tyParams)
| tyName == patternT = (_bang, PeelConT foilPatternT (tyParams ++ map VarT foilNames))
| tyName == nameT = (_bang, AppT (ConT ''Foil.Name) (VarT n))
| tyName == scopeT = (_bang, PeelConT foilScopeT (tyParams ++ [VarT l]))
| tyName == termT = (_bang, PeelConT foilTermT (tyParams ++ [VarT n]))
toTermParam _bangType = _bangType
41 changes: 27 additions & 14 deletions haskell/free-foil/src/Control/Monad/Foil/TH/MkInstancesFoil.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# OPTIONS_GHC -fno-warn-type-defaults #-}
module Control.Monad.Foil.TH.MkInstancesFoil (mkInstancesFoil) where
module Control.Monad.Foil.TH.MkInstancesFoil where

import Language.Haskell.TH

Expand All @@ -18,26 +18,22 @@ mkInstancesFoil
-> Name -- ^ Type name for raw patterns.
-> Q [Dec]
mkInstancesFoil termT nameT scopeT patternT = do
TyConI (DataD _ctx _name patternTVars _kind patternCons _deriv) <- reify patternT
TyConI (DataD _ctx _name scopeTVars _kind scopeCons _deriv) <- reify scopeT
TyConI (DataD _ctx _name termTVars _kind termCons _deriv) <- reify termT

return
coSinkablePatternD <- deriveCoSinkable nameT patternT

return $
[ InstanceD Nothing [] (AppT (ConT ''Foil.Sinkable) (PeelConT foilScopeT (map (VarT . tvarName) scopeTVars)))
[ FunD 'Foil.sinkabilityProof (map clauseScopedTerm scopeCons) ]

, InstanceD Nothing [] (AppT (ConT ''Foil.CoSinkable) (PeelConT foilPatternT (map (VarT . tvarName) patternTVars)))
[ FunD 'Foil.coSinkabilityProof (map clausePattern patternCons)
, FunD 'Foil.withPattern (map clauseWithPattern patternCons) ]

, InstanceD Nothing [] (AppT (ConT ''Foil.Sinkable) (PeelConT foilTermT (map (VarT . tvarName) termTVars)))
[ FunD 'Foil.sinkabilityProof (map clauseTerm termCons)]
]
] ++ coSinkablePatternD

where
foilTermT = mkName ("Foil" ++ nameBase termT)
foilScopeT = mkName ("Foil" ++ nameBase scopeT)
foilPatternT = mkName ("Foil" ++ nameBase patternT)

clauseScopedTerm :: Con -> Clause
clauseScopedTerm = clauseTerm
Expand All @@ -55,7 +51,7 @@ mkInstancesFoil termT nameT scopeT patternT = do
[]
where
foilConName = mkName ("Foil" ++ nameBase conName)
rename = mkName "rename"
rename = mkName "_rename"
conParamPatterns = zipWith mkConParamPattern params [1..]
mkConParamPattern _ i = VarP (mkName ("x" ++ show i))

Expand All @@ -81,6 +77,23 @@ mkInstancesFoil termT nameT scopeT patternT = do
where
xi = mkName ("x" ++ show i)

-- | Generate 'Foil.Sinkable' and 'Foil.CoSinkable' instances.
deriveCoSinkable
:: Name -- ^ Type name for raw variable identifiers.
-> Name -- ^ Type name for raw patterns.
-> Q [Dec]
deriveCoSinkable nameT patternT = do
TyConI (DataD _ctx _name patternTVars _kind patternCons _deriv) <- reify patternT

return
[ InstanceD Nothing [] (AppT (ConT ''Foil.CoSinkable) (PeelConT foilPatternT (map (VarT . tvarName) patternTVars)))
[ FunD 'Foil.coSinkabilityProof (map clausePattern patternCons)
, FunD 'Foil.withPattern (map clauseWithPattern patternCons) ]
]

where
foilPatternT = mkName ("Foil" ++ nameBase patternT)

clausePattern :: Con -> Clause
clausePattern RecC{} = error "Record constructors (RecC) are not supported yet!"
clausePattern InfixC{} = error "Infix constructors (InfixC) are not supported yet!"
Expand Down Expand Up @@ -128,10 +141,10 @@ mkInstancesFoil termT nameT scopeT patternT = do
[]
where
foilConName = mkName ("Foil" ++ nameBase conName)
withNameBinder = mkName "withNameBinder"
withNameBinder = mkName "_withNameBinder"
id' = mkName "id'"
comp = mkName "comp"
scope = mkName "scope"
comp = mkName "_comp"
scope = mkName "_scope"
cont = mkName "cont"
conParamPatterns = zipWith mkConParamPattern params [1..]
mkConParamPattern _ i = VarP (mkName ("x" ++ show i))
Expand All @@ -148,7 +161,7 @@ mkInstancesFoil termT nameT scopeT patternT = do
xi = mkName ("x" ++ show i)
xi' = mkName ("x" ++ show i ++ "'")
renamei = mkName ("f" ++ show i)
scopei = mkName ("scope" ++ show i)
scopei = mkName ("_scope" ++ show i)
go i scope' rename' p (_ : conPatterns) =
go (i + 1) scope' rename' (AppE p (VarE xi)) conPatterns
where
Expand Down
12 changes: 6 additions & 6 deletions haskell/free-foil/src/Control/Monad/Foil/TH/MkToFoil.hs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ mkWithRefreshedFoilPattern nameT patternT = do
where
xi = mkName ("x" <> show i)
xi' = mkName ("x" <> show i <> "'")
scopei = mkName ("scope" <> show i)
scopei = mkName ("_scope" <> show i)
xsubst = mkName ("subst" <> show i)
subst = mkName "subst"
fi = LamE [VarP subst]
Expand Down Expand Up @@ -295,7 +295,7 @@ mkToFoilTerm termT nameT scopeT patternT = do
toMatch (NormalC conName params) =
Match (ConP conName [] conParamPatterns) (NormalB conMatchBody) [toFoilVarD]
where
toFoilVarFunName = mkName "lookupRawVar"
toFoilVarFunName = mkName "_lookupRawVar"
toFoilVarFun = VarE toFoilVarFunName
x = mkName "x"
name = mkName "name"
Expand Down Expand Up @@ -324,7 +324,7 @@ mkToFoilTerm termT nameT scopeT patternT = do
where
xi = mkName ("x" <> show i)
xi' = mkName ("x" <> show i <> "'")
scopei = mkName ("scope" <> show i)
scopei = mkName ("_scope" <> show i)
envi = mkName ("env" <> show i)
go i scope' env' p (_ : conParams) =
go (i + 1) scope' env' (AppE p (VarE xi)) conParams
Expand Down Expand Up @@ -377,7 +377,7 @@ mkToFoilTerm termT nameT scopeT patternT = do
where
xi = mkName ("x" <> show i)
xi' = mkName ("x" <> show i <> "'")
scopei = mkName ("scope" <> show i)
scopei = mkName ("_scope" <> show i)
envi = mkName ("env" <> show i)
go i scope' env' p (_ : conParams) =
go (i + 1) scope' env' (AppE p (VarE xi)) conParams
Expand Down Expand Up @@ -407,7 +407,7 @@ mkToFoilTerm termT nameT scopeT patternT = do
toMatch (NormalC conName params) =
Match (ConP conName [] conParamPatterns) (NormalB conMatchBody) [toFoilVarD]
where
toFoilVarFunName = mkName "lookupRawVar"
toFoilVarFunName = mkName "_lookupRawVar"
toFoilVarFun = VarE toFoilVarFunName
x = mkName "x"
name = mkName "name"
Expand Down Expand Up @@ -436,7 +436,7 @@ mkToFoilTerm termT nameT scopeT patternT = do
where
xi = mkName ("x" <> show i)
xi' = mkName ("x" <> show i <> "'")
scopei = mkName ("scope" <> show i)
scopei = mkName ("_scope" <> show i)
envi = mkName ("env" <> show i)
go i scope' env' p (_ : conParams) =
go (i + 1) scope' env' (AppE p (VarE xi)) conParams
Expand Down
2 changes: 0 additions & 2 deletions haskell/lambda-pi/src/Language/LambdaPi/Impl/FoilTH.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-warn-unused-binds #-}
{-# OPTIONS_GHC -fno-warn-unused-matches #-}
-- | Foil implementation of the \(\lambda\Pi\)-calculus (with pairs)
-- using Template Haskell to reduce boilerplate.
--
Expand Down
7 changes: 7 additions & 0 deletions haskell/lambda-pi/src/Language/LambdaPi/Impl/FreeFoilTH.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
Expand Down Expand Up @@ -31,6 +32,7 @@ module Language.LambdaPi.Impl.FreeFoilTH where

import qualified Control.Monad.Foil as Foil
import Control.Monad.Free.Foil
import Control.Monad.Foil.TH
import Control.Monad.Free.Foil.TH
import Data.Bifunctor.TH
import Data.Map (Map)
Expand Down Expand Up @@ -67,6 +69,11 @@ mkPatternSynonyms ''Term'Sig
mkConvertToFreeFoil ''Raw.Term' ''Raw.VarIdent ''Raw.ScopedTerm' ''Raw.Pattern'
mkConvertFromFreeFoil ''Raw.Term' ''Raw.VarIdent ''Raw.ScopedTerm' ''Raw.Pattern'

-- ** Scope-safe patterns

mkFoilPattern ''Raw.VarIdent ''Raw.Pattern'
deriveCoSinkable ''Raw.VarIdent ''Raw.Pattern'

-- * User-defined code

type Term' a = AST (Term'Sig a)
Expand Down

0 comments on commit ed986c4

Please sign in to comment.