Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow unpacking ADTs in types (dependent projections as atoms) #290

Merged
merged 14 commits into from
Dec 15, 2020
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions examples/adt-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,59 @@ def catLists (xs:List a) (ys:List a) : List a =
n = 1 + 4
AsList _ (for i:(Fin n). ordinal i)
> (AsList 5 [0, 1, 2, 3, 4])



def listToTable ((AsList n xs): List a) : (Fin n)=>a = xs

:t listToTable
> ((a:Type) ?-> (pat:(List a)) -> (Fin (%projectElt [0] pat)) => a)
danieldjohnson marked this conversation as resolved.
Show resolved Hide resolved

:p
l = AsList _ [1, 2, 3]
sum $ listToTable l
> 6

def listLength ((AsList length xs):List a) : Int = length
def listToTable2 (l: List a) : (Fin (listLength l))=>a =
(AsList _ xs) = l
xs
danieldjohnson marked this conversation as resolved.
Show resolved Hide resolved

:t listToTable2
> ((a:Type) ?-> (l:(List a)) -> (Fin (%projectElt [0] l)) => a)

:p
l = AsList _ [1, 2, 3]
sum $ listToTable2 l
danieldjohnson marked this conversation as resolved.
Show resolved Hide resolved
> 6

data Graph a:Type =
MkGraph n:Type nodes:(n=>a) m:Type edges:(m=>(n & n))

def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool =
init = for i j. False
snd $ withState init \mRef.
for i:m.
(from, to) = edges.i
mRef!from!to := True

:t graphToAdjacencyMatrix
> ((a:Type)
> ?-> (pat:(Graph a)) -> (%projectElt [0] pat) => (%projectElt [0] pat) => Bool)

:p
g : Graph Int = MkGraph (Fin 3) [5, 6, 7] (Fin 4) [(0@_, 1@_), (0@_, 2@_), (2@_, 0@_), (1@_, 1@_)]
graphToAdjacencyMatrix g
> [[False, True, True], [False, True, False], [True, False, False]]


def deepUnpack (x:MyPair a (MyPair (MyIntish & b) c)) : Int =
(MkMyPair _ (MkMyPair (MkIntish y, _) _)) = x
y
danieldjohnson marked this conversation as resolved.
Show resolved Hide resolved

:t deepUnpack
> ((a:Type)
> ?-> (b:Type) ?-> (c:Type) ?-> (MyPair a (MyPair (MyIntish & b) c)) -> Int32)

:p deepUnpack (MkMyPair 3 (MkMyPair (MkIntish 4, 5) 6))
> 4
4 changes: 2 additions & 2 deletions prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def fdiv (d:Fractional a) ?=> : a -> a -> a = case d of MkFractional div -> div

def (&) (a:Type) (b:Type) : Type = %PairType a b
def (,) (x:a) (y:b) : (a & b) = %pair x y
def fst (p: (a & b)) : a = %fst p
def snd (p: (a & b)) : b = %snd p
def fst ((x, _): (a & b)) : a = x
def snd ((_, y): (a & b)) : b = y
def swap (p:(a&b)) : (b&a) = (snd p, fst p)
danieldjohnson marked this conversation as resolved.
Show resolved Hide resolved

def (<<<) (f: b -> c) (g: a -> b) : a -> c = \x. f (g x)
Expand Down
62 changes: 35 additions & 27 deletions src/lib/Autodiff.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import Control.Applicative
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State.Strict
import qualified Data.List.NonEmpty as NE
import Data.Maybe
import Data.Foldable
import Data.Traversable
Expand All @@ -24,7 +25,7 @@ import Syntax
import PPrint
import Embed
import Cat
import Util (bindM2, zipWithT, enumerate)
import Util (bindM2, zipWithT, enumerate, restructure)
import GHC.Stack

