Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Consider array parameters differentiable in forward mode
Browse files Browse the repository at this point in the history
PetroZarytskyi committed Aug 28, 2024
1 parent e2f4638 commit 20b2a28
Showing 9 changed files with 148 additions and 23 deletions.
14 changes: 14 additions & 0 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
@@ -108,6 +108,20 @@ template <typename T> class array {

/// Returns the size of the underlying array, i.e. the current value of
/// `m_size`.
CUDA_HOST_DEVICE std::size_t size() const { return m_size; }
/// Extends the array so that it holds at least `size` elements.
///
/// If `size` is greater than the current size, a new buffer of `size`
/// elements is allocated, the existing elements are copied into it, the
/// remaining elements are value-initialized (`T()`), and the old buffer is
/// released. If `size` is less than or equal to the current size the array
/// is left untouched; this function never shrinks the array.
///
/// \param size the minimum number of elements required after the call.
CUDA_HOST_DEVICE void extend(std::size_t size) {
  // Nothing to do when the array is already large enough.
  if (size <= m_size)
    return;
  T* extendedArr = new T[size];
  for (std::size_t i = 0; i < m_size; ++i)
    extendedArr[i] = m_arr[i];
  for (std::size_t i = m_size; i < size; ++i)
    extendedArr[i] = T();
  // The buffer installed below comes from `new T[]`, so the storage must be
  // released with the array form of delete. Using scalar `delete` on it is
  // undefined behavior.
  delete[] m_arr;
  m_arr = extendedArr;
  m_size = size;
}
/// Iterator functions
/// Returns a pointer to the start of the underlying storage (`m_arr`).
CUDA_HOST_DEVICE T* begin() { return m_arr; }
/// Const overload of `begin()`; returns `m_arr` as a pointer-to-const.
CUDA_HOST_DEVICE const T* begin() const { return m_arr; }
4 changes: 4 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
@@ -607,6 +607,10 @@ namespace clad {
clang::SourceLocation srcLoc);

clang::QualType DetermineCladArrayValueType(clang::QualType T);
/// Extend the size of `arr` to safely access the element corresponding to
/// `idx`. Works only for clad::array when handling array parameters in
/// forward mode.
void EmitCladArrayExtend(StmtDiff arr, clang::Expr* idx);

/// Returns clad::Identify template declaration.
clang::TemplateDecl* GetCladConstructorPushforwardTag();
45 changes: 33 additions & 12 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
@@ -234,15 +234,32 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
// non-reference type for creating the derivatives.
QualType dParamType = param->getType().getNonReferenceType();
// We do not create derived variable for array/pointer parameters.
if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) ||
utils::isArrayOrPointerType(dParamType))
if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType))
continue;
Expr* dParam = nullptr;
bool isArrayTy = utils::isArrayOrPointerType(dParamType);
if (dParamType->isRealType()) {
// If param is independent variable, its derivative is 1, otherwise 0.
int dValue = (param == m_IndependentVar);
dParam = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
dValue);
} else if (isArrayTy) {
if (param == m_IndependentVar)
continue;

if (auto* DT = dyn_cast<DecayedType>(dParamType)) {
if (auto* CAT = dyn_cast<ConstantArrayType>(DT->getOriginalType())) {
if (param == m_IndependentVar)
continue;
Expr* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy,
m_Context, 0);
dParam = m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get();
dParamType = QualType::getFromOpaquePtr(CAT);
}
} else {
dParamType = GetCladArrayOfType(utils::GetValueType(dParamType));
dParam = getZeroInit(dParamType);
}
}
// For each function arg, create a variable _d_arg to store derivatives
// of potential reassignments, e.g.:
@@ -254,6 +271,10 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
BuildVarDecl(dParamType, "_d_" + param->getNameAsString(), dParam);
addToCurrentBlock(BuildDeclStmt(dParamDecl));
dParam = BuildDeclRef(dParamDecl);
if (!isa<ConstantArrayType>(dParamType) && isArrayTy) {
llvm::SmallVector<Expr*, 0> noParams{};
dParam = BuildCallExprToMemFn(dParam, "ptr", noParams);
}
if (dParamType->isRecordType() && param == m_IndependentVar) {
llvm::SmallVector<llvm::StringRef, 4> ref(diffVarInfo.fields.begin(),
diffVarInfo.fields.end());
@@ -992,7 +1013,6 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) {
VD = DRE->getDecl();
}
if (VD == m_IndependentVar) {
llvm::APSInt index;
Expr* diffExpr = nullptr;
Expr::EvalResult res;
Expr::SideEffectsKind AllowSideEffects =
@@ -1017,12 +1037,10 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) {
return StmtDiff(cloned, zero);

Expr* target = it->second;
// FIXME: fix when adding array inputs
if (!isArrayOrPointerType(target->getType()))
return StmtDiff(cloned, zero);
// llvm::APSInt IVal;
// if (!I->EvaluateAsInt(IVal, m_Context))
// return;
// The size of array parameters is unknown
// so we need to always extend the adjoint size before accessing the element.
if (utils::isArrayOrPointerType(clonedBase->getType()))
EmitCladArrayExtend({clonedBase, target}, clonedIndices.back());
// Create the _result[idx] expression.
auto result_at_is = BuildArraySubscript(target, clonedIndices);
return StmtDiff(cloned, result_at_is);
@@ -1372,10 +1390,13 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
opKind == UnaryOperatorKind::UO_Imag) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_Deref) {
if (Expr* dx = diff.getExpr_dx())
Expr* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
/*val=*/0);
if (Expr* dx = diff.getExpr_dx()) {
EmitCladArrayExtend(diff, zero);
return StmtDiff(op, BuildOp(opKind, dx));
return StmtDiff(op, ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0));
}
return StmtDiff(op, zero);
} else if (opKind == UnaryOperatorKind::UO_AddrOf) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_LNot) {
23 changes: 23 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
@@ -1601,6 +1601,29 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {call, callDiff};
}

