Skip to content

Commit

Permalink
Support pointer-valued functions in the reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Aug 21, 2024
1 parent a1fed2a commit e201c1c
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 7 deletions.
19 changes: 12 additions & 7 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
const Expr* value = RS->getRetValue();
QualType type = value->getType();
auto* dfdf = m_Pullback;
if (isa<FloatingLiteral>(dfdf) || isa<IntegerLiteral>(dfdf)) {
if (dfdf && (isa<FloatingLiteral>(dfdf) || isa<IntegerLiteral>(dfdf))) {
ExprResult tmp = dfdf;
dfdf = m_Sema
.ImpCastExprToType(tmp.get(), type,
Expand Down Expand Up @@ -2016,7 +2016,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

if (Expr* customForwardPassCE =
BuildCallToCustomForwPassFn(FD, CallArgs, CallArgDx, baseExpr)) {
if (!utils::isNonConstReferenceType(returnType))
if (!utils::isNonConstReferenceType(returnType) &&
!returnType->isPointerType())
return StmtDiff{customForwardPassCE};
auto* callRes = StoreAndRef(customForwardPassCE);
auto* resValue =
Expand All @@ -2025,7 +2026,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return StmtDiff(resValue, nullptr, resAdjoint);
}
if (utils::isNonConstReferenceType(returnType)) {
if (utils::isNonConstReferenceType(returnType) ||
returnType->isPointerType()) {
DiffRequest calleeFnForwPassReq;
calleeFnForwPassReq.Function = FD;
calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass;
Expand Down Expand Up @@ -2083,7 +2085,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value");
auto* resAdjoint =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return StmtDiff(resValue, nullptr, resAdjoint);
return StmtDiff(resValue, resAdjoint, resAdjoint);
} // Recreate the original call expression.
call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
Expand Down Expand Up @@ -4114,7 +4116,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// respect to variable of type X, then the derivative should be of type
// X. Check this related issue for more details:
// https://github.com/vgvassilev/clad/issues/385
if (effectiveReturnType->isVoidType())
if (effectiveReturnType->isVoidType() ||
effectiveReturnType->isPointerType())
effectiveReturnType = m_Context.DoubleTy;
else
paramTypes.push_back(effectiveReturnType);
Expand Down Expand Up @@ -4154,7 +4157,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
std::size_t dParamTypesIdx = m_DiffReq->getNumParams();

if (m_DiffReq.Mode == DiffMode::experimental_pullback &&
!m_DiffReq->getReturnType()->isVoidType()) {
!m_DiffReq->getReturnType()->isVoidType() &&
!m_DiffReq->getReturnType()->isPointerType()) {
++dParamTypesIdx;
}

Expand Down Expand Up @@ -4225,7 +4229,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

if (m_DiffReq.Mode == DiffMode::experimental_pullback &&
!m_DiffReq->getReturnType()->isVoidType()) {
!m_DiffReq->getReturnType()->isVoidType() &&
!m_DiffReq->getReturnType()->isPointerType()) {
IdentifierInfo* pullbackParamII = CreateUniqueIdentifier("_d_y");
QualType pullbackType =
derivativeFnType->getParamType(m_DiffReq->getNumParams());
Expand Down
55 changes: 55 additions & 0 deletions test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,37 @@ double cStyleMemoryAlloc(double x, size_t n) {
// CHECK-NEXT: free(_d_t);
// CHECK-NEXT: }

double* ptrValFn (double* x, int n) {
x += n;
return x;
}

// CHECK: void ptrValFn_pullback(double *x, int n, double *_d_x, int *_d_n);
// CHECK: clad::ValueAndAdjoint<double *, double *> ptrValFn_forw(double *x, int n, double *_d_x, int _d_n);

double nestedPtrFn (double x, double y) {
double arr[] = {x, y};
double* z = ptrValFn(arr, 1);
return *z;
}

// CHECK: void nestedPtrFn_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: double _d_arr[2] = {0};
// CHECK-NEXT: double arr[2] = {x, y};
// CHECK-NEXT: clad::ValueAndAdjoint<double *, double *> _t0 = ptrValFn_forw(arr, 1, _d_arr, 0);
// CHECK-NEXT: double *_d_z = _t0.adjoint;
// CHECK-NEXT: double *z = _t0.value;
// CHECK-NEXT: *_d_z += 1;
// CHECK-NEXT: {
// CHECK-NEXT: int _r0 = 0;
// CHECK-NEXT: ptrValFn_pullback(arr, 1, _d_arr, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: *_d_x += _d_arr[0];
// CHECK-NEXT: *_d_y += _d_arr[1];
// CHECK-NEXT: }
// CHECK-NEXT: }

#define NON_MEM_FN_TEST(var)\
res[0]=0;\
var.execute(5,res);\
Expand Down Expand Up @@ -573,4 +604,28 @@ int main() {
d_x = 0;
d_cStyleMemoryAlloc.execute(5, 7, &d_x);
printf("%.2f\n", d_x); // CHECK-EXEC: 4.00

auto d_nestedPtrFn = clad::gradient(nestedPtrFn);
d_i = 0; d_j = 0;
d_nestedPtrFn.execute(5, 7, &d_i, &d_j);
printf("%.2f %.2f\n", d_i, d_j); // CHECK-EXEC: 0.00 1.00
}

// CHECK: void ptrValFn_pullback(double *x, int n, double *_d_x, int *_d_n) {
// CHECK-NEXT: double *_t0 = x;
// CHECK-NEXT: double *_t1 = _d_x;
// CHECK-NEXT: _d_x += n;
// CHECK-NEXT: x += n;
// CHECK-NEXT: {
// CHECK-NEXT: x = _t0;
// CHECK-NEXT: _d_x = _t1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: clad::ValueAndAdjoint<double *, double *> ptrValFn_forw(double *x, int n, double *_d_x, int _d_n) {
// CHECK-NEXT: double *_t0 = x;
// CHECK-NEXT: double *_t1 = _d_x;
// CHECK-NEXT: _d_x += n;
// CHECK-NEXT: x += n;
// CHECK-NEXT: return {x, _d_x};
// CHECK-NEXT: }

0 comments on commit e201c1c

Please sign in to comment.