Skip to content

Commit

Permalink
Merge pull request google#408 from j2kun:forward-stores-to-loads-2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 603132812
  • Loading branch information
copybara-github committed Jan 31, 2024
2 parents 1ee2d5f + 7ad513d commit bf50f2a
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lib/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ bool ForwardSingleStoreToLoad::isForwardableOp(Operation *potentialStore,
<< "loadOp and store op do not have matching indices\n");
return false;
}
// get this node to the load node and check if any in between
// isForwardableOp

// Naively scan through the operations between the two ops and check if
// anything prevents forwarding.
for (auto currentNode = storeOp->getNextNode();
currentNode != loadOp.getOperation();
currentNode = currentNode->getNextNode()) {
Expand Down
119 changes: 119 additions & 0 deletions tests/forward_store_to_load/forward_add_one.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// RUN: heir-opt --forward-store-to-load %s | FileCheck %s

module {
// Larger, straight-line input for the pass: every op of this function lives in
// a single block, with no nested regions.
// CHECK-LABEL: @add_one
func.func @add_one(%arg0: !tfhe_rust.server_key, %arg1: memref<8x!tfhe_rust.eui3>) -> memref<8x!tfhe_rust.eui3> {
%c1_i8 = arith.constant 1 : i8
%c2_i8 = arith.constant 2 : i8
%true = arith.constant true
%false = arith.constant false
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
// Each index 0..7 of %alloc is stored to exactly once below (%true at index 0,
// %false elsewhere) and later loaded exactly once at the same constant index,
// so every one of those loads should be replaced by the stored value.
// Stores to other indices of %alloc, and unrelated ops, sit between most
// store/load pairs and are expected not to block forwarding.
// The directives only require that no load from %alloc survives; the alloc
// itself is still expected to be present in the output.
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<8xi1>
// CHECK-NOT: memref.load %[[ALLOC]]
%alloc = memref.alloc() : memref<8xi1>
memref.store %true, %alloc[%c0] : memref<8xi1>
memref.store %false, %alloc[%c1] : memref<8xi1>
memref.store %false, %alloc[%c2] : memref<8xi1>
memref.store %false, %alloc[%c3] : memref<8xi1>
memref.store %false, %alloc[%c4] : memref<8xi1>
memref.store %false, %alloc[%c5] : memref<8xi1>
memref.store %false, %alloc[%c6] : memref<8xi1>
memref.store %false, %alloc[%c7] : memref<8xi1>
// Element 0: load of %alloc[%c0] pairs with the store of %true above.
%0 = memref.load %alloc[%c0] : memref<8xi1>
// Loads from %arg1 have no preceding store in this function and are not
// covered by any directive here.
%1 = memref.load %arg1[%c0] : memref<8x!tfhe_rust.eui3>
%2 = tfhe_rust.create_trivial %arg0, %false : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3
%3 = tfhe_rust.create_trivial %arg0, %0 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3
%4 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 8 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table
%5 = tfhe_rust.scalar_left_shift %arg0, %2, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%6 = tfhe_rust.scalar_left_shift %arg0, %3, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%7 = tfhe_rust.add %arg0, %5, %6 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%8 = tfhe_rust.add %arg0, %7, %1 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%9 = tfhe_rust.apply_lookup_table %arg0, %8, %4 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
// Element 1.
%10 = memref.load %alloc[%c1] : memref<8xi1>
%11 = memref.load %arg1[%c1] : memref<8x!tfhe_rust.eui3>
%12 = tfhe_rust.create_trivial %arg0, %10 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3
%13 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 150 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table
%14 = tfhe_rust.scalar_left_shift %arg0, %12, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%15 = tfhe_rust.scalar_left_shift %arg0, %11, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%16 = tfhe_rust.add %arg0, %14, %15 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%17 = tfhe_rust.add %arg0, %16, %9 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%18 = tfhe_rust.apply_lookup_table %arg0, %17, %13 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
%19 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 23 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table
%20 = tfhe_rust.apply_lookup_table %arg0, %17, %19 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
// Element 2.
%21 = memref.load %alloc[%c2] : memref<8xi1>
%22 = memref.load %arg1[%c2] : memref<8x!tfhe_rust.eui3>
%23 = tfhe_rust.create_trivial %arg0, %21 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3
%24 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 43 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table
%25 = tfhe_rust.scalar_left_shift %arg0, %23, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%26 = tfhe_rust.scalar_left_shift %arg0, %22, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%27 = tfhe_rust.add %arg0, %25, %26 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%28 = tfhe_rust.add %arg0, %27, %20 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%29 = tfhe_rust.apply_lookup_table %arg0, %28, %24 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
// Element 3.
%30 = memref.load %alloc[%c3] : memref<8xi1>
%31 = memref.load %arg1[%c3] : memref<8x!tfhe_rust.eui3>
%32 = tfhe_rust.create_trivial %arg0, %30 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3
%33 = tfhe_rust.scalar_left_shift %arg0, %32, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%34 = tfhe_rust.scalar_left_shift %arg0, %31, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%35 = tfhe_rust.add %arg0, %33, %34 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%36 = tfhe_rust.add %arg0, %35, %29 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%37 = tfhe_rust.apply_lookup_table %arg0, %36, %24 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
// Element 4.
%38 = memref.load %alloc[%c4] : memref<8xi1>
%39 = memref.load %arg1[%c4] : memref<8x!tfhe_rust.eui3>
%40 = tfhe_rust.create_trivial %arg0, %38 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3
%41 = tfhe_rust.scalar_left_shift %arg0, %40, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%42 = tfhe_rust.scalar_left_shift %arg0, %39, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%43 = tfhe_rust.add %arg0, %41, %42 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%44 = tfhe_rust.add %arg0, %43, %37 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%45 = tfhe_rust.apply_lookup_table %arg0, %44, %24 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
// Element 5.
%46 = memref.load %alloc[%c5] : memref<8xi1>
%47 = memref.load %arg1[%c5] : memref<8x!tfhe_rust.eui3>
%48 = tfhe_rust.create_trivial %arg0, %46 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3
%49 = tfhe_rust.scalar_left_shift %arg0, %48, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%50 = tfhe_rust.scalar_left_shift %arg0, %47, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%51 = tfhe_rust.add %arg0, %49, %50 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%52 = tfhe_rust.add %arg0, %51, %45 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%53 = tfhe_rust.apply_lookup_table %arg0, %52, %24 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
// Element 6.
%54 = memref.load %alloc[%c6] : memref<8xi1>
%55 = memref.load %arg1[%c6] : memref<8x!tfhe_rust.eui3>
%56 = tfhe_rust.create_trivial %arg0, %54 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3
%57 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 105 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table
%58 = tfhe_rust.scalar_left_shift %arg0, %56, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%59 = tfhe_rust.scalar_left_shift %arg0, %55, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%60 = tfhe_rust.add %arg0, %58, %59 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%61 = tfhe_rust.add %arg0, %60, %53 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%62 = tfhe_rust.apply_lookup_table %arg0, %61, %57 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
%63 = tfhe_rust.apply_lookup_table %arg0, %61, %24 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
// Element 7.
%64 = memref.load %alloc[%c7] : memref<8xi1>
%65 = memref.load %arg1[%c7] : memref<8x!tfhe_rust.eui3>
%66 = tfhe_rust.create_trivial %arg0, %64 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3
%67 = tfhe_rust.scalar_left_shift %arg0, %66, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%68 = tfhe_rust.scalar_left_shift %arg0, %65, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%69 = tfhe_rust.add %arg0, %67, %68 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%70 = tfhe_rust.add %arg0, %69, %63 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3
%71 = tfhe_rust.apply_lookup_table %arg0, %70, %57 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
%72 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 6 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table
%73 = tfhe_rust.apply_lookup_table %arg0, %8, %72 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
%74 = tfhe_rust.apply_lookup_table %arg0, %28, %57 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
%75 = tfhe_rust.apply_lookup_table %arg0, %36, %57 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
%76 = tfhe_rust.apply_lookup_table %arg0, %44, %57 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
%77 = tfhe_rust.apply_lookup_table %arg0, %52, %57 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
// Result buffer: each index is written once and the buffer is returned; it is
// never loaded from inside this function.
%alloc_0 = memref.alloc() : memref<8x!tfhe_rust.eui3>
memref.store %73, %alloc_0[%c0] : memref<8x!tfhe_rust.eui3>
memref.store %18, %alloc_0[%c1] : memref<8x!tfhe_rust.eui3>
memref.store %74, %alloc_0[%c2] : memref<8x!tfhe_rust.eui3>
memref.store %75, %alloc_0[%c3] : memref<8x!tfhe_rust.eui3>
memref.store %76, %alloc_0[%c4] : memref<8x!tfhe_rust.eui3>
memref.store %77, %alloc_0[%c5] : memref<8x!tfhe_rust.eui3>
memref.store %62, %alloc_0[%c6] : memref<8x!tfhe_rust.eui3>
memref.store %71, %alloc_0[%c7] : memref<8x!tfhe_rust.eui3>
return %alloc_0 : memref<8x!tfhe_rust.eui3>
}
}
125 changes: 125 additions & 0 deletions tests/forward_store_to_load/forward_store_to_load.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,128 @@ func.func @single_store(%arg0: memref<10xi8>) -> i8 {
%1 = memref.load %arg0[%c0] : memref<10xi8>
return %1: i8
}

// The stored value is a function argument (a block argument, not the result
// of an op). The load should disappear and the function should return the
// argument directly.
// CHECK-LABEL: func.func @block_arg
// CHECK-SAME: %[[ARG1:.*]]: i8
func.func @block_arg(%arg0: i8) -> i8 {
%c0 = arith.constant 0 : index
%0 = memref.alloc() : memref<1xi8>
// Store and load use the same memref and the same index value, in the same
// block, with nothing in between.
memref.store %arg0, %0[%c0] : memref<1xi8>
// CHECK-NOT: memref.load
%1 = memref.load %0[%c0] : memref<1xi8>
// CHECK: return %[[ARG1]] : i8
return %1: i8
}


// The store/load pair lives inside the body of an affine.for, and both use the
// loop induction variable as the index. Forwarding is expected to work there
// just as it does at function top level.
// CHECK-LABEL: func.func @inside_region
// CHECK-SAME: (%[[MEMREF0:.*]]: memref<10xi8>, %[[MEMREF1:.*]]: memref<10xi8>)
func.func @inside_region(%memref0: memref<10xi8>, %memref1: memref<10xi8>) -> memref<10xi8> {
// CHECK: %[[OUT:.*]] = memref.alloc() : memref<10xi8>
%out = memref.alloc() : memref<10xi8>
// CHECK: affine.for %[[I:.*]] = 0 to 10 {
affine.for %i = 0 to 10 {
// This load has no preceding store in the loop body, so it must remain.
// CHECK: %[[V0:.*]] = memref.load %[[MEMREF0]][%[[I]]]
%0 = memref.load %memref0[%i] : memref<10xi8>
// CHECK: memref.store %[[V0]], %[[MEMREF1]][%[[I]]]
// CHECK-NOT: memref.load
memref.store %0, %memref1[%i] : memref<10xi8>
// The load of %memref1[%i] below should be replaced by %0, so the store into
// %out is expected to use the value loaded from %memref0 directly and to
// follow the previous store immediately.
// CHECK-NEXT: memref.store %[[V0]], %[[OUT]][%[[I]]]
%1 = memref.load %memref1[%i] : memref<10xi8>
memref.store %1, %out[%i] : memref<10xi8>
}
return %out: memref<10xi8>
}

// Two stores write the same memref at the same index before the load. The
// load must be replaced by the value of the most recent store (%val1), not the
// earlier one (%val0). Both stores are expected to stay in the output.
// CHECK-LABEL: func.func @forward_latest
// CHECK-SAME: (%[[MEMREF0:.*]]: memref<10xi8>)
func.func @forward_latest(%memref0: memref<10xi8>) -> i8 {
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK-NEXT: %[[VAL0:.*]] = arith.constant 7 : i8
%val0 = arith.constant 7 : i8
// CHECK-NEXT: %[[VAL1:.*]] = arith.constant 8 : i8
%val1 = arith.constant 8 : i8

// Earlier store: overwritten by the next one before any load happens.
// CHECK-NEXT: memref.store %[[VAL0]], %[[MEMREF0]][%[[C0]]]
memref.store %val0, %memref0[%c0] : memref<10xi8>
// Later store: this is the value the load should observe.
// CHECK-NEXT: memref.store %[[VAL1]], %[[MEMREF0]][%[[C0]]]
memref.store %val1, %memref0[%c0] : memref<10xi8>
// CHECK-NOT: memref.load
// CHECK-NEXT: return %[[VAL1]] : i8
%1 = memref.load %memref0[%c0] : memref<10xi8>
return %1: i8
}


// Negative test: the only store to the loaded location sits inside an scf.if
// region, i.e. in a different block from the load. The load must be left
// untouched.
// CHECK-LABEL: func.func @skip_different_blocks
// CHECK-SAME: (%[[MEMREF0:.*]]: memref<10xi8>)
func.func @skip_different_blocks(%memref0: memref<10xi8>) -> i8 {
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK-NEXT: %[[VAL0:.*]] = arith.constant 7 : i8
%val0 = arith.constant 7 : i8
%true = arith.constant true

// The store is nested in the scf.if body rather than in the function's entry
// block.
scf.if %true {
// CHECK: memref.store %[[VAL0]], %[[MEMREF0]][%[[C0]]]
memref.store %val0, %memref0[%c0] : memref<10xi8>
}

// The store is in a different block, so we don't forward it
// CHECK: %[[V1:.*]] = memref.load %[[MEMREF0]][%[[C0]]]
%1 = memref.load %memref0[%c0] : memref<10xi8>
// CHECK-NEXT: return %[[V1]] : i8
return %1: i8
}


// Negative test: the store and the final load are in the same block and use
// the same memref and index, but an op with a nested region (scf.if) sits
// between them. Neither the load inside that region nor the load after it is
// expected to be forwarded.
// CHECK-LABEL: func.func @skip_intermediate_region_holding_op
// CHECK-SAME: (%[[MEMREF0:.*]]: memref<10xi8>)
func.func @skip_intermediate_region_holding_op(%memref0: memref<10xi8>) -> i8 {
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK-NEXT: %[[VAL0:.*]] = arith.constant 7 : i8
%val0 = arith.constant 7 : i8
%true = arith.constant true
// CHECK: memref.alloc
%0 = memref.alloc() : memref<1xi8>

// CHECK: memref.store %[[VAL0]], %[[MEMREF0]][%[[C0]]]
memref.store %val0, %memref0[%c0] : memref<10xi8>

// The load inside the scf.if body is in a different block from the store
// above, so it is expected to remain as well. The region only writes to the
// local buffer %0, never to %memref0.
// CHECK-NEXT: scf.if
// CHECK-NEXT: %[[V1:.*]] = memref.load %[[MEMREF0]][%[[C0]]]
// CHECK-NEXT: memref.store %[[V1]]
scf.if %true {
%1 = memref.load %memref0[%c0] : memref<10xi8>
memref.store %1, %0[%c0] : memref<1xi8>
}

// The store has a region-holding op between it and this load, and we don't
// check if the memref is impacted inside that region. Assume it is and don't
// forward.
// CHECK: %[[V2:.*]] = memref.load %[[MEMREF0]][%[[C0]]]
%2 = memref.load %memref0[%c0] : memref<10xi8>
// CHECK-NEXT: return %[[V2]] : i8
return %2: i8
}


// Negative test: the store writes index %c0 but the load reads index %c1 of
// the same memref, so the stored value must not be forwarded to the load.
// CHECK-LABEL: func.func @wrong_indices
// CHECK-SAME: (%[[MEMREF0:.*]]: memref<10xi8>)
func.func @wrong_indices(%memref0: memref<10xi8>) -> i8 {
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%val0 = arith.constant 7 : i8
// Store at index 0 only.
memref.store %val0, %memref0[%c0] : memref<10xi8>

// Loading at a different index, so do not forward
// CHECK: %[[V2:.*]] = memref.load %[[MEMREF0]][%[[C1]]]
%2 = memref.load %memref0[%c1] : memref<10xi8>
// CHECK-NEXT: return %[[V2]] : i8
return %2: i8
}

0 comments on commit bf50f2a

Please sign in to comment.