Skip to content

Commit

Permalink
Use a select on the mask instead of a predicated branch for scalar masked loads
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbaden committed Aug 8, 2024
1 parent 2bc3b06 commit dce39b4
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 63 deletions.
83 changes: 21 additions & 62 deletions test/Conversion/intel/tritongpu_to_gen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-NEXT: [[CST_0:%.*]] = llvm.mlir.constant(0 : index) : i32
// CHECK-NEXT: [[IE1:%.*]] = llvm.insertelement [[ARG2_0]], [[VEC]][[[CST_0]] : i32] : vector<1xf32>
// CHECK-NEXT: [[BCAST0:%.*]] = llvm.bitcast {{.*}} : vector<1xf32> to i32
// CHECK: llvm.cond_br [[ARG1_0]], ^bb1, ^bb2([[BCAST0]] : i32)
// CHECK-NEXT: ^bb1:
// CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast [[ARG0_0]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK-NEXT: [[CST_1:%.*]] = llvm.mlir.constant(0 : index) : i32
// CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast [[ARG0_0]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK-NEXT: [[LOAD1:%.*]] = llvm.load [[BCAST1]] {alignment = 4 : i64} : !llvm.ptr<1> -> i32
// CHECK-NEXT: llvm.br ^bb2([[LOAD1]] : i32)
// CHECK-NEXT: ^bb2([[V1:%.*]]: i32):
// CHECK-NEXT: [[V1:%.*]] = llvm.select {{.*}}, [[LOAD1]], [[BCAST0]] : i1, i32
// CHECK-NEXT: [[BCAST_V1:%.*]] = llvm.bitcast [[V1]] : i32 to vector<1xf32>
// CHECK: [[EE1:%.*]] = llvm.extractelement [[BCAST_V1]][{{.*}} : i32] : vector<1xf32>
// CHECK: [[BCAST2:%.*]] = llvm.bitcast {{.*}} : vector<1xf32> to i32
// CHECK: llvm.cond_br [[ARG1_1]], ^bb3, ^bb4([[BCAST2]] : i32)
// CHECK-NEXT: ^bb3:
// CHECK-NEXT: [[BCAST3:%.*]] = llvm.bitcast [[ARG0_1]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK: [[BCAST3:%.*]] = llvm.bitcast [[ARG0_1]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK-NEXT: [[LOAD2:%.*]] = llvm.load [[BCAST3]] {alignment = 4 : i64} : !llvm.ptr<1> -> i32
// CHECK-NEXT: llvm.br ^bb4([[LOAD2]] : i32)
// CHECK-NEXT: ^bb4([[V2:%.*]]: i32):
// CHECK-NEXT: [[V2:%.*]] = llvm.select {{.*}}, [[LOAD2]], [[BCAST2]] : i1, i32
// CHECK-NEXT: [[BCAST_V2:%.*]] = llvm.bitcast [[V2]] : i32 to vector<1xf32>
// CHECK: [[EE2:%.*]] = llvm.extractelement [[BCAST_V2]][{{.*}} : i32] : vector<1xf32>
// CHECK-NEXT: [[RES1:%.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)>
Expand All @@ -65,21 +60,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-NEXT: [[CST_0:%.*]] = llvm.mlir.constant(0 : index) : i32
// CHECK-NEXT: [[IE1:%.*]] = llvm.insertelement [[ARG2_0]], [[VEC]][[[CST_0]] : i32] : vector<1xf32>
// CHECK-NEXT: [[BCAST0:%.*]] = llvm.bitcast {{.*}} : vector<1xf32> to i32
// CHECK: llvm.cond_br [[ARG1_0]], ^bb1, ^bb2([[BCAST0]] : i32)
// CHECK-NEXT: ^bb1:
// CHECK-NEXT: [[CST_1:%.*]] = llvm.mlir.constant(0 : index) : i32
// CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast [[ARG0_0]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK-NEXT: [[LOAD1:%.*]] = llvm.load [[BCAST1]] {alignment = 4 : i64} : !llvm.ptr<1> -> i32
// CHECK-NEXT: llvm.br ^bb2([[LOAD1]] : i32)
// CHECK-NEXT: ^bb2([[V1:%.*]]: i32):
// CHECK-NEXT: [[V1:%.*]] = llvm.select {{.*}}, [[LOAD1]], [[BCAST0]] : i1, i32
// CHECK-NEXT: [[BCAST_V1:%.*]] = llvm.bitcast [[V1]] : i32 to vector<1xf32>
// CHECK: [[EE1:%.*]] = llvm.extractelement [[BCAST_V1]][{{.*}} : i32] : vector<1xf32>
// CHECK: [[BCAST2:%.*]] = llvm.bitcast {{.*}} : vector<1xf32> to i32
// CHECK: llvm.cond_br [[ARG1_1]], ^bb3, ^bb4([[BCAST2]] : i32)
// CHECK-NEXT: ^bb3:
// CHECK-NEXT: [[BCAST3:%.*]] = llvm.bitcast [[ARG0_1]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK-NEXT: [[LOAD2:%.*]] = llvm.load [[BCAST3]] {alignment = 4 : i64} : !llvm.ptr<1> -> i32
// CHECK-NEXT: llvm.br ^bb4([[LOAD2]] : i32)
// CHECK-NEXT: ^bb4([[V2:%.*]]: i32):
// CHECK: [[BCAST3:%.*]] = llvm.bitcast [[ARG0_1]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK: [[LOAD2:%.*]] = llvm.load [[BCAST3]] {alignment = 4 : i64} : !llvm.ptr<1> -> i32
// CHECK-NEXT: [[V2:%.*]] = llvm.select {{.*}}, [[LOAD2]], [[BCAST2]] : i1, i32
// CHECK-NEXT: [[BCAST_V2:%.*]] = llvm.bitcast [[V2]] : i32 to vector<1xf32>
// CHECK: [[EE2:%.*]] = llvm.extractelement [[BCAST_V2]][{{.*}} : i32] : vector<1xf32>
// CHECK-NEXT: [[RES1:%.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)>
Expand Down Expand Up @@ -117,82 +107,51 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-NEXT: [[CST_0:%.*]] = llvm.mlir.constant(0 : index) : i32
// CHECK-NEXT: [[IE1:%.*]] = llvm.insertelement [[ARG2_0]], [[VEC]][[[CST_0]] : i32] : vector<1xf16>
// CHECK-NEXT: [[BCAST0:%.*]] = llvm.bitcast [[IE1]] : vector<1xf16> to i16
// CHECK: llvm.cond_br [[ARG1_0]], ^bb1, ^bb2([[BCAST0]] : i16)
// CHECK-NEXT: ^bb1:
// CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast [[ARG0_0]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK: [[BCAST1:%.*]] = llvm.bitcast [[ARG0_0]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK-NEXT: [[LOAD1:%.*]] = llvm.load [[BCAST1]] {alignment = 2 : i64} : !llvm.ptr<1> -> i16
// CHECK-NEXT: llvm.br ^bb2([[LOAD1]] : i16)
// CHECK-NEXT: ^bb2([[V1:%.*]]: i16):
// CHECK-NEXT: [[V1:%.*]] = llvm.select {{.*}}, [[LOAD1]], [[BCAST0]] : i1, i16
// CHECK-NEXT: [[BCAST_V1:%.*]] = llvm.bitcast [[V1]] : i16 to vector<1xf16>
// CHECK: [[EE1:%.*]] = llvm.extractelement [[BCAST_V1]][{{.*}} : i32] : vector<1xf16>
// CHECK: [[BCAST2:%.*]] = llvm.bitcast {{.*}} : vector<1xf16> to i16
// CHECK: llvm.cond_br [[ARG1_1]], ^bb3, ^bb4([[BCAST2]] : i16)

