Skip to content

Commit

Permalink
Fix static asserts in generated code
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Apr 15, 2024
1 parent c3202c1 commit 3b216a0
Show file tree
Hide file tree
Showing 12 changed files with 126 additions and 35 deletions.
5 changes: 4 additions & 1 deletion include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ class BaseForwardModeVisitor
StmtDiff VisitStmt(const clang::Stmt* S);
StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp);
// Decl is not Stmt, so it cannot be visited directly.
virtual VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD);
virtual DeclDiff<clang::VarDecl>
DifferentiateVarDecl(const clang::VarDecl* VD);
/// Shorthand for warning on differentiation of unsupported operators
void unsupportedOpWarn(clang::SourceLocation loc,
llvm::ArrayRef<llvm::StringRef> args = {}) {
Expand Down Expand Up @@ -108,6 +109,8 @@ class BaseForwardModeVisitor
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff VisitImplicitValueInitExpr(const clang::ImplicitValueInitExpr* IVIE);
StmtDiff VisitCStyleCastExpr(const clang::CStyleCastExpr* CSCE);
static DeclDiff<clang::StaticAssertDecl>
DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD);

virtual clang::QualType
GetPushForwardDerivativeType(clang::QualType ParamType);
Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/ErrorEstimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class ErrorEstimationHandler : public ExternalRMVSource {
/// \param[in] VDDiff The variable declaration to calculate the error in.
/// \param[in] isInsideLoop A flag to keep track of if we are inside a
/// loop.
void EmitDeclErrorStmts(VarDeclDiff VDDiff, bool isInsideLoop);
void EmitDeclErrorStmts(DeclDiff<clang::VarDecl> VDDiff, bool isInsideLoop);

/// This function returns the size expression for a given variable
/// (`var.size()` for clad::array/clad::array_ref
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/ExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ namespace clad {

struct DiffRequest;
class StmtDiff;
class VarDeclDiff;

template <typename T> class DeclDiff;

using direction = rmv::direction;

Expand Down
4 changes: 3 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,13 @@ namespace clad {
StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS);
StmtDiff VisitCaseStmt(const clang::CaseStmt* CS);
StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS);
VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD);
DeclDiff<clang::VarDecl> DifferentiateVarDecl(const clang::VarDecl* VD);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff
VisitCXXNullPtrLiteralExpr(const clang::CXXNullPtrLiteralExpr* NPE);
static DeclDiff<clang::StaticAssertDecl>
DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD);

/// A helper method to differentiate a single Stmt in the reverse mode.
/// Internally, calls Visit(S, expr). Its result is wrapped into a
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/VectorForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE) override;
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
// Decl is not Stmt, so it cannot be visited directly.
VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD) override;
DeclDiff<clang::VarDecl>
DifferentiateVarDecl(const clang::VarDecl* VD) override;

clang::QualType
GetPushForwardDerivativeType(clang::QualType ParamType) override;
Expand Down
17 changes: 8 additions & 9 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,20 @@ namespace clad {
void setForwSweepStmt_dx(clang::Stmt* S) { m_DerivativeForForwSweep = S; }
};