// FIXME: There is nothing special about the `clad::array::extend` function
// but we need to be able to differentiate it to prevent segfaults in
// hessians. Once we add support for methods that change their objects, this
// section should be removed.
if (FDName == "extend")
if (auto* MCE = dyn_cast<CXXMemberCallExpr>(CE->IgnoreImplicit())) {
const Expr* cladArr =
MCE->getImplicitObjectArgument()->IgnoreImplicit();
if (isCladArrayType(cladArr->getType())) {
Expr* size = Clone(CE->getArg(0));
llvm::SmallVector<Expr*, 1> param{size};
StmtDiff arrDiff = Visit(cladArr);
Expr* extendCall_dx =
BuildCallExprToMemFn(arrDiff.getExpr_dx(), "extend", param);
Expr* extendCall =
BuildCallExprToMemFn(arrDiff.getExpr(), "extend", param);
beginBlock(direction::forward);
addToCurrentBlock(extendCall_dx);
addToCurrentBlock(extendCall);
return endBlock(direction::forward);
}
}

auto NArgs = FD->getNumParams();
// If the function has no args and is not a member function call then we
// assume that it is not related to independent variables and does not
46 changes: 43 additions & 3 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
@@ -822,11 +822,14 @@ namespace clad {
derivedL = LDiff.getExpr_dx();
derivedR = RDiff.getExpr_dx();
if (utils::isArrayOrPointerType(LDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(RDiff.getExpr()->getType()))
!utils::isArrayOrPointerType(RDiff.getExpr()->getType())) {
derivedR = RDiff.getExpr();
else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(LDiff.getExpr()->getType()))
EmitCladArrayExtend(LDiff, derivedR);
} else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(LDiff.getExpr()->getType())) {
derivedL = LDiff.getExpr();
EmitCladArrayExtend(RDiff, derivedL);
}
}

