Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Differentiate array parameters in forward mode #1062

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,21 @@ template <typename T> class array {

/// Reports the number of elements tracked by the underlying array.
CUDA_HOST_DEVICE std::size_t size() const {
  return this->m_size;
}
/// Extends the size of the array to `size` and value-initializes (`T()`) the
/// new elements if the current array size is less than `size`. Existing
/// elements are copied into the new storage. The array is never shrunk: when
/// `size` is not greater than the current size, nothing happens.
CUDA_HOST_DEVICE void extend(std::size_t size) {
  if (size <= m_size)
    return;
  // clang-tidy reports cppcoreguidelines-owning-memory on the allocation
  // below (a newly created 'gsl::owner<>' initializing a non-owner 'T *');
  // the buffer is handed over to m_arr at the end of this function.
  // NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
  T* extendedArr = new T[size];
  for (std::size_t i = 0; i < m_size; ++i)
    extendedArr[i] = m_arr[i];
  for (std::size_t i = m_size; i < size; ++i)
    extendedArr[i] = T();
  // m_arr is assigned from `new T[size]` here, so the storage it holds must
  // be released with the array form of delete; scalar `delete` on array
  // storage is undefined behavior.
  // NOTE(review): presumably the constructors also allocate m_arr with
  // `new T[]` -- confirm against the rest of the class.
  delete[] m_arr;
  m_arr = extendedArr;
  m_size = size;
}
/// Iterator functions
/// Mutable overload: hands out the underlying storage pointer as the start
/// of the range.
CUDA_HOST_DEVICE T* begin() {
  return this->m_arr;
}
/// Const overload: hands out the underlying storage pointer as a
/// pointer-to-const start of the range.
CUDA_HOST_DEVICE const T* begin() const {
  return this->m_arr;
}
Expand Down Expand Up @@ -446,6 +461,20 @@ operator/(const array<T>& arr1, const array<U>& arr2) {
arr2);
}

namespace custom_derivatives {
namespace class_functions {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: nested namespaces can be concatenated [modernize-concat-nested-namespaces]

Suggested change
namespace class_functions {
namespace custom_derivatives::class_functions {

include/clad/Differentiator/Array.h:474:

- } // namespace class_functions
- } // namespace custom_derivatives
+ } // namespace custom_derivatives

/// `reverse_forw` custom derivative for `array<T>::extend`: extends `*arr`
/// and then `*d_arr` to `size` elements so both arrays stay equally sized.
/// `d_size` is accepted but not used.
template <typename T>
void extend_reverse_forw(array<T>* arr, std::size_t size, array<T>* d_arr,
                         std::size_t d_size) {
  (*arr).extend(size);
  (*d_arr).extend(size);
}
/// Pullback custom derivative for `array<T>::extend`. The body is empty: no
/// argument is read or modified.
template <typename T>
void extend_pullback(array<T>* arr, std::size_t size, array<T>* d_arr,
std::size_t* d_size) {}
} // namespace class_functions
} // namespace custom_derivatives

} // namespace clad

#endif // CLAD_ARRAY_H
4 changes: 4 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,10 @@ namespace clad {
clang::SourceLocation srcLoc);

clang::QualType DetermineCladArrayValueType(clang::QualType T);
/// Extend the size of `arr` to safely access the element corresponding to
/// `idx`. Works only for clad::array when handling array parameters in
/// forward mode.
void EmitCladArrayExtend(StmtDiff arr, clang::Expr* idx);

