Skip to content

Commit

Permalink
[tmva][sofie] Add Sin/Cos operators
Browse files Browse the repository at this point in the history
Add Sin and Cos operators as new Unary operators.
Add also tests, taken from Vedant's PR  #16809
  • Loading branch information
lmoneta committed Dec 13, 2024
1 parent d462aa0 commit 4b12e44
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 2 deletions.
14 changes: 13 additions & 1 deletion tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace TMVA {
namespace Experimental {
namespace SOFIE {

enum class EBasicUnaryOperator { kReciprocal, kSqrt , kNeg, kExp, kLog };
enum class EBasicUnaryOperator { kReciprocal, kSqrt , kNeg, kExp, kLog, kSin, kCos };

template <typename T, EBasicUnaryOperator Op>
struct UnaryOpTraits {
Expand Down Expand Up @@ -45,6 +45,18 @@ struct UnaryOpTraits<T, EBasicUnaryOperator::kLog> {
static std::string Op(const std::string &X) { return "std::log(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kSin> {
static std::string Name() { return "Sin"; }
static std::string Op(const std::string &X) { return "std::sin(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kCos> {
static std::string Name() { return "Cos"; }
static std::string Op(const std::string &X) { return "std::cos(" + X + ")"; }
};

template <typename T, EBasicUnaryOperator Op>
class ROperator_BasicUnary final : public ROperator {
private:
Expand Down
2 changes: 1 addition & 1 deletion tmva/sofie/inc/TMVA/SOFIE_common.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ void BroadcastTensor(ContT data, const std::vector<size_t>& shape, const std::ve
if (shape.front() == targetShape.front() && shape.back() == 1 && size > 1) {
size_t bsize = targetShape.back();
// compute the size of the data to broadcast
for (size_t k = size-2; k >=0; k--) {
for (int k = int(size)-2; k >=0; k--) {
if (shape[k] != 1) break;
bsize *= targetShape[k];
}
Expand Down
50 changes: 50 additions & 0 deletions tmva/sofie/test/TestCustomModelsFromONNX.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@

#include "Where_FromONNX.hxx"

#include "Sin_FromONNX.hxx"

#include "Cos_FromONNX.hxx"

#include "gtest/gtest.h"

constexpr float DEFAULT_TOLERANCE = 1e-3f;
Expand Down Expand Up @@ -2937,4 +2941,50 @@ TEST(ONNX, Where) {
for (size_t i = 0; i < output.size(); i++) {
EXPECT_EQ(output[i], correct[i]);
}
}
float outputs[] = {0.406200, 0.111242, 0.770231, 0.940162, 0.260436, -0.258742,
0.304129, 0.999899, 0.256423, 0.410855, 0.843406, 0.862500};

TEST(ONNX, Sin)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;

// Preparing some random input
std::vector<float> input({
-0.786738,-0.197796,-0.187787,0.142758,0.876096,-0.653239,0.145444,-1.107658,2.259171,-0.947054,-0.506689,1.801250
});

TMVA_SOFIE_Sin::Session s("Sin_FromONNX.dat");

std::vector<float> output = s.infer(input.data());

// Checking output size
EXPECT_EQ(output.size(), input.size());

// Checking every output value, one by one
for (size_t i = 0; i < output.size(); ++i) {
EXPECT_LE(std::abs(output[i] - std::sin(input[i])), TOLERANCE);
}
}

TEST(ONNX, Cos)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;

// Preparing the random input
std::vector<float> input({
1.152504,-1.459324,0.691594,0.347690,-1.307323,1.832516,-1.261772,0.014224,1.311477,1.147405,-0.567206,-0.530606
});

TMVA_SOFIE_Cos::Session s("Cos_FromONNX.dat");

std::vector<float> output = s.infer(input.data());

// Checking output size
EXPECT_EQ(output.size(), input.size());

// Checking every output value, one by one
for (size_t i = 0; i < output.size(); ++i) {
EXPECT_LE(std::abs(output[i] - std::cos(input[i])), TOLERANCE);
}
}
12 changes: 12 additions & 0 deletions tmva/sofie/test/input_models/Cos.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

 cos_example:S

inputoutput"CosCosGraphZ
input


b
output


B
12 changes: 12 additions & 0 deletions tmva/sofie/test/input_models/Sin.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

 onnx-example:S

inputoutput"Sinsin_testZ
input


b
output


B
10 changes: 10 additions & 0 deletions tmva/sofie_parsers/src/ParseBasicUnary.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ ParserFuncSignature ParseLog = [](RModelParser_ONNX &parser, const onnx::NodePro
return ParseBasicUnary<EBasicUnaryOperator::kLog>(parser, nodeproto);
};

// Parse Sin
ParserFuncSignature ParseSin = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
return ParseBasicUnary<EBasicUnaryOperator::kSin>(parser, nodeproto);
};

// Parse Cos
ParserFuncSignature ParseCos = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
return ParseBasicUnary<EBasicUnaryOperator::kCos>(parser, nodeproto);
};

} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA
4 changes: 4 additions & 0 deletions tmva/sofie_parsers/src/RModelParser_ONNX.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ extern ParserFuncSignature ParseReciprocal;
extern ParserFuncSignature ParseNeg;
extern ParserFuncSignature ParseExp;
extern ParserFuncSignature ParseLog;
extern ParserFuncSignature ParseSin;
extern ParserFuncSignature ParseCos;
// Binary operators
extern ParserFuncSignature ParseAdd;
extern ParserFuncSignature ParseSub;
Expand Down Expand Up @@ -152,6 +154,8 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un
RegisterOperator("Neg", ParseNeg);
RegisterOperator("Exp", ParseExp);
RegisterOperator("Log", ParseLog);
RegisterOperator("Sin", ParseSin);
RegisterOperator("Cos", ParseCos);
// Binary operators
RegisterOperator("Add", ParseAdd);
RegisterOperator("Sub", ParseSub);
Expand Down

0 comments on commit 4b12e44

Please sign in to comment.