-- === linearization ===
Expand Down Expand Up @@ -53,13 +54,11 @@ linearizeBlock :: SubstEnv -> Block -> LinA Atom
linearizeBlock env (Block decls result) = case decls of
Empty -> linearizeExpr env result
Nest decl rest -> case decl of
(Let _ b expr) -> linearizeBinding False [b] expr
(Unpack bs expr) -> linearizeBinding True (toList bs) expr
(Let _ b expr) -> linearizeBinding b expr
where
body = Block rest result
takeWhere l m = fmap snd $ filter fst $ zip m l
linearizeBinding :: Bool -> [Binder] -> Expr -> LinA Atom
linearizeBinding isUnpack bs expr = LinA $ do
linearizeBinding :: Binder -> Expr -> LinA Atom
linearizeBinding b expr = LinA $ do
-- Don't linearize expressions with no free active variables.
-- Technically, we could do this and later run the code through a simplification
-- pass that would eliminate a bunch of multiplications with zeros, but this seems
Expand All @@ -69,8 +68,7 @@ linearizeBlock env (Block decls result) = case decls of
if any id varsAreActive
then do
(x, boundLin) <- runLinA $ linearizeExpr env expr
xs <- if isUnpack then emitUnpack (Atom x) else (:[]) <$> emit (Atom x)
let vs = fmap (\(Var v) -> v) xs
~(Var v) <- emit $ Atom x
-- NB: This can still overestimate the set of active variables (e.g.
-- when multiple values are returned from a case statement).
-- Don't mark variables with trivial tangent types as active. This lets us avoid
Expand All @@ -80,20 +78,20 @@ linearizeBlock env (Block decls result) = case decls of
-- variables, but I don't think that we want to define them to have tangents.
-- We should delete this check, but to do that we would have to support differentiation
-- through case statements with active scrutinees.
let nontrivialVsMask = [not $ isSingletonType $ tangentType $ varType v | v <- vs]
let nontrivialVs = vs `takeWhere` nontrivialVsMask
(ans, bodyLin) <- extendWrt nontrivialVs [] $ runLinA $ linearizeBlock (env <> newEnv bs xs) body
let vIsTrivial = isSingletonType $ tangentType $ varType v
let nontrivialVs = if vIsTrivial then [] else [v]
(ans, bodyLin) <- extendWrt nontrivialVs [] $ runLinA $
linearizeBlock (env <> b @> Var v) body
return (ans, do
t <- boundLin
ts <- if isUnpack then emitUnpack (Atom t) else return [t]
-- Tangent environment needs to be synced between the primal and tangent
-- monads (tangentFunAsLambda and applyLinToTangents need that).
let nontrivialTs = ts `takeWhere` nontrivialVsMask
let nontrivialTs = if vIsTrivial then [] else [t]
extendTangentEnv (newEnv nontrivialVs nontrivialTs) [] bodyLin)
else do
expr' <- substEmbed env expr
xs <- if isUnpack then emitUnpack expr' else (:[]) <$> emit expr'
runLinA $ linearizeBlock (env <> newEnv bs xs) body
x <- emit expr'
runLinA $ linearizeBlock (env <> b @> x) body

linearizeExpr :: SubstEnv -> Expr -> LinA Atom
linearizeExpr env expr = case expr of
Expand Down Expand Up @@ -132,8 +130,6 @@ linearizeOp op = case op of
MTell x -> liftA MTell $ la x
MGet -> pure MGet
MPut x -> liftA MPut $ la x) `bindLin` emitOp
Fst x -> (Fst <$> la x) `bindLin` emitOp
Snd x -> (Snd <$> la x) `bindLin` emitOp
IndexRef ref i -> (IndexRef <$> la ref <*> pure i) `bindLin` emitOp
FstRef ref -> (FstRef <$> la ref ) `bindLin` emitOp
SndRef ref -> (SndRef <$> la ref ) `bindLin` emitOp
Expand Down Expand Up @@ -341,6 +337,7 @@ linearizeAtom atom = case atom of
Pi _ -> emitWithZero
TC _ -> emitWithZero
Eff _ -> emitWithZero
ProjectElt idxs v -> getProjection (toList idxs) <$> linearizeAtom (Var v)
-- Those should be gone after simplification
Lam _ -> error "Unexpected non-table lambda"
ACase _ _ _ -> error "Unexpected ACase"
Expand Down Expand Up @@ -395,8 +392,9 @@ addTangent x y = case getType x of