/// Returns clad::Identify template declaration.
clang::TemplateDecl* GetCladConstructorPushforwardTag();
Expand Down
39 changes: 29 additions & 10 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,31 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
// non-reference type for creating the derivatives.
QualType dParamType = param->getType().getNonReferenceType();
// We do not create derived variable for array/pointer parameters.
if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) ||
utils::isArrayOrPointerType(dParamType))
if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType))
continue;
Expr* dParam = nullptr;
bool isArrayTy = utils::isArrayOrPointerType(dParamType);
if (dParamType->isRealType()) {
// If param is independent variable, its derivative is 1, otherwise 0.
int dValue = (param == m_IndependentVar);
dParam = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
dValue);
} else if (isArrayTy) {
if (param == m_IndependentVar)
continue;

if (const auto* DT = dyn_cast<DecayedType>(dParamType)) {
if (const auto* CAT =
dyn_cast<ConstantArrayType>(DT->getOriginalType())) {
Expr* zero = ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0);
dParam = m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get();
dParamType = QualType::getFromOpaquePtr(CAT);
}
} else {
dParamType = GetCladArrayOfType(utils::GetValueType(dParamType));
dParam = getZeroInit(dParamType);
}
}
// For each function arg, create a variable _d_arg to store derivatives
// of potential reassignments, e.g.:
Expand All @@ -249,6 +265,10 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
BuildVarDecl(dParamType, "_d_" + param->getNameAsString(), dParam);
addToCurrentBlock(BuildDeclStmt(dParamDecl));
dParam = BuildDeclRef(dParamDecl);
if (!isa<ConstantArrayType>(dParamType) && isArrayTy) {
llvm::SmallVector<Expr*, 0> noParams{};
dParam = BuildCallExprToMemFn(dParam, "ptr", noParams);
}
if (dParamType->isRecordType() && param == m_IndependentVar) {
llvm::SmallVector<llvm::StringRef, 4> ref(diffVarInfo.fields.begin(),
diffVarInfo.fields.end());
Expand Down Expand Up @@ -984,7 +1004,6 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) {
VD = DRE->getDecl();
}
if (VD == m_IndependentVar) {
llvm::APSInt index;
Expr* diffExpr = nullptr;
Expr::EvalResult res;
Expr::SideEffectsKind AllowSideEffects =
Expand All @@ -1009,12 +1028,10 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) {
return StmtDiff(cloned, zero);

Expr* target = it->second;
// FIXME: fix when adding array inputs
if (!isArrayOrPointerType(target->getType()))
return StmtDiff(cloned, zero);
// llvm::APSInt IVal;
// if (!I->EvaluateAsInt(IVal, m_Context))
// return;
// The size of array parameters is unknown
// so we need to always extend the adjoint size before accessing the element.
if (utils::isArrayOrPointerType(clonedBase->getType()))
EmitCladArrayExtend({clonedBase, target}, clonedIndices.back());
// Create the _result[idx] expression.
auto result_at_is = BuildArraySubscript(target, clonedIndices);
return StmtDiff(cloned, result_at_is);
Expand Down Expand Up @@ -1366,8 +1383,10 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
opKind == UnaryOperatorKind::UO_Imag) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_Deref) {
if (Expr* dx = diff.getExpr_dx())
if (Expr* dx = diff.getExpr_dx()) {
EmitCladArrayExtend(diff, getZeroInit(m_Context.IntTy));
return StmtDiff(op, BuildOp(opKind, dx));
}
QualType literalTy =
utils::GetValueType(UnOp->getSubExpr()->getType()->getPointeeType());
return StmtDiff(
Expand Down
46 changes: 43 additions & 3 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -853,11 +853,14 @@ namespace clad {
derivedL = LDiff.getExpr_dx();
derivedR = RDiff.getExpr_dx();
if (utils::isArrayOrPointerType(LDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(RDiff.getExpr()->getType()))
!utils::isArrayOrPointerType(RDiff.getExpr()->getType())) {
derivedR = RDiff.getExpr();
else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(LDiff.getExpr()->getType()))
EmitCladArrayExtend(LDiff, derivedR);
} else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(LDiff.getExpr()->getType())) {
derivedL = LDiff.getExpr();
EmitCladArrayExtend(RDiff, derivedL);
}
}

