Skip to content

Commit

Permalink
Remove array_ref from differentiation with Enzyme
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Mar 21, 2024
1 parent 4eea1b7 commit 761dbfc
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 164 deletions.
53 changes: 16 additions & 37 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Account for the this pointer.
if (isa<CXXMethodDecl>(m_Function) && !utils::IsStaticMethod(m_Function))
++numOfDerivativeParams;
// All output parameters will be of type `clad::array_ref<void>`. These
// All output parameters will be of type `void*`. These
// parameters will be casted to correct type before the call to the actual
// derived function.
// We require each output parameter to be of same type in the overloaded
Expand Down Expand Up @@ -635,29 +635,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
const auto* originalFnType =
dyn_cast<FunctionProtoType>(m_Function->getType());

// Extract Pointer from Clad Array Ref
llvm::SmallVector<VarDecl*, 8> cladRefParams;
for (unsigned i = 0; i < numParams; i++) {
QualType paramType = origParams[i]->getOriginalType();
if (paramType->isRealType()) {
cladRefParams.push_back(nullptr);
continue;
}

paramType = m_Context.getPointerType(
QualType(paramType->getPointeeOrArrayElementType(), 0));
auto* arrayRefNameExpr = BuildDeclRef(paramsRef[numParams + i]);
auto* getPointerExpr = BuildCallExprToMemFn(arrayRefNameExpr, "ptr", {});
auto* arrayRefToArrayStmt = BuildVarDecl(
paramType, "d_" + paramsRef[i]->getNameAsString(), getPointerExpr);
addToCurrentBlock(BuildDeclStmt(arrayRefToArrayStmt), direction::forward);
cladRefParams.push_back(arrayRefToArrayStmt);
}
// Prepare Arguments and Parameters to enzyme_autodiff
llvm::SmallVector<Expr*, 16> enzymeArgs;
llvm::SmallVector<ParmVarDecl*, 16> enzymeParams;
llvm::SmallVector<ParmVarDecl*, 16> enzymeRealParams;
llvm::SmallVector<ParmVarDecl*, 16> enzymeRealParamsRef;
llvm::SmallVector<ParmVarDecl*, 16> enzymeRealParamsDerived;

// First add the function itself as a parameter/argument
enzymeArgs.push_back(BuildDeclRef(const_cast<FunctionDecl*>(m_Function)));
Expand All @@ -673,23 +655,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, paramsRef[i]->getType()));

// If the original parameter is not of array/pointer type, then we don't
// have to extract its pointer from clad array_ref and add it to the
// enzyme parameters, so we can skip the rest of the code
if (!cladRefParams[i]) {
// If original parameter is of a differentiable real type(but not
// array/pointer), then add it to the list of params whose gradient must
// be extracted later from the EnzymeGradient structure
if (paramsRef[i]->getOriginalType()->isRealFloatingType()) {
enzymeRealParams.push_back(paramsRef[i]);
enzymeRealParamsRef.push_back(paramsRef[numParams + i]);
}
continue;
QualType paramType = origParams[i]->getOriginalType();
// If original parameter is of a differentiable real type(but not
// array/pointer), then add it to the list of params whose gradient must
// be extracted later from the EnzymeGradient structure
if (paramType->isRealFloatingType()) {
enzymeRealParams.push_back(paramsRef[i]);
enzymeRealParamsDerived.push_back(paramsRef[numParams + i]);
} else if (utils::isArrayOrPointerType(paramType)) {
// Add the corresponding array/pointer variable
enzymeArgs.push_back(BuildDeclRef(paramsRef[numParams + i]));
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, paramsRef[numParams + i]->getType()));
}
// Then add the corresponding clad array ref pointer variable
enzymeArgs.push_back(BuildDeclRef(cladRefParams[i]));
enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef(
fdDeclContext, noLoc, cladRefParams[i]->getType()));
}

llvm::SmallVector<QualType, 16> enzymeParamsType;
Expand Down Expand Up @@ -732,7 +710,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(BuildDeclStmt(gradDeclStmt), direction::forward);