// CHECK-NEXT: ^bb3:
// CHECK-NEXT: [[BCAST3:%.*]] = llvm.bitcast [[ARG0_1]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK: [[BCAST3:%.*]] = llvm.bitcast [[ARG0_1]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK-NEXT: [[LOAD2:%.*]] = llvm.load [[BCAST3]] {alignment = 2 : i64} : !llvm.ptr<1> -> i16
// CHECK-NEXT: llvm.br ^bb4([[LOAD2]] : i16)
// CHECK-NEXT: ^bb4([[V2:%.*]]: i16):
// CHECK-NEXT: [[V2:%.*]] = llvm.select {{.*}}, [[LOAD2]], [[BCAST2]] : i1, i16
// CHECK-NEXT: [[BCAST_V2:%.*]] = llvm.bitcast [[V2]] : i16 to vector<1xf16>
// CHECK: [[EE2:%.*]] = llvm.extractelement [[BCAST_V2]][{{.*}} : i32] : vector<1xf16>
// CHECK: [[BCAST4:%.*]] = llvm.bitcast {{.*}} : vector<1xf16> to i16
// CHECK: llvm.cond_br [[ARG1_2]], ^bb5, ^bb6([[BCAST4]] : i16)

// CHECK-NEXT: ^bb5:
// CHECK-NEXT: [[BCAST5:%.*]] = llvm.bitcast [[ARG0_2]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK: [[BCAST5:%.*]] = llvm.bitcast [[ARG0_2]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK-NEXT: [[LOAD3:%.*]] = llvm.load [[BCAST5]] {alignment = 2 : i64} : !llvm.ptr<1> -> i16
// CHECK-NEXT: llvm.br ^bb6([[LOAD3]] : i16)
// CHECK-NEXT: ^bb6([[V3:%.*]]: i16):
// CHECK-NEXT: [[V3:%.*]] = llvm.select {{.*}}, [[LOAD3]], [[BCAST4]] : i1, i16
// CHECK-NEXT: [[BCAST_V3:%.*]] = llvm.bitcast [[V3]] : i16 to vector<1xf16>
// CHECK: [[EE3:%.*]] = llvm.extractelement [[BCAST_V3]][{{.*}} : i32] : vector<1xf16>
// CHECK: [[BCAST5:%.*]] = llvm.bitcast {{.*}} : vector<1xf16> to i16
// CHECK: llvm.cond_br [[ARG1_3]], ^bb7, ^bb8([[BCAST5]] : i16)

