Skip to content

Commit

Permalink
Fix rvalue references decl name in pullbacks and support make_float2
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Nov 17, 2024
1 parent 59129ad commit 6bb0572
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 16 deletions.
8 changes: 3 additions & 5 deletions demos/CUDA/BlackScholes/BlackScholes_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ __global__ void BlackScholesGPU(float2* __restrict d_CallResult,
BlackScholesBodyGPU(callResult2, putResult2, d_StockPrice[opt].y,
d_OptionStrike[opt].y, d_OptionYears[opt].y, Riskfree,
Volatility);
d_CallResult[opt].x = callResult1;
d_CallResult[opt].y = callResult2;
d_PutResult[opt].x = putResult1;
d_PutResult[opt].y = putResult2;
d_CallResult[opt] = make_float2(callResult1, callResult2);
d_PutResult[opt] = make_float2(putResult1, putResult2);
}
}
}
9 changes: 9 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,15 @@ CUDA_HOST_DEVICE inline void sqrtf_pullback(float a, float d_y, float* d_a) {
*d_a += (1.F / (2.F * sqrtf(a))) * d_y;
}


#ifdef __CUDACC__
CUDA_HOST_DEVICE inline void make_float2_pullback(float a, float b, float2 d_y,
float* d_a, float* d_b) {
*d_a += d_y.x;
*d_b += d_y.y;
}
#endif

// These are required because C variants of mathematical functions are
// defined in global namespace.
using std::abs_pushforward;
Expand Down
8 changes: 6 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1891,8 +1891,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
e = CE->getNumArgs();
i != e; ++i) {
const Expr* arg = CE->getArg(i);
const auto* PVD = FD->getParamDecl(
i - static_cast<unsigned long>(isMethodOperatorCall));
auto* PVD = const_cast<ParmVarDecl*>(FD->getParamDecl(
i - static_cast<unsigned long>(isMethodOperatorCall)));
if (PVD->getType()->isRValueReferenceType()) {
IdentifierInfo* RValueName = CreateUniqueIdentifier("_r");
PVD->setDeclName(RValueName);
}
StmtDiff argDiff{};
// We do not need to create result arg for arguments passed by reference
// because the derivatives of arguments passed by reference are directly
Expand Down
18 changes: 9 additions & 9 deletions test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -817,28 +817,28 @@ int main() {
// CHECK-NEXT: std::vector<double> _d_a({});
// CHECK-NEXT: std::vector<double> a;
// CHECK-NEXT: std::vector<double> _t0 = a;
// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, 0{{.*}}, &_d_a, _r0);
// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, 0{{.*}}, &_d_a, _r1);
// CHECK-NEXT: std::vector<double> _t1 = a;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t2 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r1);
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t2 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r2);
// CHECK-NEXT: double _t3 = _t2.value;
// CHECK-NEXT: _t2.value = x * x;
// CHECK-NEXT: std::vector<double> _t4 = a;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t5 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r2);
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t5 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r3);
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 0, 1, &_d_a, &_r2);
// CHECK-NEXT: {{.*}}size_type _r3 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 0, 1, &_d_a, &_r3);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: _t2.value = _t3;
// CHECK-NEXT: double _r_d0 = _t2.adjoint;
// CHECK-NEXT: _t2.adjoint = 0{{.*}};
// CHECK-NEXT: *_d_x += _r_d0 * x;
// CHECK-NEXT: *_d_x += x * _r_d0;
// CHECK-NEXT: {{.*}}size_type _r1 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 0, 0{{.*}}, &_d_a, &_r1);
// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}};
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 0, 0{{.*}}, &_d_a, &_r2);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}value_type _r0 = 0.;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r0);
// CHECK-NEXT: {{.*}}value_type _r1 = 0.;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r1);
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit 6bb0572

Please sign in to comment.