Skip to content

Commit

Permalink
Remove static asserts from generated code
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Apr 12, 2024
1 parent 4f8292c commit 11a5aa3
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class BaseForwardModeVisitor
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE);
StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE);
StmtDiff VisitStaticAssertDecl(const clang::StaticAssertDecl* SAD);

virtual clang::QualType
GetPushForwardDerivativeType(clang::QualType ParamType);
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ namespace clad {
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff
VisitCXXNullPtrLiteralExpr(const clang::CXXNullPtrLiteralExpr* NPE);
StmtDiff VisitStaticAssertDecl(const clang::StaticAssertDecl* SAD);

/// A helper method to differentiate a single Stmt in the reverse mode.
/// Internally, calls Visit(S, expr). Its result is wrapped into a
Expand Down
8 changes: 7 additions & 1 deletion lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/StmtClone.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
Expand Down Expand Up @@ -1454,6 +1453,8 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
m_DeclReplacements[VD] = VDDiff.getDecl();
decls.push_back(VDDiff.getDecl());
declsDiff.push_back(VDDiff.getDecl_dx());
} else if (auto SAD = dyn_cast<StaticAssertDecl>(D)) {
return VisitStaticAssertDecl(SAD);
} else {
diag(DiagnosticsEngine::Warning, D->getEndLoc(),
"Unsupported declaration");
Expand Down Expand Up @@ -2022,4 +2023,9 @@ StmtDiff BaseForwardModeVisitor::VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP) {
return Visit(NTTP->getReplacement());
}

StmtDiff BaseForwardModeVisitor::VisitStaticAssertDecl(
const clang::StaticAssertDecl* SAD) {
return nullptr;
}
} // end namespace clad
7 changes: 7 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2878,6 +2878,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
localDeclsDiff.push_back(VDDiff.getDecl_dx());
else
declsDiff.push_back(VDDiff.getDecl_dx());
} else if (auto SAD = dyn_cast<StaticAssertDecl>(D)) {
return VisitStaticAssertDecl(SAD);
} else {
diag(DiagnosticsEngine::Warning,
D->getEndLoc(),
Expand Down Expand Up @@ -3831,6 +3833,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return Visit(NTTP->getReplacement());
}

StmtDiff ReverseModeVisitor::VisitStaticAssertDecl(
const clang::StaticAssertDecl* SAD) {
return nullptr;
}

QualType ReverseModeVisitor::GetParameterDerivativeType(QualType yType,
QualType xType) {

Expand Down
22 changes: 22 additions & 0 deletions test/FirstDerivative/FunctionCallsWithResults.C
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// RUN: ./FunctionCallsWithResults.out | FileCheck -check-prefix=CHECK-EXEC %s

#include "clad/Differentiator/Differentiator.h"
#include <random>

int printf(const char* fmt, ...);

Expand Down Expand Up @@ -289,6 +290,25 @@ double fn9 (double i, double j) {
// CHECK-NEXT: return _t0.pushforward * _t3 + _t2 * _t1.pushforward;
// CHECK-NEXT: }

double fn10(double x) {
std::mt19937 gen64;
std::uniform_real_distribution<double> distribution(0.0,1.0);
double rand = distribution(gen64);
return x+rand;
}

// CHECK: double f_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: std::mt19937 _d_gen64;
// CHECK-NEXT: std::mt19937 gen64;
// CHECK-NEXT: std::uniform_real_distribution<double> _d_distribution(0., 0.);
// CHECK-NEXT: std::uniform_real_distribution<double> distribution(0., 1.);
// CHECK-NEXT: clad::ValueAndPushforward<result_type, result_type> _t0 = distribution.operator_call_pushforward(gen64, &_d_distribution, _d_gen64);
// CHECK-NEXT: double _d_rand = _t0.pushforward;
// CHECK-NEXT: double rand0 = _t0.value;
// CHECK-NEXT: return _d_x * x + x * _d_x + _d_rand;
// CHECK-NEXT: }

float test_1_darg0(float x);
float test_2_darg0(float x);
float test_4_darg0(float x);
Expand Down Expand Up @@ -318,6 +338,7 @@ int main () {
INIT(fn7, "i");
INIT(fn8, "i");
INIT(fn9, "i");
INIT(fn10, "x");

TEST(fn1, 3, 5); // CHECK-EXEC: {12.00}
TEST(fn2, 3, 5); // CHECK-EXEC: {181.00}
Expand All @@ -328,6 +349,7 @@ int main () {
TEST(fn7, 3, 5); // CHECK-EXEC: {8.00}
TEST(fn8, 3, 5); // CHECK-EXEC: {19.04}
TEST(fn9, 3, 5); // CHECK-EXEC: {5.00}
TEST(fn10, 3); // CHECK-EXEC: {1.00}
return 0;

// CHECK: clad::ValueAndPushforward<double, double> sum_of_squares_pushforward(double u, double v, double _d_u, double _d_v) {
Expand Down

0 comments on commit 11a5aa3

Please sign in to comment.