Skip to content

Commit

Permalink
Make integral type variables non-differentiable in the forward mode
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Apr 30, 2024
1 parent 53387b9 commit 685ad18
Show file tree
Hide file tree
Showing 33 changed files with 610 additions and 777 deletions.
15 changes: 7 additions & 8 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ namespace custom_derivatives {
#ifdef __CUDACC__
template <typename T>
ValueAndPushforward<cudaError_t, cudaError_t>
cudaMalloc_pushforward(T** devPtr, size_t sz, T** d_devPtr, size_t d_sz)
cudaMalloc_pushforward(T** devPtr, size_t sz, T** d_devPtr)
__attribute__((host)) {
return {cudaMalloc(devPtr, sz), cudaMalloc(d_devPtr, sz)};
}

ValueAndPushforward<cudaError_t, cudaError_t>
cudaMemcpy_pushforward(void* destPtr, void* srcPtr, size_t count,
cudaMemcpyKind kind, void* d_destPtr, void* d_srcPtr,
size_t d_count) __attribute__((host)) {
cudaMemcpyKind kind, void* d_destPtr, void* d_srcPtr)
__attribute__((host)) {
return {cudaMemcpy(destPtr, srcPtr, count, kind),
cudaMemcpy(d_destPtr, d_srcPtr, count, kind)};
}
Expand Down Expand Up @@ -199,18 +199,17 @@ CUDA_HOST_DEVICE void clamp_pullback(const T& v, const T& lo, const T& hi,

// NOLINTBEGIN(cppcoreguidelines-no-malloc)
// NOLINTBEGIN(cppcoreguidelines-owning-memory)
inline ValueAndPushforward<void*, void*> malloc_pushforward(size_t sz,
size_t d_sz) {
inline ValueAndPushforward<void*, void*> malloc_pushforward(size_t sz) {
return {malloc(sz), malloc(sz)};
}

inline ValueAndPushforward<void*, void*>
calloc_pushforward(size_t n, size_t sz, size_t d_n, size_t d_sz) {
inline ValueAndPushforward<void*, void*> calloc_pushforward(size_t n,
size_t sz) {
return {calloc(n, sz), calloc(n, sz)};
}

inline ValueAndPushforward<void*, void*>
realloc_pushforward(void* ptr, size_t sz, void* d_ptr, size_t d_sz) {
realloc_pushforward(void* ptr, size_t sz, void* d_ptr) {
return {realloc(ptr, sz), realloc(d_ptr, sz)};
}

Expand Down
29 changes: 18 additions & 11 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
}
}
}

// If clad failed to derive it, try finding its derivative using
// numerical diff.
if (!callDiff) {
Expand Down Expand Up @@ -1298,9 +1297,12 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
opDiff = BuildOp(opCode, derivedL, derivedR);
} else if (BinOp->isAssignmentOp()) {
if (Ldiff.getExpr_dx()->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
diag(DiagnosticsEngine::Warning, BinOp->getEndLoc(),
"derivative of an assignment attempts to assign to unassignable "
"expr, assignment ignored");
// If the LHS has a non-differentiable type, Ldiff.getExpr_dx() will be 0.
// Don't create a warning then.
if (IsDifferentiableType(BinOp->getLHS()->getType()))
diag(DiagnosticsEngine::Warning, BinOp->getEndLoc(),
"derivative of an assignment attempts to assign to unassignable "
"expr, assignment ignored");
opDiff = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
} else if (opCode == BO_Assign || opCode == BO_AddAssign ||
opCode == BO_SubAssign) {
Expand Down Expand Up @@ -1377,10 +1379,13 @@ BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
BuildVarDecl(VD->getType(), VD->getNameAsString(), initDiff.getExpr(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
// FIXME: Create unique identifier for derivative.
VarDecl* VDDerived = BuildVarDecl(
VD->getType(), "_d_" + VD->getNameAsString(), initDiff.getExpr_dx(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
VarDecl* VDDerived = nullptr;
if (IsDifferentiableType(VD->getType())) {
VDDerived = BuildVarDecl(VD->getType(), "_d_" + VD->getNameAsString(),
initDiff.getExpr_dx(), VD->isDirectInit(), nullptr,
VD->getInitStyle());
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
}
return DeclDiff<VarDecl>(VDClone, VDDerived);
}

Expand Down Expand Up @@ -1442,7 +1447,8 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
if (VDDiff.getDecl()->getDeclName() != VD->getDeclName())
m_DeclReplacements[VD] = VDDiff.getDecl();
decls.push_back(VDDiff.getDecl());
declsDiff.push_back(VDDiff.getDecl_dx());
if (VDDiff.getDecl_dx())
declsDiff.push_back(VDDiff.getDecl_dx());
} else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
DeclDiff<StaticAssertDecl> SADDiff = DifferentiateStaticAssertDecl(SAD);
if (SADDiff.getDecl())
Expand Down Expand Up @@ -1581,7 +1587,7 @@ StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) {
// ...
// ...
// }
if (condVarClone) {
if (condVarRes.getDecl_dx()) {
bodyResult = utils::PrependAndCreateCompoundStmt(
m_Sema.getASTContext(), cast<CompoundStmt>(bodyResult),
BuildDeclStmt(condVarRes.getDecl_dx()));
Expand Down Expand Up @@ -1660,7 +1666,8 @@ StmtDiff BaseForwardModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) {
if (condVarDecl) {
DeclDiff<VarDecl> condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarDeclDiff.getDecl();
addToCurrentBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()));
if (condVarDeclDiff.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()));
}

StmtDiff initVarRes = (SS->getInit() ? Visit(SS->getInit()) : StmtDiff());
Expand Down
64 changes: 42 additions & 22 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -892,30 +892,45 @@ namespace clad {
Expr* UnresolvedLookup =
m_Sema.BuildDeclarationNameExpr(SS, R, /*ADL*/ false).get();

llvm::SmallVector<Expr*, 16> ExtendedCallArgs(CallArgs.begin(),
CallArgs.end());
llvm::SmallVector<Expr*, 16> ExtendedCallArgs;
llvm::SmallVector<Stmt*, 16> DeclStmts;
// FIXME: for now, integer types are considered differentiable in the
// forward mode.
if (m_Mode != DiffMode::forward &&
m_Mode != DiffMode::vector_forward_mode &&
m_Mode != DiffMode::experimental_pushforward)
for (size_t i = 0, e = originalFD->getNumParams(); i < e; ++i) {
QualType paramTy = originalFD->getParamDecl(i)->getType();
if (!IsDifferentiableType(paramTy)) {
QualType argTy = utils::getNonConstType(paramTy, m_Context, m_Sema);
VarDecl* argDecl = BuildVarDecl(argTy, "_r", getZeroInit(argTy));
Expr* arg = BuildDeclRef(argDecl);
if (!utils::isArrayOrPointerType(argTy))
arg = BuildOp(UO_AddrOf, arg);
ExtendedCallArgs.insert(ExtendedCallArgs.begin() + e + i + 1, arg);
DeclStmts.push_back(BuildDeclStmt(argDecl));
auto MARargs = llvm::MutableArrayRef<Expr*>(CallArgs);
if (noOverloadExists(UnresolvedLookup, MARargs)) {
bool isMethodCall = isa<CXXMethodDecl>(originalFD);
ExtendedCallArgs =
llvm::SmallVector<Expr*, 16>(CallArgs.begin(), CallArgs.end());
if (m_Mode != DiffMode::forward &&
m_Mode != DiffMode::vector_forward_mode &&
m_Mode != DiffMode::experimental_pushforward)
for (size_t i = 0, e = originalFD->getNumParams(); i < e; ++i) {
QualType paramTy = originalFD->getParamDecl(i)->getType();
if (!IsDifferentiableType(paramTy)) {
QualType argTy =
utils::getNonConstType(paramTy, m_Context, m_Sema);
VarDecl* argDecl = BuildVarDecl(argTy, "_r", getZeroInit(argTy));
Expr* arg = BuildDeclRef(argDecl);
if (!utils::isArrayOrPointerType(argTy))
arg = BuildOp(UO_AddrOf, arg);
ExtendedCallArgs.insert(
ExtendedCallArgs.begin() + e + i + 1 + 2 * isMethodCall, arg);
DeclStmts.push_back(BuildDeclStmt(argDecl));
}
}
}
auto MARargs = llvm::MutableArrayRef<Expr*>(ExtendedCallArgs);

if (noOverloadExists(UnresolvedLookup, MARargs))
return nullptr;
else
for (size_t i = 0, e = originalFD->getNumParams(); i < e; ++i) {
QualType paramTy = originalFD->getParamDecl(i)->getType();
if (!IsDifferentiableType(paramTy)) {
QualType argTy =
utils::getNonConstType(paramTy, m_Context, m_Sema);
Expr* zero = getZeroInit(argTy);
ExtendedCallArgs.insert(
ExtendedCallArgs.begin() + e + i + 2 * isMethodCall, zero);
}
}
MARargs = llvm::MutableArrayRef<Expr*>(ExtendedCallArgs);
if (noOverloadExists(UnresolvedLookup, MARargs))
return nullptr;
}

OverloadedFn =
m_Sema.ActOnCallExpr(S, UnresolvedLookup, noLoc, MARargs, noLoc)
Expand Down Expand Up @@ -958,6 +973,11 @@ namespace clad {
return true;
}
}
return false;
}
if (const auto* DRE = dyn_cast<DeclRefExpr>(UnresolvedLookup)) {
const auto* FD = cast<FunctionDecl>(DRE->getDecl());
return FD->getNumParams() != ARargs.size();
}
return false;
}
Expand Down
30 changes: 9 additions & 21 deletions test/Arrays/ArrayInputsForwardMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,11 @@ double addArr(const double *arr, int n) {
}

//CHECK: double addArr_darg0_1(const double *arr, int n) {
//CHECK-NEXT: int _d_n = 0;
//CHECK-NEXT: double _d_ret = 0;
//CHECK-NEXT: double ret = 0;
//CHECK-NEXT: {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: for (int i = 0; i < n; i++) {
//CHECK-NEXT: _d_ret += (i == 1.);
//CHECK-NEXT: ret += arr[i];
//CHECK-NEXT: }
//CHECK-NEXT: for (int i = 0; i < n; i++) {
//CHECK-NEXT: _d_ret += (i == 1.);
//CHECK-NEXT: ret += arr[i];
//CHECK-NEXT: }
//CHECK-NEXT: return _d_ret;
//CHECK-NEXT: }
Expand All @@ -59,25 +55,17 @@ double numMultIndex(double* arr, size_t n, double x) {
}

// CHECK: double numMultIndex_darg2(double *arr, size_t n, double x) {
// CHECK-NEXT: size_t _d_n = 0;
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: bool _d_flag = 0;
// CHECK-NEXT: bool flag = false;
// CHECK-NEXT: size_t _d_idx = 0;
// CHECK-NEXT: size_t idx = 0;
// CHECK-NEXT: {
// CHECK-NEXT: size_t _d_i = 0;
// CHECK-NEXT: for (size_t i = 0; i < n; ++i) {
// CHECK-NEXT: if (arr[i] == x) {
// CHECK-NEXT: _d_flag = 0;
// CHECK-NEXT: flag = true;
// CHECK-NEXT: _d_idx = _d_i;
// CHECK-NEXT: idx = i;
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: for (size_t i = 0; i < n; ++i) {
// CHECK-NEXT: if (arr[i] == x) {
// CHECK-NEXT: flag = true;
// CHECK-NEXT: idx = i;
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return flag ? _d_idx * x + idx * _d_x : 0;
// CHECK-NEXT: return flag ? 0 * x + idx * _d_x : 0;
// CHECK-NEXT: }

int main() {
Expand Down
27 changes: 9 additions & 18 deletions test/Arrays/Arrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,9 @@ double sum(double x, double y, double z) {
//CHECK-NEXT: double vars[3] = {x, y, z};
//CHECK-NEXT: double _d_s = 0;
//CHECK-NEXT: double s = 0;
//CHECK-NEXT: {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_s = _d_s + _d_vars[i];
//CHECK-NEXT: s = s + vars[i];
//CHECK-NEXT: }
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_s = _d_s + _d_vars[i];
//CHECK-NEXT: s = s + vars[i];
//CHECK-NEXT: }
//CHECK-NEXT: return _d_s;
//CHECK-NEXT: }
Expand All @@ -55,21 +52,15 @@ double sum_squares(double x, double y, double z) {
//CHECK-NEXT: double vars[3] = {x, y, z};
//CHECK-NEXT: double _d_squares[3];
//CHECK-NEXT: double squares[3];
//CHECK-NEXT: {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_squares[i] = _d_vars[i] * vars[i] + vars[i] * _d_vars[i];
//CHECK-NEXT: squares[i] = vars[i] * vars[i];
//CHECK-NEXT: }
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_squares[i] = _d_vars[i] * vars[i] + vars[i] * _d_vars[i];
//CHECK-NEXT: squares[i] = vars[i] * vars[i];
//CHECK-NEXT: }
//CHECK-NEXT: double _d_s = 0;
//CHECK-NEXT: double s = 0;
//CHECK-NEXT: {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_s = _d_s + _d_squares[i];
//CHECK-NEXT: s = s + squares[i];
//CHECK-NEXT: }
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_s = _d_s + _d_squares[i];
//CHECK-NEXT: s = s + squares[i];
//CHECK-NEXT: }
//CHECK-NEXT: return _d_s;
//CHECK-NEXT: }
Expand Down
42 changes: 17 additions & 25 deletions test/CUDA/ForwardMode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ __global__ void add(double *a, double *b, double *c, int n) {
c[idx] = a[idx] + b[idx];
}

// CHECK: void add_pushforward(double *a, double *b, double *c, int n, double *_d_a, double *_d_b, double *_d_c, int _d_n) __attribute__((global)) {
// CHECK-NEXT: int _d_idx = 0;
// CHECK: void add_pushforward(double *a, double *b, double *c, int n, double *_d_a, double *_d_b, double *_d_c) __attribute__((global)) {
// CHECK-NEXT: int idx = threadIdx.x;
// CHECK-NEXT: if (idx < n) {
// CHECK-NEXT: _d_c[idx] = _d_a[idx] + _d_b[idx];
Expand Down Expand Up @@ -71,16 +70,12 @@ double fn1(double i, double j) {
// CHECK-NEXT: double b[500] = {};
// CHECK-NEXT: double _d_c[500] = {};
// CHECK-NEXT: double c[500] = {};
// CHECK-NEXT: int _d_n = 0;
// CHECK-NEXT: int n = 500;
// CHECK-NEXT: {
// CHECK-NEXT: int _d_idx = 0;
// CHECK-NEXT: for (int idx = 0; idx < 500; ++idx) {
// CHECK-NEXT: _d_a[idx] = 0;
// CHECK-NEXT: a[idx] = 7;
// CHECK-NEXT: _d_b[idx] = 0;
// CHECK-NEXT: b[idx] = 9;
// CHECK-NEXT: }
// CHECK-NEXT: for (int idx = 0; idx < 500; ++idx) {
// CHECK-NEXT: _d_a[idx] = 0;
// CHECK-NEXT: a[idx] = 7;
// CHECK-NEXT: _d_b[idx] = 0;
// CHECK-NEXT: b[idx] = 9;
// CHECK-NEXT: }
// CHECK-NEXT: double *_d_device_a = nullptr;
// CHECK-NEXT: double *device_a = nullptr;
Expand All @@ -89,29 +84,26 @@ double fn1(double i, double j) {
// CHECK-NEXT: double *_d_device_c = nullptr;
// CHECK-NEXT: double *device_c = nullptr;
// CHECK-NEXT: unsigned long _t0 = sizeof(double);
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t1 = clad::custom_derivatives::cudaMalloc_pushforward(&device_a, n * _t0, &_d_device_a, _d_n * _t0 + n * sizeof(double));
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t1 = clad::custom_derivatives::cudaMalloc_pushforward(&device_a, n * _t0, &_d_device_a);
// CHECK-NEXT: unsigned long _t2 = sizeof(double);
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t3 = clad::custom_derivatives::cudaMalloc_pushforward(&device_b, n * _t2, &_d_device_b, _d_n * _t2 + n * sizeof(double));
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t3 = clad::custom_derivatives::cudaMalloc_pushforward(&device_b, n * _t2, &_d_device_b);
// CHECK-NEXT: unsigned long _t4 = sizeof(double);
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t5 = clad::custom_derivatives::cudaMalloc_pushforward(&device_c, n * _t4, &_d_device_c, _d_n * _t4 + n * sizeof(double));
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t5 = clad::custom_derivatives::cudaMalloc_pushforward(&device_c, n * _t4, &_d_device_c);
// CHECK-NEXT: unsigned long _t6 = sizeof(double);
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t7 = clad::custom_derivatives::cudaMemcpy_pushforward(device_a, a, n * _t6, cudaMemcpyHostToDevice, _d_device_a, _d_a, _d_n * _t6 + n * sizeof(double));
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t7 = clad::custom_derivatives::cudaMemcpy_pushforward(device_a, a, n * _t6, cudaMemcpyHostToDevice, _d_device_a, _d_a);
// CHECK-NEXT: unsigned long _t8 = sizeof(double);
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t9 = clad::custom_derivatives::cudaMemcpy_pushforward(device_b, b, n * _t8, cudaMemcpyHostToDevice, _d_device_b, _d_b, _d_n * _t8 + n * sizeof(double));
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t9 = clad::custom_derivatives::cudaMemcpy_pushforward(device_b, b, n * _t8, cudaMemcpyHostToDevice, _d_device_b, _d_b);
// CHECK-NEXT: unsigned long _t10 = sizeof(double);
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t11 = clad::custom_derivatives::cudaMemcpy_pushforward(device_c, c, n * _t10, cudaMemcpyHostToDevice, _d_device_c, _d_c, _d_n * _t10 + n * sizeof(double));
// CHECK-NEXT: add_pushforward<<<1, 700>>>(device_a, device_b, device_c, n, _d_device_a, _d_device_b, _d_device_c, _d_n);
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t11 = clad::custom_derivatives::cudaMemcpy_pushforward(device_c, c, n * _t10, cudaMemcpyHostToDevice, _d_device_c, _d_c);
// CHECK-NEXT: add_pushforward<<<1, 700>>>(device_a, device_b, device_c, n, _d_device_a, _d_device_b, _d_device_c);
// CHECK-NEXT: ValueAndPushforward<int, int> _t12 = clad::custom_derivatives::cudaDeviceSynchronize_pushforward();
// CHECK-NEXT: unsigned long _t13 = sizeof(double);
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t14 = clad::custom_derivatives::cudaMemcpy_pushforward(c, device_c, n * _t13, cudaMemcpyDeviceToHost, _d_c, _d_device_c, _d_n * _t13 + n * sizeof(double));
// CHECK-NEXT: ValueAndPushforward<cudaError_t, cudaError_t> _t14 = clad::custom_derivatives::cudaMemcpy_pushforward(c, device_c, n * _t13, cudaMemcpyDeviceToHost, _d_c, _d_device_c);
// CHECK-NEXT: double _d_sum = 0;
// CHECK-NEXT: double sum = 0;
// CHECK-NEXT: {
// CHECK-NEXT: int _d_idx = 0;
// CHECK-NEXT: for (int idx = 0; idx < n; ++idx) {
// CHECK-NEXT: _d_sum += _d_c[idx];
// CHECK-NEXT: sum += c[idx];
// CHECK-NEXT: }
// CHECK-NEXT: for (int idx = 0; idx < n; ++idx) {
// CHECK-NEXT: _d_sum += _d_c[idx];
// CHECK-NEXT: sum += c[idx];
// CHECK-NEXT: }
// CHECK-NEXT: double _t15 = 2 * sum;
// CHECK-NEXT: return _d_sum * i + sum * _d_i + (0 * sum + 2 * _d_sum) * j + _t15 * _d_j;
Expand Down
Loading

0 comments on commit 685ad18

Please sign in to comment.