Skip to content

Commit

Permalink
Fix #323 by generating fixity declarations for defunctionalization sy…
Browse files Browse the repository at this point in the history
…mbols
  • Loading branch information
RyanGlScott committed May 6, 2018
1 parent 21f73d7 commit e1ab884
Show file tree
Hide file tree
Showing 17 changed files with 120 additions and 38 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ Changelog for singletons project
synonym did nothing whatsoever, and promoting or singling a type family
produced an error.)

* `singletons` now produces fixity declarations for defunctionalization
symbols when appropriate.

* Add `(%<=?)`, a singled version of `(<=?)` from `GHC.TypeNats`, as well as
defunctionalization symbols for `(<=?)`, to `Data.Singletons.TypeLits`.

Expand Down
44 changes: 25 additions & 19 deletions src/Data/Singletons/Promote.hs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ promoteClassDec decl@(ClassDecl { cd_cxt = cxt
(default_decs, ann_rhss, prom_rhss)
<- mapAndUnzip3M (promoteMethod Nothing meth_sigs) defaults_list

let infix_decls' = catMaybes $ map (uncurry promoteInfixDecl) infix_decls
let infix_decls' = catMaybes $ map (uncurry promoteInfixDecl)
$ Map.toList infix_decls

-- no need to do anything to the fundeps. They work as is!
emitDecs [DClassD pCxt pClsName tvbs fundeps
Expand All @@ -312,7 +313,7 @@ promoteClassDec decl@(ClassDecl { cd_cxt = cxt
let proName = promoteValNameLhs name
(argKs, resK) <- promoteUnraveled ty
args <- mapM (const $ qNewName "arg") argKs
emitDecsM $ defunctionalize proName (map Just argKs) (Just resK)
emitDecsM $ defunReifyFixity proName (map Just argKs) (Just resK)

return $ DOpenTypeFamilyD (DTypeFamilyHead proName
(zipWith DKindedTV args argKs)
Expand Down Expand Up @@ -409,7 +410,8 @@ promoteMethod :: Maybe (Map Name DKind)
promoteMethod m_subst sigs_map (meth_name, meth_rhs) = do
(arg_kis, res_ki) <- lookup_meth_ty
((_, _, _, eqns), _defuns, ann_rhs)
<- promoteLetDecRHS (Just (arg_kis, res_ki)) sigs_map noPrefix meth_name meth_rhs
<- promoteLetDecRHS (Just (arg_kis, res_ki)) sigs_map Map.empty
noPrefix meth_name meth_rhs
meth_arg_tvs <- mapM (const $ qNewName "a") arg_kis
let -- If we're dealing with an associated type family instance, substitute
-- in the kind of the instance for better kind information in the RHS
Expand Down Expand Up @@ -443,7 +445,7 @@ promoteMethod m_subst sigs_map (meth_name, meth_rhs) = do
(DKindSig meth_res_ki')
Nothing)
eqns]
emitDecsM (defunctionalize helperName (map Just meth_arg_kis') (Just meth_res_ki'))
emitDecsM (defunctionalize helperName Nothing (map Just meth_arg_kis') (Just meth_res_ki'))
return ( DTySynInstD
proName
(DTySynEqn family_args
Expand Down Expand Up @@ -504,22 +506,23 @@ promoted method implementations like MHelper2.
promoteLetDecEnv :: (String, String) -> ULetDecEnv -> PrM ([DDec], ALetDecEnv)
promoteLetDecEnv prefixes (LetDecEnv { lde_defns = value_env
, lde_types = type_env
, lde_infix = infix_decls }) = do
let infix_decls' = catMaybes $ map (uncurry promoteInfixDecl) infix_decls
, lde_infix = fix_env }) = do
let infix_decls = catMaybes $ map (uncurry promoteInfixDecl)
$ Map.toList fix_env

-- promote all the declarations, producing annotated declarations
let (names, rhss) = unzip $ Map.toList value_env
(payloads, defun_decss, ann_rhss)
<- fmap unzip3 $ zipWithM (promoteLetDecRHS Nothing type_env prefixes) names rhss
<- fmap unzip3 $ zipWithM (promoteLetDecRHS Nothing type_env fix_env prefixes) names rhss

emitDecs $ concat defun_decss
bound_kvs <- allBoundKindVars
let decs = map payload_to_dec payloads ++ infix_decls'
let decs = map payload_to_dec payloads ++ infix_decls

-- build the ALetDecEnv
let let_dec_env' = LetDecEnv { lde_defns = Map.fromList $ zip names ann_rhss
, lde_types = type_env
, lde_infix = infix_decls
, lde_infix = fix_env
, lde_proms = Map.empty -- filled in promoteLetDecs
, lde_bound_kvs = Map.fromList $ map (, bound_kvs) names }

Expand All @@ -531,8 +534,8 @@ promoteLetDecEnv prefixes (LetDecEnv { lde_defns = value_env
where
sig = maybe DNoSig DKindSig m_ki

promoteInfixDecl :: Fixity -> Name -> Maybe DDec
promoteInfixDecl fixity name
promoteInfixDecl :: Name -> Fixity -> Maybe DDec
promoteInfixDecl name fixity
| nameBase name == nameBase promoted_name
-- If a name and its promoted counterpart are the same (modulo module
-- prefixes), then there's no need to promote a fixity declaration for
Expand All @@ -554,13 +557,14 @@ promoteInfixDecl fixity name
promoteLetDecRHS :: Maybe ([DKind], DKind) -- the promoted type of the RHS (if known)
-- needed to fix #136
-> Map Name DType -- local type env't
-> Map Name Fixity -- local fixity env't
-> (String, String) -- let-binding prefixes
-> Name -- name of the thing being promoted
-> ULetDecRHS -- body of the thing
-> PrM ( (Name, [DTyVarBndr], Maybe DKind, [DTySynEqn]) -- "type family"
, [DDec] -- defunctionalization
, ALetDecRHS ) -- annotated RHS
promoteLetDecRHS m_rhs_ki type_env prefixes name (UValue exp) = do
promoteLetDecRHS m_rhs_ki type_env fix_env prefixes name (UValue exp) = do
(res_kind, num_arrows)
<- case m_rhs_ki of
Just (arg_kis, res_ki) -> return ( Just (ravelTyFun (arg_kis ++ [res_ki]))
Expand All @@ -575,8 +579,9 @@ promoteLetDecRHS m_rhs_ki type_env prefixes name (UValue exp) = do
all_locals <- allLocals
let lde_kvs_to_bind = foldMap fvDType res_kind
(exp', ann_exp) <- forallBind lde_kvs_to_bind $ promoteExp exp
let proName = promoteValNameLhsPrefix prefixes name
defuns <- defunctionalize proName (map (const Nothing) all_locals) res_kind
let proName = promoteValNameLhsPrefix prefixes name
m_fixity = Map.lookup name fix_env
defuns <- defunctionalize proName m_fixity (map (const Nothing) all_locals) res_kind
return ( ( proName, map DPlainTV all_locals, res_kind
, [DTySynEqn (map DVarT all_locals) exp'] )
, defuns
Expand All @@ -586,10 +591,10 @@ promoteLetDecRHS m_rhs_ki type_env prefixes name (UValue exp) = do
names <- replicateM num_arrows (newUniqueName "a")
let pats = map DVarPa names
newArgs = map DVarE names
promoteLetDecRHS m_rhs_ki type_env prefixes name
promoteLetDecRHS m_rhs_ki type_env fix_env prefixes name
(UFunction [DClause pats (foldExp exp newArgs)])

promoteLetDecRHS m_rhs_ki type_env prefixes name (UFunction clauses) = do
promoteLetDecRHS m_rhs_ki type_env fix_env prefixes name (UFunction clauses) = do
numArgs <- count_args clauses
(m_argKs, m_resK, ty_num_args) <- case m_rhs_ki of
Just (arg_kis, res_ki) -> return (map Just arg_kis, Just res_ki, length arg_kis)
Expand All @@ -604,9 +609,10 @@ promoteLetDecRHS m_rhs_ki type_env prefixes name (UFunction clauses) = do

| otherwise
-> return (replicate numArgs Nothing, Nothing, numArgs)
let proName = promoteValNameLhsPrefix prefixes name
let proName = promoteValNameLhsPrefix prefixes name
m_fixity = Map.lookup name fix_env
all_locals <- allLocals
defun_decs <- defunctionalize proName
defun_decs <- defunctionalize proName m_fixity
(map (const Nothing) all_locals ++ m_argKs) m_resK
let local_tvbs = map DPlainTV all_locals
tyvarNames <- mapM (const $ qNewName "a") m_argKs
Expand Down Expand Up @@ -728,7 +734,7 @@ promoteExp (DLamE names exp) = do
Nothing)
[DTySynEqn (map DVarT (all_locals ++ tyNames))
rhs]]
emitDecsM $ defunctionalize lambdaName (map (const Nothing) all_args) Nothing
emitDecsM $ defunctionalize lambdaName Nothing (map (const Nothing) all_args) Nothing
let promLambda = foldl apply (DConT (promoteTySym lambdaName 0))
(map DVarT all_locals)
return (promLambda, ADLamE tyNames promLambda names ann_exp)
Expand Down
51 changes: 44 additions & 7 deletions src/Data/Singletons/Promote/Defun.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Language.Haskell.TH.Syntax
import Data.Singletons.Syntax
import Data.Singletons.Util
import Control.Monad
import Data.Maybe

defunInfo :: DInfo -> PrM [DDec]
defunInfo (DTyConI dec _instances) = buildDefunSyms dec
Expand Down Expand Up @@ -55,7 +56,7 @@ buildDefunSyms (DTySynD name tvbs _type) =
buildDefunSymsTySynD name tvbs
buildDefunSyms (DClassD _cxt name tvbs _fundeps _members) = do
let arg_m_kinds = map extractTvbKind tvbs
defunctionalize name arg_m_kinds (Just (DConT constraintName))
defunReifyFixity name arg_m_kinds (Just (DConT constraintName))
buildDefunSyms _ = fail $ "Defunctionalization symbols can only be built for " ++
"type families and data declarations"

Expand All @@ -74,12 +75,12 @@ buildDefunSymsTypeFamilyHead
buildDefunSymsTypeFamilyHead default_kind (DTypeFamilyHead name tvbs result_sig _) = do
let arg_kinds = map (default_kind . extractTvbKind) tvbs
res_kind = default_kind (resultSigToMaybeKind result_sig)
defunctionalize name arg_kinds res_kind
defunReifyFixity name arg_kinds res_kind

buildDefunSymsTySynD :: Name -> [DTyVarBndr] -> PrM [DDec]
buildDefunSymsTySynD name tvbs = do
let arg_m_kinds = map extractTvbKind tvbs
defunctionalize name arg_m_kinds Nothing
defunReifyFixity name arg_m_kinds Nothing

buildDefunSymsDataD :: [DCon] -> PrM [DDec]
buildDefunSymsDataD ctors =
Expand All @@ -90,7 +91,15 @@ buildDefunSymsDataD ctors =
let (name, arg_tys) = extractNameTypes ctor
arg_kis <- mapM promoteType arg_tys
res_ki <- promoteType res_ty
defunctionalize name (map Just arg_kis) (Just res_ki)
defunReifyFixity name (map Just arg_kis) (Just res_ki)

-- Generate defunctionalization symbols for a name, using reifyFixityWithLocals
-- to determine what the fixity of each symbol should be.
-- See Note [Fixity declarations for defunctionalization symbols]
defunReifyFixity :: Name -> [Maybe DKind] -> Maybe DKind -> PrM [DDec]
defunReifyFixity name m_arg_kinds m_res_kind = do
m_fixity <- reifyFixityWithLocals name
defunctionalize name m_fixity m_arg_kinds m_res_kind

-- Generate data declarations and apply instances
-- required for defunctionalization.
Expand Down Expand Up @@ -125,8 +134,10 @@ buildDefunSymsDataD ctors =
--
-- The defunctionalize function takes Maybe DKinds so that the caller can
-- indicate which kinds are known and which need to be inferred.
defunctionalize :: Name -> [Maybe DKind] -> Maybe DKind -> PrM [DDec]
defunctionalize name m_arg_kinds' m_res_kind' = do
defunctionalize :: Name
-> Maybe Fixity -- The name's fixity, if one was declared.
-> [Maybe DKind] -> Maybe DKind -> PrM [DDec]
defunctionalize name m_fixity m_arg_kinds' m_res_kind' = do
let (m_arg_kinds, m_res_kind) = eta_expand (noExactTyVars m_arg_kinds')
(noExactTyVars m_res_kind')
num_args = length m_arg_kinds
Expand Down Expand Up @@ -188,8 +199,12 @@ defunctionalize name m_arg_kinds' m_res_kind' = do

mk_rhs' ns = foldType (DConT data_name) (map DVarT ns)

-- See Note [Fixity declarations for defunctionalization symbols]
mk_fix_decl f = DLetDec $ DInfixD f data_name
fixity_decl = maybeToList $ fmap mk_fix_decl m_fixity

decls <- go (n - 1) m_args (buildTyFunArrow_maybe m_arg m_result) mk_rhs'
return $ suppress : data_decl : app_decl : decls
return $ suppress : data_decl : app_decl : fixity_decl ++ decls

-- This is a small function with large importance. When generating
-- defunctionalization data types, we often need to fill in the blank in the
Expand Down Expand Up @@ -264,3 +279,25 @@ ravelTyFun kinds = go tailK (buildTyFunArrow k2 k1)
where (k1 : k2 : tailK) = reverse kinds
go [] acc = acc
go (k:ks) acc = go ks (buildTyFunArrow k acc)

{-
Note [Fixity declarations for defunctionalization symbols]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Just like we promote fixity declarations, we should also generate fixity
declarations for defunctionaliztion symbols. A primary use case is the
following scenario:
(.) :: (b -> c) -> (a -> b) -> (a -> c)
(f . g) x = f (g x)
infixr 9 .
One often writes (f . g . h) at the value level, but because (.) is promoted
to a type family with three arguments, this doesn't directly translate to the
type level. Instead, one must write this:
f .@#@$$$ g .@#@$$$ h
But in order to ensure that this associates to the right as expected, one must
generate an `infixr 9 .@#@#$$$` declaration. This is why defunctionalize accepts
a Maybe Fixity argument.
-}
4 changes: 2 additions & 2 deletions src/Data/Singletons/Single.hs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ singClassD (ClassDecl { cd_cxt = cls_cxt
sing_meths <- mapM (uncurry (singLetDecRHS (Map.fromList tyvar_names)
res_ki_map))
(Map.toList default_defns)
fixities' <- traverse (uncurry singInfixDecl) fixities
fixities' <- traverse (uncurry singInfixDecl) $ Map.toList fixities
cls_cxt' <- mapM singPred cls_cxt
return $ DClassD cls_cxt'
(singClassName cls_name)
Expand Down Expand Up @@ -430,7 +430,7 @@ singLetDecEnv (LetDecEnv { lde_defns = defns
let prom_list = Map.toList proms
(typeSigs, letBinds, tyvarNames, res_kis)
<- unzip4 <$> mapM (uncurry (singTySig defns types bound_kvs)) prom_list
infix_decls' <- traverse (uncurry singInfixDecl) infix_decls
infix_decls' <- traverse (uncurry singInfixDecl) $ Map.toList infix_decls
let res_ki_map = Map.fromList [ (name, res_ki) | ((name, _), Just res_ki)
<- zip prom_list res_kis ]
bindLets letBinds $ do
Expand Down
6 changes: 3 additions & 3 deletions src/Data/Singletons/Single/Fixity.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import Data.Singletons.Util
import Data.Singletons.Names
import Language.Haskell.TH.Desugar

singInfixDecl :: DsMonad q => Fixity -> Name -> q DLetDec
singInfixDecl fixity name = do
singInfixDecl :: DsMonad q => Name -> Fixity -> q DLetDec
singInfixDecl name fixity = do
mb_ns <- reifyNameSpace name
pure $ DInfixD fixity
$ case mb_ns of
Expand All @@ -24,7 +24,7 @@ singFixityDeclaration name = do
mFixity <- qReifyFixity name
case mFixity of
Nothing -> pure []
Just fixity -> sequenceA [DLetDec <$> singInfixDecl fixity name]
Just fixity -> sequenceA [DLetDec <$> singInfixDecl name fixity]

singFixityDeclarations :: DsMonad q => [Name] -> q [DDec]
singFixityDeclarations = concatMapM trySingFixityDeclaration
Expand Down
6 changes: 3 additions & 3 deletions src/Data/Singletons/Syntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ type ULetDecRHS = LetDecRHS Unannotated
data LetDecEnv ann = LetDecEnv
{ lde_defns :: Map Name (LetDecRHS ann)
, lde_types :: Map Name DType -- type signatures
, lde_infix :: [(Fixity, Name)] -- infix declarations
, lde_infix :: Map Name Fixity -- infix declarations
, lde_proms :: IfAnn ann (Map Name DType) () -- possibly, promotions
, lde_bound_kvs :: IfAnn ann (Map Name (Set Name)) ()
-- The set of bound variables in scope.
Expand All @@ -161,7 +161,7 @@ instance Semigroup ULetDecEnv where
LetDecEnv (defns1 <> defns2) (types1 <> types2) (infx1 <> infx2) () ()

instance Monoid ULetDecEnv where
mempty = LetDecEnv Map.empty Map.empty [] () ()
mempty = LetDecEnv Map.empty Map.empty Map.empty () ()

valueBinding :: Name -> ULetDecRHS -> ULetDecEnv
valueBinding n v = emptyLetDecEnv { lde_defns = Map.singleton n v }
Expand All @@ -170,7 +170,7 @@ typeBinding :: Name -> DType -> ULetDecEnv
typeBinding n t = emptyLetDecEnv { lde_types = Map.singleton n t }

infixDecl :: Fixity -> Name -> ULetDecEnv
infixDecl f n = emptyLetDecEnv { lde_infix = [(f,n)] }
infixDecl f n = emptyLetDecEnv { lde_infix = Map.singleton n f }

emptyLetDecEnv :: ULetDecEnv
emptyLetDecEnv = mempty
Expand Down
1 change: 1 addition & 0 deletions tests/SingletonsTestSuite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ tests =
, compileAndDumpStdTest "T313"
, compileAndDumpStdTest "T316"
, compileAndDumpStdTest "T322"
, compileAndDumpStdTest "T323"
],
testCompileAndDumpGroup "Promote"
[ compileAndDumpStdTest "Constructors"
Expand Down
2 changes: 2 additions & 0 deletions tests/compile-and-dump/Singletons/Classes.ghc84.template
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ Singletons/Classes.hs:(0,0)-(0,0): Splicing declarations
SameKind (Apply ((<=>@#@$$) l) arg) ((<=>@#@$$$) l arg) =>
(<=>@#@$$) l l
type instance Apply ((<=>@#@$$) l) l = (<=>) l l
infix 4 <=>@#@$$
instance SuppressUnusedWarnings (<=>@#@$) where
suppressUnusedWarnings
= snd ((GHC.Tuple.(,) (:<=>@#@$###)) GHC.Tuple.())
Expand All @@ -149,6 +150,7 @@ Singletons/Classes.hs:(0,0)-(0,0): Splicing declarations
(:<=>@#@$###) :: forall l arg.
SameKind (Apply (<=>@#@$) arg) ((<=>@#@$$) arg) => (<=>@#@$) l
type instance Apply (<=>@#@$) l = (<=>@#@$$) l
infix 4 <=>@#@$
type family TFHelper_0123456789876543210 (a :: a) (a :: a) :: Ordering where
TFHelper_0123456789876543210 a_0123456789876543210 a_0123456789876543210 = Apply (Apply MycompareSym0 a_0123456789876543210) a_0123456789876543210
type TFHelper_0123456789876543210Sym2 (t :: a0123456789876543210) (t :: a0123456789876543210) =
Expand Down
4 changes: 4 additions & 0 deletions tests/compile-and-dump/Singletons/Fixity.ghc84.template
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Singletons/Fixity.hs:(0,0)-(0,0): Splicing declarations
SameKind (Apply ((====@#@$$) l) arg) ((====@#@$$$) l arg) =>
(====@#@$$) l l
type instance Apply ((====@#@$$) l) l = (====) l l
infix 4 ====@#@$$
instance SuppressUnusedWarnings (====@#@$) where
suppressUnusedWarnings
= snd ((GHC.Tuple.(,) (:====@#@$###)) GHC.Tuple.())
Expand All @@ -35,6 +36,7 @@ Singletons/Fixity.hs:(0,0)-(0,0): Splicing declarations
(:====@#@$###) :: forall l arg.
SameKind (Apply (====@#@$) arg) ((====@#@$$) arg) => (====@#@$) l
type instance Apply (====@#@$) l = (====@#@$$) l
infix 4 ====@#@$
type family (====) (a :: a) (a :: a) :: a where
(====) a _ = a
type (<=>@#@$$$) (t :: a0123456789876543210) (t :: a0123456789876543210) =
Expand All @@ -48,6 +50,7 @@ Singletons/Fixity.hs:(0,0)-(0,0): Splicing declarations
SameKind (Apply ((<=>@#@$$) l) arg) ((<=>@#@$$$) l arg) =>
(<=>@#@$$) l l
type instance Apply ((<=>@#@$$) l) l = (<=>) l l
infix 4 <=>@#@$$
instance SuppressUnusedWarnings (<=>@#@$) where
suppressUnusedWarnings
= snd ((GHC.Tuple.(,) (:<=>@#@$###)) GHC.Tuple.())
Expand All @@ -56,6 +59,7 @@ Singletons/Fixity.hs:(0,0)-(0,0): Splicing declarations
(:<=>@#@$###) :: forall l arg.
SameKind (Apply (<=>@#@$) arg) ((<=>@#@$$) arg) => (<=>@#@$) l
type instance Apply (<=>@#@$) l = (<=>@#@$$) l
infix 4 <=>@#@$
class PMyOrd (a :: GHC.Types.Type) where
type (<=>) (arg :: a) (arg :: a) :: Ordering
infix 4 %====
Expand Down
Loading

0 comments on commit e1ab884

Please sign in to comment.