From 20b2a28de487d90e2ef8b3ab18b854912ff1ac3e Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Tue, 27 Aug 2024 13:15:17 +0300 Subject: [PATCH] Consider array parameters differentiable in forward mode --- include/clad/Differentiator/Array.h | 14 ++++++ include/clad/Differentiator/VisitorBase.h | 4 ++ lib/Differentiator/BaseForwardModeVisitor.cpp | 45 +++++++++++++----- lib/Differentiator/ReverseModeVisitor.cpp | 23 ++++++++++ lib/Differentiator/VisitorBase.cpp | 46 +++++++++++++++++-- test/Arrays/ArrayInputsForwardMode.C | 1 + test/FirstDerivative/CallArguments.C | 18 ++++++-- test/ForwardMode/Pointer.C | 8 +++- test/ROOT/TFormula.C | 12 +++-- 9 files changed, 148 insertions(+), 23 deletions(-) diff --git a/include/clad/Differentiator/Array.h b/include/clad/Differentiator/Array.h index eef7de54e..0b840878e 100644 --- a/include/clad/Differentiator/Array.h +++ b/include/clad/Differentiator/Array.h @@ -108,6 +108,20 @@ template class array { /// Returns the size of the underlying array CUDA_HOST_DEVICE std::size_t size() const { return m_size; } + /// Extends the size of array to `size` and default-initializer the new + /// elements if the current array size is less than `size`. + CUDA_HOST_DEVICE void extend(std::size_t size) { + if (size > m_size) { + T* extendedArr = new T[size]; + for (std::size_t i = 0; i < m_size; ++i) + extendedArr[i] = m_arr[i]; + for (std::size_t i = m_size; i < size; ++i) + extendedArr[i] = T(); + delete m_arr; + m_arr = extendedArr; + m_size = size; + } + } /// Iterator functions CUDA_HOST_DEVICE T* begin() { return m_arr; } CUDA_HOST_DEVICE const T* begin() const { return m_arr; } diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index c209c8cee..5e02daef4 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -607,6 +607,10 @@ namespace clad { clang::SourceLocation srcLoc); clang::QualType DetermineCladArrayValueType(clang::QualType T); + /// Extend the size of `arr` to safely access the element corresponding to + /// `idx`. Works only for clad::array when handling array parameters in + /// forward mode. + void EmitCladArrayExtend(StmtDiff arr, clang::Expr* idx); /// Returns clad::Identify template declaration. clang::TemplateDecl* GetCladConstructorPushforwardTag(); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 7a70f3bcb..0a6341fa7 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -234,15 +234,32 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, // non-reference type for creating the derivatives. QualType dParamType = param->getType().getNonReferenceType(); // We do not create derived variable for array/pointer parameters. - if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) || - utils::isArrayOrPointerType(dParamType)) + if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType)) continue; Expr* dParam = nullptr; + bool isArrayTy = utils::isArrayOrPointerType(dParamType); if (dParamType->isRealType()) { // If param is independent variable, its derivative is 1, otherwise 0. int dValue = (param == m_IndependentVar); dParam = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, dValue); + } else if (isArrayTy) { + if (param == m_IndependentVar) + continue; + + if (auto* DT = dyn_cast(dParamType)) { + if (auto* CAT = dyn_cast(DT->getOriginalType())) { + if (param == m_IndependentVar) + continue; + Expr* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, + m_Context, 0); + dParam = m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get(); + dParamType = QualType::getFromOpaquePtr(CAT); + } + } else { + dParamType = GetCladArrayOfType(utils::GetValueType(dParamType)); + dParam = getZeroInit(dParamType); + } } // For each function arg, create a variable _d_arg to store derivatives // of potential reassignments, e.g.: @@ -254,6 +271,10 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, BuildVarDecl(dParamType, "_d_" + param->getNameAsString(), dParam); addToCurrentBlock(BuildDeclStmt(dParamDecl)); dParam = BuildDeclRef(dParamDecl); + if (!isa(dParamType) && isArrayTy) { + llvm::SmallVector noParams{}; + dParam = BuildCallExprToMemFn(dParam, "ptr", noParams); + } if (dParamType->isRecordType() && param == m_IndependentVar) { llvm::SmallVector ref(diffVarInfo.fields.begin(), diffVarInfo.fields.end()); @@ -992,7 +1013,6 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) { VD = DRE->getDecl(); } if (VD == m_IndependentVar) { - llvm::APSInt index; Expr* diffExpr = nullptr; Expr::EvalResult res; Expr::SideEffectsKind AllowSideEffects = @@ -1017,12 +1037,10 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) { return StmtDiff(cloned, zero); Expr* target = it->second; - // FIXME: fix when adding array inputs - if (!isArrayOrPointerType(target->getType())) - return StmtDiff(cloned, zero); - // llvm::APSInt IVal; - // if (!I->EvaluateAsInt(IVal, m_Context)) - // return; + // The size of array parameters is unknown + // so we need to always extend the adjoint size before accessing the element. + if (utils::isArrayOrPointerType(clonedBase->getType())) + EmitCladArrayExtend({clonedBase, target}, clonedIndices.back()); // Create the _result[idx] expression. auto result_at_is = BuildArraySubscript(target, clonedIndices); return StmtDiff(cloned, result_at_is); @@ -1372,10 +1390,13 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { opKind == UnaryOperatorKind::UO_Imag) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_Deref) { - if (Expr* dx = diff.getExpr_dx()) + Expr* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, + /*val=*/0); + if (Expr* dx = diff.getExpr_dx()) { + EmitCladArrayExtend(diff, zero); return StmtDiff(op, BuildOp(opKind, dx)); - return StmtDiff(op, ConstantFolder::synthesizeLiteral( - m_Context.IntTy, m_Context, /*val=*/0)); + } + return StmtDiff(op, zero); } else if (opKind == UnaryOperatorKind::UO_AddrOf) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_LNot) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 637b027cc..007375100 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1601,6 +1601,29 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return {call, callDiff}; } + // FIXME: There is nothing special about the `clad::array::extend` function + // but we need to be able to differentiate it to prevent segfaults in + // hessians. Once we add support for methods that change their objects, this + // section should be removed. + if (FDName == "extend") + if (auto* MCE = dyn_cast(CE->IgnoreImplicit())) { + const Expr* cladArr = + MCE->getImplicitObjectArgument()->IgnoreImplicit(); + if (isCladArrayType(cladArr->getType())) { + Expr* size = Clone(CE->getArg(0)); + llvm::SmallVector param{size}; + StmtDiff arrDiff = Visit(cladArr); + Expr* extendCall_dx = + BuildCallExprToMemFn(arrDiff.getExpr_dx(), "extend", param); + Expr* extendCall = + BuildCallExprToMemFn(arrDiff.getExpr(), "extend", param); + beginBlock(direction::forward); + addToCurrentBlock(extendCall_dx); + addToCurrentBlock(extendCall); + return endBlock(direction::forward); + } + } + auto NArgs = FD->getNumParams(); // If the function has no args and is not a member function call then we // assume that it is not related to independent variables and does not diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 113e01e4d..6be6623be 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -822,11 +822,14 @@ namespace clad { derivedL = LDiff.getExpr_dx(); derivedR = RDiff.getExpr_dx(); if (utils::isArrayOrPointerType(LDiff.getExpr()->getType()) && - !utils::isArrayOrPointerType(RDiff.getExpr()->getType())) + !utils::isArrayOrPointerType(RDiff.getExpr()->getType())) { derivedR = RDiff.getExpr(); - else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) && - !utils::isArrayOrPointerType(LDiff.getExpr()->getType())) + EmitCladArrayExtend(LDiff, derivedR); + } else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) && + !utils::isArrayOrPointerType(LDiff.getExpr()->getType())) { derivedL = LDiff.getExpr(); + EmitCladArrayExtend(RDiff, derivedL); + } } Stmt* VisitorBase::GetCladZeroInit(llvm::MutableArrayRef args) { @@ -853,4 +856,41 @@ namespace clad { VisitorBase::GetCladConstructorPushforwardTagOfType(clang::QualType T) { return InstantiateTemplate(GetCladConstructorPushforwardTag(), {T}); } + + void VisitorBase::EmitCladArrayExtend(StmtDiff arr, Expr* idx) { + // FIXME: For now, only forward mode supports not differentiating w.r.t. + // array parameters. + if (m_DiffReq.Mode != DiffMode::forward) + return; + if (isa(arr.getExpr()->IgnoreImplicit())) + if (auto* MCE = + dyn_cast(arr.getExpr_dx()->IgnoreImplicit())) { + Expr* cladArr = MCE->getImplicitObjectArgument()->IgnoreImplicit(); + if (isCladArrayType(cladArr->getType()) && + MCE->getDirectCallee()->getNameAsString() == "ptr") { + Expr* size = nullptr; + Expr::EvalResult index; + // If it's possible to determine the index at compile time, generate + // the `extend` argument as a literal. This will help us avoid ugly + // code like + // ``` + // _d_arr.extend(0 + 1); + // _d_arr[0] = ...; + // ``` + if (idx->EvaluateAsInt(index, m_Context, + Expr::SideEffectsKind::SE_NoSideEffects)) { + size = ConstantFolder::synthesizeLiteral( + m_Context.IntTy, m_Context, + index.Val.getInt().getExtValue() + 1); + } else { + Expr* one = ConstantFolder::synthesizeLiteral(m_Context.IntTy, + m_Context, 1); + size = BuildOp(BO_Add, idx, one); + } + llvm::SmallVector param{size}; + Expr* extendCall = BuildCallExprToMemFn(cladArr, "extend", param); + addToCurrentBlock(extendCall); + } + } + } } // end namespace clad diff --git a/test/Arrays/ArrayInputsForwardMode.C b/test/Arrays/ArrayInputsForwardMode.C index 633b969aa..e39ea5b2b 100644 --- a/test/Arrays/ArrayInputsForwardMode.C +++ b/test/Arrays/ArrayInputsForwardMode.C @@ -59,6 +59,7 @@ double numMultIndex(double* arr, size_t n, double x) { } // CHECK: double numMultIndex_darg2(double *arr, size_t n, double x) { +// CHECK-NEXT: clad::array _d_arr = {}; // CHECK-NEXT: size_t _d_n = 0; // CHECK-NEXT: double _d_x = 1; // CHECK-NEXT: bool _d_flag = 0; diff --git a/test/FirstDerivative/CallArguments.C b/test/FirstDerivative/CallArguments.C index efe05909e..2932d1c8c 100644 --- a/test/FirstDerivative/CallArguments.C +++ b/test/FirstDerivative/CallArguments.C @@ -142,13 +142,16 @@ float f_literal_args_func(float x, float y, float *z) { printf("hello world "); return x * f_literal_helper(0.5, 'a', z, nullptr); } +// CHECK: clad::ValueAndPushforward f_literal_helper_pushforward(float x, char ch, float *p, float *q, float _d_x, char _d_ch, float *_d_p, float *_d_q); // CHECK: float f_literal_args_func_darg0(float x, float y, float *z) { // CHECK-NEXT: float _d_x = 1; // CHECK-NEXT: float _d_y = 0; +// CHECK-NEXT: clad::array _d_z = {}; // CHECK-NEXT: printf("hello world "); -// CHECK-NEXT: float _t0 = f_literal_helper(0.5, 'a', z, nullptr); -// CHECK-NEXT: return _d_x * _t0 + x * 0; +// CHECK-NEXT: clad::ValueAndPushforward _t0 = f_literal_helper_pushforward(0.5, 'a', z, nullptr, 0., 0, _d_z.ptr(), nullptr); +// CHECK-NEXT: float &_t1 = _t0.value; +// CHECK-NEXT: return _d_x * _t1 + x * _t0.pushforward; // CHECK-NEXT: } inline unsigned int getBin(double low, double high, double val, unsigned int numBins) { @@ -164,8 +167,11 @@ float f_call_inline_fxn(float *params, float const *obs, float const *xlArr) { // CHECK: inline clad::ValueAndPushforward getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins); // CHECK: float f_call_inline_fxn_darg0_0(float *params, const float *obs, const float *xlArr) { +// CHECK-NEXT: clad::array _d_obs = {}; +// CHECK-NEXT: clad::array _d_xlArr = {}; // CHECK-NEXT: clad::ValueAndPushforward _t0 = getBin_pushforward(0., 1., params[0], 1, 0., 0., 1.F, 0); -// CHECK-NEXT: const float _d_t116 = 0; +// CHECK-NEXT: _d_xlArr.extend(_t0.value + 1); +// CHECK-NEXT: const float _d_t116 = *(_d_xlArr.ptr() + _t0.value); // CHECK-NEXT: const float t116 = *(xlArr + _t0.value); // CHECK-NEXT: return _d_t116 * params[0] + t116 * 1.F; // CHECK-NEXT: } @@ -216,6 +222,12 @@ int main () { // expected-no-diagnostics // CHECK-NEXT: return {x * x, _d_x * x + x * _d_x}; // CHECK-NEXT: } +// CHECK: clad::ValueAndPushforward f_literal_helper_pushforward(float x, char ch, float *p, float *q, float _d_x, char _d_ch, float *_d_p, float *_d_q) { +// CHECK-NEXT: if (ch == 'a') +// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x}; +// CHECK-NEXT: return {-x * x, -_d_x * x + -x * _d_x}; +// CHECK-NEXT: } + // CHECK: inline clad::ValueAndPushforward getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins) { // CHECK-NEXT: double _t0 = (high - low); // CHECK-NEXT: double _d_binWidth = ((_d_high - _d_low) * numBins - _t0 * _d_numBins) / (numBins * numBins); diff --git a/test/ForwardMode/Pointer.C b/test/ForwardMode/Pointer.C index 7d98cd2f8..3afa50d2c 100644 --- a/test/ForwardMode/Pointer.C +++ b/test/ForwardMode/Pointer.C @@ -198,7 +198,9 @@ double fn9(double* params, const double *constants) { } // CHECK: double fn9_darg0_0(double *params, const double *constants) { -// CHECK-NEXT: double _d_c0 = 0; +// CHECK-NEXT: clad::array _d_constants = {}; +// CHECK-NEXT: _d_constants.extend(1); +// CHECK-NEXT: double _d_c0 = *_d_constants.ptr(); // CHECK-NEXT: double c0 = *constants; // CHECK-NEXT: return 1. * c0 + params[0] * _d_c0; // CHECK-NEXT: } @@ -209,7 +211,9 @@ double fn10(double *params, const double *constants) { } // CHECK: double fn10_darg0_0(double *params, const double *constants) { -// CHECK-NEXT: double _d_c0 = 0; +// CHECK-NEXT: clad::array _d_constants = {}; +// CHECK-NEXT: _d_constants.extend(1); +// CHECK-NEXT: double _d_c0 = *(_d_constants.ptr() + 0); // CHECK-NEXT: double c0 = *(constants + 0); // CHECK-NEXT: return 1. * c0 + params[0] * _d_c0; // CHECK-NEXT: } diff --git a/test/ROOT/TFormula.C b/test/ROOT/TFormula.C index 6d44fc921..d734599dd 100644 --- a/test/ROOT/TFormula.C +++ b/test/ROOT/TFormula.C @@ -55,24 +55,30 @@ void TFormula_example_grad_1(Double_t* x, Double_t* p, Double_t* _d_p); //CHECK-NEXT: } //CHECK: Double_t TFormula_example_darg1_0(Double_t *x, Double_t *p) { +//CHECK-NEXT: clad::array _d_x = {}; +//CHECK-NEXT: _d_x.extend(1); //CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]); //CHECK-NEXT: clad::ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -1.); //CHECK-NEXT: clad::ValueAndPushforward _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0); -//CHECK-NEXT: return 0 * _t0 + x[0] * (1. + 0 + 0) + _t1.pushforward + _t2.pushforward; +//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (1. + 0 + 0) + _t1.pushforward + _t2.pushforward; //CHECK-NEXT: } //CHECK: Double_t TFormula_example_darg1_1(Double_t *x, Double_t *p) { +//CHECK-NEXT: clad::array _d_x = {}; +//CHECK-NEXT: _d_x.extend(1); //CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]); //CHECK-NEXT: clad::ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0); //CHECK-NEXT: clad::ValueAndPushforward _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 1.); -//CHECK-NEXT: return 0 * _t0 + x[0] * (0 + 1. + 0) + _t1.pushforward + _t2.pushforward; +//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (0 + 1. + 0) + _t1.pushforward + _t2.pushforward; //CHECK-NEXT: } //CHECK: Double_t TFormula_example_darg1_2(Double_t *x, Double_t *p) { +//CHECK-NEXT: clad::array _d_x = {}; +//CHECK-NEXT: _d_x.extend(1); //CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]); //CHECK-NEXT: clad::ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0); //CHECK-NEXT: clad::ValueAndPushforward _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0); -//CHECK-NEXT: return 0 * _t0 + x[0] * (0 + 0 + 1.) + _t1.pushforward + _t2.pushforward; +//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (0 + 0 + 1.) + _t1.pushforward + _t2.pushforward; //CHECK-NEXT: } Double_t TFormula_hess1(Double_t *x, Double_t *p) {