Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow unpacking ADTs in types (dependent projections as atoms) #290

Merged
merged 14 commits into from
Dec 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions examples/adt-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,82 @@ def catLists (xs:List a) (ys:List a) : List a =
n = 1 + 4
AsList _ (for i:(Fin n). ordinal i)
> (AsList 5 [0, 1, 2, 3, 4])



def listToTable ((AsList n xs): List a) : (Fin n)=>a = xs

:t listToTable
> ((a:Type) ?-> (pat:(List a)) -> (Fin ((\((AsList n _)). n) pat)) => a)

:p
l = AsList _ [1, 2, 3]
sum $ listToTable l
> 6

def listToTable2 (l: List a) : (Fin (listLength l))=>a =
(AsList _ xs) = l
xs
danieldjohnson marked this conversation as resolved.
Show resolved Hide resolved

:t listToTable2
> ((a:Type) ?-> (l:(List a)) -> (Fin ((\((AsList n _)). n) l)) => a)

:p
l = AsList _ [1, 2, 3]
sum $ listToTable2 l
danieldjohnson marked this conversation as resolved.
Show resolved Hide resolved
> 6

l2 = AsList _ [1, 2, 3]
:p sum $ listToTable2 l2
> 6

def zerosLikeList (l : List a) : (Fin (listLength l))=>Float =
for i:(Fin $ listLength l). 0.0

:p zerosLikeList l2
> [0.0, 0.0, 0.0]

data Graph a:Type =
MkGraph n:Type nodes:(n=>a) m:Type edges:(m=>(n & n))

def graphToAdjacencyMatrix ((MkGraph n nodes m edges):Graph a) : n=>n=>Bool =
init = for i j. False
snd $ withState init \mRef.
for i:m.
(from, to) = edges.i
mRef!from!to := True

:t graphToAdjacencyMatrix
> ((a:Type)
> ?-> (pat:(Graph a))
> -> ((\((MkGraph n _ _ _)). n) pat) => ((\((MkGraph n _ _ _)). n) pat) => Bool)

:p
g : Graph Int = MkGraph (Fin 3) [5, 6, 7] (Fin 4) [(0@_, 1@_), (0@_, 2@_), (2@_, 0@_), (1@_, 1@_)]
graphToAdjacencyMatrix g
> [[False, True, True], [False, True, False], [True, False, False]]

-- Test how (nested) projections are handled and pretty-printed.

def pairUnpack ((v, _):(Int & Float)) : Int = v
:p pairUnpack
> \pat:(Int32 & Float32). (\(a, _). a) pat

def adtUnpack ((MkMyPair v _):MyPair Int Float) : Int = v
:p adtUnpack
> \pat:(MyPair Int32 Float32). (\((MkMyPair elt _)). elt) pat

def recordUnpack ({a=v, b=_}:{a:Int & b:Float}) : Int = v
:p recordUnpack
> \pat:{a: Int32 & b: Float32}. (\{a = a, b = _}. a) pat

def nestedUnpack (x:MyPair Int (MyPair (MyIntish & Int) Int)) : Int =
(MkMyPair _ (MkMyPair (MkIntish y, _) _)) = x
y
danieldjohnson marked this conversation as resolved.
Show resolved Hide resolved

:p nestedUnpack
> \x:(MyPair Int32 (MyPair (MyIntish & Int32) Int32)).
> (\((MkIntish (((MkMyPair ((MkMyPair _ elt)) _)), _))). elt) x

:p nestedUnpack (MkMyPair 3 (MkMyPair (MkIntish 4, 5) 6))
> 4
51 changes: 25 additions & 26 deletions examples/isomorphisms.dx
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,23 @@ that produce isos. We will start with the first two:
:t #b : Iso {a:Int & b:Float & c:Unit} _
> (Iso {a: Int32 & b: Float32 & c: Unit} (Float32 & {a: Int32 & c: Unit}))
> === parse ===
> _ans_:() =
> MkIso
> {bwd = \(x:(), r:()). {b = x, ...r}, fwd = \{b = x:(), ...r:()}. (,) x r}
> _ans_ =
> MkIso {bwd = \(x, r). {b = x, ...r}, fwd = \{b = x, ...r}. (,) x r}
> : Iso {a: Int & b: Float & c: Unit} _