// CHECK-NEXT: ^bb7:
// CHECK: [[BCAST6:%.*]] = llvm.bitcast [[ARG0_3]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK-NEXT: [[LOAD4:%.*]] = llvm.load [[BCAST6]] {alignment = 2 : i64} : !llvm.ptr<1> -> i16
// CHECK-NEXT: llvm.br ^bb8([[LOAD4]] : i16)
// CHECK-NEXT: ^bb8([[V4:%.*]]: i16):
// CHECK-NEXT: [[V4:%.*]] = llvm.select {{.*}}, [[LOAD4]], [[BCAST5]] : i1, i16
// CHECK-NEXT: [[BCAST_V4:%.*]] = llvm.bitcast [[V4]] : i16 to vector<1xf16>
// CHECK: [[EE4:%.*]] = llvm.extractelement [[BCAST_V4]][{{.*}} : i32] : vector<1xf16>
// CHECK: [[BCAST7:%.*]] = llvm.bitcast {{.*}} : vector<1xf16> to i16
// CHECK: llvm.cond_br [[ARG1_4]], ^bb9, ^bb10([[BCAST7]] : i16)

// CHECK-NEXT: ^bb9:
// CHECK: [[BCAST8:%.*]] = llvm.bitcast [[ARG0_4]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK-NEXT: [[LOAD5:%.*]] = llvm.load [[BCAST8]] {alignment = 2 : i64} : !llvm.ptr<1> -> i16
// CHECK-NEXT: llvm.br ^bb10([[LOAD5]] : i16)
// CHECK-NEXT: ^bb10([[V5:%.*]]: i16):
// CHECK-NEXT: [[V5:%.*]] = llvm.select {{.*}}, [[LOAD5]], [[BCAST7]] : i1, i16
// CHECK-NEXT: [[BCAST_V5:%.*]] = llvm.bitcast [[V5]] : i16 to vector<1xf16>
// CHECK: [[EE5:%.*]] = llvm.extractelement [[BCAST_V5]][{{.*}} : i32] : vector<1xf16>
// CHECK: [[BCAST8:%.*]] = llvm.bitcast {{.*}} : vector<1xf16> to i16
// CHECK: llvm.cond_br [[ARG1_5]], ^bb11, ^bb12([[BCAST8]] : i16)

