Skip to content

Commit

Permalink
Add support for C-style memory alloc and free in reverse mode AD
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Feb 22, 2024
1 parent 1cb590e commit dbcf6f7
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 0 deletions.
3 changes: 3 additions & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,9 @@ namespace clad {
void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt);

bool IsLiteral(const clang::Expr* E);

bool IsMemoryAllocationFunction(const clang::FunctionDecl* FD);
bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD);
} // namespace utils
} // namespace clad

Expand Down
16 changes: 16 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -641,5 +641,21 @@ namespace clad {
isa<ObjCBoolLiteralExpr>(E) || isa<CXXBoolLiteralExpr>(E) ||
isa<GNUNullExpr>(E);
}

bool IsMemoryAllocationFunction(const clang::FunctionDecl* FD) {
if (FD->getNameAsString() == "malloc")
return true;
if (FD->getNameAsString() == "calloc")
return true;
if (FD->getNameAsString() == "realloc")
return true;
return false;
}

bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD) {
if (FD->getNameAsString() == "free")
return true;
return false;
}
} // namespace utils
} // namespace clad
42 changes: 42 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1441,6 +1441,48 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Stores tape decl and pushes for multiarg numerically differentiated
// calls.
llvm::SmallVector<Stmt*, 16> NumericalDiffMultiArg{};

// For calls to C-style memory allocation functions, we do not need to
// differentiate the call. We just need to visit the arguments to the
// function.
if (utils::IsMemoryAllocationFunction(FD)) {
for (const Expr* Arg : CE->arguments()) {
StmtDiff ArgDiff = Visit(Arg, dfdx());
CallArgs.push_back(ArgDiff.getExpr());
}
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), noLoc)
.get();
return StmtDiff(call, call);
}
// For calls to C-style memory deallocation functions, we do not need to
// differentiate the call. We just need to visit the arguments to the
// function. Also, don't add any statements either in forward or reverse
// pass. Instead, add it in m_DeallocExprs.
if (utils::IsMemoryDeallocationFunction(FD)) {
for (const Expr* Arg : CE->arguments()) {
StmtDiff ArgDiff = Visit(Arg, dfdx());
CallArgs.push_back(ArgDiff.getExpr());
DerivedCallArgs.push_back(ArgDiff.getExpr_dx());
}
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), noLoc)
.get();
Expr* call_dx =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs),
noLoc)
.get();
m_DeallocExprs.push_back(call);
m_DeallocExprs.push_back(call_dx);
return StmtDiff();
}

// If the result does not depend on the result of the call, just clone
// the call and visit arguments (since they may contain side-effects like
// f(x = y))
Expand Down
38 changes: 38 additions & 0 deletions test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,38 @@ double structPointer (double x) {
// CHECK-NEXT: delete _d_t;
// CHECK-NEXT: }

double cStyleMemoryAlloc(double x, size_t n) {
T* t = (T*)malloc(n * sizeof(T));
t->x = x;
double res = t->x;
free(t);
return res;
}

// CHECK: void cStyleMemoryAlloc_grad_0(double x, size_t n, clad::array_ref<double> _d_x) {
// CHECK-NEXT: size_t _d_n = 0;
// CHECK-NEXT: T *_d_t = 0;
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: _d_t = (T *)malloc(n * sizeof(T));
// CHECK-NEXT: T *t = (T *)malloc(n * sizeof(T));
// CHECK-NEXT: _t0 = t->x;
// CHECK-NEXT: t->x = x;
// CHECK-NEXT: double res = t->x;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: _d_t->x += _d_res;
// CHECK-NEXT: {
// CHECK-NEXT: t->x = _t0;
// CHECK-NEXT: double _r_d0 = _d_t->x;
// CHECK-NEXT: _d_t->x -= _r_d0;
// CHECK-NEXT: * _d_x += _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: free(t);
// CHECK-NEXT: free(_d_t);
// CHECK-NEXT: }

#define NON_MEM_FN_TEST(var)\
res[0]=0;\
var.execute(5,res);\
Expand Down Expand Up @@ -533,4 +565,10 @@ int main() {
auto d_structPointer = clad::gradient(structPointer);
double d_x = 0;
d_structPointer.execute(5, &d_x);
printf("%.2f\n", d_x); // CHECK-EXEC: 1.00

auto d_cStyleMemoryAlloc = clad::gradient(cStyleMemoryAlloc, "x");
d_x = 0;
d_cStyleMemoryAlloc.execute(5, 7, &d_x);
printf("%.2f\n", d_x); // CHECK-EXEC: 1.00
}

0 comments on commit dbcf6f7

Please sign in to comment.