%passes parse
:t #?b : Iso {a:Int | b:Float | c:Unit} _
> (Iso {a: Int32 | b: Float32 | c: Unit} (Float32 | {a: Int32 | c: Unit}))
> === parse ===
> _ans_:() =
> _ans_ =
> MkIso
> { bwd = \v:(). case v
> ((Left x:())) -> {| b = x |}
> ((Right r:())) -> {|b| ...r |}
> { bwd = \v. case v
> ((Left x)) -> {| b = x |}
> ((Right r)) -> {|b| ...r |}
>
> , fwd = \v:(). case v
> {| b = x:() |} -> (Left x)
> {|b| ...r:() |} -> (Right r)
> , fwd = \v. case v
> {| b = x |} -> (Left x)
> {|b| ...r |} -> (Right r)
> }
> : Iso {a: Int | b: Float | c: Unit} _

Expand Down Expand Up @@ -143,10 +142,10 @@ another. For instance:
> ({ &} & {a: Int32 & b: Float32 & c: Unit})
> ({a: Int32} & {b: Float32 & c: Unit}))
> === parse ===
> _ans_:() =
> _ans_ =
> MkIso
> { bwd = \({a = x:(), ...l:()}, {, ...r:()}). (,) {, ...l} {a = x, ...r}
> , fwd = \({, ...l:()}, {a = x:(), ...r:()}). (,) {a = x, ...l} {, ...r}}
> { bwd = \({a = x, ...l}, {, ...r}). (,) {, ...l} {a = x, ...r}
> , fwd = \({, ...l}, {a = x, ...r}). (,) {a = x, ...l} {, ...r}}
> : Iso ((&) { &} {a: Int & b: Float & c: Unit}) _

