diff --git a/horde-ad.cabal b/horde-ad.cabal index ea3fd114f..58325b810 100644 --- a/horde-ad.cabal +++ b/horde-ad.cabal @@ -734,6 +734,7 @@ test-suite simplifiedOnlyTest , ghc-typelits-natnormalise , hmatrix , ilist + , mono-traversable , orthotope , random , strict-containers diff --git a/test/simplified/TestMnistFCNNR.hs b/test/simplified/TestMnistFCNNR.hs index 3c809966a..076cfce3b 100644 --- a/test/simplified/TestMnistFCNNR.hs +++ b/test/simplified/TestMnistFCNNR.hs @@ -9,8 +9,11 @@ import qualified Data.Array.DynamicS as OD import qualified Data.Array.RankedS as OR import qualified Data.EnumMap.Strict as EM import Data.List.Index (imap) +import Data.MonoTraversable (Element) import qualified Data.Strict.IntMap as IM +import qualified Data.Strict.Vector as Data.Vector import qualified Data.Vector.Generic as V +import Numeric.LinearAlgebra (Vector) import qualified Numeric.LinearAlgebra as LA import System.IO (hPutStrLn, stderr) import System.Random @@ -21,9 +24,10 @@ import Text.Printf import HordeAd.Core.Ast import HordeAd.Core.AstInterpret import HordeAd.Core.AstSimplify -import HordeAd.Core.DualNumber +import HordeAd.Core.DualNumber (ADVal, dDnotShared) import HordeAd.Core.Engine import HordeAd.Core.SizedIndex +import HordeAd.Core.TensorADVal (ADTensor) import HordeAd.Core.TensorClass import HordeAd.External.Adaptor import HordeAd.External.CommonRankedOps @@ -51,12 +55,15 @@ testTrees = [ tensorADValMnistTests -- POPL differentiation, straight via the ADVal instance of Tensor mnistTestCase2VTA :: forall r. - ( ADReady r, ADReady (ADVal r), ADNum r, PrintfArg r - , Primal r ~ r, ScalarOf r ~ r, AssertEqualUpToEpsilon r + ( ADReady r, ADReady (ADVal r), ScalarOf r ~ r, ScalarOf (ADVal r) ~ r , TensorOf 0 (ADVal r) ~ ADVal (OR.Array 0 r) , TensorOf 1 (ADVal r) ~ ADVal (OR.Array 1 r) , DTensorOf (ADVal r) ~ ADVal (OD.Array r) - , Primal (ADVal r) ~ r, ScalarOf (ADVal r) ~ r ) + , PrintfArg r, AssertEqualUpToEpsilon r + , Floating (Vector r), ADTensor r + , DynamicTensor r, DomainsTensor r, Element r ~ r + , DTensorOf r ~ OD.Array r, TensorOf 1 r ~ OR.Array 1 r + , DomainsOf r ~ Data.Vector.Vector (OD.Array r) ) => String -> Int -> Int -> Int -> Int -> r -> Int -> r -> TestTree @@ -141,12 +148,15 @@ tensorADValMnistTests = testGroup "ShortRanked ADVal MNIST tests" -- POPL differentiation, Ast term defined only once but differentiated each time mnistTestCase2VTI :: forall r. - ( ADReady r, ADReady (ADVal r), ADReady (Ast0 r), ADNum r, PrintfArg r - , Primal r ~ r, ScalarOf r ~ r, AssertEqualUpToEpsilon r - , TensorOf 0 (ADVal r) ~ ADVal (OR.Array 0 r) + ( ADReady r, ADReady (ADVal r), ScalarOf r ~ r, ScalarOf (ADVal r) ~ r + , TensorOf 1 (ADVal r) ~ ADVal (OR.Array 1 r) , DTensorOf (ADVal r) ~ ADVal (OD.Array r) - , ScalarOf (ADVal r) ~ r - , InterpretAst (ADVal r) ) + , InterpretAst (ADVal r) + , PrintfArg r, AssertEqualUpToEpsilon r + , Floating (Vector r), ADTensor r + , DynamicTensor r, DomainsTensor r, Element r ~ r + , DTensorOf r ~ OD.Array r, TensorOf 1 r ~ OR.Array 1 r + , DomainsOf r ~ Data.Vector.Vector (OD.Array r) ) => String -> Int -> Int -> Int -> Int -> r -> Int -> r -> TestTree @@ -256,9 +266,10 @@ tensorIntermediateMnistTests = testGroup "ShortRankedIntermediate MNIST tests" -- JAX differentiation, Ast term built and differentiated only once mnistTestCase2VTO :: forall r. - ( ADReady r, ADReady (ADVal r), ADReady (Ast0 r), ADNum r, PrintfArg r - , Primal r ~ r, ScalarOf r ~ r, AssertEqualUpToEpsilon r - , TensorOf 0 (ADVal r) ~ ADVal (OR.Array 0 r), InterpretAst r ) + ( ADReady r, ScalarOf r ~ r, InterpretAst r + , PrintfArg r, AssertEqualUpToEpsilon r + , Floating (Vector r), ADTensor r, DomainsTensor r + , DTensorOf r ~ OD.Array r, TensorOf 1 r ~ OR.Array 1 r ) => String -> Int -> Int -> Int -> Int -> r -> Int -> r -> TestTree @@ -387,13 +398,16 @@ tensorADOnceMnistTests = testGroup "ShortRankedOnce MNIST tests" -- POPL differentiation, straight via the ADVal instance of Tensor mnistTestCase2VT2A :: forall r. - ( ADReady r, ADReady (ADVal r), ADNum r, PrintfArg r - , Primal r ~ r, ScalarOf r ~ r, AssertEqualUpToEpsilon r + ( ADReady r, ADReady (ADVal r), ScalarOf r ~ r, ScalarOf (ADVal r) ~ r , TensorOf 0 (ADVal r) ~ ADVal (OR.Array 0 r) , TensorOf 1 (ADVal r) ~ ADVal (OR.Array 1 r) , TensorOf 2 (ADVal r) ~ ADVal (OR.Array 2 r) , DTensorOf (ADVal r) ~ ADVal (OD.Array r) - , Primal (ADVal r) ~ r, ScalarOf (ADVal r) ~ r ) + , PrintfArg r, AssertEqualUpToEpsilon r + , Floating (Vector r), ADTensor r + , DynamicTensor r, DomainsTensor r, Element r ~ r + , DTensorOf r ~ OD.Array r, DomainsOf r ~ Data.Vector.Vector (OD.Array r) + , TensorOf 1 r ~ OR.Array 1 r, TensorOf 2 r ~ OR.Array 2 r ) => String -> Int -> Int -> Int -> Int -> r -> Int -> r -> TestTree @@ -479,12 +493,15 @@ tensorADValMnistTests2 = testGroup "ShortRanked ADVal MNIST tests" -- POPL differentiation, Ast term defined only once but differentiated each time mnistTestCase2VT2I :: forall r. - ( ADReady r, ADReady (ADVal r), ADReady (Ast0 r), ADNum r, PrintfArg r - , Primal r ~ r, ScalarOf r ~ r, AssertEqualUpToEpsilon r - , TensorOf 0 (ADVal r) ~ ADVal (OR.Array 0 r) + ( ADReady r, ADReady (ADVal r), ScalarOf r ~ r, ScalarOf (ADVal r) ~ r + , TensorOf 1 (ADVal r) ~ ADVal (OR.Array 1 r) , DTensorOf (ADVal r) ~ ADVal (OD.Array r) - , ScalarOf (ADVal r) ~ r - , InterpretAst (ADVal r) ) + , InterpretAst (ADVal r) + , PrintfArg r, AssertEqualUpToEpsilon r + , Floating (Vector r), ADTensor r + , DynamicTensor r, DomainsTensor r, Element r ~ r + , DTensorOf r ~ OD.Array r, DomainsOf r ~ Data.Vector.Vector (OD.Array r) + , TensorOf 1 r ~ OR.Array 1 r, TensorOf 2 r ~ OR.Array 2 r ) => String -> Int -> Int -> Int -> Int -> r -> Int -> r -> TestTree @@ -595,9 +612,11 @@ tensorIntermediateMnistTests2 = testGroup "ShortRankedIntermediate MNIST tests" -- JAX differentiation, Ast term built and differentiated only once mnistTestCase2VT2O :: forall r. - ( ADReady r, ADReady (ADVal r), ADReady (Ast0 r), ADNum r, PrintfArg r - , Primal r ~ r, ScalarOf r ~ r, AssertEqualUpToEpsilon r - , TensorOf 0 (ADVal r) ~ ADVal (OR.Array 0 r), InterpretAst r ) + ( ADReady r, ScalarOf r ~ r, InterpretAst r + , PrintfArg r, AssertEqualUpToEpsilon r + , Floating (Vector r), ADTensor r, DomainsTensor r + , DTensorOf r ~ OD.Array r + , TensorOf 1 r ~ OR.Array 1 r, TensorOf 2 r ~ OR.Array 2 r ) => String -> Int -> Int -> Int -> Int -> r -> Int -> r -> TestTree