Stmt* VisitorBase::GetCladZeroInit(llvm::MutableArrayRef<Expr*> args) {
Expand Down Expand Up @@ -896,4 +899,41 @@ namespace clad {
VisitorBase::GetCladConstructorReverseForwTagOfType(clang::QualType T) {
return InstantiateTemplate(GetCladConstructorReverseForwTag(), {T});
}

/// Emits a call `<clad array>.extend(idx + 1)` into the current block so
/// that the element at `idx` can be accessed afterwards.
///
/// A call is emitted only when all of the following hold; otherwise the
/// function does nothing:
///  - the current request is in forward mode,
///  - the primal expression of `arr` is a plain declaration reference,
///  - the derivative expression of `arr` is a member call named `ptr` whose
///    object has clad array type.
///
/// \param[in] arr primal/derivative expression pair of the indexed array.
/// \param[in] idx index expression that is about to be applied to `arr`.
void VisitorBase::EmitCladArrayExtend(StmtDiff arr, Expr* idx) {
  // FIXME: For now, only forward mode supports not differentiating w.r.t.
  // array parameters.
  if (m_DiffReq.Mode != DiffMode::forward)
    return;

  // Either side of the pair (or the index) may be missing; there is nothing
  // to extend in that case and dereferencing would crash.
  auto* primal = arr.getExpr();
  auto* derived = arr.getExpr_dx();
  if (!primal || !derived || !idx)
    return;

  if (!isa<DeclRefExpr>(primal->IgnoreImplicit()))
    return;
  auto* MCE = dyn_cast<CXXMemberCallExpr>(derived->IgnoreImplicit());
  if (!MCE)
    return;

  // Both the implicit object and the direct callee can be null for member
  // calls that do not go through a plain member access.
  Expr* object = MCE->getImplicitObjectArgument();
  if (!object)
    return;
  Expr* cladArr = object->IgnoreImplicit();
  if (!isCladArrayType(cladArr->getType()))
    return;
  const auto* callee = MCE->getDirectCallee();
  if (!callee || callee->getNameAsString() != "ptr")
    return;

  Expr* size = nullptr;
  Expr::EvalResult index;
  // If it's possible to determine the index at compile time, generate
  // the `extend` argument as a literal. This will help us avoid ugly
  // code like
  // ```
  // _d_arr.extend(0 + 1);
  // _d_arr[0] = ...;
  // ```
  if (idx->EvaluateAsInt(index, m_Context,
                         Expr::SideEffectsKind::SE_NoSideEffects)) {
    size = ConstantFolder::synthesizeLiteral(
        m_Context.IntTy, m_Context, index.Val.getInt().getExtValue() + 1);
  } else {
    Expr* one = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
                                                  /*val=*/1);
    size = BuildOp(BO_Add, idx, one);
  }
  llvm::SmallVector<Expr*, 1> param{size};
  Expr* extendCall = BuildCallExprToMemFn(cladArr, "extend", param);
  addToCurrentBlock(extendCall);
}
} // end namespace clad
58 changes: 58 additions & 0 deletions test/Arrays/ArrayInputsForwardMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ double numMultIndex(double* arr, size_t n, double x) {
}

