[Mosaic TPU] Add general tpu.vector_store and support masked store.
This CL introduces a general store op called tpu.vector_store, which aims to unify vector::store, tpu::strided_store, and vector::masked_store. tpu.vector_store should also provide a general lowering interface for both TensorCore and SparseCore.

This CL also adds support for (dynamic) masked stores.
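For illustration, a masked store written with the new op could look like the following (a hypothetical snippet following the assemblyFormat defined in tpu.td below; the shapes, indices, and SSA names are invented):

    // Hypothetical IR: store %val at %base[%i, %j], writing only elements where %mask is set.
    tpu.vector_store %base[%i, %j], %val masked %mask {strides = array<i32>}
        : memref<512x512xf32>, vector<8x128xf32>, vector<8x128xi1>

The empty strides attribute matches the new verifier, which rejects non-empty strides as not yet implemented.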

PiperOrigin-RevId: 681952969
bythew3i authored and Google-ML-Automation committed Nov 19, 2024
1 parent 12a43f1 commit 58571de
Showing 4 changed files with 107 additions and 22 deletions.
16 changes: 16 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
@@ -214,6 +214,22 @@ def TPU_LoadOp : TPU_Op<"load"> {
}];
}

// TODO(jevinjiang): migrate tpu.strided_store to general vector store op.
def TPU_VectorStoreOp : TPU_Op<"vector_store", [AttrSizedOperandSegments]> {
let arguments = (ins
AnyVector:$valueToStore,
AnyMemRef:$base,
Variadic<Index>:$indices,
DenseI32ArrayAttr:$strides,
Optional<AnyVector>:$mask // Elementwise mask.
);
let results = (outs);
let assemblyFormat = [{
$base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask)
}];
let hasVerifier = 1;
}

def TPU_StridedLoadOp : TPU_Op<"strided_load"> {
let arguments = (ins
AnyMemRef:$base,
27 changes: 26 additions & 1 deletion jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -440,6 +440,31 @@ LogicalResult StridedStoreOp::verify() {
getValueToStore().getType());
}

LogicalResult VectorStoreOp::verify() {
if (!getStrides().empty()) {
return emitError("Not implemented: general vector store with strides.");
}
VectorType value_ty = getValueToStore().getType();
MemRefType ref_ty = getBase().getType();

if (value_ty.getElementType() != ref_ty.getElementType()) {
return emitOpError(
"Expected base and valueToStore element type should match");
}
if (llvm::size(getIndices()) != ref_ty.getRank()) {
return emitOpError("Expected ") << ref_ty.getRank() << " indices";
}
if (getMask()) {
if (value_ty.getElementTypeBitWidth() != 32) {
return emitError(
"Not implemented: masked store with non-32-bit element type");
}
if (value_ty.getShape() != getMask().getType().getShape())
return emitOpError("Expected valueToStore shape to match mask shape");
}
return success();
}
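For example, a masked store of a 16-bit element type, such as the following hypothetical IR, would be rejected by the verifier above:

    tpu.vector_store %base[%i, %j], %v masked %m {strides = array<i32>}
        : memref<512x512xbf16>, vector<8x128xbf16>, vector<8x128xi1>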

LogicalResult ReinterpretCastOp::verify() {
auto source_type = getMemRefType(getInput());
auto target_type = getType();
@@ -468,7 +493,7 @@ LogicalResult verifyRotateOp(Op op) {
}
if (op.getStride().has_value() != op.getStrideDimension().has_value()) {
op.emitOpError(
"Expected either none or both stride and stride dimension are "
"Expected either none or both stride and stride dimension are "
"present");
return failure();
}
71 changes: 52 additions & 19 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -4194,18 +4194,15 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
shape_cast_op->erase();
return success();
}
LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
TPU_ASSERT_EQ_OP(layouts_out.size(), 0);
MLIRContext *const mlir_ctx = op.getContext();
TPU_ASSERT_OP(layouts_in.front().has_value());
TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(),
[&](const Layout &l) { return l.has_value(); }));

template <typename Op>
LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op,
const VectorLayout &to_store_layout,
TypedValue<VectorType> store_mask = nullptr) {
Operation &op = *(store_op.getOperation());
MLIRContext *const mlir_ctx = store_op.getContext();
ImplicitLocOpBuilder builder(op.getLoc(), &op);
vector::StoreOp store_op = cast<vector::StoreOp>(op);
const VectorType ty = store_op.getValueToStore().getType();
const VectorLayout &to_store_layout = *layouts_in.front();
const auto memref_ty = getMemRefType(store_op.getBase());
if (!ty.getRank()) {
return op.emitOpError("Not implemented: scalar stores to vmem");
@@ -4302,10 +4299,9 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
} else {
// Convert dynamic store to dynamic slice + static store. This saves us a
// bunch of scalar core work.
auto slice_result =
sliceRef(builder, store_op.getBase(),
store_op.getVectorType().getShape(), store_op.getIndices(),
ArrayRef<int64_t>(memref_tiling).take_back(tiled_dims));
auto slice_result = sliceRef(
builder, store_op.getBase(), ty.getShape(), store_op.getIndices(),
ArrayRef<int64_t>(memref_tiling).take_back(tiled_dims));
if (failed(slice_result)) {
return failure();
}
@@ -4326,6 +4322,13 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
xla::Array<Value> tiles,
disassemble(builder, to_store_layout, store_op.getValueToStore(),
ctx.target_shape));
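// If a store mask was provided, disassemble it with the same layout as the
// value, so that every value vreg has a matching mask vreg (checked below).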
std::optional<xla::Array<Value>> tile_masks;
if (store_mask) {
FAILUREOR_ASSIGN_OR_RETURN(
tile_masks,
disassemble(builder, to_store_layout, store_mask, ctx.target_shape));
TPU_ASSERT_EQ_OP(tile_masks->dimensions(), tiles.dimensions());
}
const int64_t ndims = ty.getRank();
const auto base_s =
is_1d ? IdxConst(0, builder, op.getLoc()) : tile_base_idxs.front();
@@ -4347,6 +4350,7 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
const absl::Status status =
tiles.EachStatus([&](const absl::Span<const int64_t> idx,
const Value tile) -> absl::Status {
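// Pick the mask vreg corresponding to this value vreg, if a mask was given.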
const auto tile_mask = store_mask ? (*tile_masks)(idx) : nullptr;
const std::unique_ptr<VRegDataBounds> bounds =
to_store_layout.tileDataBounds(mlir_ctx, stored_shape,
toArrayRef(idx), ctx.target_shape);
@@ -4406,19 +4410,19 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
updated = builder.create<arith::SelectOp>(mask, tile, data);
}
builder.create<tpu::StoreOp>(
updated, base_addr, indices, sublane_mask,
/*mask=*/nullptr,
updated, base_addr, indices, sublane_mask, tile_mask,
/*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride));
} else {
builder.create<tpu::StoreOp>(
tile, base_addr, indices, sublane_mask,
/*mask=*/mask,
tile_mask
? builder.create<arith::AndIOp>(mask, tile_mask).getResult()
: mask,
/*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride));
}
} else {
builder.create<tpu::StoreOp>(
tile, base_addr, indices, sublane_mask,
/*mask=*/nullptr,
tile, base_addr, indices, sublane_mask, tile_mask,
/*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride));
}
return absl::OkStatus();
@@ -4428,7 +4432,35 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
}
store_op->erase();
return success();
}

LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
auto store_op = cast<vector::StoreOp>(op);
TPU_ASSERT_EQ_OP(layouts_out.size(), 0);
TPU_ASSERT_OP(layouts_in.front().has_value());
TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(),
[&](const Layout &l) { return l.has_value(); }));
return vector_store_impl(ctx, store_op, *layouts_in.front());
}

LogicalResult tpu_vector_store_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
auto store_op = cast<tpu::VectorStoreOp>(op);
TPU_ASSERT_EQ_OP(layouts_out.size(), 0);
TPU_ASSERT_OP(layouts_in.front().has_value());
auto other_layouts_in = layouts_in.drop_front();
if (store_op.getMask()) {
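// The mask, when present, is the last operand and shares the layout of the
// value being stored.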
TPU_ASSERT_EQ_OP(layouts_in.front(), layouts_in.back());
other_layouts_in = other_layouts_in.drop_back();
}
TPU_ASSERT_OP(llvm::none_of(other_layouts_in,
[&](const Layout &l) { return l.has_value(); }));
return vector_store_impl(ctx, store_op, *layouts_in.front(),
store_op.getMask());
}

LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
@@ -4642,6 +4674,7 @@ const llvm::StringMap<rule_type> &rules() {
{tpu::StoreOp::getOperationName(), tpu_store_rule},
{tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule},
{tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule},
{tpu::VectorStoreOp::getOperationName(), tpu_vector_store_rule},
{tpu::MatmulOp::getOperationName(), tpu_matmul_rule},
{tpu::RegionOp::getOperationName(), tpu_region_rule},
{tpu::BitcastOp::getOperationName(), tpu_bitcast_rule},
15 changes: 13 additions & 2 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -321,8 +321,14 @@ class VectorLayoutInferer {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::VectorStoreOp>(any_op)) {
if (inferStore<tpu::VectorStoreOp>(op,
/*has_mask=*/op.getMask() != nullptr)
.failed()) {
return failure();
}
} else if (auto op = dyn_cast<vector::StoreOp>(any_op)) {
if (infer(op).failed()) {
if (inferStore<vector::StoreOp>(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<vector::TransposeOp>(any_op)) {
@@ -1540,7 +1546,8 @@ class VectorLayoutInferer {
return failure();
}

LogicalResult infer(vector::StoreOp op) {
template <typename Op>
LogicalResult inferStore(Op op, bool has_mask = false) {
auto ref_ty = getMemRefType(op.getBase());
auto store_ty = op.getValueToStore().getType();
TPU_CHECK_OP(ref_ty.getRank() == store_ty.getRank(),
@@ -1648,6 +1655,10 @@ class VectorLayoutInferer {
}
SmallVector<Layout, 5> in_layout{store_layout};
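// No layout for the base memref or for the indices (hence the +1).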
in_layout.insert(in_layout.end(), op.getIndices().size() + 1, kNoLayout);
if (has_mask) {
// Mask layout should be the same as the layout of the value to store.
in_layout.push_back(store_layout);
}
setInLayout(op, in_layout);
return success();
}
