Skip to content

Commit

Permalink
Remove ADNum constraint from simplified tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Apr 18, 2023
1 parent 5f21ebe commit 35ea918
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 23 deletions.
1 change: 1 addition & 0 deletions horde-ad.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,7 @@ test-suite simplifiedOnlyTest
, ghc-typelits-natnormalise
, hmatrix
, ilist
, mono-traversable
, orthotope
, random
, strict-containers
Expand Down
65 changes: 42 additions & 23 deletions test/simplified/TestMnistFCNNR.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ import qualified Data.Array.DynamicS as OD
import qualified Data.Array.RankedS as OR
import qualified Data.EnumMap.Strict as EM
import Data.List.Index (imap)
import Data.MonoTraversable (Element)
import qualified Data.Strict.IntMap as IM
import qualified Data.Strict.Vector as Data.Vector
import qualified Data.Vector.Generic as V
import Numeric.LinearAlgebra (Vector)
import qualified Numeric.LinearAlgebra as LA
import System.IO (hPutStrLn, stderr)
import System.Random
Expand All @@ -21,9 +24,10 @@ import Text.Printf
import HordeAd.Core.Ast
import HordeAd.Core.AstInterpret
import HordeAd.Core.AstSimplify
import HordeAd.Core.DualNumber
import HordeAd.Core.DualNumber (ADVal, dDnotShared)
import HordeAd.Core.Engine
import HordeAd.Core.SizedIndex
import HordeAd.Core.TensorADVal (ADTensor)
import HordeAd.Core.TensorClass
import HordeAd.External.Adaptor
import HordeAd.External.CommonRankedOps
Expand Down Expand Up @@ -51,12 +55,15 @@ testTrees = [ tensorADValMnistTests
-- POPL differentiation, straight via the ADVal instance of Tensor
mnistTestCase2VTA
:: forall r.
( ADReady r, ADReady (ADVal r), ADNum r, PrintfArg r
, Primal r ~ r, ScalarOf r ~ r, AssertEqualUpToEpsilon r
( ADReady r, ADReady (ADVal r), ScalarOf r ~ r, ScalarOf (ADVal r) ~ r
, TensorOf 0 (ADVal r) ~ ADVal (OR.Array 0 r)
, TensorOf 1 (ADVal r) ~ ADVal (OR.Array 1 r)
, DTensorOf (ADVal r) ~ ADVal (OD.Array r)
, Primal (ADVal r) ~ r, ScalarOf (ADVal r) ~ r )
, PrintfArg r, AssertEqualUpToEpsilon r
, Floating (Vector r), ADTensor r
, DynamicTensor r, DomainsTensor r, Element r ~ r
, DTensorOf r ~ OD.Array r, TensorOf 1 r ~ OR.Array 1 r
, DomainsOf r ~ Data.Vector.Vector (OD.Array r) )
=> String
-> Int -> Int -> Int -> Int -> r -> Int -> r
-> TestTree
Expand Down Expand Up @@ -141,12 +148,15 @@ tensorADValMnistTests = testGroup "ShortRanked ADVal MNIST tests"
-- POPL differentiation, Ast term defined only once but differentiated each time
mnistTestCase2VTI
:: forall r.
( ADReady r, ADReady (ADVal r), ADReady (Ast0 r), ADNum r, PrintfArg r
, Primal r ~ r, ScalarOf r ~ r, AssertEqualUpToEpsilon r
, TensorOf 0 (ADVal r) ~ ADVal (OR.Array 0 r)
( ADReady r, ADReady (ADVal r), ScalarOf r ~ r, ScalarOf (ADVal r) ~ r
, TensorOf 1 (ADVal r) ~ ADVal (OR.Array 1 r)
, DTensorOf (ADVal r) ~ ADVal (OD.Array r)
, ScalarOf (ADVal r) ~ r
, InterpretAst (ADVal r) )
, InterpretAst (ADVal r)
, PrintfArg r, AssertEqualUpToEpsilon r
, Floating (Vector r), ADTensor r
, DynamicTensor r, DomainsTensor r, Element r ~ r
, DTensorOf r ~ OD.Array r, TensorOf 1 r ~ OR.Array 1 r
, DomainsOf r ~ Data.Vector.Vector (OD.Array r) )
=> String
-> Int -> Int -> Int -> Int -> r -> Int -> r
-> TestTree
Expand Down Expand Up @@ -256,9 +266,10 @@ tensorIntermediateMnistTests = testGroup "ShortRankedIntermediate MNIST tests"
-- JAX differentiation, Ast term built and differentiated only once
mnistTestCase2VTO
:: forall r.
( ADReady r, ADReady (ADVal r), ADReady (Ast0 r), ADNum r, PrintfArg r
, Primal r ~ r, ScalarOf r ~ r, AssertEqualUpToEpsilon r
, TensorOf 0 (ADVal r) ~ ADVal (OR.Array 0 r), InterpretAst r )
( ADReady r, ScalarOf r ~ r, InterpretAst r
, PrintfArg r, AssertEqualUpToEpsilon r
, Floating (Vector r), ADTensor r, DomainsTensor r
, DTensorOf r ~ OD.Array r, TensorOf 1 r ~ OR.Array 1 r )
=> String
-> Int -> Int -> Int -> Int -> r -> Int -> r
-> TestTree
Expand Down Expand Up @@ -387,13 +398,16 @@ tensorADOnceMnistTests = testGroup "ShortRankedOnce MNIST tests"
-- POPL differentiation, straight via the ADVal instance of Tensor
mnistTestCase2VT2A
:: forall r.
( ADReady r, ADReady (ADVal r), ADNum r, PrintfArg r
, Primal r ~ r, ScalarOf r ~ r, AssertEqualUpToEpsilon r
( ADReady r, ADReady (ADVal r), ScalarOf r ~ r, ScalarOf (ADVal r) ~ r
, TensorOf 0 (ADVal r) ~ ADVal (OR.Array 0 r)
, TensorOf 1 (ADVal r) ~ ADVal (OR.Array 1 r)
, TensorOf 2 (ADVal r) ~ ADVal (OR.Array 2 r)
, DTensorOf (ADVal r) ~ ADVal (OD.Array r)
, Primal (ADVal r) ~ r, ScalarOf (ADVal r) ~ r )
, PrintfArg r, AssertEqualUpToEpsilon r
, Floating (Vector r), ADTensor r
, DynamicTensor r, DomainsTensor r, Element r ~ r
, DTensorOf r ~ OD.Array r, DomainsOf r ~ Data.Vector.Vector (OD.Array r)
, TensorOf 1 r ~ OR.Array 1 r, TensorOf 2 r ~ OR.Array 2 r )
=> String
-> Int -> Int -> Int -> Int -> r -> Int -> r
-> TestTree
Expand Down Expand Up @@ -479,12 +493,15 @@ tensorADValMnistTests2 = testGroup "ShortRanked ADVal MNIST tests"
-- POPL differentiation, Ast term defined only once but differentiated each time
mnistTestCase2VT2I
:: forall r.
( ADReady r, ADReady (ADVal r), ADReady (Ast0 r), ADNum r, PrintfArg r
, Primal r ~ r, ScalarOf r ~ r, AssertEqualUpToEpsilon r
, TensorOf 0 (ADVal r) ~ ADVal (OR.Array 0 r)
( ADReady r, ADReady (ADVal r), ScalarOf r ~ r, ScalarOf (ADVal r) ~ r
, TensorOf 1 (ADVal r) ~ ADVal (OR.Array 1 r)
, DTensorOf (ADVal r) ~ ADVal (OD.Array r)
, ScalarOf (ADVal r) ~ r
, InterpretAst (ADVal r) )
, InterpretAst (ADVal r)
, PrintfArg r, AssertEqualUpToEpsilon r
, Floating (Vector r), ADTensor r
, DynamicTensor r, DomainsTensor r, Element r ~ r
, DTensorOf r ~ OD.Array r, DomainsOf r ~ Data.Vector.Vector (OD.Array r)
, TensorOf 1 r ~ OR.Array 1 r, TensorOf 2 r ~ OR.Array 2 r )
=> String
-> Int -> Int -> Int -> Int -> r -> Int -> r
-> TestTree
Expand Down Expand Up @@ -595,9 +612,11 @@ tensorIntermediateMnistTests2 = testGroup "ShortRankedIntermediate MNIST tests"
-- JAX differentiation, Ast term built and differentiated only once
mnistTestCase2VT2O
:: forall r.
( ADReady r, ADReady (ADVal r), ADReady (Ast0 r), ADNum r, PrintfArg r
, Primal r ~ r, ScalarOf r ~ r, AssertEqualUpToEpsilon r
, TensorOf 0 (ADVal r) ~ ADVal (OR.Array 0 r), InterpretAst r )
( ADReady r, ScalarOf r ~ r, InterpretAst r
, PrintfArg r, AssertEqualUpToEpsilon r
, Floating (Vector r), ADTensor r, DomainsTensor r
, DTensorOf r ~ OD.Array r
, TensorOf 1 r ~ OR.Array 1 r, TensorOf 2 r ~ OR.Array 2 r )
=> String
-> Int -> Int -> Int -> Int -> r -> Int -> r
-> TestTree
Expand Down

0 comments on commit 35ea918

Please sign in to comment.