diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index 66b4bc00437..762096bc7a1 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -1837,15 +1837,6 @@ void ReduceWindowOp::build( llvm::report_fatal_error("Failed to infer result type(s)."); } -//===----------------------------------------------------------------------===// -// ReducePrecisionOp -//===----------------------------------------------------------------------===// - -LogicalResult ReducePrecisionOp::verify() { - return hlo::verifyReducePrecisionOp(getLoc(), getExponentBits(), - getMantissaBits()); -} - //===----------------------------------------------------------------------===// // ReduceOp //===----------------------------------------------------------------------===// diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index b5d734a13d4..5a845174de4 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -112,7 +112,9 @@ def StableHLO_IotaOp : StableHLO_Op<"iota", [Pure]> { %output = stablehlo.iota dim = 0 : tensor<4x5xi32> ``` }]; - let arguments = (ins I64Attr:$iota_dimension); + let arguments = (ins + ConfinedAttr:$iota_dimension /*iota_c1*/ + ); let results = (outs HLO_StaticShapeIntFpComplexOrQuantizedTensor:$output); @@ -140,7 +142,7 @@ def StableHLO_DynamicIotaOp: StableHLO_ShapedInterfaceOp<"dynamic_iota", [Condit let arguments = (ins HLO_StaticDimensionTensor:$output_shape /*dynamic_iota_i1*/, - I64Attr:$iota_dimension /*dynamic_iota_i2*/ + ConfinedAttr:$iota_dimension /*dynamic_iota_c1, dynamic_iota_i2*/ ); let results = (outs HLO_Tensor:$result); let hasVerifier = 1; @@ -1345,7 +1347,7 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather", let arguments = (ins HLO_Tensor:$operand, /*all_gather_i1*/ - I64Attr:$all_gather_dim, /*all_gather_i2*/ + ConfinedAttr:$all_gather_dim, /*all_gather_c1, all_gather_i2*/ I64ElementsAttr:$replica_groups, /*all_gather_i3*/ OptionalAttr:$channel_handle, /*all_gather_i4*/ UnitAttr:$use_global_device_ids /*all_gather_i5*/ @@ -1423,7 +1425,7 @@ def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter", [ConditionallySpe let arguments = (ins HLO_Tensor:$operand, /*reduce_scatter_i1*/ - I64Attr:$scatter_dimension, /*reduce_scatter_i2*/ + ConfinedAttr:$scatter_dimension, /*reduce_scatter_c2, reduce_scatter_i2*/ I64ElementsAttr:$replica_groups, /*reduce_scatter_i3*/ OptionalAttr:$channel_handle, /*reduce_scatter_i4*/ UnitAttr:$use_global_device_ids /*reduce_scatter_i5*/ @@ -1465,9 +1467,9 @@ def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all", let arguments = (ins HLO_Tensor:$operand, /*all_to_all_i1*/ - I64Attr:$split_dimension, /*all_to_all_i2*/ - I64Attr:$concat_dimension, /*all_to_all_i3*/ - I64Attr:$split_count, /*all_to_all_i4*/ + ConfinedAttr:$split_dimension, /*all_to_all_c1, all_to_all_i2*/ + ConfinedAttr:$concat_dimension, /*all_to_all_c3, all_to_all_i3*/ + ConfinedAttr:$split_count, /*all_to_all_c4, all_to_all_i4*/ I64ElementsAttr:$replica_groups, /*all_to_all_i5*/ OptionalAttr:$channel_handle /*all_to_all_i6*/ ); @@ -1558,7 +1560,7 @@ def StableHLO_GetTupleElementOp: StableHLO_Op<"get_tuple_element", [Pure, }]; let arguments = (ins HLO_Tuple:$operand, /*get_tuple_element_i1*/ - I32Attr:$index /*get_tuple_element_i2*/ + ConfinedAttr:$index /*get_tuple_element_c1, get_tuple_element_i2*/ ); let results = (outs HLO_TensorOrPerAxisQuantizedTensorOrTokenOrTuple); @@ -1778,7 +1780,7 @@ def StableHLO_BatchNormGradOp : StableHLO_Op<"batch_norm_grad", 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$variance, /*batch_norm_grad_i4*/ RankedTensorOf<[HLO_Float, HLO_QuantizedInt]>:$grad_output, /*batch_norm_grad_i5*/ F32Attr:$epsilon, /*batch_norm_grad_i6*/ - I64Attr:$feature_index /*batch_norm_grad_i7*/ + ConfinedAttr:$feature_index /*batch_norm_grad_c1, batch_norm_grad_i7*/ ); let results = (outs @@ -1815,7 +1817,7 @@ def StableHLO_BatchNormInferenceOp : StableHLO_Op<"batch_norm_inference", 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$mean /*batch_norm_inference_i4*/, 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$variance /*batch_norm_inference_i5*/, F32Attr:$epsilon /*batch_norm_inference_i6*/, - I64Attr:$feature_index /*batch_norm_inference_i7*/ + ConfinedAttr:$feature_index /*batch_norm_inference_c1, batch_norm_inference_i7*/ ); let results = (outs RankedTensorOf<[HLO_Float, HLO_QuantizedInt]>:$result); @@ -1849,7 +1851,7 @@ def StableHLO_BatchNormTrainingOp : StableHLO_Op<"batch_norm_training", 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$scale /*batch_norm_training_i2*/, 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$offset /*batch_norm_training_i3*/, F32Attr:$epsilon /*batch_norm_training_i4*/, - I64Attr:$feature_index /*batch_norm_training_i5*/ + ConfinedAttr:$feature_index /*batch_norm_training_c1, batch_norm_training_i5*/ ); let results = (outs @@ -2098,7 +2100,7 @@ def StableHLO_ConcatenateOp : StableHLO_ShapedInterfaceOp<"concatenate", let arguments = (ins Variadic:$inputs /*concatenate_i1*/, - I64Attr:$dimension /*concatenate_i2*/ + ConfinedAttr:$dimension /*concatenate_c4, concatenate_i2*/ ); let results = (outs HLO_Tensor); @@ -2274,8 +2276,8 @@ def StableHLO_ConvolutionOp : StableHLO_Op<"convolution", // Default value: false for each of the spatial dimension. OptionalAttr:$window_reversal, /*convolution_i7*/ StableHLO_ConvDimensionNumbers:$dimension_numbers, /*convolution_i8...convolution_i16*/ - I64Attr:$feature_group_count, /*convolution_i17*/ - I64Attr:$batch_group_count, /*convolution_i18*/ + ConfinedAttr:$feature_group_count, /*convolution_c21, convolution_i17*/ + ConfinedAttr:$batch_group_count, /*convolution_c22, convolution_i18*/ StableHLO_PrecisionConfigAttr:$precision_config /*convolution_i19*/ ); @@ -2609,7 +2611,7 @@ def StableHLO_GetDimensionSizeOp: StableHLO_Op<"get_dimension_size", }]; let arguments = (ins HLO_TensorOrPerAxisQuantizedTensor:$operand, /*get_dimension_size_i1*/ - I64Attr:$dimension /*get_dimension_size_i2*/ + ConfinedAttr:$dimension /*get_dimension_size_c1, get_dimension_size_i2*/ ); // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the // XLA semantics is available. This limitation is because of the current XLA @@ -2868,7 +2870,7 @@ def StableHLO_SetDimensionSizeOp: StableHLO_Op<"set_dimension_size", let arguments = (ins HLO_Tensor:$operand, I32RankedTensor:$size, - I64Attr:$dimension + ConfinedAttr:$dimension ); let results = (outs HLO_Tensor); @@ -3376,10 +3378,9 @@ def StableHLO_ReducePrecisionOp : StableHLO_Op<"reduce_precision", }]; let arguments = (ins HLO_FpOrQuantizedIntTensor:$operand, /*reduce_precision_i1*/ - I32Attr:$exponent_bits, /*reduce_precision_i2*/ - I32Attr:$mantissa_bits /*reduce_precision_i3*/ + ConfinedAttr:$exponent_bits, /*reduce_precision_c2, reduce_precision_i2*/ + ConfinedAttr:$mantissa_bits /*reduce_precision_c3, reduce_precision_i3*/ ); - let hasVerifier = 1; let results = (outs HLO_FpOrQuantizedIntTensor:$output); let assemblyFormat = [{ @@ -3535,8 +3536,8 @@ def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv", OptionalAttr:$rhs_dilation, /*dynamic_conv_i6*/ OptionalAttr:$window_reversal, /*dynamic_conv_i7*/ StableHLO_ConvDimensionNumbers:$dimension_numbers, /*dynamic_conv_i8...dynamic_conv_i16*/ - I64Attr:$feature_group_count, /*dynamic_conv_i17*/ - I64Attr:$batch_group_count, /*dynamic_conv_i18*/ + ConfinedAttr:$feature_group_count, /*dynamic_conv_c21, dynamic_conv_i17*/ + ConfinedAttr:$batch_group_count, /*dynamic_conv_c22, dynamic_conv_i18*/ StableHLO_PrecisionConfigAttr:$precision_config /*dynamic_conv_i19*/ ); let results = (outs HLO_Tensor); diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index edf7a3f1397..9946479e6fe 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -682,11 +682,6 @@ LogicalResult verifyBatchNorm(std::optional location, "multi-dimensional operands; got featureIndex ", featureIndex, ", and rank ", multiDimType.getRank(), "."); - // batch_norm_grad_c1, batch_norm_inference_c1, batch_norm_training_c1 - if (featureIndex < 0) - return emitOptionalError(location, "expects featureIndex to be a ", - "non-negative number, got ", featureIndex, "."); - const int64_t featureCount = multiDimType.getDimSize(featureIndex); const int64_t singleDimSize = cast(singleDimOperands[0].getType()).getDimSize(0); @@ -1258,18 +1253,6 @@ LogicalResult verifyConvolutionAttributes( location))) return failure(); - // convolution_c21, dynamic_conv_c21 - if (featureGroupCount <= 0) - return emitOptionalError( - location, "expects feature_group_count to be a positive number, got ", - featureGroupCount, "."); - - // convolution_c22, dynamic_conv_c22 - if (batchGroupCount <= 0) - return emitOptionalError( - location, "expects batch_group_count to be a positive number, got ", - batchGroupCount, "."); - // convolution_c23, dynamic_conv_c23 if (batchGroupCount > 1 && featureGroupCount > 1) return emitOptionalError( @@ -1839,26 +1822,12 @@ LogicalResult inferAllToAllOp( int64_t concatDimension, int64_t splitCount, DenseIntElementsAttr replicaGroups, SmallVectorImpl& inferredReturnShapes) { - // all_to_all_c4 - if (splitCount <= 0) - return emitOptionalError(location, "AllToAll split_count must be > 0"); - // all_to_all_c5, all_to_all_c7, all_to_all_i5 if (failed(verifyReplicaGroups(location, replicaGroups, /*allGroupsMustHaveSameSize=*/true, /*useGlobalDeviceIds=*/false, splitCount))) return failure(); - // all_to_all_c1 - if (splitDimension < 0) - return emitOptionalError(location, - "AllToAll split_dimension cannot be negative"); - - // all_to_all_c3 - if (concatDimension < 0) - return emitOptionalError(location, - "AllToAll concat_dimension cannot be negative"); - Type operandType = operand.getType(); auto operandRankedType = cast(operandType); @@ -2051,10 +2020,6 @@ LogicalResult inferComplexOp(std::optional location, Value lhs, LogicalResult inferConcatenateOp(std::optional location, TypeRange inputTypes, int64_t dimension, SmallVectorImpl& inferredReturnTypes) { - // concatenate_c4 - if (dimension < 0) - return emitOptionalError(location, "dimension ", dimension, " is negative"); - auto witnessType = cast(inputTypes[0]); int64_t rank = witnessType.getRank(); @@ -2876,7 +2841,7 @@ LogicalResult inferGetTupleElementOp( auto operandType = dyn_cast(operand.getType()); if (!operandType) return failure(); // get_tuple_element_c1 - if (index < 0 || index >= static_cast(operandType.size())) + if (index >= static_cast(operandType.size())) return emitOptionalError(location, "index ", index, " is out of bounds of operand with size ", operandType.size()); @@ -3566,10 +3531,6 @@ LogicalResult verifyAllGatherOp(std::optional location, Value operand, auto operandType = cast(operand.getType()); auto resultType = cast(result.getType()); - // all_gather_c1 - if (allGatherDim < 0) - return emitOptionalError(location, "all_gather_dim cannot be negative"); - // all_gather_c1 if (allGatherDim >= operandType.getRank()) return emitOptionalError(location, @@ -4222,10 +4183,9 @@ LogicalResult verifyDynamicIotaOp(std::optional location, auto resultType = cast(result.getType()); // dynamic_iota_c1 - if (iotaDimension >= resultType.getRank() || iotaDimension < 0) + if (iotaDimension >= resultType.getRank()) return emitOptionalError( - location, - "iota dimension cannot go beyond the output rank or be negative."); + location, "iota dimension cannot go beyond the output rank."); // dynamic_iota_c2 if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape, resultType))) @@ -4410,10 +4370,9 @@ LogicalResult verifyIotaOp(std::optional location, if (shape.getRank() == 0) return emitOptionalError(location, "does not support scalars."); - if (iotaDimension >= shape.getRank() || iotaDimension < 0) + if (iotaDimension >= shape.getRank()) return emitOptionalError( - location, - "iota dimension cannot go beyond the output rank or be negative."); + location, "iota dimension cannot go beyond the output rank."); return success(); } @@ -4508,18 +4467,6 @@ LogicalResult verifyReduceOp(std::optional location, return success(); } -LogicalResult verifyReducePrecisionOp(std::optional location, - int32_t exponentBits, - int32_t mantissaBits) { - // reduce_precision_c2 - if (exponentBits < 1) - return emitOptionalError(location, "exponent_bits must be at least 1."); - // reduce_precision_c3 - if (mantissaBits < 0) - return emitOptionalError(location, "mantissa_bits must be at least 0."); - return success(); -} - LogicalResult verifyReduceScatterOp(std::optional location, Value operand, int64_t scatterDimension, DenseIntElementsAttr replicaGroups, @@ -4544,10 +4491,6 @@ LogicalResult verifyReduceScatterOp(std::optional location, return emitOptionalError(location, "operand and result should have same rank"); - // reduce_scatter_c2 - if (scatterDimension < 0) - return emitOptionalError(location, "expects scatter_dimension >= 0"); - // reduce_scatter_c2 if (scatterDimension >= operandType.getRank()) return emitOptionalError( diff --git a/stablehlo/dialect/TypeInference.h b/stablehlo/dialect/TypeInference.h index fbf601aff3e..0a2b8334987 100644 --- a/stablehlo/dialect/TypeInference.h +++ b/stablehlo/dialect/TypeInference.h @@ -515,10 +515,6 @@ LogicalResult verifyReduceOpInputsAndInferShape( ArrayRef dimensions, SmallVector& newDimensions, Attribute& encoding); -LogicalResult verifyReducePrecisionOp(std::optional location, - int32_t exponentBits, - int32_t mantissaBits); - LogicalResult verifyReduceScatterOp(std::optional location, Value operand, int64_t scatterDimension, DenseIntElementsAttr replicaGroups, diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index d8f6f563aa4..8b526ed35de 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -349,7 +349,7 @@ func.func @reduce_scatter_with_promotable_quantized_types( // ----- func.func @reduce_scatter_c2(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { - // expected-error@+1 {{expects scatter_dimension >= 0}} + // expected-error@+1 {{op attribute 'scatter_dimension' failed to satisfy constraint: 64-bit signless integer attribute whose value is non-negative}} %0 = "stablehlo.reduce_scatter"(%data) ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = stablehlo.add %arg2, %arg3 : tensor @@ -702,8 +702,7 @@ func.func @all_to_all_dynamic_concat_dim(%data: tensor) -> tensor) -> tensor<16x4xf32> { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{AllToAll split_dimension cannot be negative}} + // expected-error@+1 {{op attribute 'split_dimension' failed to satisfy constraint: 64-bit signless integer attribute whose value is non-negative}} %0 = "stablehlo.all_to_all"(%data) { split_dimension = -1 : i64, concat_dimension = 0 : i64, @@ -744,8 +743,7 @@ func.func @all_to_all_c2(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { // ----- func.func @all_to_all_c3(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{AllToAll concat_dimension cannot be negative}} + // expected-error@+1 {{op attribute 'concat_dimension' failed to satisfy constraint: 64-bit signless integer attribute whose value is non-negative}} %0 = "stablehlo.all_to_all"(%data) { split_dimension = 1 : i64, concat_dimension = -1 : i64, @@ -772,8 +770,7 @@ func.func @all_to_all_c3(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { // ----- func.func @all_to_all_c4(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{AllToAll split_count must be > 0}} + // expected-error@+1 {{op attribute 'split_count' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}} %0 = "stablehlo.all_to_all"(%data) { split_dimension = 1 : i64, concat_dimension = 0 : i64, @@ -854,7 +851,7 @@ func.func @allgather_gather_along_zero_dimension(%arg0: tensor<128x0xf32>) -> te // ----- func.func @all_gather_c1(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { - // expected-error@+1 {{all_gather_dim cannot be negative}} + // expected-error@+1 {{op attribute 'all_gather_dim' failed to satisfy constraint: 64-bit signless integer attribute whose value is non-negative}} %0 = "stablehlo.all_gather"(%arg0) { all_gather_dim = -1 : i64, channel_handle = #stablehlo.channel_handle, @@ -1558,23 +1555,13 @@ func.func @concatenate_c3() -> tensor<2xi32> { // ----- func.func @concatenate_c4(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{dimension -1 is negative}} + // expected-error@+1 {{op attribute 'dimension' failed to satisfy constraint: 64-bit signless integer attribute whose value is non-negative}} %0 = "stablehlo.concatenate"(%arg0, %arg1) { dimension = -1 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> func.return %0 : tensor<3xi32> } // ----- -func.func @concatenate_c4(%arg0: tensor, %arg1: tensor) -> tensor { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{dimension -1 is negative}} - %0 = "stablehlo.concatenate"(%arg0, %arg1) { dimension = -1 : i64 } : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - func.func @concatenate_c4(%arg0: tensor, %arg1: tensor) -> tensor<2xi32> { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{rank-0 values cannot be concatenated}} @@ -1838,7 +1825,7 @@ func.func @iota_scalar() -> tensor { // ----- func.func @iota_invalid_iota_dimension() -> tensor<4xi32> { - // expected-error@+1 {{iota dimension cannot go beyond the output rank or be negative}} + // expected-error@+1 {{iota dimension cannot go beyond the output rank}} %0 = "stablehlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4xi32> func.return %0 : tensor<4xi32> } @@ -2733,8 +2720,7 @@ func.func @get_tuple_element(%arg0: tuple, tensor>) -> tensor, tensor>) -> tensor { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{index -1 is out of bounds of operand with size 2}} + // expected-error@+1 {{op attribute 'index' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}} %0 = "stablehlo.get_tuple_element"(%arg0) {index = -1 : i32} : (tuple, tensor>) -> tensor func.return %0 : tensor } @@ -3457,7 +3443,7 @@ func.func @reduce_precision(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { // ----- func.func @reduce_precision_c2(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { - // expected-error @+1 {{exponent_bits must be at least 1.}} + // expected-error @+1 {{op attribute 'exponent_bits' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}} %0 = "stablehlo.reduce_precision"(%arg) {exponent_bits=0 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32> func.return %0 : tensor<2x4xf32> } @@ -3465,7 +3451,7 @@ func.func @reduce_precision_c2(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { // ----- func.func @reduce_precision_c3(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { - // expected-error @+1 {{mantissa_bits must be at least 0.}} + // expected-error @+1 {{op attribute 'mantissa_bits' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}} %0 = "stablehlo.reduce_precision"(%arg) {exponent_bits=1 : i32, mantissa_bits=-1 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32> func.return %0 : tensor<2x4xf32> } @@ -4680,8 +4666,7 @@ func.func @get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor { // ----- func.func @get_dimension_size_c1(%I: tensor<1x128x512xf32>) -> tensor { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{requires non-negative dimension attribute; found (-1)}} + // expected-error@+1 {{op attribute 'dimension' failed to satisfy constraint: 64-bit signless integer attribute whose value is non-negative}} %size = "stablehlo.get_dimension_size"(%I) {dimension = -1 : i64} : (tensor<1x128x512xf32>) -> tensor func.return %size : tensor } @@ -4710,8 +4695,7 @@ func.func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32 func.func @set_dimension_size_negative_dimension(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> { %dim = stablehlo.constant dense<512> : tensor - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{requires non-negative dimension attribute; found (-1)}} + // expected-error@+1 {{op attribute 'dimension' failed to satisfy constraint: 64-bit signless integer attribute whose value is non-negative}} %result = "stablehlo.set_dimension_size"(%I, %dim) {dimension =-1 : i64} : (tensor<1x128x512xf32>, tensor) -> tensor<1x128x512xf32> func.return %result : tensor<1x128x512xf32> } @@ -5133,8 +5117,7 @@ func.func @batch_norm_training_dynamic(%input: tensor, %scale: tens // ----- func.func @batch_norm_training_c1(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{expects featureIndex to be a non-negative number, got -1.}} + // expected-error@+1 {{op attribute 'feature_index' failed to satisfy constraint: 64-bit signless integer attribute whose value is non-negative}} %0:3 = "stablehlo.batch_norm_training" (%input, %scale, %offset) { epsilon = 0.001 : f32, feature_index = -1 : i64 @@ -5196,8 +5179,7 @@ func.func @batch_norm_inference_dynamic(%input: tensor<4x?xf32>, %scale: tensor< // ----- func.func @batch_norm_inference_c1(%input: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %mean: tensor<256xf32>, %variance: tensor<256xf32>) -> (tensor<4x256xf32>) { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{expects featureIndex to be a non-negative number, got -1.}} + // expected-error@+1 {{op attribute 'feature_index' failed to satisfy constraint: 64-bit signless integer attribute whose value is non-negative}} %0 = "stablehlo.batch_norm_inference" (%input, %scale, %offset, %mean, %variance) { epsilon = 1.001000e-05 : f32, feature_index = -1 : i64 @@ -5249,8 +5231,7 @@ func.func @batch_norm_grad_dynamic(%input: tensor, %scale: tensor, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{expects featureIndex to be a non-negative number, got -1.}} + // expected-error@+1 {{op attribute 'feature_index' failed to satisfy constraint: 64-bit signless integer attribute whose value is non-negative}} %0:3 = "stablehlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = -1 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) func.return %0#0 : tensor<2x2x2x2xf32> } @@ -6006,8 +5987,7 @@ func.func @is_finite_mismatch_return_shape(%arg0: tensor<3xf32>) -> tensor<4xi1> // ----- func.func @negative_dimension_attr(%arg0: tensor>, %arg1: tensor) -> tensor { - // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{requires non-negative dimension attribute; found (-1)}} + // expected-error@+1 {{op attribute 'dimension' failed to satisfy constraint: 64-bit signless integer attribute whose value is non-negative}} %result = "stablehlo.set_dimension_size"(%arg0, %arg1) {dimension = -1 : i64} : (tensor>, tensor) -> tensor func.return %result : tensor } @@ -6080,7 +6060,7 @@ func.func @dynamic_iota_dynamic() -> tensor { // ----- func.func @dynamic_iota_invalid_iota_dimension_negative() -> tensor { - // expected-error@+2 {{iota dimension cannot go beyond the output rank or be negative}} + // expected-error@+2 {{op attribute 'iota_dimension' failed to satisfy constraint: 64-bit signless integer attribute whose value is non-negative}} %0 = stablehlo.constant dense<[4]> : tensor<1xi64> %1 = stablehlo.dynamic_iota %0, dim = -1 : (tensor<1xi64>) -> tensor func.return %1 : tensor @@ -6090,7 +6070,7 @@ func.func @dynamic_iota_invalid_iota_dimension_negative() -> tensor { func.func @dynamic_iota_invalid_iota_dimension_too_big() -> tensor { %0 = stablehlo.constant dense<[4]> : tensor<1xi64> - // expected-error@+1 {{iota dimension cannot go beyond the output rank or be negative}} + // expected-error@+1 {{iota dimension cannot go beyond the output rank}} %1 = stablehlo.dynamic_iota %0, dim = 2 : (tensor<1xi64>) -> tensor func.return %1 : tensor } diff --git a/stablehlo/tests/verify_convolution.mlir b/stablehlo/tests/verify_convolution.mlir index bdfc6e69066..c50382caf4a 100644 --- a/stablehlo/tests/verify_convolution.mlir +++ b/stablehlo/tests/verify_convolution.mlir @@ -737,7 +737,7 @@ func.func @convolution_c20(%arg0 : tensor<100x26x26x32xf32>, func.func @convolution_c21(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects feature_group_count to be a positive number, got 0.}} + // expected-error@+5 {{op attribute 'feature_group_count' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], @@ -755,7 +755,7 @@ func.func @convolution_c21(%arg0: tensor<1x8x8x207xf32>, func.func @convolution_c22(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{expects batch_group_count to be a positive number, got 0.}} + // expected-error@+5 {{op attribute 'batch_group_count' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}} %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], diff --git a/stablehlo/tests/verify_dynamic_conv.mlir b/stablehlo/tests/verify_dynamic_conv.mlir index 8ae0660f51c..7e0a156804b 100644 --- a/stablehlo/tests/verify_dynamic_conv.mlir +++ b/stablehlo/tests/verify_dynamic_conv.mlir @@ -575,7 +575,7 @@ func.func @dynamic_conv_c20(%arg0 : tensor<100x26x26x32xf32>, func.func @dynamic_conv_c21(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+2 {{expects feature_group_count to be a positive number, got 0.}} + // expected-error@+2 {{op attribute 'feature_group_count' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}} %padding = stablehlo.constant dense<0> : tensor<2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %padding) { dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, @@ -590,7 +590,7 @@ func.func @dynamic_conv_c21(%arg0: tensor<1x8x8x207xf32>, func.func @dynamic_conv_c22(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+2 {{expects batch_group_count to be a positive number, got 0.}} + // expected-error@+2 {{op attribute 'batch_group_count' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}} %padding = stablehlo.constant dense<0> : tensor<2x2xi64> %result = "stablehlo.dynamic_conv"(%arg0, %arg1, %padding) { dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,