Skip to content

Commit

Permalink
Test that simplification now works after AD thanks to only local sharing
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Apr 15, 2023
1 parent 62d53f0 commit 66d2033
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion simplified/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ module HordeAd.Core.AstSimplify
, astConstant, astSum, astScatter, astFromList, astFromVector, astKonst
, astAppend, astSlice, astReverse, astFromDynamic
, astIntCond
, ShowAstSimplify, simplifyAst
, ShowAstSimplify, simplifyAst, simplifyAstDomains
, substituteAst, substituteAstInt, substituteAstBool
, resetVarCounter
) where
Expand Down
4 changes: 4 additions & 0 deletions test/simplified/TestAdaptorSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,10 @@ testReluPP2 = do
@?= length "\\s0 dt x3 -> dlet (tkonst 5 (s0 ! [0])) (\\x6 -> dlet (x3 * x6) (\\x7 -> dlet (tgather [5] (tconst (fromList [2] [0.0,1.0])) (\\[i5] -> [ifB (tlet (s0 ! [0]) (\\x12 -> tlet (x3 ! [i5]) (\\x13 -> x13 * x12)) <=* tconst 0.0) 0 1])) (\\x8 -> dlet (x8 * dt) (\\x9 -> dlet (tscatter [1] (tfromList [tsum (x3 * x9)]) (\\[i10] -> [0])) (\\x11 -> dmkDomains (fromList [dfromR (tfromList [tconst 0.0 + x11 ! [0]]), dfromR (x6 * x9)]))))))"
length ("\\" ++ varsPP ++ " -> " ++ printAstSimple renames vAst)
@?= length "\\s0 x3 -> tlet (tkonst 5 (s0 ! [0])) (\\x6 -> tlet (x3 * x6) (\\x7 -> tlet (tgather [5] (tconst (fromList [2] [0.0,1.0])) (\\[i5] -> [ifB (tlet (s0 ! [0]) (\\x12 -> tlet (x3 ! [i5]) (\\x13 -> x13 * x12)) <=* tconst 0.0) 0 1])) (\\x8 -> x8 * x7)))"
length ("\\" ++ varsPPD
++ " -> " ++ printAstDomainsPretty renames
(simplifyAstDomains letGradientAst))
@?= length "\\s0 dt x3 -> let x6 = tkonst 5 (s0 ! [0]) in let x7 = x3 * x6 in let x8 = tconstant (tgather [5] (tconst (fromList [2] [0.0,1.0])) (\\[i5] -> [ifB ((let x12 = s0 ! [0] in let x13 = x3 ! [i5] in x13 * x12) <=* tconst 0.0) 0 1])) in let x9 = x8 * dt in let x11 = tscatter [1] (tkonst 1 (tsum (x3 * x9))) (\\[i10] -> [0]) in (tkonst 1 (tconst 0.0 + x11 ! [0]), x6 * x9)"

reluMax :: forall n r. (ADReady r, KnownNat n)
=> TensorOf n r -> TensorOf n r
Expand Down

0 comments on commit 66d2033

Please sign in to comment.