:t (#&a &>> #&b) : Iso ({&} & {a:Int & b:Float & c:Unit}) _
Expand Down Expand Up @@ -213,21 +212,21 @@ zipper isomorphisms:
> ({ |} | {a: Int32 | b: Float32 | c: Unit})
> ({a: Int32} | {b: Float32 | c: Unit}))
> === parse ===
> _ans_:() =
> _ans_ =
> MkIso
> { bwd = \v:(). case v
> ((Left w:())) -> (case w
> {| a = x:() |} -> (Right {| a = x |})
> {|a| ...r:() |} -> (Left r)
> )
> ((Right l:())) -> (Right {|a| ...l |})
> { bwd = \v. case v
> ((Left w)) -> (case w
> {| a = x |} -> (Right {| a = x |})
> {|a| ...r |} -> (Left r)
> )
> ((Right l)) -> (Right {|a| ...l |})
>
> , fwd = \v:(). case v
> ((Left l:())) -> (Left {|a| ...l |})
> ((Right w:())) -> (case w
> {| a = x:() |} -> (Left {| a = x |})
> {|a| ...r:() |} -> (Right r)
> )
> , fwd = \v. case v
> ((Left l)) -> (Left {|a| ...l |})
> ((Right w)) -> (case w
> {| a = x |} -> (Left {| a = x |})
> {|a| ...r |} -> (Right r)
> )
> }
> : Iso ((|) { |} {a: Int | b: Float | c: Unit}) _

Expand Down
6 changes: 3 additions & 3 deletions prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ instance float32Fractional : Fractional Float32 where

def (&) (a:Type) (b:Type) : Type = %PairType a b
def (,) (x:a) (y:b) : (a & b) = %pair x y
def fst (p: (a & b)) : a = %fst p
def snd (p: (a & b)) : b = %snd p
def swap (p:(a&b)) : (b&a) = (snd p, fst p)
def fst ((x, _): (a & b)) : a = x
def snd ((_, y): (a & b)) : b = y
def swap ((x, y):(a&b)) : (b&a) = (y, x)

def (<<<) (f: b -> c) (g: a -> b) : a -> c = \x. f (g x)
def (>>>) (g: a -> b) (f: b -> c) : a -> c = \x. f (g x)
Expand Down
71 changes: 38 additions & 33 deletions src/lib/Autodiff.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import Control.Applicative
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State.Strict
import qualified Data.List.NonEmpty as NE
import Data.Maybe
import Data.Foldable
import Data.Traversable
Expand All @@ -24,7 +25,7 @@ import Syntax
import PPrint
import Embed
import Cat
import Util (bindM2, zipWithT, enumerate)
import Util (bindM2, zipWithT, enumerate, restructure)
import GHC.Stack

-- === linearization ===
Expand Down Expand Up @@ -53,13 +54,11 @@ linearizeBlock :: SubstEnv -> Block -> LinA Atom
linearizeBlock env (Block decls result) = case decls of
Empty -> linearizeExpr env result
Nest decl rest -> case decl of
(Let _ b expr) -> linearizeBinding False [b] expr
(Unpack bs expr) -> linearizeBinding True (toList bs) expr
(Let _ b expr) -> linearizeBinding b expr
where
body = Block rest result
takeWhere l m = fmap snd $ filter fst $ zip m l
linearizeBinding :: Bool -> [Binder] -> Expr -> LinA Atom
linearizeBinding isUnpack bs expr = LinA $ do
linearizeBinding :: Binder -> Expr -> LinA Atom
linearizeBinding b expr = LinA $ do
-- Don't linearize expressions with no free active variables.
-- Technically, we could do this and later run the code through a simplification
-- pass that would eliminate a bunch of multiplications with zeros, but this seems
Expand All @@ -69,8 +68,7 @@ linearizeBlock env (Block decls result) = case decls of
if any id varsAreActive
then do
(x, boundLin) <- runLinA $ linearizeExpr env expr
xs <- if isUnpack then emitUnpack (Atom x) else (:[]) <$> emit (Atom x)
let vs = fmap (\(Var v) -> v) xs
~(Var v) <- emit $ Atom x
-- NB: This can still overestimate the set of active variables (e.g.
-- when multiple values are returned from a case statement).
-- Don't mark variables with trivial tangent types as active. This lets us avoid
Expand All @@ -80,20 +78,20 @@ linearizeBlock env (Block decls result) = case decls of
-- variables, but I don't think that we want to define them to have tangents.
-- We should delete this check, but to do that we would have to support differentiation
-- through case statements with active scrutinees.
let nontrivialVsMask = [not $ isSingletonType $ tangentType $ varType v | v <- vs]
let nontrivialVs = vs `takeWhere` nontrivialVsMask
(ans, bodyLin) <- extendWrt nontrivialVs [] $ runLinA $ linearizeBlock (env <> newEnv bs xs) body
let vIsTrivial = isSingletonType $ tangentType $ varType v
let nontrivialVs = if vIsTrivial then [] else [v]
(ans, bodyLin) <- extendWrt nontrivialVs [] $ runLinA $
linearizeBlock (env <> b @> Var v) body
return (ans, do
t <- boundLin
ts <- if isUnpack then emitUnpack (Atom t) else return [t]
-- Tangent environment needs to be synced between the primal and tangent
-- monads (tangentFunAsLambda and applyLinToTangents need that).
let nontrivialTs = ts `takeWhere` nontrivialVsMask
let nontrivialTs = if vIsTrivial then [] else [t]
extendTangentEnv (newEnv nontrivialVs nontrivialTs) [] bodyLin)
else do
expr' <- substEmbed env expr
xs <- if isUnpack then emitUnpack expr' else (:[]) <$> emit expr'
runLinA $ linearizeBlock (env <> newEnv bs xs) body
x <- emit expr'
runLinA $ linearizeBlock (env <> b @> x) body

linearizeExpr :: SubstEnv -> Expr -> LinA Atom
linearizeExpr env expr = case expr of
Expand Down Expand Up @@ -132,8 +130,6 @@ linearizeOp op = case op of
MTell x -> liftA MTell $ la x
MGet -> pure MGet
MPut x -> liftA MPut $ la x) `bindLin` emitOp
Fst x -> (Fst <$> la x) `bindLin` emitOp
Snd x -> (Snd <$> la x) `bindLin` emitOp
IndexRef ref i -> (IndexRef <$> la ref <*> pure i) `bindLin` emitOp
FstRef ref -> (FstRef <$> la ref ) `bindLin` emitOp
SndRef ref -> (SndRef <$> la ref ) `bindLin` emitOp
Expand Down Expand Up @@ -344,6 +340,7 @@ linearizeAtom atom = case atom of
Pi _ -> emitWithZero
TC _ -> emitWithZero
Eff _ -> emitWithZero
ProjectElt idxs v -> getProjection (toList idxs) <$> linearizeAtom (Var v)
-- Those should be gone after simplification
Lam _ -> error "Unexpected non-table lambda"
ACase _ _ _ -> error "Unexpected ACase"
Expand Down Expand Up @@ -381,7 +378,9 @@ tangentType ty = case ty of