Stmt* VisitorBase::GetCladZeroInit(llvm::MutableArrayRef<Expr*> args) {
@@ -853,4 +856,41 @@ namespace clad {
VisitorBase::GetCladConstructorPushforwardTagOfType(clang::QualType T) {
return InstantiateTemplate(GetCladConstructorPushforwardTag(), {T});
}

/// Emits a call `<derivative array>.extend(idx + 1)` into the current block
/// so that the element at `idx` can be accessed afterwards.
///
/// The call is only emitted when all of the following hold:
///  - the current differentiation mode is forward mode;
///  - the original expression of `arr` is a plain declaration reference;
///  - the derived expression of `arr` is a member call to `ptr` on an object
///    of clad array type.
/// In every other case the function does nothing.
///
/// \param arr the original/derived expression pair of the accessed array.
/// \param idx the index that is about to be accessed.
void VisitorBase::EmitCladArrayExtend(StmtDiff arr, Expr* idx) {
  // FIXME: For now, only forward mode supports not differentiating w.r.t.
  // array parameters.
  if (m_DiffReq.Mode != DiffMode::forward)
    return;
  // Without the original expression, its derivative and an index there is
  // nothing to extend; bail out instead of dereferencing a null pointer.
  if (!arr.getExpr() || !arr.getExpr_dx() || !idx)
    return;
  if (!isa<DeclRefExpr>(arr.getExpr()->IgnoreImplicit()))
    return;
  auto* MCE = dyn_cast<CXXMemberCallExpr>(arr.getExpr_dx()->IgnoreImplicit());
  if (!MCE)
    return;
  Expr* cladArr = MCE->getImplicitObjectArgument()->IgnoreImplicit();
  if (!isCladArrayType(cladArr->getType()))
    return;
  // `getDirectCallee` yields null when the callee is not a plain function
  // declaration, so it has to be checked before asking for its name.
  const auto* callee = MCE->getDirectCallee();
  if (!callee || callee->getNameAsString() != "ptr")
    return;

  Expr* size = nullptr;
  Expr::EvalResult index;
  // If it's possible to determine the index at compile time, generate
  // the `extend` argument as a literal. This will help us avoid ugly
  // code like
  // ```
  // _d_arr.extend(0 + 1);
  // _d_arr[0] = ...;
  // ```
  if (idx->EvaluateAsInt(index, m_Context,
                         Expr::SideEffectsKind::SE_NoSideEffects)) {
    size = ConstantFolder::synthesizeLiteral(
        m_Context.IntTy, m_Context, index.Val.getInt().getExtValue() + 1);
  } else {
    // Index only known at run time: build the expression `idx + 1`.
    Expr* one =
        ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 1);
    size = BuildOp(BO_Add, idx, one);
  }
  llvm::SmallVector<Expr*, 1> param{size};
  Expr* extendCall = BuildCallExprToMemFn(cladArr, "extend", param);
  addToCurrentBlock(extendCall);
}
} // end namespace clad
1 change: 1 addition & 0 deletions test/Arrays/ArrayInputsForwardMode.C
Original file line number Diff line number Diff line change
@@ -59,6 +59,7 @@ double numMultIndex(double* arr, size_t n, double x) {
}

// CHECK: double numMultIndex_darg2(double *arr, size_t n, double x) {
// CHECK-NEXT: clad::array<double> _d_arr = {};
// CHECK-NEXT: size_t _d_n = 0;
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: bool _d_flag = 0;
18 changes: 15 additions & 3 deletions test/FirstDerivative/CallArguments.C
Original file line number Diff line number Diff line change
@@ -142,13 +142,16 @@ float f_literal_args_func(float x, float y, float *z) {
printf("hello world ");
return x * f_literal_helper(0.5, 'a', z, nullptr);
}
// CHECK: clad::ValueAndPushforward<float, float> f_literal_helper_pushforward(float x, char ch, float *p, float *q, float _d_x, char _d_ch, float *_d_p, float *_d_q);

// CHECK: float f_literal_args_func_darg0(float x, float y, float *z) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: clad::array<float> _d_z = {};
// CHECK-NEXT: printf("hello world ");
// CHECK-NEXT: float _t0 = f_literal_helper(0.5, 'a', z, nullptr);
// CHECK-NEXT: return _d_x * _t0 + x * 0;
// CHECK-NEXT: clad::ValueAndPushforward<float, float> _t0 = f_literal_helper_pushforward(0.5, 'a', z, nullptr, 0., 0, _d_z.ptr(), nullptr);
// CHECK-NEXT: float &_t1 = _t0.value;
// CHECK-NEXT: return _d_x * _t1 + x * _t0.pushforward;
// CHECK-NEXT: }