pack :: Type -> [Atom] -> Atom
pack ty elems = case ty of
danieldjohnson marked this conversation as resolved.
Show resolved Hide resolved
PairTy _ _ -> let [x, y] = elems in PairVal x y
TypeCon def params -> DataCon def params 0 elems
RecordTy (NoExt types) -> Record $ snd $ mapAccumL (\(h:t) _ -> (t, h)) elems types
RecordTy (NoExt types) -> Record $ restructure elems types
_ -> error $ "Unexpected Unpack argument type: " ++ pprint ty

isTrivialForAD :: Expr -> Bool
Expand Down Expand Up @@ -536,20 +534,19 @@ transposeBlock :: Block -> Atom -> TransposeM ()
transposeBlock (Block decls result) ct = case decls of
Empty -> transposeExpr result ct
Nest decl rest -> case decl of
(Let _ b expr) -> transposeBinding False [b] expr
(Unpack bs expr) -> transposeBinding True (toList bs) expr
(Let _ b expr) -> transposeBinding b expr
where
body = Block rest result
transposeBinding isUnpack bs expr = do
transposeBinding b expr = do
isLinearExpr <- (||) <$> isLinEff (exprEffs expr) <*> isLin expr
if isLinearExpr
then do
cts <- withLinVars bs $ transposeBlock body ct
transposeExpr expr $ if isUnpack then pack (getType expr) cts else head cts
cts <- withLinVars [b] $ transposeBlock body ct
transposeExpr expr $ head cts
else do
expr' <- substNonlin expr
xs <- if isUnpack then emitUnpack expr' else (:[]) <$> emit expr'
localNonlinSubst (newEnv bs xs) $ transposeBlock body ct
x <- emit expr'
localNonlinSubst (b @> x) $ transposeBlock body ct

withLinVars :: [Binder] -> TransposeM () -> TransposeM [Atom]
withLinVars [] m = m >> return []
Expand Down Expand Up @@ -582,8 +579,6 @@ transposeExpr expr ct = case expr of

transposeOp :: Op -> Atom -> TransposeM ()
transposeOp op ct = case op of
Fst x -> flip emitCTToRef ct =<< (traverse $ emitOp . FstRef) =<< linAtomRef x
Snd x -> flip emitCTToRef ct =<< (traverse $ emitOp . SndRef) =<< linAtomRef x
ScalarUnOp FNeg x -> transposeAtom x =<< neg ct
ScalarUnOp _ _ -> notLinear
ScalarBinOp FAdd x y -> transposeAtom x ct >> transposeAtom y ct
Expand Down Expand Up @@ -641,6 +636,16 @@ linAtomRef (Var x) = do
case envLookup refs x of
Just ref -> return ref
_ -> error $ "Not a linear var: " ++ pprint (Var x)
linAtomRef (ProjectElt (i NE.:| is) x) = do
let subproj = getProjection is (Var x)
case getType subproj of
PairTy _ _ -> do
ref <- linAtomRef subproj
(traverse $ emitOp . getter) ref
where getter = case i of 0 -> FstRef
1 -> SndRef
_ -> error "bad pair projection"
ty -> error $ "Projecting references not implemented for type " <> pprint ty
danieldjohnson marked this conversation as resolved.
Show resolved Hide resolved
linAtomRef a = error $ "Not a linear var: " ++ pprint a

transposeHof :: Hof -> Atom -> TransposeM ()
Expand Down Expand Up @@ -700,6 +705,9 @@ transposeAtom atom ct = case atom of
ACase _ _ _ -> error "Unexpected ACase"
DataConRef _ _ _ -> error "Unexpected ref"
BoxedRef _ _ _ _ -> error "Unexpected ref"
ProjectElt _ v -> do
lin <- isLin $ Var v
when lin $ flip emitCTToRef ct =<< linAtomRef atom
where notTangent = error $ "Not a tangent atom: " ++ pprint atom

transposeCon :: Con -> Atom -> TransposeM ()
Expand Down
Loading