Use ConfinedAttr to verify attr range is positive or nonnegative (#2416)
I recently came across [confining attributes](https://mlir.llvm.org/docs/DefiningDialects/Operations/#confining-attributes), a nice ODS feature we should leverage: declaring a constraint such as `IntNonNegative` or `IntPositive` directly on an attribute makes the generated verifier enforce it, so the equivalent hand-written C++ checks can be deleted.
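
As a minimal sketch of the pattern (a hypothetical op for illustration, not one of the ops touched here), wrapping an integer attribute in `ConfinedAttr` moves the range check into the ODS-generated verifier:

```tablegen
// Hypothetical op for illustration only. IntNonNegative rejects values < 0;
// IntPositive rejects values <= 0. Violations fail the generated verifier
// with a "failed to satisfy constraint" diagnostic instead of a hand-rolled
// error message.
def StableHLO_ExampleOp : StableHLO_Op<"example", [Pure]> {
  let arguments = (ins
    HLO_Tensor:$operand,
    ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension,
    ConfinedAttr<I64Attr, [IntPositive]>:$group_count
  );
  let results = (outs HLO_Tensor);
}
```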

I'm still looking into whether we can use them for array attributes. If
possible, we may be able to remove tons of manual verification.
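
If that pans out, something along these lines might work. This is a speculative sketch: it assumes upstream's `DenseArrayNonNegative` constraint (from `CommonAttrConstraints.td`) composes with `ConfinedAttr` the same way the scalar predicates do, and the attribute name is made up:

```tablegen
// Speculative sketch only -- not part of this change. If this composes,
// element-wise checks like "all window dimensions are non-negative" could
// move out of TypeInference.cpp as well.
ConfinedAttr<DenseI64ArrayAttr,
             [DenseArrayNonNegative<DenseI64ArrayAttr>]>:$window_dimensions
```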
ghpvnist authored Jun 26, 2024
1 parent 6b69e21 commit 007c059
Showing 7 changed files with 49 additions and 138 deletions.
9 changes: 0 additions & 9 deletions stablehlo/dialect/StablehloOps.cpp
@@ -1837,15 +1837,6 @@ void ReduceWindowOp::build(
llvm::report_fatal_error("Failed to infer result type(s).");
}

-//===----------------------------------------------------------------------===//
-// ReducePrecisionOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult ReducePrecisionOp::verify() {
-return hlo::verifyReducePrecisionOp(getLoc(), getExponentBits(),
-getMantissaBits());
-}
-
//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//
43 changes: 22 additions & 21 deletions stablehlo/dialect/StablehloOps.td
@@ -112,7 +112,9 @@ def StableHLO_IotaOp : StableHLO_Op<"iota", [Pure]> {
%output = stablehlo.iota dim = 0 : tensor<4x5xi32>
```
}];
-let arguments = (ins I64Attr:$iota_dimension);
+let arguments = (ins
+ConfinedAttr<I64Attr, [IntNonNegative]>:$iota_dimension /*iota_c1*/
+);

let results = (outs HLO_StaticShapeIntFpComplexOrQuantizedTensor:$output);

@@ -140,7 +142,7 @@ def StableHLO_DynamicIotaOp: StableHLO_ShapedInterfaceOp<"dynamic_iota", [Condit

let arguments = (ins
HLO_StaticDimensionTensor:$output_shape /*dynamic_iota_i1*/,
-I64Attr:$iota_dimension /*dynamic_iota_i2*/
+ConfinedAttr<I64Attr, [IntNonNegative]>:$iota_dimension /*dynamic_iota_c1, dynamic_iota_i2*/
);
let results = (outs HLO_Tensor:$result);
let hasVerifier = 1;
@@ -1345,7 +1347,7 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather",

let arguments = (ins
HLO_Tensor:$operand, /*all_gather_i1*/
-I64Attr:$all_gather_dim, /*all_gather_i2*/
+ConfinedAttr<I64Attr, [IntNonNegative]>:$all_gather_dim, /*all_gather_c1, all_gather_i2*/
I64ElementsAttr:$replica_groups, /*all_gather_i3*/
OptionalAttr<StableHLO_ChannelHandle>:$channel_handle, /*all_gather_i4*/
UnitAttr:$use_global_device_ids /*all_gather_i5*/
@@ -1423,7 +1425,7 @@ def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter", [ConditionallySpe

let arguments = (ins
HLO_Tensor:$operand, /*reduce_scatter_i1*/
-I64Attr:$scatter_dimension, /*reduce_scatter_i2*/
+ConfinedAttr<I64Attr, [IntNonNegative]>:$scatter_dimension, /*reduce_scatter_c2, reduce_scatter_i2*/
I64ElementsAttr:$replica_groups, /*reduce_scatter_i3*/
OptionalAttr<StableHLO_ChannelHandle>:$channel_handle, /*reduce_scatter_i4*/
UnitAttr:$use_global_device_ids /*reduce_scatter_i5*/
@@ -1465,9 +1467,9 @@ def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all",

let arguments = (ins
HLO_Tensor:$operand, /*all_to_all_i1*/
-I64Attr:$split_dimension, /*all_to_all_i2*/
-I64Attr:$concat_dimension, /*all_to_all_i3*/
-I64Attr:$split_count, /*all_to_all_i4*/
+ConfinedAttr<I64Attr, [IntNonNegative]>:$split_dimension, /*all_to_all_c1, all_to_all_i2*/
+ConfinedAttr<I64Attr, [IntNonNegative]>:$concat_dimension, /*all_to_all_c3, all_to_all_i3*/
+ConfinedAttr<I64Attr, [IntPositive]>:$split_count, /*all_to_all_c4, all_to_all_i4*/
I64ElementsAttr:$replica_groups, /*all_to_all_i5*/
OptionalAttr<StableHLO_ChannelHandle>:$channel_handle /*all_to_all_i6*/
);
@@ -1558,7 +1560,7 @@ def StableHLO_GetTupleElementOp: StableHLO_Op<"get_tuple_element", [Pure,
}];
let arguments = (ins
HLO_Tuple:$operand, /*get_tuple_element_i1*/
-I32Attr:$index /*get_tuple_element_i2*/
+ConfinedAttr<I32Attr, [IntNonNegative]>:$index /*get_tuple_element_c1, get_tuple_element_i2*/
);

let results = (outs HLO_TensorOrPerAxisQuantizedTensorOrTokenOrTuple);
@@ -1778,7 +1780,7 @@ def StableHLO_BatchNormGradOp : StableHLO_Op<"batch_norm_grad",
1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$variance, /*batch_norm_grad_i4*/
RankedTensorOf<[HLO_Float, HLO_QuantizedInt]>:$grad_output, /*batch_norm_grad_i5*/
F32Attr:$epsilon, /*batch_norm_grad_i6*/
-I64Attr:$feature_index /*batch_norm_grad_i7*/
+ConfinedAttr<I64Attr, [IntNonNegative]>:$feature_index /*batch_norm_grad_c1, batch_norm_grad_i7*/
);

let results = (outs
@@ -1815,7 +1817,7 @@ def StableHLO_BatchNormInferenceOp : StableHLO_Op<"batch_norm_inference",
1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$mean /*batch_norm_inference_i4*/,
1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$variance /*batch_norm_inference_i5*/,
F32Attr:$epsilon /*batch_norm_inference_i6*/,
-I64Attr:$feature_index /*batch_norm_inference_i7*/
+ConfinedAttr<I64Attr, [IntNonNegative]>:$feature_index /*batch_norm_inference_c1, batch_norm_inference_i7*/
);

let results = (outs RankedTensorOf<[HLO_Float, HLO_QuantizedInt]>:$result);
@@ -1849,7 +1851,7 @@ def StableHLO_BatchNormTrainingOp : StableHLO_Op<"batch_norm_training",
1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$scale /*batch_norm_training_i2*/,
1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$offset /*batch_norm_training_i3*/,
F32Attr:$epsilon /*batch_norm_training_i4*/,
-I64Attr:$feature_index /*batch_norm_training_i5*/
+ConfinedAttr<I64Attr, [IntNonNegative]>:$feature_index /*batch_norm_training_c1, batch_norm_training_i5*/
);

let results = (outs
@@ -2098,7 +2100,7 @@ def StableHLO_ConcatenateOp : StableHLO_ShapedInterfaceOp<"concatenate",

let arguments = (ins
Variadic<HLO_Tensor>:$inputs /*concatenate_i1*/,
-I64Attr:$dimension /*concatenate_i2*/
+ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension /*concatenate_c4, concatenate_i2*/
);

let results = (outs HLO_Tensor);
@@ -2274,8 +2276,8 @@ def StableHLO_ConvolutionOp : StableHLO_Op<"convolution",
// Default value: false for each of the spatial dimension.
OptionalAttr<GenericDenseBoolArrayAttr>:$window_reversal, /*convolution_i7*/
StableHLO_ConvDimensionNumbers:$dimension_numbers, /*convolution_i8...convolution_i16*/
-I64Attr:$feature_group_count, /*convolution_i17*/
-I64Attr:$batch_group_count, /*convolution_i18*/
+ConfinedAttr<I64Attr, [IntPositive]>:$feature_group_count, /*convolution_c21, convolution_i17*/
+ConfinedAttr<I64Attr, [IntPositive]>:$batch_group_count, /*convolution_c22, convolution_i18*/
StableHLO_PrecisionConfigAttr:$precision_config /*convolution_i19*/
);

@@ -2609,7 +2611,7 @@ def StableHLO_GetDimensionSizeOp: StableHLO_Op<"get_dimension_size",
}];
let arguments = (ins
HLO_TensorOrPerAxisQuantizedTensor:$operand, /*get_dimension_size_i1*/
-I64Attr:$dimension /*get_dimension_size_i2*/
+ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension /*get_dimension_size_c1, get_dimension_size_i2*/
);
// TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
// XLA semantics is available. This limitation is because of the current XLA
@@ -2868,7 +2870,7 @@ def StableHLO_SetDimensionSizeOp: StableHLO_Op<"set_dimension_size",
let arguments = (ins
HLO_Tensor:$operand,
I32RankedTensor:$size,
-I64Attr:$dimension
+ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension
);
let results = (outs HLO_Tensor);

@@ -3376,10 +3378,9 @@ def StableHLO_ReducePrecisionOp : StableHLO_Op<"reduce_precision",
}];
let arguments = (ins
HLO_FpOrQuantizedIntTensor:$operand, /*reduce_precision_i1*/
-I32Attr:$exponent_bits, /*reduce_precision_i2*/
-I32Attr:$mantissa_bits /*reduce_precision_i3*/
+ConfinedAttr<I32Attr, [IntPositive]>:$exponent_bits, /*reduce_precision_c2, reduce_precision_i2*/
+ConfinedAttr<I32Attr, [IntNonNegative]>:$mantissa_bits /*reduce_precision_c3, reduce_precision_i3*/
);
-let hasVerifier = 1;
let results = (outs HLO_FpOrQuantizedIntTensor:$output);

let assemblyFormat = [{
@@ -3535,8 +3536,8 @@ def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv",
OptionalAttr<GenericDenseI64ArrayAttr>:$rhs_dilation, /*dynamic_conv_i6*/
OptionalAttr<GenericDenseBoolArrayAttr>:$window_reversal, /*dynamic_conv_i7*/
StableHLO_ConvDimensionNumbers:$dimension_numbers, /*dynamic_conv_i8...dynamic_conv_i16*/
-I64Attr:$feature_group_count, /*dynamic_conv_i17*/
-I64Attr:$batch_group_count, /*dynamic_conv_i18*/
+ConfinedAttr<I64Attr, [IntPositive]>:$feature_group_count, /*dynamic_conv_c21, dynamic_conv_i17*/
+ConfinedAttr<I64Attr, [IntPositive]>:$batch_group_count, /*dynamic_conv_c22, dynamic_conv_i18*/
StableHLO_PrecisionConfigAttr:$precision_config /*dynamic_conv_i19*/
);
let results = (outs HLO_Tensor);
67 changes: 5 additions & 62 deletions stablehlo/dialect/TypeInference.cpp
@@ -682,11 +682,6 @@ LogicalResult verifyBatchNorm(std::optional<Location> location,
"multi-dimensional operands; got featureIndex ",
featureIndex, ", and rank ", multiDimType.getRank(), ".");

-// batch_norm_grad_c1, batch_norm_inference_c1, batch_norm_training_c1
-if (featureIndex < 0)
-return emitOptionalError(location, "expects featureIndex to be a ",
-"non-negative number, got ", featureIndex, ".");
-
const int64_t featureCount = multiDimType.getDimSize(featureIndex);
const int64_t singleDimSize =
cast<RankedTensorType>(singleDimOperands[0].getType()).getDimSize(0);
@@ -1258,18 +1253,6 @@ LogicalResult verifyConvolutionAttributes(
location)))
return failure();

-// convolution_c21, dynamic_conv_c21
-if (featureGroupCount <= 0)
-return emitOptionalError(
-location, "expects feature_group_count to be a positive number, got ",
-featureGroupCount, ".");
-
-// convolution_c22, dynamic_conv_c22
-if (batchGroupCount <= 0)
-return emitOptionalError(
-location, "expects batch_group_count to be a positive number, got ",
-batchGroupCount, ".");
-
// convolution_c23, dynamic_conv_c23
if (batchGroupCount > 1 && featureGroupCount > 1)
return emitOptionalError(
@@ -1839,26 +1822,12 @@ LogicalResult inferAllToAllOp(
int64_t concatDimension, int64_t splitCount,
DenseIntElementsAttr replicaGroups,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
-// all_to_all_c4
-if (splitCount <= 0)
-return emitOptionalError(location, "AllToAll split_count must be > 0");
-
// all_to_all_c5, all_to_all_c7, all_to_all_i5
if (failed(verifyReplicaGroups(location, replicaGroups,
/*allGroupsMustHaveSameSize=*/true,
/*useGlobalDeviceIds=*/false, splitCount)))
return failure();

-// all_to_all_c1
-if (splitDimension < 0)
-return emitOptionalError(location,
-"AllToAll split_dimension cannot be negative");
-
-// all_to_all_c3
-if (concatDimension < 0)
-return emitOptionalError(location,
-"AllToAll concat_dimension cannot be negative");
-
Type operandType = operand.getType();
auto operandRankedType = cast<RankedTensorType>(operandType);

@@ -2051,10 +2020,6 @@ LogicalResult inferComplexOp(std::optional<Location> location, Value lhs,
LogicalResult inferConcatenateOp(std::optional<Location> location,
TypeRange inputTypes, int64_t dimension,
SmallVectorImpl<Type>& inferredReturnTypes) {
-// concatenate_c4
-if (dimension < 0)
-return emitOptionalError(location, "dimension ", dimension, " is negative");
-
auto witnessType = cast<RankedTensorType>(inputTypes[0]);
int64_t rank = witnessType.getRank();

@@ -2876,7 +2841,7 @@ LogicalResult inferGetTupleElementOp(
auto operandType = dyn_cast<TupleType>(operand.getType());
if (!operandType) return failure();
// get_tuple_element_c1
-if (index < 0 || index >= static_cast<int64_t>(operandType.size()))
+if (index >= static_cast<int64_t>(operandType.size()))
return emitOptionalError(location, "index ", index,
" is out of bounds of operand with size ",
operandType.size());
@@ -3566,10 +3531,6 @@ LogicalResult verifyAllGatherOp(std::optional<Location> location, Value operand,
auto operandType = cast<RankedTensorType>(operand.getType());
auto resultType = cast<RankedTensorType>(result.getType());

-// all_gather_c1
-if (allGatherDim < 0)
-return emitOptionalError(location, "all_gather_dim cannot be negative");
-
// all_gather_c1
if (allGatherDim >= operandType.getRank())
return emitOptionalError(location,
@@ -4222,10 +4183,9 @@ LogicalResult verifyDynamicIotaOp(std::optional<Location> location,
auto resultType = cast<ShapedType>(result.getType());

// dynamic_iota_c1
-if (iotaDimension >= resultType.getRank() || iotaDimension < 0)
+if (iotaDimension >= resultType.getRank())
return emitOptionalError(
-location,
-"iota dimension cannot go beyond the output rank or be negative.");
+location, "iota dimension cannot go beyond the output rank.");
// dynamic_iota_c2
if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape,
resultType)))
@@ -4410,10 +4370,9 @@ LogicalResult verifyIotaOp(std::optional<Location> location,
if (shape.getRank() == 0)
return emitOptionalError(location, "does not support scalars.");

-if (iotaDimension >= shape.getRank() || iotaDimension < 0)
+if (iotaDimension >= shape.getRank())
return emitOptionalError(
-location,
-"iota dimension cannot go beyond the output rank or be negative.");
+location, "iota dimension cannot go beyond the output rank.");
return success();
}

@@ -4508,18 +4467,6 @@ LogicalResult verifyReduceOp(std::optional<Location> location,
return success();
}

-LogicalResult verifyReducePrecisionOp(std::optional<Location> location,
-int32_t exponentBits,
-int32_t mantissaBits) {
-// reduce_precision_c2
-if (exponentBits < 1)
-return emitOptionalError(location, "exponent_bits must be at least 1.");
-// reduce_precision_c3
-if (mantissaBits < 0)
-return emitOptionalError(location, "mantissa_bits must be at least 0.");
-return success();
-}
-
LogicalResult verifyReduceScatterOp(std::optional<Location> location,
Value operand, int64_t scatterDimension,
DenseIntElementsAttr replicaGroups,
@@ -4544,10 +4491,6 @@ LogicalResult verifyReduceScatterOp(std::optional<Location> location,
return emitOptionalError(location,
"operand and result should have same rank");

-// reduce_scatter_c2
-if (scatterDimension < 0)
-return emitOptionalError(location, "expects scatter_dimension >= 0");
-
// reduce_scatter_c2
if (scatterDimension >= operandType.getRank())
return emitOptionalError(
4 changes: 0 additions & 4 deletions stablehlo/dialect/TypeInference.h
@@ -515,10 +515,6 @@ LogicalResult verifyReduceOpInputsAndInferShape(
ArrayRef<int64_t> dimensions, SmallVector<int64_t>& newDimensions,
Attribute& encoding);

-LogicalResult verifyReducePrecisionOp(std::optional<Location> location,
-int32_t exponentBits,
-int32_t mantissaBits);
-
LogicalResult verifyReduceScatterOp(std::optional<Location> location,
Value operand, int64_t scatterDimension,
DenseIntElementsAttr replicaGroups,