addTangent :: MonadEmbed m => Atom -> Atom -> m Atom
addTangent x y = case getType x of
RecordTy _ -> pack (getType x) <$> bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y)
RecordTy (NoExt tys) -> do
elems <- bindM2 (zipWithT addTangent) (getUnpacked x) (getUnpacked y)
return $ Record $ restructure elems tys
TabTy b _ -> buildFor Fwd b $ \i -> bindM2 addTangent (tabGet x i) (tabGet y i)
TC con -> case con of
BaseType (Scalar _) -> emitOp $ ScalarBinOp FAdd x y
Expand All @@ -395,12 +394,6 @@ addTangent x y = case getType x of
_ -> notTangent
where notTangent = error $ "Not a tangent type: " ++ pprint (getType x)

pack :: Type -> [Atom] -> Atom
pack ty elems = case ty of
TypeCon def params -> DataCon def params 0 elems
RecordTy (NoExt types) -> Record $ snd $ mapAccumL (\(h:t) _ -> (t, h)) elems types
_ -> error $ "Unexpected Unpack argument type: " ++ pprint ty

isTrivialForAD :: Expr -> Bool
isTrivialForAD expr = isSingletonType tangentTy && exprEffs expr == mempty
where tangentTy = tangentType $ getType expr
Expand Down Expand Up @@ -538,20 +531,19 @@ transposeBlock :: Block -> Atom -> TransposeM ()
transposeBlock (Block decls result) ct = case decls of
Empty -> transposeExpr result ct
Nest decl rest -> case decl of
(Let _ b expr) -> transposeBinding False [b] expr
(Unpack bs expr) -> transposeBinding True (toList bs) expr
(Let _ b expr) -> transposeBinding b expr
where
body = Block rest result
transposeBinding isUnpack bs expr = do
transposeBinding b expr = do
isLinearExpr <- (||) <$> isLinEff (exprEffs expr) <*> isLin expr
if isLinearExpr
then do
cts <- withLinVars bs $ transposeBlock body ct
transposeExpr expr $ if isUnpack then pack (getType expr) cts else head cts
cts <- withLinVars [b] $ transposeBlock body ct
transposeExpr expr $ head cts
else do
expr' <- substNonlin expr
xs <- if isUnpack then emitUnpack expr' else (:[]) <$> emit expr'
localNonlinSubst (newEnv bs xs) $ transposeBlock body ct
x <- emit expr'
localNonlinSubst (b @> x) $ transposeBlock body ct

withLinVars :: [Binder] -> TransposeM () -> TransposeM [Atom]
withLinVars [] m = m >> return []
Expand Down Expand Up @@ -584,8 +576,6 @@ transposeExpr expr ct = case expr of

transposeOp :: Op -> Atom -> TransposeM ()
transposeOp op ct = case op of
Fst x -> flip emitCTToRef ct =<< (traverse $ emitOp . FstRef) =<< linAtomRef x
Snd x -> flip emitCTToRef ct =<< (traverse $ emitOp . SndRef) =<< linAtomRef x
ScalarUnOp FNeg x -> transposeAtom x =<< neg ct
ScalarUnOp _ _ -> notLinear
ScalarBinOp FAdd x y -> transposeAtom x ct >> transposeAtom y ct
Expand Down Expand Up @@ -645,6 +635,18 @@ linAtomRef (Var x) = do
case envLookup refs x of
Just ref -> return ref
_ -> error $ "Not a linear var: " ++ pprint (Var x)
linAtomRef (ProjectElt (i NE.:| is) x) =
let subproj = case NE.nonEmpty is of
Just is' -> ProjectElt is' x
Nothing -> Var x
in case getType subproj of
PairTy _ _ -> do
ref <- linAtomRef subproj
(traverse $ emitOp . getter) ref
where getter = case i of 0 -> FstRef
1 -> SndRef
_ -> error "bad pair projection"
ty -> error $ "Projecting references not implemented for type " <> pprint ty
danieldjohnson marked this conversation as resolved.
Show resolved Hide resolved
linAtomRef a = error $ "Not a linear var: " ++ pprint a

transposeHof :: Hof -> Atom -> TransposeM ()
Expand Down Expand Up @@ -704,6 +706,9 @@ transposeAtom atom ct = case atom of
ACase _ _ _ -> error "Unexpected ACase"
DataConRef _ _ _ -> error "Unexpected ref"
BoxedRef _ _ _ _ -> error "Unexpected ref"
ProjectElt _ v -> do
lin <- isLin $ Var v
when lin $ flip emitCTToRef ct =<< linAtomRef atom
where notTangent = error $ "Not a tangent atom: " ++ pprint atom

transposeCon :: Con -> Atom -> TransposeM ()
Expand Down
Loading