diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index aa5b69cf5..aafa32e47 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -25,6 +25,9 @@ namespace clad { class ExternalRMVSource; class MultiplexExternalRMVSource; + using VectorOutputString = + std::vector>; + /// A visitor for processing the function code in reverse mode. /// Used to compute derivatives by clad::gradient. class ReverseModeVisitor @@ -38,6 +41,10 @@ namespace clad { // several private/protected members of the visitor classes. friend class ErrorEstimationHandler; llvm::SmallVector m_IndependentVars; + llvm::SmallVector m_IndependentVarsSize; + std::unordered_map m_ExprVariables; + VectorOutputString m_VectorOutputString; + /// In addition to a sequence of forward-accumulated Stmts (m_Blocks), in /// the reverse mode we also accumulate Stmts for the reverse pass which /// will be executed on return. @@ -62,6 +69,7 @@ namespace clad { std::vector m_LoopBlock; unsigned outputArrayCursor = 0; unsigned numParams = 0; + unsigned numActualParams = 0; bool isVectorValued = false; bool use_enzyme = false; // FIXME: Should we make this an object instead of a pointer? diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index d4b0de93e..15acca423 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -373,28 +373,54 @@ namespace clad { // Creates the ArraySubscriptExprs for the independent variables size_t idx = 0; + auto size_type = m_Context.getSizeType(); + unsigned size_type_bits = m_Context.getIntWidth(size_type); for (auto arg : args) { // FIXME: fix when adding array inputs, now we are just skipping all // array/pointer inputs (not treating them as independent variables). - if (utils::isArrayOrPointerType(arg->getType())) { + if (utils::isArrayOrPointerType(arg->getType())) { //is a array or pointer type parameter if (arg->getName() == "p") m_Variables[arg] = m_Result; + + ParmVarDecl* parg=dyn_cast(const_cast(arg)); + QualType qt = parg->getOriginalType(); + assert(qt->isConstantArrayType() && "Only Constant type arrays are allowed to be parameters of function whose jacobian is to be found. Non Constant types and Pointer type are not supported"); + ConstantArrayType* t = dyn_cast(const_cast(qt.getTypePtr())); + int sizeOfArray = (int)(t->getSize().roundToDouble(false)); + m_IndependentVars.push_back(arg); + m_IndependentVarsSize.push_back(sizeOfArray); + + for(int j=0;jgetNameAsString()+"["+std::to_string(j)+"]"; + // Create the idx literal. + auto i = + IntegerLiteral::Create(m_Context, llvm::APInt(size_type_bits, idx), + size_type, noLoc); + // Create the jacobianMatrix[idx] expression. + auto result_at_i = + m_Sema.CreateBuiltinArraySubscriptExpr(m_Result, noLoc, i, noLoc) + .get(); + m_ExprVariables[name]=result_at_i; + idx+=1; + numActualParams++; + } + }else{ //is normal variable parameter + // Create the idx literal. + auto i = + IntegerLiteral::Create(m_Context, llvm::APInt(size_type_bits, idx), + size_type, noLoc); + // Create the jacobianMatrix[idx] expression. + auto result_at_i = + m_Sema.CreateBuiltinArraySubscriptExpr(m_Result, noLoc, i, noLoc) + .get(); + m_Variables[arg] = result_at_i; + m_ExprVariables[arg->getNameAsString()]=result_at_i; idx += 1; - continue; + numActualParams++; + m_IndependentVars.push_back(arg); + m_IndependentVarsSize.push_back(1); } - auto size_type = m_Context.getSizeType(); - unsigned size_type_bits = m_Context.getIntWidth(size_type); - // Create the idx literal. - auto i = - IntegerLiteral::Create(m_Context, llvm::APInt(size_type_bits, idx), - size_type, noLoc); - // Create the jacobianMatrix[idx] expression. - auto result_at_i = - m_Sema.CreateBuiltinArraySubscriptExpr(m_Result, noLoc, i, noLoc) - .get(); - m_Variables[arg] = result_at_i; - idx += 1; - m_IndependentVars.push_back(arg); + } } @@ -1067,7 +1093,44 @@ namespace clad { auto ASI = SplitArraySubscript(ASE); const Expr* Base = ASI.first; const auto& Indices = ASI.second; - StmtDiff BaseDiff = Visit(Base); + StmtDiff BaseDiff; + + //Check is we are visiting an Independent variable expressed as an array subscript expression in Jacobian Mode + + if(isVectorValued){ + if(auto dyn = dyn_cast(dyn_cast(Base)->getDecl())){ + //Check if this an independent variable + for(auto i:m_IndependentVars){ + if(dyn->getNameAsString()==i->getNameAsString()){ + llvm::APSInt intIdx; + auto isIdxValid = + clad_compat::Expr_EvaluateAsInt(ASE->getIdx(), intIdx, m_Context); + + // FIXME: We assume that inside the index is just an Integer Expression + // and not a Complex expression + assert(isIdxValid && "Only Integer Literals allowed as array indices of Independent Variables"); + + int index = intIdx.getExtValue(); + std::string indVarName = dyn->getNameAsString()+"["+std::to_string(index)+"]"; + auto it = m_VectorOutputString[outputArrayCursor].find(indVarName); + + // Create the (jacobianMatrix[idx] += dfdx) statement. + if (dfdx()) { + auto add_assign = BuildOp(BO_AddAssign, it->second, dfdx()); + // Add it to the body statements. + addToCurrentBlock(add_assign, direction::reverse); + } + break; + } + } + BaseDiff = StmtDiff(dyn_cast(Clone(Base))); + }else{ + BaseDiff = Visit(Base); + } + }else{ + BaseDiff = Visit(Base); + } + llvm::SmallVector clonedIndices(Indices.size()); llvm::SmallVector reverseIndices(Indices.size()); llvm::SmallVector forwSweepDerivativeIndices(Indices.size()); @@ -1918,21 +1981,44 @@ namespace clad { std::unordered_map temp_m_Variables; - for (unsigned i = 0; i < numParams; i++) { - auto size_type = m_Context.getSizeType(); - unsigned size_type_bits = m_Context.getIntWidth(size_type); - llvm::APInt idxValue(size_type_bits, - i + (outputArrayCursor * numParams)); - auto idx = IntegerLiteral::Create(m_Context, idxValue, - size_type, noLoc); - // Create the jacobianMatrix[idx] expression. - auto result_at_i = m_Sema - .CreateBuiltinArraySubscriptExpr( - m_Result, noLoc, idx, noLoc) - .get(); - temp_m_Variables[m_IndependentVars[i]] = result_at_i; + std::unordered_map temp_m_VariablesStr; + auto size_type = m_Context.getSizeType(); + unsigned size_type_bits = m_Context.getIntWidth(size_type); + for (unsigned i = 0, j=0; i < numParams; i++) { + auto arg = m_IndependentVars[i]; + ParmVarDecl* parg=dyn_cast(const_cast(arg)); + if(parg->getOriginalType()->isConstantArrayType()){ + int arrSize = m_IndependentVarsSize[i]; + for(int k=0;kgetNameAsString()+"["+std::to_string(k)+"]"; + temp_m_VariablesStr[sName]=result_at_i; + } + }else{ + llvm::APInt idxValue(size_type_bits, + j + (outputArrayCursor * numActualParams)); + auto idx = IntegerLiteral::Create(m_Context, idxValue, + size_type, noLoc); + // Create the jacobianMatrix[idx] expression. + auto result_at_i = m_Sema + .CreateBuiltinArraySubscriptExpr( + m_Result, noLoc, idx, noLoc) + .get(); + temp_m_Variables[m_IndependentVars[i]] = result_at_i; + temp_m_VariablesStr[arg->getNameAsString()]=result_at_i; + j++; + } + } m_VectorOutput.push_back(temp_m_Variables); + m_VectorOutputString.push_back(temp_m_VariablesStr); } auto dfdf = ConstantFolder::synthesizeLiteral(m_Context.IntTy, diff --git a/test/Jacobian/Jacobian.C b/test/Jacobian/Jacobian.C index a267562a3..4da0dff4e 100644 --- a/test/Jacobian/Jacobian.C +++ b/test/Jacobian/Jacobian.C @@ -304,6 +304,118 @@ void f_1_jac_0(double a, double b, double c, double output[], double *jacobianMa // CHECK-NEXT: } // CHECK-NEXT:} +void f_5(float a[3], float output[]){ + output[0]=a[0]*a[1]; + output[1]=a[1]*a[2]; + output[2]=a[0]*a[2]; +} +void f_5_jac(float a[3], float output[], float *jacobianMatrix); +// CHECK: void f_5_jac(float a[3], float output[], float *jacobianMatrix) { +// CHECK-NEXT: float _t0; +// CHECK-NEXT: float _t1; +// CHECK-NEXT: float _t2; +// CHECK-NEXT: float _t3; +// CHECK-NEXT: float _t4; +// CHECK-NEXT: float _t5; +// CHECK-NEXT: _t1 = a[0]; +// CHECK-NEXT: _t0 = a[1]; +// CHECK-NEXT: output[0] = a[0] * a[1]; +// CHECK-NEXT: _t3 = a[1]; +// CHECK-NEXT: _t2 = a[2]; +// CHECK-NEXT: output[1] = a[1] * a[2]; +// CHECK-NEXT: _t5 = a[0]; +// CHECK-NEXT: _t4 = a[2]; +// CHECK-NEXT: output[2] = a[0] * a[2]; +// CHECK-NEXT: { +// CHECK-NEXT: float _r4 = 1 * _t4; +// CHECK-NEXT: jacobianMatrix[6UL] += _r4; +// CHECK-NEXT: float _r5 = _t5 * 1; +// CHECK-NEXT: jacobianMatrix[8UL] += _r5; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: float _r2 = 1 * _t2; +// CHECK-NEXT: jacobianMatrix[4UL] += _r2; +// CHECK-NEXT: float _r3 = _t3 * 1; +// CHECK-NEXT: jacobianMatrix[5UL] += _r3; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: float _r0 = 1 * _t0; +// CHECK-NEXT: jacobianMatrix[0UL] += _r0; +// CHECK-NEXT: float _r1 = _t1 * 1; +// CHECK-NEXT: jacobianMatrix[1UL] += _r1; +// CHECK-NEXT: } +// CHECK-NEXT: } + +void f_6(float a[1], float b, float output[]) { + output[0] = a[0] * a[0] * a[0]; + output[1] = a[0] * a[0] * a[0] + b * b * b; + output[2] = 2 * (a[0] + b); +} +void f_6_jac(float a[1], float b, float output[], float *jacobianMatrix); +// CHECK: void f_6_jac(float a[1], float b, float output[], float *jacobianMatrix) { +// CHECK-NEXT: float _t0; +// CHECK-NEXT: float _t1; +// CHECK-NEXT: float _t2; +// CHECK-NEXT: float _t3; +// CHECK-NEXT: float _t4; +// CHECK-NEXT: float _t5; +// CHECK-NEXT: float _t6; +// CHECK-NEXT: float _t7; +// CHECK-NEXT: float _t8; +// CHECK-NEXT: float _t9; +// CHECK-NEXT: float _t10; +// CHECK-NEXT: float _t11; +// CHECK-NEXT: float _t12; +// CHECK-NEXT: _t2 = a[0]; +// CHECK-NEXT: _t1 = a[0]; +// CHECK-NEXT: _t3 = _t2 * _t1; +// CHECK-NEXT: _t0 = a[0]; +// CHECK-NEXT: output[0] = a[0] * a[0] * a[0]; +// CHECK-NEXT: _t6 = a[0]; +// CHECK-NEXT: _t5 = a[0]; +// CHECK-NEXT: _t7 = _t6 * _t5; +// CHECK-NEXT: _t4 = a[0]; +// CHECK-NEXT: _t10 = b; +// CHECK-NEXT: _t9 = b; +// CHECK-NEXT: _t11 = _t10 * _t9; +// CHECK-NEXT: _t8 = b; +// CHECK-NEXT: output[1] = a[0] * a[0] * a[0] + b * b * b; +// CHECK-NEXT: _t12 = (a[0] + b); +// CHECK-NEXT: output[2] = 2 * (a[0] + b); +// CHECK-NEXT: { +// CHECK-NEXT: float _r12 = 1 * _t12; +// CHECK-NEXT: float _r13 = 2 * 1; +// CHECK-NEXT: jacobianMatrix[4UL] += _r13; +// CHECK-NEXT: jacobianMatrix[5UL] += _r13; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: float _r4 = 1 * _t4; +// CHECK-NEXT: float _r5 = _r4 * _t5; +// CHECK-NEXT: jacobianMatrix[2UL] += _r5; +// CHECK-NEXT: float _r6 = _t6 * _r4; +// CHECK-NEXT: jacobianMatrix[2UL] += _r6; +// CHECK-NEXT: float _r7 = _t7 * 1; +// CHECK-NEXT: jacobianMatrix[2UL] += _r7; +// CHECK-NEXT: float _r8 = 1 * _t8; +// CHECK-NEXT: float _r9 = _r8 * _t9; +// CHECK-NEXT: jacobianMatrix[3UL] += _r9; +// CHECK-NEXT: float _r10 = _t10 * _r8; +// CHECK-NEXT: jacobianMatrix[3UL] += _r10; +// CHECK-NEXT: float _r11 = _t11 * 1; +// CHECK-NEXT: jacobianMatrix[3UL] += _r11; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: float _r0 = 1 * _t0; +// CHECK-NEXT: float _r1 = _r0 * _t1; +// CHECK-NEXT: jacobianMatrix[0UL] += _r1; +// CHECK-NEXT: float _r2 = _t2 * _r0; +// CHECK-NEXT: jacobianMatrix[0UL] += _r2; +// CHECK-NEXT: float _r3 = _t3 * 1; +// CHECK-NEXT: jacobianMatrix[0UL] += _r3; +// CHECK-NEXT: } +// CHECK-NEXT: } + + #define TEST(F, x, y, z) { \ result[0] = 0; result[1] = 0; result[2] = 0;\ result[3] = 0; result[4] = 0; result[5] = 0;\ @@ -335,4 +447,21 @@ int main() { TEST(f_3, 1, 2, 3); // CHECK-EXEC: Result is = {22.69, 0.00, 0.00, 0.00, -17.48, 0.00, 0.00, 0.00, -41.58} TEST(f_4, 1, 2, 3); // CHECK-EXEC: Result is = {84.00, 42.00, 0.00, 0.00, 126.00, 84.00, 126.00, 0.00, 42.00} TEST_F_1_SINGLE_PARAM(1, 2, 3); // CHECK-EXEC: Result is = {3.00, 3.00, -2.00} + + + auto d_f_5 = clad::jacobian(f_5); + float a5[3]={3,4,5}; + float op5[3]={0}; + float jc5[9]={0}; + d_f_5.execute(a5,op5,jc5); + printf("Result is = {%.2f, %.2f, %.2f, %.2f, %.2f, %.2f, %.2f, %.2f, %.2f}\n", jc5[0],jc5[1],jc5[2],jc5[3],jc5[4],jc5[5],jc5[6],jc5[7],jc5[8]); + //CHECK-EXEC: Result is = {4.00, 3.00, 0.00, 0.00, 5.00, 4.00, 5.00, 0.00, 3.00} + + auto d_f_6 = clad::jacobian(f_6); + float a6[1]={3};float b6=5; + float op6[3]={0}; + float jc6[6]={0}; + d_f_6.execute(a6,b6,op6,jc6); + printf("Result is = {%.2f, %.2f, %.2f, %.2f, %.2f, %.2f}\n", jc6[0],jc6[1],jc6[2],jc6[3],jc6[4],jc6[5]); + //CHECK-EXEC: Result is = {27.00, 0.00, 27.00, 75.00, 2.00, 2.00} }