for (unsigned i = 0; i < enzymeRealParams.size(); i++) {
auto* LHSExpr = BuildOp(UO_Deref, BuildDeclRef(enzymeRealParamsRef[i]));
auto* LHSExpr =
BuildOp(UO_Deref, BuildDeclRef(enzymeRealParamsDerived[i]));

auto* ME = utils::BuildMemberExpr(m_Sema, getCurrentScope(),
BuildDeclRef(gradDeclStmt), "d_arr");
Expand Down
2 changes: 1 addition & 1 deletion test/CUDA/GradientCuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ __device__ __host__ double gauss(double* x, double* p, double sigma, int dim) {
}


// CHECK: void gauss_grad_1(double *x, double *p, double sigma, int dim, clad::array_ref<double> _d_p) __attribute__((device)) __attribute__((host)) {
// CHECK: void gauss_grad_1(double *x, double *p, double sigma, int dim, double *_d_p) __attribute__((device)) __attribute__((host)) {
//CHECK-NEXT: double _d_sigma = 0;
//CHECK-NEXT: int _d_dim = 0;
//CHECK-NEXT: double _d_t = 0;
Expand Down
12 changes: 6 additions & 6 deletions test/Enzyme/DifferentCladEnzymeDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@ double foo(double x, double y){
return x*y;
}

// CHECK: void foo_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
// CHECK: void foo_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: * _d_x += 1 * y;
// CHECK-NEXT: * _d_y += x * 1;
// CHECK-NEXT: *_d_x += 1 * y;
// CHECK-NEXT: *_d_y += x * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void foo_grad_enzyme(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
// CHECK: void foo_grad_enzyme(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: clad::EnzymeGradient<2> grad = __enzyme_autodiff_foo(foo, x, y);
// CHECK-NEXT: * _d_x = grad.d_arr[0U];
// CHECK-NEXT: * _d_y = grad.d_arr[1U];
// CHECK-NEXT: *_d_x = grad.d_arr[0U];
// CHECK-NEXT: *_d_y = grad.d_arr[1U];
// CHECK-NEXT: }

int main(){
Expand Down
35 changes: 15 additions & 20 deletions test/Enzyme/FunctionPrototypesReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,19 @@

double f1(double* arr) { return arr[0] * arr[1]; }

// CHECK: void f1_grad_enzyme(double *arr, clad::array_ref<double> _d_arr) {
// CHECK-NEXT: double *d_arr = _d_arr.ptr();
// CHECK-NEXT: __enzyme_autodiff_f1(f1, arr, d_arr);
// CHECK: void f1_grad_enzyme(double *arr, double *_d_arr) {
// CHECK-NEXT: __enzyme_autodiff_f1(f1, arr, _d_arr);
// CHECK-NEXT:}

double f2(double x, double y, double z){
return x * y * z;
}

// CHECK: void f2_grad_enzyme(double x, double y, double z, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y, clad::array_ref<double> _d_z) {
// CHECK: void f2_grad_enzyme(double x, double y, double z, double *_d_x, double *_d_y, double *_d_z) {
// CHECK-NEXT: clad::EnzymeGradient<3> grad = __enzyme_autodiff_f2(f2, x, y, z);
// CHECK-NEXT: * _d_x = grad.d_arr[0U];
// CHECK-NEXT: * _d_y = grad.d_arr[1U];
// CHECK-NEXT: * _d_z = grad.d_arr[2U];
// CHECK-NEXT: *_d_x = grad.d_arr[0U];
// CHECK-NEXT: *_d_y = grad.d_arr[1U];
// CHECK-NEXT: *_d_z = grad.d_arr[2U];
// CHECK-NEXT:}

double f3(double* arr, int n){
Expand All @@ -31,12 +30,11 @@ double f3(double* arr, int n){
return sum;
}

// CHECK: void f3_grad_enzyme(double *arr, int n, clad::array_ref<double> _d_arr, clad::array_ref<int> _d_n) {
// CHECK-NEXT: double *d_arr = _d_arr.ptr();
// CHECK-NEXT: __enzyme_autodiff_f3(f3, arr, d_arr, n);
// CHECK: void f3_grad_enzyme(double *arr, int n, double *_d_arr, int *_d_n) {
// CHECK-NEXT: __enzyme_autodiff_f3(f3, arr, _d_arr, n);
// CHECK-NEXT: }

double f4(double* arr1, int n, double*arr2, int m){
double f4(double* arr1, int n, double* arr2, int m){
double sum=0;
for(int i=0;i<n;i++){
sum+=arr1[i]*arr1[i];
Expand All @@ -47,10 +45,8 @@ double f4(double* arr1, int n, double*arr2, int m){
return sum;
}

// CHECK: void f4_grad_enzyme(double *arr1, int n, double *arr2, int m, clad::array_ref<double> _d_arr1, clad::array_ref<int> _d_n, clad::array_ref<double> _d_arr2, clad::array_ref<int> _d_m) {
// CHECK-NEXT: double *d_arr1 = _d_arr1.ptr();
// CHECK-NEXT: double *d_arr2 = _d_arr2.ptr();
// CHECK-NEXT: __enzyme_autodiff_f4(f4, arr1, d_arr1, n, arr2, d_arr2, m);
// CHECK: void f4_grad_enzyme(double *arr1, int n, double *arr2, int m, double *_d_arr1, int *_d_n, double *_d_arr2, int *_d_m) {
// CHECK-NEXT: __enzyme_autodiff_f4(f4, arr1, _d_arr1, n, arr2, _d_arr2, m);
// CHECK-NEXT: }

double f5(double arr[], double x,int n,double y){
Expand All @@ -61,11 +57,10 @@ double f5(double arr[], double x,int n,double y){
return res;
}

// CHECK: void f5_grad_enzyme(double arr[], double x, int n, double y, clad::array_ref<double> _d_arr, clad::array_ref<double> _d_x, clad::array_ref<int> _d_n, clad::array_ref<double> _d_y) {
// CHECK-NEXT: double *d_arr = _d_arr.ptr();
// CHECK-NEXT: clad::EnzymeGradient<2> grad = __enzyme_autodiff_f5(f5, arr, d_arr, x, n, y);
// CHECK-NEXT: * _d_x = grad.d_arr[0U];
// CHECK-NEXT: * _d_y = grad.d_arr[1U];
// CHECK: void f5_grad_enzyme(double arr[], double x, int n, double y, double *_d_arr, double *_d_x, int *_d_n, double *_d_y) {
// CHECK-NEXT: clad::EnzymeGradient<2> grad = __enzyme_autodiff_f5(f5, arr, _d_arr, x, n, y);
// CHECK-NEXT: *_d_x = grad.d_arr[0U];
// CHECK-NEXT: *_d_y = grad.d_arr[1U];
// CHECK-NEXT: }


Expand Down
Loading

0 comments on commit 761dbfc

Please sign in to comment.