Skip to content

Commit

Permalink
[MLIR] Handle multiple return activities (#1851)
Browse files Browse the repository at this point in the history
* Handle multiple return activities

* now fully functioning
  • Loading branch information
wsmoses authored Apr 28, 2024
1 parent fb3e5d6 commit 69d8a1c
Show file tree
Hide file tree
Showing 27 changed files with 197 additions and 52 deletions.
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1070,7 +1070,7 @@ void enzyme::runDataFlowActivityAnalysis(
llvm::zip(callee.getArguments(), argumentActivity)) {
// enzyme_dup, dupnoneed are initialized within the dense forward/backward
// analyses, enzyme_const is the default.
if (activity == enzyme::Activity::enzyme_out) {
if (activity == enzyme::Activity::enzyme_active) {
auto *argLattice = solver.getOrCreateState<ForwardValueActivity>(arg);
(void)argLattice->join(ValueActivity::getActiveVal());
}
Expand Down
8 changes: 5 additions & 3 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def Activity : I32EnumAttr<"Activity",
"Possible activity states for variables",
[
I32EnumAttrCase<"enzyme_out", 0>,
I32EnumAttrCase<"enzyme_active", 0>,
I32EnumAttrCase<"enzyme_dup", 1>,
I32EnumAttrCase<"enzyme_const",2>,
I32EnumAttrCase<"enzyme_dupnoneed", 3>,
I32EnumAttrCase<"enzyme_activenoneed", 4>,
I32EnumAttrCase<"enzyme_constnoneed", 5>,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::enzyme";
Expand All @@ -52,7 +54,7 @@ def PlaceholderOp : Enzyme_Op<"placeholder",
def ForwardDiffOp : Enzyme_Op<"fwddiff",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Perform forward mode AD on a funcop";
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity);
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity);
let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
Expand All @@ -63,7 +65,7 @@ def ForwardDiffOp : Enzyme_Op<"fwddiff",
def AutoDiffOp : Enzyme_Op<"autodiff",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Perform reverse mode AD on a funcop";
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity);
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity);
let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
Expand Down
6 changes: 6 additions & 0 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(

auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) {
SmallVector<mlir::Value> retargs;
for (auto [arg, returnPrimal] :
llvm::zip(oBB->getArguments(), returnPrimals)) {
if (returnPrimal) {
retargs.push_back(gutils->getNewFromOriginal(arg));
}
}
for (auto [arg, cv] : llvm::zip(oBB->getArguments(), constants)) {
if (cv == DIFFE_TYPE::OUT_DIFF) {
retargs.push_back(gutils->diffe(arg, builder));
Expand Down
142 changes: 121 additions & 21 deletions enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,30 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {

auto mop = activityAttr[truei];
auto iattr = cast<mlir::enzyme::ActivityAttr>(mop);
DIFFE_TYPE ty = (DIFFE_TYPE)(iattr.getValue());
DIFFE_TYPE ty;

switch (iattr.getValue()) {
case mlir::enzyme::Activity::enzyme_active:
ty = DIFFE_TYPE::OUT_DIFF;
break;
case mlir::enzyme::Activity::enzyme_dup:
ty = DIFFE_TYPE::DUP_ARG;
break;
case mlir::enzyme::Activity::enzyme_const:
ty = DIFFE_TYPE::CONSTANT;
break;
case mlir::enzyme::Activity::enzyme_dupnoneed:
ty = DIFFE_TYPE::DUP_NONEED;
break;
case mlir::enzyme::Activity::enzyme_activenoneed:
ty = DIFFE_TYPE::OUT_DIFF;
assert(0 && "unsupported arg activenoneed");
break;
case mlir::enzyme::Activity::enzyme_constnoneed:
ty = DIFFE_TYPE::CONSTANT;
assert(0 && "unsupported arg constnoneed");
break;
}

constants.push_back(ty);
args.push_back(res);
Expand All @@ -78,7 +101,40 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
auto fn = cast<FunctionOpInterface>(symbolOp);

auto mode = DerivativeMode::ForwardMode;
std::vector<DIFFE_TYPE> retType = mode_from_fn(fn, mode);
std::vector<DIFFE_TYPE> retType;

std::vector<bool> returnPrimals;
for (auto act : CI.getRetActivity()) {
auto iattr = cast<mlir::enzyme::ActivityAttr>(act);
auto val = iattr.getValue();
DIFFE_TYPE ty;
bool primalNeeded = true;
switch (val) {
case mlir::enzyme::Activity::enzyme_active:
ty = DIFFE_TYPE::OUT_DIFF;
break;
case mlir::enzyme::Activity::enzyme_dup:
ty = DIFFE_TYPE::DUP_ARG;
break;
case mlir::enzyme::Activity::enzyme_const:
ty = DIFFE_TYPE::CONSTANT;
break;
case mlir::enzyme::Activity::enzyme_dupnoneed:
ty = DIFFE_TYPE::DUP_NONEED;
primalNeeded = false;
break;
case mlir::enzyme::Activity::enzyme_activenoneed:
ty = DIFFE_TYPE::OUT_DIFF;
primalNeeded = false;
break;
case mlir::enzyme::Activity::enzyme_constnoneed:
ty = DIFFE_TYPE::CONSTANT;
primalNeeded = false;
break;
}
retType.push_back(ty);
returnPrimals.push_back(primalNeeded);
}

MTypeAnalysis TA;
auto type_args = TA.getAnalyzedTypeInfo(fn);
Expand All @@ -91,12 +147,6 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined));
}

std::vector<bool> returnPrimals;
for (auto act : retType) {
(void)act;
returnPrimals.push_back(false);
}

FunctionOpInterface newFunc = Logic.CreateForwardDiff(
fn, retType, constants, TA, returnPrimals, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
Expand All @@ -115,7 +165,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
template <typename T>
LogicalResult HandleAutoDiffReverse(SymbolTableCollection &symbolTable,
T CI) {
std::vector<DIFFE_TYPE> constants;
std::vector<DIFFE_TYPE> arg_activities;
SmallVector<mlir::Value, 2> args;

size_t call_idx = 0;
Expand All @@ -125,9 +175,31 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
++call_idx;

auto iattr = cast<mlir::enzyme::ActivityAttr>(act);
DIFFE_TYPE ty = (DIFFE_TYPE)(iattr.getValue());

constants.push_back(ty);
auto val = iattr.getValue();
DIFFE_TYPE ty;
switch (val) {
case mlir::enzyme::Activity::enzyme_active:
ty = DIFFE_TYPE::OUT_DIFF;
break;
case mlir::enzyme::Activity::enzyme_dup:
ty = DIFFE_TYPE::DUP_ARG;
break;
case mlir::enzyme::Activity::enzyme_const:
ty = DIFFE_TYPE::CONSTANT;
break;
case mlir::enzyme::Activity::enzyme_dupnoneed:
ty = DIFFE_TYPE::DUP_NONEED;
break;
case mlir::enzyme::Activity::enzyme_activenoneed:
ty = DIFFE_TYPE::OUT_DIFF;
assert(0 && "unsupported arg activenoneed");
break;
case mlir::enzyme::Activity::enzyme_constnoneed:
ty = DIFFE_TYPE::CONSTANT;
assert(0 && "unsupported arg constnoneed");
break;
}
arg_activities.push_back(ty);
args.push_back(res);
if (ty == DIFFE_TYPE::DUP_ARG || ty == DIFFE_TYPE::DUP_NONEED) {
res = CI.getInputs()[call_idx];
Expand All @@ -141,13 +213,45 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
auto fn = cast<FunctionOpInterface>(symbolOp);

auto mode = DerivativeMode::ReverseModeCombined;
std::vector<DIFFE_TYPE> retType = mode_from_fn(fn, mode);
std::vector<DIFFE_TYPE> retType;
std::vector<bool> returnPrimals;
std::vector<bool> returnShadows;

// Add the return gradient
for (auto act : retType) {
if (act == DIFFE_TYPE::OUT_DIFF) {
for (auto act : CI.getRetActivity()) {
auto iattr = cast<mlir::enzyme::ActivityAttr>(act);
auto val = iattr.getValue();
DIFFE_TYPE ty;
bool primalNeeded = true;
switch (val) {
case mlir::enzyme::Activity::enzyme_active:
ty = DIFFE_TYPE::OUT_DIFF;
break;
case mlir::enzyme::Activity::enzyme_dup:
ty = DIFFE_TYPE::DUP_ARG;
break;
case mlir::enzyme::Activity::enzyme_const:
ty = DIFFE_TYPE::CONSTANT;
break;
case mlir::enzyme::Activity::enzyme_dupnoneed:
ty = DIFFE_TYPE::DUP_NONEED;
primalNeeded = false;
break;
case mlir::enzyme::Activity::enzyme_activenoneed:
ty = DIFFE_TYPE::OUT_DIFF;
primalNeeded = false;
break;
case mlir::enzyme::Activity::enzyme_constnoneed:
ty = DIFFE_TYPE::CONSTANT;
primalNeeded = false;
break;
}
retType.push_back(ty);
returnPrimals.push_back(primalNeeded);
returnShadows.push_back(false);
if (ty == DIFFE_TYPE::OUT_DIFF) {
mlir::Value res = CI.getInputs()[call_idx];
call_idx++;
++call_idx;
args.push_back(res);
}
}
Expand All @@ -158,17 +262,13 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
size_t width = 1;

std::vector<bool> volatile_args;
std::vector<bool> returnPrimals;
std::vector<bool> returnShadows;
for (auto &a : fn.getFunctionBody().getArguments()) {
(void)a;
volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined));
returnPrimals.push_back(false);
returnShadows.push_back(false);
}

FunctionOpInterface newFunc =
Logic.CreateReverseDiff(fn, retType, constants, TA, returnPrimals,
Logic.CreateReverseDiff(fn, retType, arg_activities, TA, returnPrimals,
returnShadows, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr);
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ enzyme::Activity getDefaultActivity(Type argType) {
return enzyme::Activity::enzyme_const;

if (isa<FloatType, ComplexType>(argType))
return enzyme::Activity::enzyme_out;
return enzyme::Activity::enzyme_active;

if (auto T = dyn_cast<TensorType>(argType))
return getDefaultActivity(T.getElementType());
Expand Down Expand Up @@ -104,7 +104,7 @@ struct PrintActivityAnalysisPass
argActivities[paramIdx] =
llvm::TypeSwitch<Type, enzyme::Activity>(paramType)
.Case<FloatType, ComplexType>(
[](auto type) { return enzyme::Activity::enzyme_out; })
[](auto type) { return enzyme::Activity::enzyme_active; })
.Case<LLVM::LLVMPointerType, MemRefType>([&](auto type) {
// Skip the shadow
argIdx++;
Expand All @@ -121,7 +121,7 @@ struct PrintActivityAnalysisPass
resultActivities[resIdx] =
llvm::TypeSwitch<Type, enzyme::Activity>(resType)
.Case<FloatType, ComplexType>(
[](auto type) { return enzyme::Activity::enzyme_out; })
[](auto type) { return enzyme::Activity::enzyme_active; })
.Default(
[](Type type) { return enzyme::Activity::enzyme_const; });
}
Expand Down
2 changes: 1 addition & 1 deletion enzyme/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ new_local_repository(

load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")

llvm_configure(name = "llvm-project", targets = ["X86", "NVPTX"])
llvm_configure(name = "llvm-project", targets = ["AArch64", "NVPTX"])

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
Expand Down
6 changes: 3 additions & 3 deletions enzyme/test/MLIR/ForwardMode/affine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module {
return %r : f64
}
func.func @dloop(%x : f64, %dx : f64) -> f64 {
%r = enzyme.fwddiff @loop(%x, %dx) { activity=[#enzyme<activity enzyme_dup>] } : (f64, f64) -> (f64)
%r = enzyme.fwddiff @loop(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
return %r : f64
}
// CHECK: @fwddiffeloop
Expand Down Expand Up @@ -38,7 +38,7 @@ module {
return %res : f64
}
func.func @dif_then_else(%x : f64, %dx : f64, %c : i1) -> f64 {
%r = enzyme.fwddiff @if_then_else(%x, %dx, %c) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_const>] } : (f64, f64, i1) -> (f64)
%r = enzyme.fwddiff @if_then_else(%x, %dx, %c) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_const>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64, i1) -> (f64)
return %r : f64
}
// CHECK: @fwddiffeif_then_else
Expand Down Expand Up @@ -74,7 +74,7 @@ module {
return %res : f64
}
func.func @dif_then(%x : f64, %dx : f64, %c : i1) -> f64 {
%r = enzyme.fwddiff @if_then(%x, %dx, %c) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_const>] } : (f64, f64, i1) -> (f64)
%r = enzyme.fwddiff @if_then(%x, %dx, %c) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_const>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64, i1) -> (f64)
return %r : f64
}
// CHECK: @fwddiffeif_then
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ForwardMode/branch-self-recursive.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ module {
return %sum : f64
}
func.func @dsq(%x : f64, %dx : f64, %y : f64, %dy : f64) -> f64 {
%r = enzyme.fwddiff @infinite(%x, %dx, %y, %dy) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>] } : (f64, f64, f64, f64) -> (f64)
%r = enzyme.fwddiff @infinite(%x, %dx, %y, %dy) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64, f64, f64) -> (f64)
return %r : f64
}
}
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ForwardMode/branch.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ module {
return %r : f64
}
func.func @dsq(%x : f64, %dx : f64, %y : f64, %dy : f64) -> f64 {
%r = enzyme.fwddiff @square(%x, %dx, %y, %dy) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>] } : (f64, f64, f64, f64) -> (f64)
%r = enzyme.fwddiff @square(%x, %dx, %y, %dy) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64, f64, f64) -> (f64)
return %r : f64
}
}
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ForwardMode/executeop.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ module {
return %res : f64
}
func.func @dsq(%x : f64, %dx : f64, %c : i32) -> f64 {
%r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_const>] } : (f64, f64, i32) -> (f64)
%r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_const>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64, i32) -> (f64)
return %r : f64
}
}
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ForwardMode/for.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ module {
return %r : f64
}
func.func @dsq(%x : f64, %dx : f64) -> f64 {
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>] } : (f64, f64) -> (f64)
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
return %r : f64
}
}
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ForwardMode/for2.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ module {
return %r : f64
}
func.func @dsq(%x : f64, %dx : f64) -> f64 {
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>] } : (f64, f64) -> (f64)
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
return %r : f64
}
}
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ForwardMode/if1.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ module {
return %res : f64
}
func.func @dsq(%x : f64, %dx : f64, %c : i1) -> f64 {
%r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_const>] } : (f64, f64, i1) -> (f64)
%r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_const>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64, i1) -> (f64)
return %r : f64
}
}
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ForwardMode/inactive.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ module {
return %x : f64
}
func.func @diff(%x : f64, %dx : f64) -> f64 {
%r = enzyme.fwddiff @inactive(%x, %dx) { activity=[#enzyme<activity enzyme_dup>] } : (f64, f64) -> (f64)
%r = enzyme.fwddiff @inactive(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
return %r : f64
}
}
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ForwardMode/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module {
return %r : f64
}
func.func @diff(%x : f64, %dx : f64) -> f64 {
%r = enzyme.fwddiff @unsupported(%x, %dx) { activity=[#enzyme<activity enzyme_dup>] } : (f64, f64) -> (f64)
%r = enzyme.fwddiff @unsupported(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
return %r : f64
}
}
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ForwardMode/llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module {
return %r : f64
}
func.func @dsq(%x : f64, %dx : f64) -> f64 {
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>] } : (f64, f64) -> (f64)
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
return %r : f64
}
}
Expand Down
Loading

0 comments on commit 69d8a1c

Please sign in to comment.