Skip to content

Commit

Permalink
test option and fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Dec 2, 2023
1 parent cb1c388 commit b1429a2
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 94 deletions.
71 changes: 4 additions & 67 deletions lib/Dialect/Secret/Transforms/DistributeGeneric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,6 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
// to the corresponding secret operands (via the block argument number).
rewriter.startRootUpdate(genericOp);

LLVM_DEBUG({
llvm::dbgs() << "\n\ngeneric op before updating loop operands\n\n";
genericOp.dump();
});

// Set the loop op's operands that came from the secret generic block
// to be the the corresponding operand of the generic op.
for (OpOperand &operand : opToDistribute.getOpOperands()) {
Expand All @@ -240,35 +235,14 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
}
}

LLVM_DEBUG({
llvm::dbgs() << "\n\ngeneric op after updating loop operands:\n\n";
genericOp.dump();
});

// Set the op's region iter arg types, which need to match the possibly
// new type of the operands modified above
for (auto [arg, operand] :
llvm::zip(loop.getRegionIterArgs(), loop.getInits())) {
arg.setType(operand.getType());
}

LLVM_DEBUG({
llvm::dbgs() << "\n\ngeneric op after updating region iter args\n\n";
genericOp.dump();
});

// There is a slight type conflict here: the loop's iter arg is
// secret<index>, but its block argument is just index. Since the
// CollapseSecretlessGeneric pattern will resolve this type conflict
// later, we leave it as-is here.

opToDistribute.moveBefore(genericOp);

LLVM_DEBUG({
llvm::dbgs() << "\n\nparent after moving loop out of generic body:\n\n";
genericOp->getParentOp()->dump();
});

// Now the loop is before the secret generic, but the generic still
// yields the loop's result (the loop should yield the generic's result)
// and the generic's body still needs to be moved inside the loop.
Expand All @@ -284,11 +258,6 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
// Move the generic op to be the first op of the loop body.
genericOp->moveBefore(&loopBodyBlocks.front().getOperations().front());

LLVM_DEBUG({
llvm::dbgs() << "\n\nloop after moving generic into the loop body:\n\n";
opToDistribute.dump();
});

// Update the yielded values by the terminators of the two ops' blocks.
auto yieldedValues = loop.getYieldedValues();
genericOp.getBody(0)->getTerminator()->setOperands(yieldedValues);
Expand All @@ -300,11 +269,6 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
terminator->setOperands(genericOp.getResults());
}

LLVM_DEBUG({
llvm::dbgs() << "\n\nloop after updating yielded values:\n\n";
opToDistribute.dump();
});

// Update the return type of the loop op to match its terminator.
auto resultRange = loop.getLoopResults();
if (resultRange.has_value()) {
Expand All @@ -314,22 +278,11 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
}
}

LLVM_DEBUG({
llvm::dbgs() << "\n\nloop after updating return types:\n\n";
opToDistribute.dump();
});

// Move the old loop body ops into the secret.generic
for (auto *op : loopBodyOps) {
op->moveBefore(genericOp.getBody(0)->getTerminator());
}

LLVM_DEBUG({
llvm::dbgs() << "\n\nloop after moving old loop body ops into the "
"secret.generic:\n\n";
opToDistribute.dump();
});

// One of the secret.generic's inputs may still refer to the loop's
// iter_args initializer, when now it should refer to the iter_arg itself.
for (OpOperand &operand : genericOp->getOpOperands()) {
Expand All @@ -339,12 +292,6 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
}
}

LLVM_DEBUG({
llvm::dbgs()
<< "\n\nloop after updating secret.generic to use iter_arg:\n\n";
opToDistribute.dump();
});

// The ops within the secret generic may still refer to the loop
// iter_args, which are not part of of the secret.generic's block. To be
// a bit more general, walk the entire generic body, and for any operand
Expand All @@ -371,12 +318,6 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
}
});

LLVM_DEBUG({
llvm::dbgs() << "\n\nloop after updating op args to use plaintext "
"analogues:\n\n";
opToDistribute.dump();
});

