From 4b12e441b86de1e07ff8f792f424d54dbf7a2bde Mon Sep 17 00:00:00 2001 From: moneta Date: Thu, 12 Dec 2024 17:21:24 +0100 Subject: [PATCH] [tmva][sofie] Add Sin/Cos operators Add Sin and Cos operators as new Unary operators. Add also tests, taken from Vedant's PR #16809 --- tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx | 14 +++++- tmva/sofie/inc/TMVA/SOFIE_common.hxx | 2 +- tmva/sofie/test/TestCustomModelsFromONNX.cxx | 50 ++++++++++++++++++++ tmva/sofie/test/input_models/Cos.onnx | 12 +++++ tmva/sofie/test/input_models/Sin.onnx | 12 +++++ tmva/sofie_parsers/src/ParseBasicUnary.cxx | 10 ++++ tmva/sofie_parsers/src/RModelParser_ONNX.cxx | 4 ++ 7 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 tmva/sofie/test/input_models/Cos.onnx create mode 100644 tmva/sofie/test/input_models/Sin.onnx diff --git a/tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx b/tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx index f23686e92de69..3393421f18473 100644 --- a/tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx @@ -9,7 +9,7 @@ namespace TMVA { namespace Experimental { namespace SOFIE { -enum class EBasicUnaryOperator { kReciprocal, kSqrt , kNeg, kExp, kLog }; +enum class EBasicUnaryOperator { kReciprocal, kSqrt , kNeg, kExp, kLog, kSin, kCos }; template struct UnaryOpTraits { @@ -45,6 +45,18 @@ struct UnaryOpTraits { static std::string Op(const std::string &X) { return "std::log(" + X + ")"; } }; +template +struct UnaryOpTraits { + static std::string Name() { return "Sin"; } + static std::string Op(const std::string &X) { return "std::sin(" + X + ")"; } +}; + +template +struct UnaryOpTraits { + static std::string Name() { return "Cos"; } + static std::string Op(const std::string &X) { return "std::cos(" + X + ")"; } +}; + template class ROperator_BasicUnary final : public ROperator { private: diff --git a/tmva/sofie/inc/TMVA/SOFIE_common.hxx b/tmva/sofie/inc/TMVA/SOFIE_common.hxx index 44cbc00230ca9..32ee288aa2ed3 100644 --- a/tmva/sofie/inc/TMVA/SOFIE_common.hxx +++ b/tmva/sofie/inc/TMVA/SOFIE_common.hxx @@ -292,7 +292,7 @@ void BroadcastTensor(ContT data, const std::vector& shape, const std::ve if (shape.front() == targetShape.front() && shape.back() == 1 && size > 1) { size_t bsize = targetShape.back(); // compute the size of the data to broadcast - for (size_t k = size-2; k >=0; k--) { + for (int k = int(size)-2; k >=0; k--) { if (shape[k] != 1) break; bsize *= targetShape[k]; } diff --git a/tmva/sofie/test/TestCustomModelsFromONNX.cxx b/tmva/sofie/test/TestCustomModelsFromONNX.cxx index f95143d9c1122..20e0d39378de9 100644 --- a/tmva/sofie/test/TestCustomModelsFromONNX.cxx +++ b/tmva/sofie/test/TestCustomModelsFromONNX.cxx @@ -304,6 +304,10 @@ #include "Where_FromONNX.hxx" +#include "Sin_FromONNX.hxx" + +#include "Cos_FromONNX.hxx" + #include "gtest/gtest.h" constexpr float DEFAULT_TOLERANCE = 1e-3f; @@ -2937,4 +2941,50 @@ TEST(ONNX, Where) { for (size_t i = 0; i < output.size(); i++) { EXPECT_EQ(output[i], correct[i]); } +} +float outputs[] = {0.406200, 0.111242, 0.770231, 0.940162, 0.260436, -0.258742, + 0.304129, 0.999899, 0.256423, 0.410855, 0.843406, 0.862500}; + +TEST(ONNX, Sin) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + + // Preparing some random input + std::vector input({ + -0.786738,-0.197796,-0.187787,0.142758,0.876096,-0.653239,0.145444,-1.107658,2.259171,-0.947054,-0.506689,1.801250 + }); + + TMVA_SOFIE_Sin::Session s("Sin_FromONNX.dat"); + + std::vector output = s.infer(input.data()); + + // Checking output size + EXPECT_EQ(output.size(), input.size()); + + // Checking every output value, one by one + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - std::sin(input[i])), TOLERANCE); + } +} + +TEST(ONNX, Cos) +{ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + + // Preparing the random input + std::vector input({ + 1.152504,-1.459324,0.691594,0.347690,-1.307323,1.832516,-1.261772,0.014224,1.311477,1.147405,-0.567206,-0.530606 + }); + + TMVA_SOFIE_Cos::Session s("Cos_FromONNX.dat"); + + std::vector output = s.infer(input.data()); + + // Checking output size + EXPECT_EQ(output.size(), input.size()); + + // Checking every output value, one by one + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - std::cos(input[i])), TOLERANCE); + } } \ No newline at end of file diff --git a/tmva/sofie/test/input_models/Cos.onnx b/tmva/sofie/test/input_models/Cos.onnx new file mode 100644 index 0000000000000..31877b49bba14 --- /dev/null +++ b/tmva/sofie/test/input_models/Cos.onnx @@ -0,0 +1,12 @@ + + cos_example:S + +inputoutput"CosCosGraphZ +input +  + +b +output +  + +B \ No newline at end of file diff --git a/tmva/sofie/test/input_models/Sin.onnx b/tmva/sofie/test/input_models/Sin.onnx new file mode 100644 index 0000000000000..dedfc39623d0a --- /dev/null +++ b/tmva/sofie/test/input_models/Sin.onnx @@ -0,0 +1,12 @@ + + onnx-example:S + +inputoutput"Sinsin_testZ +input +  + +b +output +  + +B \ No newline at end of file diff --git a/tmva/sofie_parsers/src/ParseBasicUnary.cxx b/tmva/sofie_parsers/src/ParseBasicUnary.cxx index 5521dc08f8cc0..1cbed78e00b7f 100644 --- a/tmva/sofie_parsers/src/ParseBasicUnary.cxx +++ b/tmva/sofie_parsers/src/ParseBasicUnary.cxx @@ -65,6 +65,16 @@ ParserFuncSignature ParseLog = [](RModelParser_ONNX &parser, const onnx::NodePro return ParseBasicUnary(parser, nodeproto); }; +// Parse Sin +ParserFuncSignature ParseSin = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + return ParseBasicUnary(parser, nodeproto); +}; + +// Parse Cos +ParserFuncSignature ParseCos = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + return ParseBasicUnary(parser, nodeproto); +}; + } // namespace SOFIE } // namespace Experimental } // namespace TMVA diff --git a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx index 25dd5eb10c8dc..c83b92bc9a640 100644 --- a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx +++ b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx @@ -22,6 +22,8 @@ extern ParserFuncSignature ParseReciprocal; extern ParserFuncSignature ParseNeg; extern ParserFuncSignature ParseExp; extern ParserFuncSignature ParseLog; +extern ParserFuncSignature ParseSin; +extern ParserFuncSignature ParseCos; // Binary operators extern ParserFuncSignature ParseAdd; extern ParserFuncSignature ParseSub; @@ -152,6 +154,8 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un RegisterOperator("Neg", ParseNeg); RegisterOperator("Exp", ParseExp); RegisterOperator("Log", ParseLog); + RegisterOperator("Sin", ParseSin); + RegisterOperator("Cos", ParseCos); // Binary operators RegisterOperator("Add", ParseAdd); RegisterOperator("Sub", ParseSub);