diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index ebe153fd1..dea966744 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2125,7 +2125,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return {cloneE, derivedE}; } else { if (opCode != UO_LNot) - // We should not output any warning on visiting boolean conditions + // We should only output warnings on visiting boolean conditions + // when it is related to some indepdendent variable and causes + // discontinuity in the function space. // FIXME: We should support boolean differentiation or ignore it // completely unsupportedOpWarn(UnOp->getEndLoc()); @@ -2329,7 +2331,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, utils::GetInnermostReturnExpr(Ldiff.getExpr(), ExprsToStore); // We need to store values of derivative pointer variables in forward pass - // and restore them in reverese pass. + // and restore them in reverse pass. if (isPointerOp) { Expr* Edx = Ldiff.getExpr_dx(); ExprsToStore.push_back(Edx); @@ -2595,8 +2597,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // expression of the corresponding pointer type. else if (isPointerType && VD->getInit()) { initDiff = Visit(VD->getInit()); - if (initDiff.getExpr_dx()) - VDDerivedInit = initDiff.getExpr_dx(); + VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema); + // If it's a pointer to a constant type, then remove the constness. + if (VD->getType()->getPointeeType().isConstQualified()) { + // first extract the pointee type + auto pointeeType = VD->getType()->getPointeeType(); + // then remove the constness + pointeeType.removeLocalConst(); + // then create a new pointer type with the new pointee type + VDDerivedType = m_Context.getPointerType(pointeeType); + } + VDDerivedInit = getZeroInit(VDDerivedType); } // Here separate behaviour for record and non-record types is only // necessary to preserve the old tests. @@ -2681,6 +2692,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE); } } + if (isPointerType) { + Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, + derivedVDE, initDiff.getExpr_dx()); + addToCurrentBlock(assignDerivativeE, direction::forward); + } m_Variables.emplace(VDClone, derivedVDE); return VarDeclDiff(VDClone, VDDerived); diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index 7cbb3922f..86c9602fb 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -21,29 +21,23 @@ double nonMemFn(double i) { // CHECK-NEXT: } double minimalPointer(double x) { - double *p; - p = &x; + double* const p = &x; *p = (*p)*(*p); return *p; // x*x } // CHECK: void minimalPointer_grad(double x, clad::array_ref _d_x) { // CHECK-NEXT: double *_d_p = 0; -// CHECK-NEXT: double *_t0; -// CHECK-NEXT: double *_t1; -// CHECK-NEXT: double _t2; -// CHECK-NEXT: double *p; -// CHECK-NEXT: _t0 = p; -// CHECK-NEXT: _t1 = _d_p; +// CHECK-NEXT: double _t0; // CHECK-NEXT: _d_p = &* _d_x; -// CHECK-NEXT: p = &x; -// CHECK-NEXT: _t2 = *p; +// CHECK-NEXT: double *const p = &x; +// CHECK-NEXT: _t0 = *p; // CHECK-NEXT: *p = *p * (*p); // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: *_d_p += 1; // CHECK-NEXT: { -// CHECK-NEXT: *p = _t2; +// CHECK-NEXT: *p = _t0; // CHECK-NEXT: double _r_d0 = *_d_p; // CHECK-NEXT: double _r0 = _r_d0 * (*p); // CHECK-NEXT: *_d_p += _r0; @@ -51,14 +45,10 @@ double minimalPointer(double x) { // CHECK-NEXT: *_d_p += _r1; // CHECK-NEXT: *_d_p -= _r_d0; // CHECK-NEXT: } -// CHECK-NEXT: { -// CHECK-NEXT: p = _t0; -// CHECK-NEXT: _d_p = _t1; -// CHECK-NEXT: } // CHECK-NEXT: } -double arrayPointer(double* arr) { - double *p = arr; +double arrayPointer(const double* arr) { + const double *p = arr; p = p + 1; double sum = *p; p++; @@ -73,8 +63,8 @@ double arrayPointer(double* arr) { return sum; // 5*arr[0] + arr[1] + 2*arr[2] + 4*arr[3] + 3*arr[4] } -// CHECK: void arrayPointer_grad(double *arr, clad::array_ref _d_arr) { -// CHECK-NEXT: double *_d_p = _d_arr; +// CHECK: void arrayPointer_grad(const double *arr, clad::array_ref _d_arr) { +// CHECK-NEXT: double *_d_p = 0; // CHECK-NEXT: double *_t0; // CHECK-NEXT: double *_t1; // CHECK-NEXT: double _d_sum = 0; @@ -88,6 +78,7 @@ double arrayPointer(double* arr) { // CHECK-NEXT: double *_t9; // CHECK-NEXT: double *_t10; // CHECK-NEXT: double _t11; +// CHECK-NEXT: _d_p = _d_arr; // CHECK-NEXT: double *p = arr; // CHECK-NEXT: _t0 = p; // CHECK-NEXT: _t1 = _d_p;