Skip to content

Commit

Permalink
fixed some failing compound expr tests due to issue in the TA codegen #…
Browse files Browse the repository at this point in the history
  • Loading branch information
rizwanashraf committed Aug 25, 2023
1 parent 8e164b5 commit 5729b83
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions frontends/comet_dsl/mlir/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1112,16 +1112,35 @@ namespace
{
comet_debug() << " is TransposeOp\n";

// infer the format
mlir::ArrayAttr opFormatsArrayAttr = dyn_cast<mlir::tensorAlgebra::TransposeOp>(e.getDefiningOp()).getFormats();
unsigned int i = opFormatsArrayAttr.size() - 1;
std::string lhs_format(opFormatsArrayAttr[i].cast<mlir::StringAttr>().getValue());
comet_debug() << __LINE__ << " lhs_format: " << lhs_format << "\n";

comet_debug() << " lhs_format: " << lhs_format << "\n";
formats.push_back(lhs_format);
// get the real transpose op output via the set op.
mlir::Value transposeOut;
mlir::Operation *firstUser = e.getDefiningOp()->getNextNode();
if (isa<TensorSetOp>(firstUser))
{
TensorSetOp setOp = cast<TensorSetOp>(firstUser);
transposeOut = setOp.getOperand(1);
}
else
{
assert(false && "Transpose has no set_op after it!");
}

tensors.push_back(dyn_cast<mlir::tensorAlgebra::TransposeOp>(e.getDefiningOp()).getOperation()->getResult(0));
// get the format of transposeOut tensor
if (isa<DenseTensorDeclOp>(transposeOut.getDefiningOp()))
{
auto denseFormat = dyn_cast<DenseTensorDeclOp>(transposeOut.getDefiningOp()).getFormat();
formats.push_back(denseFormat);
}
else if (isa<SparseTensorDeclOp>(transposeOut.getDefiningOp()))
{
auto sparseFormat = dyn_cast<SparseTensorDeclOp>(transposeOut.getDefiningOp()).getFormat();
formats.push_back(sparseFormat);
}
else
{
assert(false && "Can not determine tensor format with transpose op");
}
tensors.push_back(transposeOut);
}
else
{
Expand Down Expand Up @@ -1150,7 +1169,7 @@ namespace
comet_debug() << " formats.size(): " << formats.size() << "\n";
auto strAttr = builder.getStrArrayAttr(formats);

assert(tensors.size() == 2 && " less than 2 input tensors for ta.tc or ta.elews_mul\n");
assert(tensors.size() == 2 && " less than 2 input tensors for ta.mul or ta.elews_mul\n");

std::vector<mlir::Value> labels;
for (auto i : ret_lbls)
Expand Down

0 comments on commit 5729b83

Please sign in to comment.