Skip to content

Commit

Permalink
Introduce interpretAstDual
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Apr 15, 2023
1 parent 66d2033 commit 6f88617
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions simplified/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,12 @@ emptyMemo :: AstMemo a
emptyMemo = ()

-- Strict environment and strict ADVal and Delta make this is hard to optimize.
-- Either the environment has to be traverse to remove the dual parts or
-- the dual part needs to be needlessly computed.
-- Either the environment has to be traversed to remove the dual parts or
-- the dual part needs to be potentially needlessly computed.
-- However, with correct sharing and large tensors, the overall cost
-- is negligible, so we optimize only minimally.
-- It helps that usually the dual part is either trivially computed
-- to be zero or is used elsewhere. It's rarely really lost and forgotten.
interpretAstPrimal
:: forall n a. (KnownNat n, Evidence a)
=> AstEnv a -> AstMemo a
Expand All @@ -150,6 +152,14 @@ interpretAstPrimal env memo (AstPrimalPart v1) = case v1 of
AstD u _-> interpretAstPrimal env memo u
_ -> second tprimalPart $ interpretAst env memo v1

interpretAstDual
:: forall n a. (KnownNat n, Evidence a)
=> AstEnv a -> AstMemo a
-> AstDualPart n (ScalarOf a) -> (AstMemo a, DualOf n a)
interpretAstDual env memo (AstDualPart v1) = case v1 of
AstD _ u'-> interpretAstDual env memo u'
_ -> second tdualPart $ interpretAst env memo v1

interpretAst
:: forall n a. (KnownNat n, Evidence a)
=> AstEnv a -> AstMemo a
Expand Down Expand Up @@ -270,9 +280,9 @@ interpretAst env memo | Dict <- evi1 @a @n Proxy = \case
-- leads to a tensor of deltas
AstConst a -> (memo, tconst a)
AstConstant a -> second tconstant $ interpretAstPrimal env memo a
AstD u (AstDualPart u') ->
AstD u u' ->
let (memo1, t1) = interpretAstPrimal env memo u
(memo2, t2) = second tdualPart $ interpretAst env memo1 u'
(memo2, t2) = interpretAstDual env memo1 u'
in (memo2, tD t1 t2)
AstLetDomains vars l v ->
let (memo2, l2) = interpretAstDomains env memo l
Expand Down

0 comments on commit 6f88617

Please sign in to comment.