Skip to content

Commit

Permalink
[NFC] Rename "Embed" to "Builder".
Browse files Browse the repository at this point in the history
`EmbedT` is really an IR builder type.
"Builder" is standard terminology and easier to understand.

There exist both user-facing ADT builders and Imp builders.
  • Loading branch information
dan-zheng authored and dougalm committed Jan 19, 2021
1 parent 17267a2 commit 3a1d1fd
Show file tree
Hide file tree
Showing 15 changed files with 456 additions and 456 deletions.
4 changes: 2 additions & 2 deletions dex.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ library dex-resources

library
exposed-modules: Env, Syntax, Type, Inference, JIT, LLVMExec,
Parser, Util, Imp, Imp.Embed, Imp.Optimize,
Parser, Util, Imp, Imp.Builder, Imp.Optimize,
PPrint, Algebra, Parallelize, Optimize, Serialize
Actor, Cat, Embed, Export,
Actor, Builder, Cat, Export,
RenderHtml, LiveOutput, Simplify, TopLevel,
Autodiff, Interpreter, Logging, CUDA,
LLVM.JIT, LLVM.Shims
Expand Down
24 changes: 12 additions & 12 deletions src/lib/Algebra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import Data.Coerce
import Env
import Syntax
import PPrint
import Embed ( MonadEmbed, iadd, imul, idiv, clampPositive, ptrOffset
, indexToIntE, indexSetSizeE )
import Builder ( MonadBuilder, iadd, imul, idiv, clampPositive, ptrOffset
, indexToIntE, indexSetSizeE )

-- MVar is like Var, but it additionally defines Ord. The invariant here is that the variables
-- should never be shadowing, and so it is sufficient to only use the name for equality and
Expand Down Expand Up @@ -50,18 +50,18 @@ type ClampPolynomial = PolynomialP ClampMonomial
data SumPolynomial = SumPolynomial Polynomial Var deriving (Show, Eq)
data SumClampPolynomial = SumClampPolynomial ClampPolynomial Var deriving (Show, Eq)

applyIdxs :: MonadEmbed m => Atom -> IndexStructure -> m Atom
applyIdxs :: MonadBuilder m => Atom -> IndexStructure -> m Atom
applyIdxs ptr Empty = return ptr
applyIdxs ptr idxs@(Nest ~(Bind i) rest) = do
ordinal <- indexToIntE $ Var i
offset <- offsetToE idxs ordinal
ptr' <- ptrOffset ptr offset
applyIdxs ptr' rest

offsetToE :: MonadEmbed m => IndexStructure -> Atom -> m Atom
offsetToE :: MonadBuilder m => IndexStructure -> Atom -> m Atom
offsetToE idxs i = evalSumClampPolynomial (offsets idxs) i

elemCountE :: MonadEmbed m => IndexStructure -> m Atom
elemCountE :: MonadBuilder m => IndexStructure -> m Atom
elemCountE idxs = case idxs of
Empty -> return $ IdxRepVal 1
Nest b _ -> offsetToE idxs =<< indexSetSizeE (binderType b)
Expand Down Expand Up @@ -124,12 +124,12 @@ toPolynomial atom = case atom of
fromInt i = poly [((fromIntegral i) % 1, mono [])]
unreachable = error $ "Unsupported or invalid atom in index set: " ++ pprint atom

-- === Embedding ===
-- === Building ===

_evalClampPolynomial :: MonadEmbed m => ClampPolynomial -> m Atom
_evalClampPolynomial :: MonadBuilder m => ClampPolynomial -> m Atom
_evalClampPolynomial cp = evalPolynomialP (evalClampMonomial Var) cp