inline unsigned int getBin(double low, double high, double val, unsigned int numBins) {
@@ -164,8 +167,11 @@ float f_call_inline_fxn(float *params, float const *obs, float const *xlArr) {
// CHECK: inline clad::ValueAndPushforward<unsigned int, unsigned int> getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins);

// CHECK: float f_call_inline_fxn_darg0_0(float *params, const float *obs, const float *xlArr) {
// CHECK-NEXT: clad::array<float> _d_obs = {};
// CHECK-NEXT: clad::array<float> _d_xlArr = {};
// CHECK-NEXT: clad::ValueAndPushforward<unsigned int, unsigned int> _t0 = getBin_pushforward(0., 1., params[0], 1, 0., 0., 1.F, 0);
// CHECK-NEXT: const float _d_t116 = 0;
// CHECK-NEXT: _d_xlArr.extend(_t0.value + 1);
// CHECK-NEXT: const float _d_t116 = *(_d_xlArr.ptr() + _t0.value);
// CHECK-NEXT: const float t116 = *(xlArr + _t0.value);
// CHECK-NEXT: return _d_t116 * params[0] + t116 * 1.F;
// CHECK-NEXT: }
@@ -216,6 +222,12 @@ int main () { // expected-no-diagnostics
// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<float, float> f_literal_helper_pushforward(float x, char ch, float *p, float *q, float _d_x, char _d_ch, float *_d_p, float *_d_q) {
// CHECK-NEXT: if (ch == 'a')
// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};
// CHECK-NEXT: return {-x * x, -_d_x * x + -x * _d_x};
// CHECK-NEXT: }

// CHECK: inline clad::ValueAndPushforward<unsigned int, unsigned int> getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins) {
// CHECK-NEXT: double _t0 = (high - low);
// CHECK-NEXT: double _d_binWidth = ((_d_high - _d_low) * numBins - _t0 * _d_numBins) / (numBins * numBins);
8 changes: 6 additions & 2 deletions test/ForwardMode/Pointer.C
Original file line number Diff line number Diff line change
@@ -198,7 +198,9 @@ double fn9(double* params, const double *constants) {
}

// CHECK: double fn9_darg0_0(double *params, const double *constants) {
// CHECK-NEXT: double _d_c0 = 0;
// CHECK-NEXT: clad::array<double> _d_constants = {};
// CHECK-NEXT: _d_constants.extend(1);
// CHECK-NEXT: double _d_c0 = *_d_constants.ptr();
// CHECK-NEXT: double c0 = *constants;
// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0;
// CHECK-NEXT: }
@@ -209,7 +211,9 @@ double fn10(double *params, const double *constants) {
}

// CHECK: double fn10_darg0_0(double *params, const double *constants) {
// CHECK-NEXT: double _d_c0 = 0;
// CHECK-NEXT: clad::array<double> _d_constants = {};
// CHECK-NEXT: _d_constants.extend(1);
// CHECK-NEXT: double _d_c0 = *(_d_constants.ptr() + 0);
// CHECK-NEXT: double c0 = *(constants + 0);
// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0;
// CHECK-NEXT: }
12 changes: 9 additions & 3 deletions test/ROOT/TFormula.C
Original file line number Diff line number Diff line change
@@ -55,24 +55,30 @@ void TFormula_example_grad_1(Double_t* x, Double_t* p, Double_t* _d_p);
//CHECK-NEXT: }

//CHECK: Double_t TFormula_example_darg1_0(Double_t *x, Double_t *p) {
//CHECK-NEXT: clad::array<Double_t> _d_x = {};
//CHECK-NEXT: _d_x.extend(1);
//CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -1.);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0);
//CHECK-NEXT: return 0 * _t0 + x[0] * (1. + 0 + 0) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (1. + 0 + 0) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: }

//CHECK: Double_t TFormula_example_darg1_1(Double_t *x, Double_t *p) {
//CHECK-NEXT: clad::array<Double_t> _d_x = {};
//CHECK-NEXT: _d_x.extend(1);
//CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 1.);
//CHECK-NEXT: return 0 * _t0 + x[0] * (0 + 1. + 0) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (0 + 1. + 0) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: }

//CHECK: Double_t TFormula_example_darg1_2(Double_t *x, Double_t *p) {
//CHECK-NEXT: clad::array<Double_t> _d_x = {};
//CHECK-NEXT: _d_x.extend(1);
//CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0);
//CHECK-NEXT: return 0 * _t0 + x[0] * (0 + 0 + 1.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (0 + 0 + 1.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: }

Double_t TFormula_hess1(Double_t *x, Double_t *p) {

0 comments on commit 20b2a28

Please sign in to comment.