Skip to content

Commit

Permalink
Fix type casting of ValueAndPushForward types vgvassilev#986
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak authored and vgvassilev committed Jul 18, 2024
1 parent e04d04e commit 3daeb51
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 44 deletions.
8 changes: 8 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ namespace clad {
template <typename T, typename U> struct ValueAndPushforward {
T value;
U pushforward;

// Define the cast operator from ValueAndPushforward<T, U> to
// ValueAndPushforward<V, w> where V is convertible to T and W is
// convertible to U.
template <typename V = T, typename W = U>
operator ValueAndPushforward<V, W>() const {
return {static_cast<V>(value), static_cast<W>(pushforward)};
}
};

/// It is used to identify constructor custom pushforwards. For
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ namespace clad {
void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt);

bool IsLiteral(const clang::Expr* E);
bool IsZeroOrNullValue(const clang::Expr* E);

bool IsMemoryFunction(const clang::FunctionDecl* FD);
bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD);
Expand Down
17 changes: 7 additions & 10 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// Returning the function call and zero derivative
return StmtDiff(Call, zero);
}

// Find the built-in derivatives namespace.
std::string s = std::to_string(m_DerivativeOrder);
if (m_DerivativeOrder == 1)
Expand Down Expand Up @@ -1200,19 +1199,17 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// FIXME: revert this when this is integrated in the activity analysis pass.
if (!callDiff) {
if (!isa<CXXOperatorCallExpr>(CE) && !isa<CXXMemberCallExpr>(CE)) {
bool allArgsAreConstantLiterals = true;
bool allArgsHaveZeroDerivatives = true;
for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i) {
const Expr* arg = CE->getArg(i);
// if it's of type MaterializeTemporaryExpr, then check its
// subexpression.
if (const auto* MTE = dyn_cast<MaterializeTemporaryExpr>(arg))
arg = clad_compat::GetSubExpr(MTE);
if (!arg->isEvaluatable(m_Context)) {
allArgsAreConstantLiterals = false;
Expr* dArg = diffArgs[i];
// If argDiff.expr_dx is nullptr or is a constant 0, then the derivative
// of the function call is 0.
if (!clad::utils::IsZeroOrNullValue(dArg)) {
allArgsHaveZeroDerivatives = false;
break;
}
}
if (allArgsAreConstantLiterals) {
if (allArgsHaveZeroDerivatives) {
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
Expand Down
16 changes: 16 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,22 @@ namespace clad {
isa<GNUNullExpr>(E);
}

bool IsZeroOrNullValue(const clang::Expr* E) {
if (!E)
return true;
if (const auto* ICE = dyn_cast<ImplicitCastExpr>(E))
return IsZeroOrNullValue(ICE->getSubExpr());
if (isa<CXXNullPtrLiteralExpr>(E))
return true;
if (const auto* FL = dyn_cast<FloatingLiteral>(E))
return FL->getValue().isZero();
if (const auto* IL = dyn_cast<IntegerLiteral>(E))
return IL->getValue() == 0;
if (const auto* SL = dyn_cast<StringLiteral>(E))
return SL->getLength() == 0;
return false;
}

bool IsMemoryFunction(const clang::FunctionDecl* FD) {

#if CLANG_VERSION_MAJOR > 12
Expand Down
21 changes: 19 additions & 2 deletions lib/Differentiator/PushForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,25 @@ StmtDiff PushForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) {
return nullptr;

StmtDiff retValDiff = Visit(RS->getRetValue());
llvm::SmallVector<Expr*, 2> returnValues = {retValDiff.getExpr(),
retValDiff.getExpr_dx()};
Expr* retVal = retValDiff.getExpr();
Expr* retVal_dx = retValDiff.getExpr_dx();
if (!m_Context.hasSameUnqualifiedType(retVal->getType(),
m_DiffReq->getReturnType())) {
// Check if implficit cast would work.
// Add a cast to the return type.
TypeSourceInfo* TSI =
m_Context.getTrivialTypeSourceInfo(m_DiffReq->getReturnType());
retVal = m_Sema
.BuildCStyleCastExpr(RS->getBeginLoc(), TSI, RS->getEndLoc(),
BuildParens(retVal))
.get();
retVal_dx =
m_Sema
.BuildCStyleCastExpr(RS->getBeginLoc(), TSI, RS->getEndLoc(),
BuildParens(retVal_dx))
.get();
}
llvm::SmallVector<Expr*, 2> returnValues = {retVal, retVal_dx};
// This can instantiate as part of the move or copy initialization and
// needs a fake source location.
SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema);
Expand Down
56 changes: 56 additions & 0 deletions test/FirstDerivative/CallArguments.C
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,44 @@ float f_const_args_func_8(const float x, float y) {
// CHECK-NEXT: return _t0.pushforward + _t1.pushforward - _d_y;
// CHECK-NEXT: }

float f_literal_helper(float x, char ch, float* p, float* q) {
if (ch == 'a')
return x * x;
return -x * x;
}

float f_literal_args_func(float x, float y, float *z) {
printf("hello world ");
return x * f_literal_helper(0.5, 'a', z, nullptr);
}

// CHECK: float f_literal_args_func_darg0(float x, float y, float *z) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: printf("hello world ");
// CHECK-NEXT: float _t0 = f_literal_helper(0.5, 'a', z, nullptr);
// CHECK-NEXT: return _d_x * _t0 + x * 0.F;
// CHECK-NEXT: }

inline unsigned int getBin(double low, double high, double val, unsigned int numBins) {
double binWidth = (high - low) / numBins;
return val >= high ? numBins - 1 : std::abs((val - low) / binWidth);
}

float f_call_inline_fxn(float *params, float const *obs, float const *xlArr) {
const float t116 = *(xlArr + getBin(0., 1., params[0], 1));
return t116 * params[0];
}

// CHECK: inline clad::ValueAndPushforward<unsigned int, unsigned int> getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins);

// CHECK: float f_call_inline_fxn_darg0_0(float *params, const float *obs, const float *xlArr) {
// CHECK-NEXT: clad::ValueAndPushforward<unsigned int, unsigned int> _t0 = getBin_pushforward(0., 1., params[0], 1, 0., 0., 1.F, 0);
// CHECK-NEXT: const float _d_t116 = 0;
// CHECK-NEXT: const float t116 = *(xlArr + _t0.value);
// CHECK-NEXT: return _d_t116 * params[0] + t116 * 1.F;
// CHECK-NEXT: }

extern "C" int printf(const char* fmt, ...);
int main () { // expected-no-diagnostics
auto f = clad::differentiate(g, 0);
Expand Down Expand Up @@ -165,9 +203,27 @@ int main () { // expected-no-diagnostics
const float f8x = 1.F;
printf("f8_darg0=%f\n", f8.execute(f8x,2.F));
//CHECK-EXEC: f8_darg0=2.000000
auto f9 = clad::differentiate(f_literal_args_func, 0);
float z = 3.0;
printf("f9_darg0=%.2f\n", f9.execute(1.F, 2.F, &z));
//CHECK-EXEC: hello world f9_darg0=0.25
auto f10 = clad::differentiate(f_call_inline_fxn, "params[0]");
float params = 1.0, obs = 5.0, xlArr = 7.0;
printf("f10_darg0_0=%.2f\n", f10.execute(&params, &obs, &xlArr));
//CHECK-EXEC: f10_darg0_0=7.00

// CHECK: clad::ValueAndPushforward<float, float> f_const_helper_pushforward(const float x, const float _d_x) {
// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};
// CHECK-NEXT: }

// CHECK: inline clad::ValueAndPushforward<unsigned int, unsigned int> getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins) {
// CHECK-NEXT: double _t0 = (high - low);
// CHECK-NEXT: double _d_binWidth = ((_d_high - _d_low) * numBins - _t0 * _d_numBins) / (numBins * numBins);
// CHECK-NEXT: double binWidth = _t0 / numBins;
// CHECK-NEXT: double _t1 = (val - low);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<double, double> _t2 = clad::custom_derivatives{{(::std)?}}::abs_pushforward(_t1 / binWidth, ((_d_val - _d_low) * binWidth - _t1 * _d_binWidth) / (binWidth * binWidth));
// CHECK-NEXT: bool _t3 = val >= high;
// CHECK-NEXT: return {(unsigned int)(_t3 ? numBins - 1 : _t2.value), (unsigned int)(_t3 ? _d_numBins - 0 : _t2.pushforward)};
// CHECK-NEXT: }

return 0;
Expand Down
4 changes: 2 additions & 2 deletions test/FirstDerivative/FunctionCallsWithResults.C
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ int main () {

// CHECK: clad::ValueAndPushforward<double, double> fn6_pushforward(double i, double j, double k, double _d_i, double _d_j, double _d_k) {
// CHECK-NEXT: if (i < 0.5)
// CHECK-NEXT: return {0, 0};
// CHECK-NEXT: return {(double)0, (double)0};
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = fn6_pushforward(i - 1, j - 1, k - 1, _d_i - 0, _d_j - 0, _d_k - 0);
// CHECK-NEXT: return {i + j + k + _t0.value, _d_i + _d_j + _d_k + _t0.pushforward};
// CHECK-NEXT: }
Expand Down Expand Up @@ -420,7 +420,7 @@ int main () {
// CHECK: clad::ValueAndPushforward<double, double> check_and_return_pushforward(double x, char c, double _d_x, char _d_c) {
// CHECK-NEXT: if (c == 'a')
// CHECK-NEXT: return {x, _d_x};
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(double)1, (double)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<double, double> g_pushforward(double x, double _d_x) {
Expand Down
6 changes: 3 additions & 3 deletions test/FirstDerivative/FunctionsInNamespaces.C
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ int main () {

// CHECK: clad::ValueAndPushforward<double, double> someFn_pushforward(double &i, double &j, double &_d_i, double &_d_j) {
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = someFn_1_pushforward(i, j, _d_i, _d_j);
// CHECK-NEXT: return {3, 0};
// CHECK-NEXT: return {(double)3, (double)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<int, int> func4_pushforward(int x, int y, int _d_x, int _d_y) {
Expand All @@ -118,13 +118,13 @@ int main () {

// CHECK: clad::ValueAndPushforward<double, double> someFn_1_pushforward(double &i, double j, double &_d_i, double _d_j) {
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = someFn_1_pushforward(i, j, j, _d_i, _d_j, _d_j);
// CHECK-NEXT: return {2, 0};
// CHECK-NEXT: return {(double)2, (double)0};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<double, double> someFn_1_pushforward(double &i, double j, double k, double &_d_i, double _d_j, double _d_k) {
// CHECK-NEXT: _d_i = _d_j;
// CHECK-NEXT: i = j;
// CHECK-NEXT: return {1, 0};
// CHECK-NEXT: return {(double)1, (double)0};
// CHECK-NEXT: }

return 0;
Expand Down
Loading

0 comments on commit 3daeb51

Please sign in to comment.