class VarDeclDiff {
template <typename T> class DeclDiff {
private:
std::array<clang::VarDecl*, 2> data;
std::array<T*, 2> m_data;

public:
VarDeclDiff(clang::VarDecl* orig = nullptr,
clang::VarDecl* diff = nullptr) {
data[1] = orig;
data[0] = diff;
DeclDiff(T* orig = nullptr, T* diff = nullptr) {
m_data[1] = orig;
m_data[0] = diff;
}

clang::VarDecl* getDecl() { return data[1]; }
clang::VarDecl* getDecl_dx() { return data[0]; }
T* getDecl() { return m_data[1]; }
T* getDecl_dx() { return m_data[0]; }
// Decl_dx goes first!
std::array<clang::VarDecl*, 2>& getBothDecls() { return data; }
std::array<T*, 2>& getBothDecls() { return m_data; }
};

/// A base class for all common functionality for visitors
Expand Down
36 changes: 26 additions & 10 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/StmtClone.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
Expand Down Expand Up @@ -574,7 +573,7 @@ StmtDiff BaseForwardModeVisitor::VisitIfStmt(const IfStmt* If) {

VarDecl* condVarClone = nullptr;
if (const VarDecl* condVarDecl = If->getConditionVariable()) {
VarDeclDiff condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
DeclDiff<VarDecl> condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarDeclDiff.getDecl();
if (condVarDeclDiff.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()));
Expand Down Expand Up @@ -672,7 +671,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
VarDecl* condVarDecl = FS->getConditionVariable();
VarDecl* condVarClone = nullptr;
if (condVarDecl) {
VarDeclDiff condVarResult = DifferentiateVarDecl(condVarDecl);
DeclDiff<VarDecl> condVarResult = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarResult.getDecl();
if (condVarResult.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx()));
Expand Down Expand Up @@ -1380,7 +1379,8 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
return StmtDiff(op, opDiff);
}

VarDeclDiff BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
DeclDiff<VarDecl>
BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
StmtDiff initDiff = VD->getInit() ? Visit(VD->getInit()) : StmtDiff{};
// Here we are assuming that derived type and the original type are same.
// This may not necessarily be true in the future.
Expand All @@ -1392,7 +1392,7 @@ VarDeclDiff BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
VD->getType(), "_d_" + VD->getNameAsString(), initDiff.getExpr_dx(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
return VarDeclDiff(VDClone, VDDerived);
return DeclDiff<VarDecl>(VDClone, VDDerived);
}

StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
Expand Down Expand Up @@ -1431,7 +1431,7 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
// double _d_y = _d_x; double y = x;
for (auto D : DS->decls()) {
if (auto VD = dyn_cast<VarDecl>(D)) {
VarDeclDiff VDDiff = DifferentiateVarDecl(VD);
DeclDiff<VarDecl> VDDiff = DifferentiateVarDecl(VD);
// Check if decl's name is the same as before. The name may be changed
// if decl name collides with something in the derivative body.
// This can happen in rare cases, e.g. when the original function
Expand All @@ -1454,14 +1454,24 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
m_DeclReplacements[VD] = VDDiff.getDecl();
decls.push_back(VDDiff.getDecl());
declsDiff.push_back(VDDiff.getDecl_dx());
} else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
DeclDiff<StaticAssertDecl> SADDiff = DifferentiateStaticAssertDecl(SAD);
if (SADDiff.getDecl())
decls.push_back(SADDiff.getDecl());
if (SADDiff.getDecl_dx())
declsDiff.push_back(SADDiff.getDecl_dx());
} else {
diag(DiagnosticsEngine::Warning, D->getEndLoc(),
"Unsupported declaration");
}
}

Stmt* DSClone = BuildDeclStmt(decls);
Stmt* DSDiff = BuildDeclStmt(declsDiff);
Stmt* DSClone = nullptr;
Stmt* DSDiff = nullptr;
if (!decls.empty())
DSClone = BuildDeclStmt(decls);
if (!declsDiff.empty())
DSDiff = BuildDeclStmt(declsDiff);
return StmtDiff(DSClone, DSDiff);
}

Expand Down Expand Up @@ -1534,7 +1544,7 @@ StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) {

const VarDecl* condVar = WS->getConditionVariable();
VarDecl* condVarClone = nullptr;
VarDeclDiff condVarRes;
DeclDiff<VarDecl> condVarRes;
if (condVar) {
condVarRes = DifferentiateVarDecl(condVar);
condVarClone = condVarRes.getDecl();
Expand Down Expand Up @@ -1659,7 +1669,7 @@ StmtDiff BaseForwardModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) {
const VarDecl* condVarDecl = SS->getConditionVariable();
VarDecl* condVarClone = nullptr;
if (condVarDecl) {
VarDeclDiff condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
DeclDiff<VarDecl> condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarDeclDiff.getDecl();
addToCurrentBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()));
}
Expand Down Expand Up @@ -2022,4 +2032,10 @@ StmtDiff BaseForwardModeVisitor::VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP) {
return Visit(NTTP->getReplacement());
}

