Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up some code smells in D.S.{Promote,Single}.Defun #424

Merged
merged 1 commit into from
Nov 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 47 additions & 16 deletions src/Data/Singletons/Promote/Defun.hs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ defunctionalize :: Name
defunctionalize name m_fixity m_arg_tvbs' m_res_kind' = do
(m_arg_tvbs, m_res_kind) <- eta_expand (noExactTyVars m_arg_tvbs')
(noExactTyVars m_res_kind')
extra_name <- qNewName "arg"

let -- Implements part (2)(i) from Note [Defunctionalization and dependent quantification]
tvb_to_type_map :: Map Name DType
Expand All @@ -224,20 +225,52 @@ defunctionalize name m_fixity m_arg_tvbs' m_res_kind' = do
map dTyVarBndrToDType m_arg_tvbs
++ maybeToList m_res_kind -- (2)(i)(a)

go :: Int -> [DTyVarBndr] -> Maybe DKind
-> PrM [DDec]
go _ [] _ = return []
go n (m_arg : m_args) m_result = do
extra_name <- qNewName "arg"
let tyfun_name = extractTvbName m_arg
-- The inner loop. @go n arg_tvbs res_tvbs@ returns @(m_result, decls)@.
-- Using one particular example:
--
-- @
-- data ExampleSym2 (x :: a) (y :: b) :: c ~> d ~> Type where ...
-- type instance Apply (ExampleSym2 x y) z = ExampleSym3 x y z
-- ...
-- @
--
-- We have:
--
-- * @n@ is 2. This is incremented in each iteration of `go`.
--
-- * @arg_tvbs@ is [(x :: a), (y :: b)].
--
-- * @res_tvbs@ is [(z :: c), (w :: d)]. The kinds of these type variable
-- binders appear in the result kind.
--
-- * @m_result@ is `Just (c ~> d ~> Type)`. @m_result@ is returned so
-- that earlier defunctionalization symbols can build on the result
-- kinds of later symbols. For instance, ExampleSym1 would get the
-- result kind `b ~> c ~> d ~> Type` by prepending `b` to ExampleSym2's
-- result kind `c ~> d ~> Type`.
--
-- * @decls@ are all of the declarations corresponding to ExampleSym2
-- and later defunctionalization symbols. This is the main payload of
-- the function.
--
-- This function is quadratic because it appends a variable at the end of
-- the @arg_tvbs@ list at each iteration. In practice, this is unlikely
-- to be a performance bottleneck since the number of arguments rarely
-- gets to be that large.
go :: Int -> [DTyVarBndr] -> [DTyVarBndr]
-> (Maybe DKind, [DDec])
go _ _ [] = (m_res_kind, [])
go n arg_tvbs (res_tvb:res_tvbs) =
let (m_result, decls) = go (n+1) (arg_tvbs ++ [res_tvb]) res_tvbs

tyfun_name = extractTvbName res_tvb
data_name = promoteTySym name n
next_name = promoteTySym name (n+1)
con_name = prefixName "" ":" $ suffixName "KindInference" "###" data_name
m_tyfun = buildTyFunArrow_maybe (extractTvbKind m_arg) m_result
m_tyfun = buildTyFunArrow_maybe (extractTvbKind res_tvb) m_result
arg_params = -- Implements part (2)(ii) from
-- Note [Defunctionalization and dependent quantification]
map (map_tvb_kind (substType tvb_to_type_map)) $
reverse m_args
map (map_tvb_kind (substType tvb_to_type_map)) arg_tvbs
arg_names = map extractTvbName arg_params
params = arg_params ++ [DPlainTV tyfun_name]
con_eq_ct = DConT sameKindName `DAppT` lhs `DAppT` rhs
Expand Down Expand Up @@ -270,12 +303,12 @@ defunctionalize name m_fixity m_arg_tvbs' m_res_kind' = do
map tvb_to_type $ -- (2)(iii)(b)
toList $ fvDType tyfun -- (2)(iii)(a)
in (arg_params, Just (DForallT ForallInvis tyfun_tvbs tyfun))
app_data_ty = foldTypeTvbs (DConT data_name) m_args
app_data_ty = foldTypeTvbs (DConT data_name) arg_tvbs
app_eqn = DTySynEqn Nothing
(DConT applyName `DAppT` app_data_ty
`DAppT` DVarT tyfun_name)
(foldTypeTvbs (DConT next_name)
(m_args ++ [DPlainTV tyfun_name]))
(arg_tvbs ++ [DPlainTV tyfun_name]))
app_decl = DTySynInstD app_eqn
suppress = DInstanceD Nothing Nothing []
(DConT suppressClassName `DAppT` app_data_ty)
Expand All @@ -287,17 +320,15 @@ defunctionalize name m_fixity m_arg_tvbs' m_res_kind' = do

-- See Note [Fixity declarations for defunctionalization symbols]
fixity_decl = maybeToList $ fmap (mk_fix_decl data_name) m_fixity

decls <- go (n - 1) m_args m_tyfun
return $ suppress : data_decl : app_decl : fixity_decl ++ decls
in (m_tyfun, suppress : data_decl : app_decl : fixity_decl ++ decls)

let num_args = length m_arg_tvbs
sat_name = promoteTySym name num_args
sat_dec = DTySynD sat_name m_arg_tvbs $ foldTypeTvbs (DConT name) m_arg_tvbs
sat_fixity_dec = maybeToList $ fmap (mk_fix_decl sat_name) m_fixity

other_decs <- go (num_args - 1) (reverse m_arg_tvbs) m_res_kind
return $ sat_dec : sat_fixity_dec ++ other_decs
(_, other_decs) = go 0 [] m_arg_tvbs
return $ other_decs ++ sat_dec : sat_fixity_dec
where
eta_expand :: [DTyVarBndr] -> Maybe DKind -> PrM ([DTyVarBndr], Maybe DKind)
eta_expand m_arg_tvbs Nothing = pure (m_arg_tvbs, Nothing)
Expand Down
99 changes: 65 additions & 34 deletions src/Data/Singletons/Single/Defun.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

module Data.Singletons.Single.Defun (singDefuns) where

import Control.Monad
import Data.List
import Data.Singletons.Names
import Data.Singletons.Promote.Defun
Expand Down Expand Up @@ -57,71 +58,101 @@ singDefuns n ns ty_ctxt mb_ty_args mb_ty_res =
[] -> pure [] -- If a function has no arguments, then it has no
-- defunctionalization symbols, so there's nothing to be done.
_ -> do sty_ctxt <- mapM singPred ty_ctxt
go 0 sty_ctxt [] mb_ty_args
names <- replicateM (length mb_ty_args) $ qNewName "d"
let tvbs = zipWith inferMaybeKindTV names mb_ty_args
(_, insts) = go 0 sty_ctxt [] tvbs
pure insts
where
num_ty_args :: Int
num_ty_args = length mb_ty_args

-- Sadly, this algorithm is quadratic, because in each iteration of the loop
-- we must:
-- The inner loop. @go n ctxt arg_tvbs res_tvbs@ returns @(m_result, insts)@.
-- Using one particular example:
--
-- * Construct an arrow type of the form (a ~> ... ~> z), using a suffix of
-- the promoted argument types.
-- * Append a new type variable to the end of an ordered list.
-- @
-- instance (SingI a, SingI b, SEq c, SEq d) =>
-- SingI (ExampleSym2 (x :: a) (y :: b) :: c ~> d ~> Type) where ...
-- @
--
-- In practice, this is unlikely to be a bottleneck, as singletons does not
-- support functions with more than 7 or so arguments anyways.
go :: Int -> DCxt -> [DTyVarBndr] -> [Maybe DKind] -> SgM [DDec]
go sym_num sty_ctxt tvbs mb_tyss
| sym_num < num_ty_args
, mb_ty:mb_tys <- mb_tyss
= do new_tvb_name <- qNewName "d"
let new_tvb = inferMaybeKindTV new_tvb_name mb_ty
insts <- go (sym_num + 1) sty_ctxt (tvbs ++ [new_tvb]) mb_tys
pure $ new_insts ++ insts
| otherwise
= pure []
-- We have:
--
-- * @n@ is 2. This is incremented in each iteration of `go`.
--
-- * @ctxt@ is (SEq c, SEq d). The (SingI a, SingI b) part of the instance
-- context is added separately.
--
-- * @arg_tvbs@ is [(x :: a), (y :: b)].
--
-- * @res_tvbs@ is [(z :: c), (w :: d)]. The kinds of these type variable
-- binders appear in the result kind.
--
-- * @m_result@ is `Just (c ~> d ~> Type)`. @m_result@ is returned so
-- that earlier defunctionalization symbols can build on the result
-- kinds of later symbols. For instance, ExampleSym1 would get the
-- result kind `b ~> c ~> d ~> Type` by prepending `b` to ExampleSym2's
-- result kind `c ~> d ~> Type`.
--
-- * @insts@ are all of the instance declarations corresponding to
-- ExampleSym2 and later defunctionalization symbols. This is the main
-- payload of the function.
--
-- This function is quadratic because it appends a variable at the end of
-- the @arg_tvbs@ list at each iteration. In practice, this is unlikely
-- to be a performance bottleneck since the number of arguments rarely
-- gets to be that large.
go :: Int -> DCxt -> [DTyVarBndr] -> [DTyVarBndr]
-> (Maybe DKind, [DDec])
go _ _ _ [] = (mb_ty_res, [])
go sym_num sty_ctxt arg_tvbs (res_tvb:res_tvbs) =
(mb_new_res, new_inst:insts)
where
mb_res :: Maybe DKind
insts :: [DDec]
(mb_res, insts) = go (sym_num + 1) sty_ctxt (arg_tvbs ++ [res_tvb]) res_tvbs

mb_new_res :: Maybe DKind
mb_new_res = mk_inst_kind res_tvb mb_res

sing_fun_num :: Int
sing_fun_num = num_ty_args - sym_num

mk_sing_fun_expr :: DExp -> DExp
mk_sing_fun_expr sing_expr =
foldl' (\f tvb_n -> f `DAppE` (DVarE singMethName `DAppTypeE` DVarT tvb_n))
sing_expr
(map extractTvbName tvbs)
(map extractTvbName arg_tvbs)

singI_ctxt :: DCxt
singI_ctxt = map (DAppT (DConT singIName) . tvbToType) tvbs
singI_ctxt = map (DAppT (DConT singIName) . tvbToType) arg_tvbs

mk_inst_ty :: DType -> DType
mk_inst_ty inst_head
= case mb_inst_kind of
= case mb_new_res of
Just inst_kind -> inst_head `DSigT` inst_kind
Nothing -> inst_head

tvb_tys :: [DType]
tvb_tys = map dTyVarBndrToDType tvbs
arg_tvb_tys :: [DType]
arg_tvb_tys = map dTyVarBndrToDType arg_tvbs

-- Construct the arrow kind used to annotate the defunctionalization
-- symbol (e.g., the `a ~> a ~> Bool` in
-- `SingI (FooSym0 :: a ~> a ~> Bool)`).
-- If any of the argument kinds or result kind isn't known (i.e., is
-- Nothing), then we opt not to construct this arrow kind altogether.
-- See Note [singDefuns and type inference]
mb_inst_kind :: Maybe DType
mb_inst_kind = foldr buildTyFunArrow_maybe mb_ty_res mb_tyss

new_insts :: [DDec]
new_insts = [DInstanceD Nothing Nothing
(sty_ctxt ++ singI_ctxt)
(DConT singIName `DAppT` mk_inst_ty defun_inst_ty)
[DLetDec $ DValD (DVarP singMethName)
$ wrapSingFun sing_fun_num defun_inst_ty
$ mk_sing_fun_expr sing_exp ]]
mk_inst_kind :: DTyVarBndr -> Maybe DKind -> Maybe DKind
mk_inst_kind tvb' = buildTyFunArrow_maybe (extractTvbKind tvb')

new_inst :: DDec
new_inst = DInstanceD Nothing Nothing
(sty_ctxt ++ singI_ctxt)
(DConT singIName `DAppT` mk_inst_ty defun_inst_ty)
[DLetDec $ DValD (DVarP singMethName)
$ wrapSingFun sing_fun_num defun_inst_ty
$ mk_sing_fun_expr sing_exp ]
where
defun_inst_ty :: DType
defun_inst_ty = foldType (DConT (promoteTySym n sym_num)) tvb_tys
defun_inst_ty = foldType (DConT (promoteTySym n sym_num)) arg_tvb_tys

sing_exp :: DExp
sing_exp = case ns of
Expand Down
Loading