evalSumClampPolynomial :: MonadEmbed m => SumClampPolynomial -> Atom -> m Atom
evalSumClampPolynomial :: MonadBuilder m => SumClampPolynomial -> Atom -> m Atom
evalSumClampPolynomial (SumClampPolynomial cp summedVar) a =
evalPolynomialP (evalClampMonomial varVal) cp
where varVal v = if MVar v == sumVar summedVar then a else Var v
Expand All @@ -139,7 +139,7 @@ evalSumClampPolynomial (SumClampPolynomial cp summedVar) a =
-- coefficients. This is why we have to find the least common multiples and do the
-- accumulation over numbers multiplied by that LCM. We essentially do fixed point
-- fractional math here.
evalPolynomialP :: MonadEmbed m => (mono -> m Atom) -> PolynomialP mono -> m Atom
evalPolynomialP :: MonadBuilder m => (mono -> m Atom) -> PolynomialP mono -> m Atom
evalPolynomialP evalMono p = do
let constLCM = asAtom $ foldl lcm 1 $ fmap (denominator . snd) $ toList p
monoAtoms <- flip traverse (toList p) $ \(m, c) -> do
Expand All @@ -153,19 +153,19 @@ evalPolynomialP evalMono p = do
-- because it might be causing overflows due to all arithmetic being shifted.
asAtom = IdxRepVal . fromInteger

evalMonomial :: MonadEmbed m => (Var -> Atom) -> Monomial -> m Atom
evalMonomial :: MonadBuilder m => (Var -> Atom) -> Monomial -> m Atom
evalMonomial varVal m = do
varAtoms <- traverse (\(MVar v, e) -> ipow (varVal v) e) $ toList m
foldM imul (IdxRepVal 1) varAtoms

evalClampMonomial :: MonadEmbed m => (Var -> Atom) -> ClampMonomial -> m Atom
evalClampMonomial :: MonadBuilder m => (Var -> Atom) -> ClampMonomial -> m Atom
evalClampMonomial varVal (ClampMonomial clamps m) = do
valuesToClamp <- traverse (evalPolynomialP (evalMonomial varVal) . coerce) clamps
clampsProduct <- foldM imul (IdxRepVal 1) =<< traverse clampPositive valuesToClamp
mval <- evalMonomial varVal m
imul clampsProduct mval

ipow :: MonadEmbed m => Atom -> Int -> m Atom
ipow :: MonadBuilder m => Atom -> Int -> m Atom
ipow x i = foldM imul (IdxRepVal 1) (replicate i x)

-- === Polynomial math ===
Expand Down
48 changes: 24 additions & 24 deletions src/lib/Autodiff.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import Type
import Env
import Syntax
import PPrint
import Embed
import Builder
import Cat
import Util (bindM2, zipWithT, enumerate, restructure)
import GHC.Stack
Expand All @@ -38,17 +38,17 @@ data DerivWrt = DerivWrt { activeVars :: Env Type
-- arguments to the linearized function.
data TangentEnv = TangentEnv { tangentVals :: SubstEnv, activeRefs :: [Name], rematVals :: SubstEnv }

type PrimalM = ReaderT DerivWrt Embed
type TangentM = ReaderT TangentEnv Embed
type PrimalM = ReaderT DerivWrt Builder
type TangentM = ReaderT TangentEnv Builder
newtype LinA a = LinA { runLinA :: PrimalM (a, TangentM a) }

linearize :: Scope -> Atom -> Atom
linearize scope ~(Lam (Abs b (_, block))) = fst $ flip runEmbed scope $ do
linearize scope ~(Lam (Abs b (_, block))) = fst $ flip runBuilder scope $ do
buildLam b PureArrow \x@(Var v) -> do
(y, yt) <- flip runReaderT (DerivWrt (varAsEnv v) [] mempty) $ runLinA $ linearizeBlock (b@>x) block
-- TODO: check linearity
fLin <- buildLam (fmap tangentType b) LinArrow \xt -> runReaderT yt $ TangentEnv (v @> xt) [] mempty
fLinChecked <- checkEmbed fLin
fLinChecked <- checkBuilder fLin
return $ PairVal y fLinChecked

linearizeBlock :: SubstEnv -> Block -> LinA Atom
Expand All @@ -64,7 +64,7 @@ linearizeBlock env (Block decls result) = case decls of
-- Technically, we could do this and later run the code through a simplification
-- pass that would eliminate a bunch of multiplications with zeros, but this seems
-- simpler to do for now.
freeAtoms <- traverse (substEmbed env . Var) $ bindingsAsVars $ freeVars expr
freeAtoms <- traverse (substBuilder env . Var) $ bindingsAsVars $ freeVars expr
varsAreActive <- traverse isActive $ bindingsAsVars $ freeVars freeAtoms
if any id varsAreActive
then do
Expand All @@ -90,15 +90,15 @@ linearizeBlock env (Block decls result) = case decls of
let nontrivialTs = if vIsTrivial then [] else [t]
extendTangentEnv (newEnv nontrivialVs nontrivialTs) [] bodyLin)
else do
expr' <- substEmbed env expr
expr' <- substBuilder env expr
x <- emit expr'
runLinA $ linearizeBlock (env <> b @> x) body