DeclDiff<StaticAssertDecl>
BaseForwardModeVisitor::DifferentiateStaticAssertDecl(
const clang::StaticAssertDecl* SAD) {
return DeclDiff<StaticAssertDecl>();
}
} // end namespace clad
6 changes: 3 additions & 3 deletions lib/Differentiator/ErrorEstimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ void ErrorEstimationHandler::EmitBinaryOpErrorStmts(Expr* LExpr,
EmitErrorEstimationStmts(direction::reverse);
}

void ErrorEstimationHandler::EmitDeclErrorStmts(VarDeclDiff VDDiff,
void ErrorEstimationHandler::EmitDeclErrorStmts(DeclDiff<VarDecl> VDDiff,
bool isInsideLoop) {
auto VD = VDDiff.getDecl();
if (!ShouldEstimateErrorFor(VD))
Expand Down Expand Up @@ -481,8 +481,8 @@ void ErrorEstimationHandler::ActBeforeFinalizingVisitDeclStmt(
// For all dependent variables, we register them for estimation
// here.
for (size_t i = 0; i < decls.size(); i++) {
VarDeclDiff VDDiff(static_cast<VarDecl*>(decls[0]),
static_cast<VarDecl*>(declsDiff[0]));
DeclDiff<VarDecl> VDDiff(cast<VarDecl>(decls[0]),
cast<VarDecl>(declsDiff[0]));
EmitDeclErrorStmts(VDDiff, m_RMV->isInsideLoop);
}
}
Expand Down
24 changes: 19 additions & 5 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

VarDecl* condVarClone = nullptr;
if (const VarDecl* condVarDecl = If->getConditionVariable()) {
VarDeclDiff condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
DeclDiff<VarDecl> condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarDeclDiff.getDecl();
if (condVarDeclDiff.getDecl_dx())
addToBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()), m_Globals);
Expand Down Expand Up @@ -2549,7 +2549,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
}

VarDeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
DeclDiff<VarDecl>
ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
StmtDiff initDiff;
Expr* VDDerivedInit = nullptr;
// Local declarations are promoted to the function global scope. This
Expand Down Expand Up @@ -2745,7 +2746,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
m_Variables.emplace(VDClone, derivedVDE);

return VarDeclDiff(VDClone, VDDerived);
return DeclDiff<VarDecl>(VDClone, VDDerived);
}

// TODO: 'shouldEmit' parameter should be removed after converting
Expand Down Expand Up @@ -2812,7 +2813,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// double _d_y = _d_x; double y = x;
for (auto* D : DS->decls()) {
if (auto* VD = dyn_cast<VarDecl>(D)) {
VarDeclDiff VDDiff = DifferentiateVarDecl(VD);
DeclDiff<VarDecl> VDDiff = DifferentiateVarDecl(VD);

// Check if decl's name is the same as before. The name may be changed
// if decl name collides with something in the derivative body.
Expand Down Expand Up @@ -2878,14 +2879,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
localDeclsDiff.push_back(VDDiff.getDecl_dx());
else
declsDiff.push_back(VDDiff.getDecl_dx());
} else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
DeclDiff<StaticAssertDecl> SADDiff = DifferentiateStaticAssertDecl(SAD);
if (SADDiff.getDecl())
decls.push_back(SADDiff.getDecl());
if (SADDiff.getDecl_dx())
declsDiff.push_back(SADDiff.getDecl_dx());
} else {
diag(DiagnosticsEngine::Warning,
D->getEndLoc(),
"Unsupported declaration");
}
}

Stmt* DSClone = BuildDeclStmt(decls);
Stmt* DSClone = nullptr;
if (!decls.empty())
DSClone = BuildDeclStmt(decls);

