Skip to content

Commit

Permalink
Remove type-inference dependency while creating QDQ patterns (#2460)
Browse files Browse the repository at this point in the history
[ParentPR](#2459)

Previously, while creating the QDQ pattern, we used the `create` API without
specifying the result type and hence relied on type inference to derive
the result type. That works for most of the element-wise operations;
however, for dot_general and convolution the result type might be
infeasible to infer in the presence of quantized input types.

The PR fixes that.

Note to the reviewers: You may just focus on the very last commit of the
chain. The rest comes from the parent PR.

[ChildPR](#2461)
  • Loading branch information
sdasgup3 authored Jul 26, 2024
1 parent 4286b80 commit 9655fab
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
6 changes: 2 additions & 4 deletions stablehlo/transforms/StablehloLegalizeQuantToMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,15 @@ void getQuantizationStorageInfo(OpBuilder &builder, Location loc,
static_cast<float>(quantType.getStorageTypeMax())));
}

// Returns the storage type reported by the given quantized type.
Type getQuantStorageType(QuantType type) { return type.getStorageType(); }

// NOTE(review): this hunk appears to be a rendered diff whose +/- markers were
// stripped, so the two comment lines and the two consecutive `return`
// statements below are presumably the old and new versions of the same line —
// confirm against the repository before relying on either.
// Extracts storage type of a UQ type. Return original type if it is no UQ type.
// Extracts storage type of a UQ type, preserving its shape.
Type getQuantStorageType(Type type) {
// Shaped type: convert the element type recursively, then clone the shaped
// type around the converted element type.
if (auto shaped = dyn_cast<ShapedType>(type)) {
return shaped.clone(getQuantStorageType(shaped.getElementType()));
}

// Otherwise ask getQuantType whether `type` is quantized.
auto quantizedType = getQuantType(type);
if (succeeded(quantizedType)) {
return getQuantStorageType(*quantizedType);
// Unreachable as written: the line above already returns.
return quantizedType->getStorageType();
}
// Not quantized: hand the type back unchanged.
return type;
}
Expand Down
30 changes: 28 additions & 2 deletions stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@ bool isAnyQuantizedTypes(TypeRange types) {
});
}

// Looks up the quant::QuantizedType behind `type`. The lookup is done on the
// element type when `type` has one, otherwise on `type` itself. Yields
// failure() when that type is not a quant::QuantizedType.
FailureOr<quant::QuantizedType> getQuantType(Type type) {
  Type elementType = getElementTypeOrSelf(type);
  auto quantType = dyn_cast<quant::QuantizedType>(elementType);
  if (!quantType) return failure();
  return quantType;
}

// Maps a uniform quantized type to its expressed type.
//  - Shaped types: the element type is mapped recursively and the shaped type
//    is cloned around the mapped element type, so the shape is preserved.
//  - Quantized types: the expressed type is returned.
//  - Anything else: returned unchanged.
Type getQuantExpressedType(Type type) {
  auto shapedType = dyn_cast<ShapedType>(type);
  if (shapedType) {
    Type expressedElementType =
        getQuantExpressedType(shapedType.getElementType());
    return shapedType.clone(expressedElementType);
  }

  FailureOr<quant::QuantizedType> maybeQuantType = getQuantType(type);
  if (!succeeded(maybeQuantType)) return type;
  return maybeQuantType->getExpressedType();
}

template <typename StablehloOpType>
struct QuantizedStablehloOpConversion
: public OpRewritePattern<StablehloOpType> {
Expand All @@ -59,10 +81,13 @@ struct QuantizedStablehloOpConversion
}

auto origOp = op.getOperation();
SmallVector<Type> newResultTypes =
llvm::map_to_vector(origOp->getResultTypes(),
[](Type t) { return getQuantExpressedType(t); });
auto origAttrs = origOp->getAttrs();
auto newOp = rewriter
.create<StablehloOpType>(op.getLoc(), dequantizedOperands,
origAttrs)
.create<StablehloOpType>(op.getLoc(), newResultTypes,
dequantizedOperands, origAttrs)
.getOperation();

SmallVector<Value> quantizedResults;
Expand All @@ -77,6 +102,7 @@ struct QuantizedStablehloOpConversion
quantizedResults.push_back(newResult);
}
}

rewriter.replaceOp(op, quantizedResults);
return success();
}
Expand Down

0 comments on commit 9655fab

Please sign in to comment.