// Finally, ops that came after the original secret.generic may still
// refer to a secret.generic result, when now they should refer to the
// corresponding result of the loop, if the loop has results.
Expand All @@ -391,12 +332,6 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
}
}

LLVM_DEBUG({
llvm::dbgs()
<< "\n\nloop after updating potential downstream users\n\n";
opToDistribute.getParentOp()->dump();
});

rewriter.finalizeRootUpdate(genericOp);
return;
}
Expand Down Expand Up @@ -482,7 +417,7 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
return failure();
}

Operation *opToDistribute;
Operation *opToDistribute = nullptr;
bool first = true;
if (opsToDistribute.empty()) {
opToDistribute = &body->front();
Expand All @@ -492,6 +427,8 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
// affine.for)
if (std::find(opsToDistribute.begin(), opsToDistribute.end(),
op.getName().getStringRef()) != opsToDistribute.end()) {
LLVM_DEBUG(llvm::dbgs()
<< "Found op to distribute: " << op.getName() << "\n");
opToDistribute = &op;
break;
}
Expand All @@ -501,7 +438,7 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {

// Base case: if none of a generic op's member ops are in the list of ops
// to process, stop.
if (!opToDistribute) return failure();
if (opToDistribute == nullptr) return failure();

if (numOps == 2 && !opToDistribute->getRegions().empty()) {
distributeThroughRegionHoldingOp(op, *opToDistribute, rewriter);
Expand Down
29 changes: 29 additions & 0 deletions tests/secret/distribute_generic_flags.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: heir-opt --secret-distribute-generic="distribute-through=affine.for" %s | FileCheck %s

// CHECK-LABEL: test_affine_for
// CHECK-SAME: %[[value:.*]]: !secret.secret<i32>
// CHECK-SAME: %[[data:.*]]: !secret.secret<memref<10xi32>>
func.func @test_affine_for(
%value: !secret.secret<i32>,
%data: !secret.secret<memref<10xi32>>) -> !secret.secret<memref<10xi32>> {
// CHECK: affine.for
// CHECK: secret.generic
// CHECK-NEXT: bb
// CHECK-NEXT: affine.load
// CHECK-NEXT: arith.addi
// CHECK-NEXT: affine.store
// CHECK-NEXT: secret.yield
// CHECK-NOT: secret.generic
// CHECK: return %[[data]]
secret.generic
ins(%value, %data : !secret.secret<i32>, !secret.secret<memref<10xi32>>) {
^bb0(%clear_value: i32, %clear_data : memref<10xi32>):
affine.for %i = 0 to 10 {
%2 = affine.load %clear_data[%i] : memref<10xi32>
%3 = arith.addi %2, %clear_value : i32
affine.store %3, %clear_data[%i] : memref<10xi32>
}
secret.yield
} -> ()
func.return %data : !secret.secret<memref<10xi32>>
}
27 changes: 0 additions & 27 deletions tests/secret/verifier.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,6 @@ func.func @test_secret_type_mismatch(%value: !secret.secret<i32>, %c1: i32) {

// -----

func.func @test_refers_to_value_outside_block(%value: !secret.secret<i32>) {
%c1 = arith.constant 1 : i32
// expected-error@+1 {{uses a value defined outside the block}}
%Z = secret.generic
ins(%value : !secret.secret<i32>) {
^bb0(%clear_value: i32):
%1 = arith.addi %clear_value, %c1 : i32
secret.yield %1 : i32
} -> (!secret.secret<i32>)
return
}

// -----

func.func @test_refers_to_block_argument_outside_block(%value: !secret.secret<i32>, %c1 : i32) {
// expected-error@+1 {{uses a block argument defined outside the block}}
%Z = secret.generic
ins(%value : !secret.secret<i32>) {
^bb0(%clear_value: i32):
%1 = arith.addi %clear_value, %c1 : i32
secret.yield %1 : i32
} -> (!secret.secret<i32>)
return
}

// -----

func.func @ensure_yield_inside_generic(%value: !secret.secret<i32>) {
// expected-error@+1 {{expects parent op 'secret.generic'}}
secret.yield %value : !secret.secret<i32>
Expand Down

0 comments on commit b1429a2

Please sign in to comment.