// CHECK-NEXT: ^bb11:
// CHECK: [[BCAST9:%.*]] = llvm.bitcast [[ARG0_5]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK-NEXT: [[LOAD6:%.*]] = llvm.load [[BCAST9]] {alignment = 2 : i64} : !llvm.ptr<1> -> i16
// CHECK-NEXT: llvm.br ^bb12([[LOAD6]] : i16)
// CHECK-NEXT: ^bb12([[V6:%.*]]: i16):
// CHECK-NEXT: [[V6:%.*]] = llvm.select {{.*}}, [[LOAD6]], [[BCAST8]] : i1, i16
// CHECK-NEXT: [[BCAST_V6:%.*]] = llvm.bitcast [[V6]] : i16 to vector<1xf16>
// CHECK: [[EE6:%.*]] = llvm.extractelement [[BCAST_V6]][{{.*}} : i32] : vector<1xf16>
// CHECK: [[BCAST10:%.*]] = llvm.bitcast {{.*}} : vector<1xf16> to i16
// CHECK: llvm.cond_br [[ARG1_6]], ^bb13, ^bb14([[BCAST10]] : i16)

// CHECK-NEXT: ^bb13:
// CHECK: [[BCAST11:%.*]] = llvm.bitcast [[ARG0_6]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK-NEXT: [[LOAD7:%.*]] = llvm.load [[BCAST11]] {alignment = 2 : i64} : !llvm.ptr<1> -> i16
// CHECK-NEXT: llvm.br ^bb14([[LOAD7]] : i16)
// CHECK-NEXT: ^bb14([[V7:%.*]]: i16):
// CHECK-NEXT: [[V7:%.*]] = llvm.select {{.*}}, [[LOAD7]], [[BCAST10]] : i1, i16
// CHECK-NEXT: [[BCAST_V7:%.*]] = llvm.bitcast [[V7]] : i16 to vector<1xf16>
// CHECK: [[EE7:%.*]] = llvm.extractelement [[BCAST_V7]][{{.*}} : i32] : vector<1xf16>
// CHECK: [[BCAST12:%.*]] = llvm.bitcast {{.*}} : vector<1xf16> to i16
// CHECK: llvm.cond_br [[ARG1_7]], ^bb15, ^bb16([[BCAST12]] : i16)

// CHECK-NEXT: ^bb15:
// CHECK: [[BCAST13:%.*]] = llvm.bitcast [[ARG0_7]] : !llvm.ptr<1> to !llvm.ptr<1>
// CHECK-NEXT: [[LOAD8:%.*]] = llvm.load [[BCAST13]] {alignment = 2 : i64} : !llvm.ptr<1> -> i16
// CHECK-NEXT: llvm.br ^bb16([[LOAD8]] : i16)
// CHECK-NEXT: ^bb16([[V8:%.*]]: i16):
// CHECK-NEXT: [[V8:%.*]] = llvm.select {{.*}}, [[LOAD8]], [[BCAST12]] : i1, i16
// CHECK-NEXT: [[BCAST_V8:%.*]] = llvm.bitcast [[V8]] : i16 to vector<1xf16>
// CHECK: [[EE8:%.*]] = llvm.extractelement [[BCAST_V8]][{{.*}} : i32] : vector<1xf16>
// CHECK-NEXT: [[RES1:%.*]] = llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,10 @@ struct LoadOpConversion
if (nWords == 1) {
Value addrElem = bitcast(ptrElems[vecStart], ptr_ty(ctx, 1 /*global*/));
uint32_t alignment = nWords * width / 8;
ret = rewriter.create<LLVM::SelectOp>(loc, pred, rewriter.create<LLVM::LoadOp>(loc, retTy, addrElem, alignment), other_);
ret = rewriter.create<LLVM::SelectOp>(
loc, pred,
rewriter.create<LLVM::LoadOp>(loc, retTy, addrElem, alignment),
other_);
} else {
Block &endBlock = LLVM::intel::createPredicatedBlock(
rewriter, loc, pred, SmallVector<Value, 1>{other_}, [&]() {
Expand Down

0 comments on commit dce39b4

Please sign in to comment.