if (!localDeclsDiff.empty()) {
Stmt* localDSDIff = BuildDeclStmt(localDeclsDiff);
Expand Down Expand Up @@ -3831,6 +3840,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return Visit(NTTP->getReplacement());
}

DeclDiff<StaticAssertDecl> ReverseModeVisitor::DifferentiateStaticAssertDecl(
const clang::StaticAssertDecl* SAD) {
return DeclDiff<StaticAssertDecl>(nullptr, nullptr);
}

QualType ReverseModeVisitor::GetParameterDerivativeType(QualType yType,
QualType xType) {

Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/VectorForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,8 @@ StmtDiff VectorForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) {
return StmtDiff(returnStmt);
}

VarDeclDiff VectorForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
DeclDiff<VarDecl>
VectorForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
StmtDiff initDiff = VD->getInit() ? Visit(VD->getInit()) : StmtDiff{};
// Here we are assuming that derived type and the original type are same.
// This may not necessarily be true in the future.
Expand Down Expand Up @@ -610,7 +611,7 @@ VarDeclDiff VectorForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
false, nullptr, VarDecl::InitializationStyle::CallInit);

m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
return VarDeclDiff(VDClone, VDDerived);
return DeclDiff<VarDecl>(VDClone, VDDerived);
}

} // namespace clad
22 changes: 22 additions & 0 deletions test/FirstDerivative/FunctionCallsWithResults.C
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// RUN: ./FunctionCallsWithResults.out | FileCheck -check-prefix=CHECK-EXEC %s

#include "clad/Differentiator/Differentiator.h"
#include <random>

int printf(const char* fmt, ...);

Expand Down Expand Up @@ -289,6 +290,25 @@ double fn9 (double i, double j) {
// CHECK-NEXT: return _t0.pushforward * _t3 + _t2 * _t1.pushforward;
// CHECK-NEXT: }

double fn10(double x) {
std::mt19937 gen64;
std::uniform_real_distribution<double> distribution(0.0,1.0);
double rand = distribution(gen64);
return x+rand;
}

// CHECK: double fn10_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: std::mt19937 _d_gen64;
// CHECK-NEXT: std::mt19937 gen64;
// CHECK-NEXT: std::uniform_real_distribution<{{(double)?}}> _d_distribution(0., 0.);
// CHECK-NEXT: std::uniform_real_distribution<{{(double)?}}> distribution(0., 1.);
// CHECK-NEXT: clad::ValueAndPushforward<result_type, result_type> _t0 = distribution.operator_call_pushforward(gen64, &_d_distribution, _d_gen64);
// CHECK-NEXT: double _d_rand = _t0.pushforward;
// CHECK-NEXT: double rand0 = _t0.value;
// CHECK-NEXT: return _d_x + _d_rand;
// CHECK-NEXT: }

float test_1_darg0(float x);
float test_2_darg0(float x);
float test_4_darg0(float x);
Expand Down Expand Up @@ -318,6 +338,7 @@ int main () {
INIT(fn7, "i");
INIT(fn8, "i");
INIT(fn9, "i");
INIT(fn10, "x");

TEST(fn1, 3, 5); // CHECK-EXEC: {12.00}
TEST(fn2, 3, 5); // CHECK-EXEC: {181.00}
Expand All @@ -328,6 +349,7 @@ int main () {
TEST(fn7, 3, 5); // CHECK-EXEC: {8.00}
TEST(fn8, 3, 5); // CHECK-EXEC: {19.04}
TEST(fn9, 3, 5); // CHECK-EXEC: {5.00}
TEST(fn10, 3); // CHECK-EXEC: {1.00}
return 0;

// CHECK: clad::ValueAndPushforward<double, double> sum_of_squares_pushforward(double u, double v, double _d_u, double _d_v) {
Expand Down
Loading

0 comments on commit 3b216a0

Please sign in to comment.