linearizeExpr :: SubstEnv -> Expr -> LinA Atom
linearizeExpr env expr = case expr of
Hof e -> linearizeHof env e
Case e alts _ -> LinA $ do
e' <- substEmbed env e
e' <- substBuilder env e
hasActiveScrutinee <- any id <$> (mapM isActive $ bindingsAsVars $ freeVars e')
case hasActiveScrutinee of
True -> notImplemented
Expand All @@ -111,7 +111,7 @@ linearizeExpr env expr = case expr of
linearizeInactiveAlt (Abs bs body) = do
buildNAbs bs \xs -> tangentFunAsLambda $ linearizeBlock (env <> newEnv bs xs) body
_ -> LinA $ do
expr' <- substEmbed env expr
expr' <- substBuilder env expr
runLinA $ case expr' of
App x i | isTabTy (getType x) -> liftA (flip App i) (linearizeAtom x) `bindLin` emit
Op e -> linearizeOp e
Expand Down Expand Up @@ -194,7 +194,7 @@ linearizeOp op = case op of
emitWithZero :: LinA Atom
emitWithZero = LinA $ withZeroTangent <$> emitOp op

emitUnOp :: MonadEmbed m => UnOp -> Atom -> m Atom
emitUnOp :: MonadBuilder m => UnOp -> Atom -> m Atom
emitUnOp op x = emitOp $ ScalarUnOp op x

linearizeUnOp :: UnOp -> Atom -> LinA Atom
Expand Down Expand Up @@ -268,7 +268,7 @@ linearizeBinOp op x' y' = LinA $ do
linearizeHof :: SubstEnv -> Hof -> LinA Atom
linearizeHof env hof = case hof of
For ~(RegularFor d) ~(LamVal i body) -> LinA $ do
i' <- mapM (substEmbed env) i
i' <- mapM (substBuilder env) i
(ansWithLinTab, vi'') <- buildForAux d i' \i''@(Var vi'') ->
(,vi'') <$> (willRemat vi'' $ tangentFunAsLambda $ linearizeBlock (env <> i@>i'') body)
(ans, linTab) <- unzipTab ansWithLinTab
Expand All @@ -280,13 +280,13 @@ linearizeHof env hof = case hof of
let RefTy _ accTy = binderType refBinder
linearizeEff lam Writer (RunWriter bm) (emitRunWriter "r" accTy bm)
RunReader val lam -> LinA $ do
(val', mkLinInit) <- runLinA <$> linearizeAtom =<< substEmbed env val
(val', mkLinInit) <- runLinA <$> linearizeAtom =<< substBuilder env val
linearizeEff lam Reader (RunReader val') \f -> mkLinInit >>= emitRunReader "r" `flip` f
RunState val lam -> LinA $ do
(val', mkLinInit) <- runLinA <$> linearizeAtom =<< substEmbed env val
(val', mkLinInit) <- runLinA <$> linearizeAtom =<< substBuilder env val
linearizeEff lam State (RunState val') \f -> mkLinInit >>= emitRunState "r" `flip` f
RunIO ~(Lam (Abs _ (arrow, body))) -> LinA $ do
arrow' <- substEmbed env arrow
arrow' <- substBuilder env arrow
-- TODO: consider the possibility of other effects here besides IO
lam <- buildLam (Ignore UnitTy) arrow' \_ ->
tangentFunAsLambda $ linearizeBlock env body
Expand Down Expand Up @@ -317,11 +317,11 @@ linearizeHof env hof = case hof of

linearizeEffectFun :: RWS -> Atom -> PrimalM (Atom, Var)
linearizeEffectFun rws ~(BinaryFunVal h ref eff body) = do
h' <- mapM (substEmbed env) h
h' <- mapM (substBuilder env) h
buildLamAux h' (const $ return PureArrow) \h''@(Var hVar) -> do
let env' = env <> h@>h''
eff' <- substEmbed env' eff
ref' <- mapM (substEmbed env') ref
eff' <- substBuilder env' eff
ref' <- mapM (substBuilder env') ref
buildLamAux ref' (const $ return $ PlainArrow eff') \ref''@(Var refVar) ->
extendWrt [refVar] [RWSEffect rws (varName hVar)] $
(,refVar) <$> (tangentFunAsLambda $ linearizeBlock (env' <> ref@>ref'') body)
Expand Down Expand Up @@ -437,7 +437,7 @@ tangentFunAsLambda m = do
-- Like buildLam, but doesn't try to deshadow the binder.
makeLambda v f = do
block <- buildScoped $ do
embedExtend $ asFst $ v @> (varType v, LamBound (void PureArrow))
builderExtend $ asFst $ v @> (varType v, LamBound (void PureArrow))
f v
return $ Lam $ makeAbs (Bind v) (PureArrow, block)

Expand Down Expand Up @@ -465,7 +465,7 @@ applyLinToTangents f = do
let args = (toList rematVals) ++ hs' ++ tangents ++ [UnitVal]
naryApp f args

bindLin :: LinA a -> (a -> Embed b) -> LinA b
bindLin :: LinA a -> (a -> Builder b) -> LinA b
bindLin (LinA m) f = LinA $ do
(e, t) <- m
x <- lift $ f e
Expand Down Expand Up @@ -535,10 +535,10 @@ instance Semigroup TransposeEnv where
instance Monoid TransposeEnv where
mempty = TransposeEnv mempty mempty mempty mempty

type TransposeM a = ReaderT TransposeEnv Embed a
type TransposeM a = ReaderT TransposeEnv Builder a

transpose :: Scope -> Atom -> Atom
transpose scope ~(Lam (Abs b (_, block))) = fst $ flip runEmbed scope $ do
transpose scope ~(Lam (Abs b (_, block))) = fst $ flip runBuilder scope $ do
buildLam (Bind $ "ct" :> getType block) LinArrow \ct -> do
snd <$> (flip runReaderT mempty $ withLinVar b $ transposeBlock block ct)

Expand Down Expand Up @@ -830,10 +830,10 @@ zeroAt ty = case ty of
Vector st -> VecLit $ replicate vectorWidth $ zeroLit $ Scalar st
_ -> unreachable

updateAddAt :: MonadEmbed m => Atom -> m Atom
updateAddAt :: MonadBuilder m => Atom -> m Atom
updateAddAt x = buildLam (Bind ("t":>getType x)) PureArrow $ addTangent x

addTangent :: MonadEmbed m => Atom -> Atom -> m Atom
addTangent :: MonadBuilder m => Atom -> Atom -> m Atom
addTangent x y = case getType x of
RecordTy (NoExt tys) -> do
elems <- bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y)
Expand All @@ -851,7 +851,7 @@ addTangent x y = case getType x of
_ -> notTangent
where notTangent = error $ "Not a tangent type: " ++ pprint (getType x)

tangentBaseMonoidFor :: MonadEmbed m => Type -> m BaseMonoid
tangentBaseMonoidFor :: MonadBuilder m => Type -> m BaseMonoid
tangentBaseMonoidFor ty = BaseMonoid (zeroAt ty) <$> buildLam (Bind ("t":>ty)) PureArrow updateAddAt

checkZeroPlusFloatMonoid :: BaseMonoid -> Bool
Expand Down
Loading

0 comments on commit 3a1d1fd

Please sign in to comment.