// CHECK: double numMultIndex_darg2(double *arr, size_t n, double x) {
// CHECK-NEXT: clad::array<double> _d_arr = {};
// CHECK-NEXT: size_t _d_n = 0;
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: bool _d_flag = 0;
Expand All @@ -78,6 +79,55 @@ double numMultIndex(double* arr, size_t n, double x) {
// CHECK-NEXT: return flag ? _d_idx * x + idx * _d_x : 0;
// CHECK-NEXT: }

// Test input for forward mode with a pointer parameter: multiplies arr[3]
// by x, then halves arr[0] through arr[4] in place, and returns arr[3] read
// through pointer arithmetic. The loop bound is hard-coded, so `arr` must
// point to at least five elements.
double modifyArr(double* arr, double x) {
arr[3] *= x;
for (int i = 0; i < 5; ++i)
arr[i] /= 2;
return *(arr + 3);
}

// CHECK: double modifyArr_darg1(double *arr, double x) {
// CHECK-NEXT: clad::array<double> _d_arr = {};
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: _d_arr.extend(4);
// CHECK-NEXT: double &_t0 = _d_arr.ptr()[3];
// CHECK-NEXT: _t0 = _t0 * x + arr[3] * _d_x;
// CHECK-NEXT: arr[3] *= x;
// CHECK-NEXT: {
// CHECK-NEXT: int _d_i = 0;
// CHECK-NEXT: for (int i = 0; i < 5; ++i) {
// CHECK-NEXT: _d_arr.extend(i + 1);
// CHECK-NEXT: double &_t1 = _d_arr.ptr()[i];
// CHECK-NEXT: _t1 = (_t1 * 2 - arr[i] * 0) / (2 * 2);
// CHECK-NEXT: arr[i] /= 2;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: _d_arr.extend(4);
// CHECK-NEXT: return *(_d_arr.ptr() + 3);
// CHECK-NEXT: }

// Same computation as modifyArr, but the parameter is declared with a
// constant array bound (`double arr[5]`): multiplies arr[3] by x, halves
// arr[0] through arr[4] in place, and returns arr[3] read through pointer
// arithmetic.
double modifyConstArr(double arr[5], double x) {
arr[3] *= x;
for (int i = 0; i < 5; ++i)
arr[i] /= 2;
return *(arr + 3);
}

// CHECK: double modifyConstArr_darg1(double arr[5], double x) {
// CHECK-NEXT: double _d_arr[5] = {0};
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: _d_arr[3] = _d_arr[3] * x + arr[3] * _d_x;
// CHECK-NEXT: arr[3] *= x;
// CHECK-NEXT: {
// CHECK-NEXT: int _d_i = 0;
// CHECK-NEXT: for (int i = 0; i < 5; ++i) {
// CHECK-NEXT: _d_arr[i] = (_d_arr[i] * 2 - arr[i] * 0) / (2 * 2);
// CHECK-NEXT: arr[i] /= 2;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return *(_d_arr + 3);
// CHECK-NEXT: }

int main() {
double arr[] = {1, 2, 3, 4, 5};
auto multiply_dx = clad::differentiate(multiply, "arr[1]");
Expand All @@ -91,4 +141,12 @@ int main() {

auto numMultIndex_dx = clad::differentiate(numMultIndex, "x");
printf("Result = {%.2f}\n", numMultIndex_dx.execute(arr, 5, 4)); // CHECK-EXEC: Result = {3.00}

auto modifyArr_dx = clad::differentiate(modifyArr, "x");
printf("Result = {%.2f}\n", modifyArr_dx.execute(arr, 5)); // CHECK-EXEC: Result = {2.00}
arr[0] = 1; arr[1] = 2; arr[2] = 3; arr[3] = 4; arr[4] = 5;

auto modifyConstArr_dx = clad::differentiate(modifyConstArr, "x");
printf("Result = {%.2f}\n", modifyConstArr_dx.execute(arr, 5)); // CHECK-EXEC: Result = {2.00}
arr[0] = 1; arr[1] = 2; arr[2] = 3; arr[3] = 4; arr[4] = 5;
}
18 changes: 15 additions & 3 deletions test/FirstDerivative/CallArguments.C
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,16 @@ float f_literal_args_func(float x, float y, float *z) {
printf("hello world ");
return x * f_literal_helper(0.5, 'a', z, nullptr);
}
// CHECK: clad::ValueAndPushforward<float, float> f_literal_helper_pushforward(float x, char ch, float *p, float *q, float _d_x, char _d_ch, float *_d_p, float *_d_q);

// CHECK: float f_literal_args_func_darg0(float x, float y, float *z) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: clad::array<float> _d_z = {};
// CHECK-NEXT: printf("hello world ");
// CHECK-NEXT: float _t0 = f_literal_helper(0.5, 'a', z, nullptr);
// CHECK-NEXT: return _d_x * _t0 + x * 0.F;
// CHECK-NEXT: clad::ValueAndPushforward<float, float> _t0 = f_literal_helper_pushforward(0.5, 'a', z, nullptr, 0., 0, _d_z.ptr(), nullptr);
// CHECK-NEXT: float &_t1 = _t0.value;
// CHECK-NEXT: return _d_x * _t1 + x * _t0.pushforward;
// CHECK-NEXT: }

inline unsigned int getBin(double low, double high, double val, unsigned int numBins) {
Expand All @@ -162,8 +165,11 @@ float f_call_inline_fxn(float *params, float const *obs, float const *xlArr) {
// CHECK: inline clad::ValueAndPushforward<unsigned int, unsigned int> getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins);

// CHECK: float f_call_inline_fxn_darg0_0(float *params, const float *obs, const float *xlArr) {
// CHECK-NEXT: clad::array<float> _d_obs = {};
// CHECK-NEXT: clad::array<float> _d_xlArr = {};
// CHECK-NEXT: clad::ValueAndPushforward<unsigned int, unsigned int> _t0 = getBin_pushforward(0., 1., params[0], 1, 0., 0., 1.F, 0);
// CHECK-NEXT: const float _d_t116 = 0.F;
// CHECK-NEXT: _d_xlArr.extend(_t0.value + 1);
// CHECK-NEXT: const float _d_t116 = *(_d_xlArr.ptr() + _t0.value);
// CHECK-NEXT: const float t116 = *(xlArr + _t0.value);
// CHECK-NEXT: return _d_t116 * params[0] + t116 * 1.F;
// CHECK-NEXT: }
Expand Down Expand Up @@ -214,6 +220,12 @@ int main () { // expected-no-diagnostics
// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<float, float> f_literal_helper_pushforward(float x, char ch, float *p, float *q, float _d_x, char _d_ch, float *_d_p, float *_d_q) {
// CHECK-NEXT: if (ch == 'a')
// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};
// CHECK-NEXT: return {-x * x, -_d_x * x + -x * _d_x};
// CHECK-NEXT: }

// CHECK: inline clad::ValueAndPushforward<unsigned int, unsigned int> getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins) {
// CHECK-NEXT: double _t0 = (high - low);
// CHECK-NEXT: double _d_binWidth = ((_d_high - _d_low) * numBins - _t0 * _d_numBins) / (numBins * numBins);
Expand Down
8 changes: 6 additions & 2 deletions test/ForwardMode/Pointer.C
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ double fn9(double* params, const double *constants) {
}

// CHECK: double fn9_darg0_0(double *params, const double *constants) {
// CHECK-NEXT: double _d_c0 = 0.;
// CHECK-NEXT: clad::array<double> _d_constants = {};
// CHECK-NEXT: _d_constants.extend(1);
// CHECK-NEXT: double _d_c0 = *_d_constants.ptr();
// CHECK-NEXT: double c0 = *constants;
// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0;
// CHECK-NEXT: }
Expand All @@ -208,7 +210,9 @@ double fn10(double *params, const double *constants) {
}

// CHECK: double fn10_darg0_0(double *params, const double *constants) {
// CHECK-NEXT: double _d_c0 = 0.;
// CHECK-NEXT: clad::array<double> _d_constants = {};
// CHECK-NEXT: _d_constants.extend(1);
// CHECK-NEXT: double _d_c0 = *(_d_constants.ptr() + 0);
// CHECK-NEXT: double c0 = *(constants + 0);
// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0;
// CHECK-NEXT: }
Expand Down
12 changes: 9 additions & 3 deletions test/ROOT/TFormula.C
Original file line number Diff line number Diff line change
Expand Up @@ -53,24 +53,30 @@ void TFormula_example_grad_1(Double_t* x, Double_t* p, Double_t* _d_p);
//CHECK-NEXT: }

//CHECK: Double_t TFormula_example_darg1_0(Double_t *x, Double_t *p) {
//CHECK-NEXT: clad::array<Double_t> _d_x = {};
//CHECK-NEXT: _d_x.extend(1);
//CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -1.);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0.);
//CHECK-NEXT: return 0. * _t0 + x[0] * (1. + 0. + 0.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (1. + 0. + 0.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: }

//CHECK: Double_t TFormula_example_darg1_1(Double_t *x, Double_t *p) {
//CHECK-NEXT: clad::array<Double_t> _d_x = {};
//CHECK-NEXT: _d_x.extend(1);
//CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0.);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 1.);
//CHECK-NEXT: return 0. * _t0 + x[0] * (0. + 1. + 0.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (0. + 1. + 0.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: }

//CHECK: Double_t TFormula_example_darg1_2(Double_t *x, Double_t *p) {
//CHECK-NEXT: clad::array<Double_t> _d_x = {};
//CHECK-NEXT: _d_x.extend(1);
//CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0.);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0.);
//CHECK-NEXT: return 0. * _t0 + x[0] * (0. + 0. + 1.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (0. + 0. + 1.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: }

Double_t TFormula_hess1(Double_t *x, Double_t *p) {
Expand Down
Loading