diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 1b494ce82..155e1e6e4 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1314,7 +1314,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const Expr* value = RS->getRetValue(); QualType type = value->getType(); auto* dfdf = m_Pullback; - if (isa(dfdf) || isa(dfdf)) { + if (dfdf && (isa(dfdf) || isa(dfdf))) { ExprResult tmp = dfdf; dfdf = m_Sema .ImpCastExprToType(tmp.get(), type, @@ -2016,7 +2016,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (Expr* customForwardPassCE = BuildCallToCustomForwPassFn(FD, CallArgs, CallArgDx, baseExpr)) { - if (!utils::isNonConstReferenceType(returnType)) + if (!utils::isNonConstReferenceType(returnType) && + !returnType->isPointerType()) return StmtDiff{customForwardPassCE}; auto* callRes = StoreAndRef(customForwardPassCE); auto* resValue = @@ -2025,7 +2026,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); return StmtDiff(resValue, nullptr, resAdjoint); } - if (utils::isNonConstReferenceType(returnType)) { + if (utils::isNonConstReferenceType(returnType) || + returnType->isPointerType()) { DiffRequest calleeFnForwPassReq; calleeFnForwPassReq.Function = FD; calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass; @@ -2083,7 +2085,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value"); auto* resAdjoint = utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); - return StmtDiff(resValue, nullptr, resAdjoint); + return StmtDiff(resValue, resAdjoint, resAdjoint); } // Recreate the original call expression. call = m_Sema .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, @@ -4114,7 +4116,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // respect to variable of type X, then the derivative should be of type // X. Check this related issue for more details: // https://github.com/vgvassilev/clad/issues/385 - if (effectiveReturnType->isVoidType()) + if (effectiveReturnType->isVoidType() || + effectiveReturnType->isPointerType()) effectiveReturnType = m_Context.DoubleTy; else paramTypes.push_back(effectiveReturnType); @@ -4154,7 +4157,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::size_t dParamTypesIdx = m_DiffReq->getNumParams(); if (m_DiffReq.Mode == DiffMode::experimental_pullback && - !m_DiffReq->getReturnType()->isVoidType()) { + !m_DiffReq->getReturnType()->isVoidType() && + !m_DiffReq->getReturnType()->isPointerType()) { ++dParamTypesIdx; } @@ -4225,7 +4229,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } if (m_DiffReq.Mode == DiffMode::experimental_pullback && - !m_DiffReq->getReturnType()->isVoidType()) { + !m_DiffReq->getReturnType()->isVoidType() && + !m_DiffReq->getReturnType()->isPointerType()) { IdentifierInfo* pullbackParamII = CreateUniqueIdentifier("_d_y"); QualType pullbackType = derivativeFnType->getParamType(m_DiffReq->getNumParams()); diff --git a/test/Gradient/Pointers.C b/test/Gradient/Pointers.C index ee31cdda8..8a57d88cd 100644 --- a/test/Gradient/Pointers.C +++ b/test/Gradient/Pointers.C @@ -466,6 +466,37 @@ double cStyleMemoryAlloc(double x, size_t n) { // CHECK-NEXT: free(_d_t); // CHECK-NEXT: } +double* ptrValFn (double* x, int n) { + x += n; + return x; +} + +// CHECK: void ptrValFn_pullback(double *x, int n, double *_d_x, int *_d_n); +// CHECK: clad::ValueAndAdjoint ptrValFn_forw(double *x, int n, double *_d_x, int _d_n); + +double nestedPtrFn (double x, double y) { + double arr[] = {x, y}; + double* z = ptrValFn(arr, 1); + return *z; +} + +// CHECK: void nestedPtrFn_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: double _d_arr[2] = {0}; +// CHECK-NEXT: double arr[2] = {x, y}; +// CHECK-NEXT: clad::ValueAndAdjoint _t0 = ptrValFn_forw(arr, 1, _d_arr, 0); +// CHECK-NEXT: double *_d_z = _t0.adjoint; +// CHECK-NEXT: double *z = _t0.value; +// CHECK-NEXT: *_d_z += 1; +// CHECK-NEXT: { +// CHECK-NEXT: int _r0 = 0; +// CHECK-NEXT: ptrValFn_pullback(arr, 1, _d_arr, &_r0); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: *_d_x += _d_arr[0]; +// CHECK-NEXT: *_d_y += _d_arr[1]; +// CHECK-NEXT: } +// CHECK-NEXT: } + #define NON_MEM_FN_TEST(var)\ res[0]=0;\ var.execute(5,res);\ @@ -573,4 +604,28 @@ int main() { d_x = 0; d_cStyleMemoryAlloc.execute(5, 7, &d_x); printf("%.2f\n", d_x); // CHECK-EXEC: 4.00 + + auto d_nestedPtrFn = clad::gradient(nestedPtrFn); + d_i = 0; d_j = 0; + d_nestedPtrFn.execute(5, 7, &d_i, &d_j); + printf("%.2f %.2f\n", d_i, d_j); // CHECK-EXEC: 0.00 1.00 } + +// CHECK: void ptrValFn_pullback(double *x, int n, double *_d_x, int *_d_n) { +// CHECK-NEXT: double *_t0 = x; +// CHECK-NEXT: double *_t1 = _d_x; +// CHECK-NEXT: _d_x += n; +// CHECK-NEXT: x += n; +// CHECK-NEXT: { +// CHECK-NEXT: x = _t0; +// CHECK-NEXT: _d_x = _t1; +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: clad::ValueAndAdjoint ptrValFn_forw(double *x, int n, double *_d_x, int _d_n) { +// CHECK-NEXT: double *_t0 = x; +// CHECK-NEXT: double *_t1 = _d_x; +// CHECK-NEXT: _d_x += n; +// CHECK-NEXT: x += n; +// CHECK-NEXT: return {x, _d_x}; +// CHECK-NEXT: }