[WIP] Index tree rewrite #57

Open: wants to merge 32 commits into master.

Commits (32)
e297f62
Begin working on IndexTree transformations
AK2000 Nov 1, 2023
25e5089
Fixing some of the problems introduced on merge
AK2000 Nov 1, 2023
250cc5f
Resolved including of device mapping attribute
AK2000 Nov 1, 2023
8ac4072
Fixed type inclusion, parsing and printing
AK2000 Nov 1, 2023
b6c8893
V1 - Lower TA to new IndexTree ops, but removed everything else
AK2000 Nov 6, 2023
e3e5a10
Fixes to TA to change how file is included
AK2000 Nov 6, 2023
1e88644
Creating new block for index tree
AK2000 Nov 8, 2023
c6289de
Implement domain inference pass, fix to index ordering
AK2000 Nov 10, 2023
bd20a01
[WIP] Fragile version of index tree to SCF lowering
AK2000 Nov 30, 2023
df08dd2
Fix carrying tensors inside loop, refactor domain concretization
AK2000 Dec 12, 2023
d25c6ca
Adding TA to index tree patterns for elementwise operations
AK2000 Dec 13, 2023
0edc1b6
[WIP] Trying to implement intersection op lowering
AK2000 Dec 19, 2023
fda15d0
[WIP] Got domain intersection working, but only with dense output
AK2000 Jan 3, 2024
d00bb75
[WIP] Minor fix to ordering of reduce args
AK2000 Jan 3, 2024
17d9f87
[WIP] Beginning support for sparse output tensors with new index tree
AK2000 Jan 10, 2024
45f39a6
[WIP] Inlined itree op, got hacky version of removing set op working
AK2000 Jan 11, 2024
0136bae
[WIP] Included lowering to LLVM, lowering print op does not work
AK2000 Jan 12, 2024
266a411
[WIP] Almost got print op lowering working
AK2000 Jan 15, 2024
47a4aef
[WIP] Fixed bufferization
AK2000 Jan 18, 2024
23bbbf6
[WIP] Generate symbolic pass for sparse tensor declarations
AK2000 Jan 24, 2024
018323e
[WIP] Lots of changes for first try at symbolic domain pass and works…
AK2000 Feb 13, 2024
c8386a2
[WIP] Broke everything trying to redo tensor conversion infrastructure
AK2000 Feb 21, 2024
adfbdf0
Changing a lot to create new sparse tensor types, and appropriate lowe…
AK2000 May 15, 2024
fc39562
Fixing some problems with tests, add pure ops
AK2000 Jun 5, 2024
19dbedc
Fixed inconsistencies in test suite
AK2000 Jun 5, 2024
5db6c87
Fixing more of the test cases
AK2000 Jun 13, 2024
4f7fc17
Fixed dense transpose and print elapsed time
AK2000 Jun 17, 2024
872efc7
Fixing errors in typing
AK2000 Jun 19, 2024
cb998c0
Fixing errors in typing and set op
AK2000 Jun 19, 2024
86176ef
Adding back ttgt pass
AK2000 Jun 20, 2024
b57e40e
Fixed delete before use errors
AK2000 Oct 14, 2024
d52d26f
Another bug found with asan
AK2000 Oct 14, 2024
8 changes: 8 additions & 0 deletions CMakeLists.txt
@@ -107,6 +107,8 @@ add_custom_target(comet-headers)
set_target_properties(comet-headers PROPERTIES FOLDER "Misc")
add_custom_target(comet-doc)

set(CMAKE_INCLUDE_CURRENT_DIR ON)

# Add MLIR, LLVM and BLIS headers to the include path
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS})
@@ -182,3 +184,9 @@ if (STANDALONE_INSTALL)
message(STATUS "Setting an $ORIGIN-based RPATH on all executables")
set_rpath_all_targets(${CMAKE_CURRENT_SOURCE_DIR})
endif()

option(DEBUG_MODE "Create an installation with debug information" off)
if (DEBUG_MODE)
message(STATUS "Building comet in debug mode")
add_compile_options(-DCOMET_DEBUG_MODE)
endif()
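
For reference, the new DEBUG_MODE option passes a project-wide -DCOMET_DEBUG_MODE definition to the compiler. Below is a minimal, hypothetical sketch of how such a definition could gate a streaming debug macro like comet_debug(); the actual contents of comet/Utils/debug.h are not part of this diff, so the macro bodies shown here are assumptions only.

#include <iostream>

// Hypothetical gating on COMET_DEBUG_MODE; the real comet/Utils/debug.h may differ.
#ifdef COMET_DEBUG_MODE
#define comet_debug() std::cerr << "[comet-debug] " << __FILE__ << ":" << __LINE__ << " "
#else
#define comet_debug() if (false) std::cerr /* statement compiles away when debug mode is off */
#endif

int main() {
  comet_debug() << "printed only when built with -DCOMET_DEBUG_MODE\n";
  return 0;
}
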
2 changes: 1 addition & 1 deletion frontends/comet_dsl/CMakeLists.txt
@@ -23,7 +23,7 @@ set(LIBS
COMETUtils
COMETTensorAlgebraDialect
COMETIndexTreeDialect
COMETIndexTreeToSCF
# COMETIndexTreeToSCF
)

target_link_libraries(comet-opt
119 changes: 65 additions & 54 deletions frontends/comet_dsl/comet.cpp
@@ -41,6 +41,9 @@
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"


#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
@@ -245,7 +248,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
pm.addPass(mlir::comet::createFuncOpLoweringPass());

mlir::OpPassManager &optPM = pm.nest<mlir::func::FuncOp>();
optPM.addPass(mlir::comet::createRemoveLabeledTensorOpsPass());
// optPM.addPass(mlir::comet::createRemoveLabeledTensorOpsPass());

/// Check to see if we are dumping to TA dialect.
if (emitTA)
@@ -287,17 +290,15 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
/// Generate the index tree IR
optPM.addPass(mlir::comet::createLowerTensorAlgebraToIndexTreePass());

if (OptKernelFusion)
{
/// Apply partial fusion on index tree dialect for some compound expressions.
optPM.addPass(mlir::comet::createIndexTreeKernelFusionPass());
}
// Create new pass manager to optimize the index tree dialect
// mlir::OpPassManager &itOptPM = optPM.nest<IndexTreeOp>();
optPM.addPass(mlir::comet::createIndexTreeDomainInferencePass());

if (OptWorkspace)
{
/// Optimized workspace transformations, reduce iteration space for nonzero elements
optPM.addPass(mlir::comet::createIndexTreeWorkspaceTransformationsPass());
}
// if (OptKernelFusion)
// {
// /// Apply partial fusion on index tree dialect for some compound expressions.
// optPM.addPass(mlir::comet::createIndexTreeKernelFusionPass());
// }

/// Dump index tree dialect.
if (emitIT)
@@ -319,8 +320,9 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
/// sparse input tensor declaration lowering, also generate sparse_output_tensor declaration if needed
/// input and output sparse tensor declaration lowering are distinct and need different information
optPM.addPass(mlir::comet::createSparseTensorDeclLoweringPass());
// optPM.addPass(mlir::comet::createSparseOutputTensorDeclLoweringPass());
optPM.addPass(mlir::comet::createDenseTensorDeclLoweringPass());
optPM.addPass(mlir::comet::createSparseTempOutputTensorDeclLoweringPass());
optPM.addPass(mlir::comet::createSparseOutputTensorDeclLoweringPass());
optPM.addPass(mlir::comet::createTensorFillLoweringPass());

/// =============================================================================
@@ -332,75 +334,83 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
optPM.addPass(mlir::comet::createLoweringTTGTPass(IsSelectBestPermTTGT, selectedPermNum, IsPrintFlops));
}

/// =============================================================================
/// Operation based optimizations
/// =============================================================================
if (OptMatmulTiling)
{
optPM.addPass(mlir::comet::createLinAlgMatmulTilingPass());
}
// /// =============================================================================
// /// Operation based optimizations
// /// =============================================================================
// if (OptMatmulTiling)
// {
// optPM.addPass(mlir::comet::createLinAlgMatmulTilingPass());
// }

if (OptCallToMatMulMicroKernel)
{
optPM.addPass(mlir::comet::createLinAlgMatmulMicroKernelPass());
}
// if (OptCallToMatMulMicroKernel)
// {
// optPM.addPass(mlir::comet::createLinAlgMatmulMicroKernelPass());
// }

/// =============================================================================
/// Lowering all the operations to loops
/// =============================================================================
if (IsLoweringtoSCF || emitLoops || emitLLVM)
{
/// Workspace transformations will create new dense tensor declarations, so we need to call createDenseTensorDeclLoweringPass
optPM.addPass(mlir::comet::createDenseTensorDeclLoweringPass()); /// lowers dense input/output tensor declaration
optPM.addPass(mlir::comet::createSparseTempOutputTensorDeclLoweringPass()); /// Temporary sparse output tensor declarations introduced by compound expressions
/// should be lowered before sparse output tensor declarations
optPM.addPass(mlir::comet::createSparseOutputTensorDeclLoweringPass()); /// lowering for sparse output tensor declarations
//(sparse_output_tensor_decl and temp_sparse_output_tensor_decl)
/// The partial Fusion pass might add new tensor.fill operations
optPM.addPass(mlir::comet::createTensorFillLoweringPass());
optPM.addPass(mlir::comet::createPCToLoopsLoweringPass());

{
/// =============================================================================
/// Lowering of other operations such as transpose, sum, etc. to SCF dialect
/// =============================================================================
/// If it is a transpose of dense tensor, the rewrites rules replaces ta.transpose with linalg.copy.
/// If it is a transpose of sparse tensor, it lowers the code to make a runtime call to specific sorting algorithm
optPM.addPass(mlir::comet::createLowerTensorAlgebraToSCFPass());

/// Finally lowering index tree to SCF dialect
optPM.addPass(mlir::comet::createLowerIndexTreeToSCFPass());
optPM.addPass(mlir::createTensorBufferizePass());
pm.addPass(mlir::func::createFuncBufferizePass()); /// Needed for func
/// Concretize the domains of all the index variables
optPM.addPass(mlir::comet::createIndexTreeDomainConcretizationPass());

if (OptDenseTransposeOp) /// Optimize Dense Transpose operation
{
/// If it is a dense transpose ops, the rewrites rules replaces ta.transpose with linalg.transpose, then
/// Create a pass to optimize LinAlg Copy Op - follow in HPTT paper
/// HPTT: A High-Performance Tensor Transposition C++ Library
/// https://arxiv.org/abs/1704.04374
optPM.addPass(mlir::comet::createOptDenseTransposePass());
if (OptWorkspace) {
/// Optimized workspace transformations, reduce iteration space for nonzero elements
optPM.addPass(mlir::comet::createIndexTreeWorkspaceTransformationsPass());
}

optPM.addPass(mlir::comet::createIndexTreeSymbolicComputePass());

/// Finally lowering index tree to SCF dialect
optPM.addPass(mlir::comet::createLowerIndexTreeToSCFPass());
optPM.addPass(mlir::comet::createConvertSymbolicDomainsPass());
optPM.addPass(mlir::comet::createSparseTensorConversionPass());
optPM.addPass(mlir::comet::createIndexTreeInliningPass());
optPM.addPass(mlir::createCanonicalizerPass());

// if (OptDenseTransposeOp) /// Optimize Dense Transpose operation
// {
// /// If it is a dense transpose ops, the rewrites rules replaces ta.transpose with linalg.transpose, then
// /// Create a pass to optimize LinAlg Copy Op - follow in HPTT paper
// /// HPTT: A High-Performance Tensor Transposition C++ Library
// /// https://arxiv.org/abs/1704.04374
// optPM.addPass(mlir::comet::createOptDenseTransposePass());
// }

/// Dump the loops (SCF dialect).
if (emitLoops)
{
if (mlir::failed(pm.run(*module)))
return 4;
return 0;
}
/// =============================================================================
}
/// =============================================================================

/// =============================================================================
/// Late lowering passes
/// =============================================================================
// /// =============================================================================
// /// Late lowering passes
// /// =============================================================================
// pm.addPass(mlir::bufferization::createEmptyTensorToAllocTensorPass());
mlir::bufferization::OneShotBufferizationOptions opts;
opts.allowUnknownOps = true;
pm.addPass(mlir::bufferization::createOneShotBufferizePass(opts));

optPM.addPass(mlir::comet::createSTCRemoveDeadOpsPass());
optPM.addPass(mlir::comet::createLateLoweringPass());
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());
mlir::OpPassManager &late_lowering_pm = pm.nest<mlir::func::FuncOp>();
late_lowering_pm.addPass(mlir::comet::createSTCRemoveDeadOpsPass());
late_lowering_pm.addPass(mlir::comet::createLateLoweringPass());

pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());

/// =============================================================================
// /// =============================================================================

if (isLoweringToLLVM || emitLLVM)
{
Expand Down Expand Up @@ -481,6 +491,7 @@ int main(int argc, char **argv)
context.loadDialect<mlir::linalg::LinalgDialect>();
context.loadDialect<mlir::scf::SCFDialect>();
context.loadDialect<mlir::bufferization::BufferizationDialect>();
context.loadDialect<mlir::index::IndexDialect>();

mlir::OwningOpRef<mlir::ModuleOp> module;

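
As context for the pipeline changes above, the following standalone sketch shows only the upstream MLIR pass-manager structure the new code relies on: a nested function-level pass manager, one-shot bufferization with unknown ops allowed, and a final canonicalize/CSE cleanup. COMET's own createIndexTree*/createLower* passes introduced by this PR are omitted, so this is illustrative rather than the PR's actual pipeline.

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Runs a minimal module pipeline mirroring the structure used in loadAndProcessMLIR.
static mlir::LogicalResult runCleanupPipeline(mlir::MLIRContext &context, mlir::ModuleOp module) {
  mlir::PassManager pm(&context);

  // Function-level passes go into a nested pass manager, like optPM / late_lowering_pm above.
  mlir::OpPassManager &funcPM = pm.nest<mlir::func::FuncOp>();
  funcPM.addPass(mlir::createCanonicalizerPass());

  // Module-level one-shot bufferization, with unknown ops allowed as in the diff.
  mlir::bufferization::OneShotBufferizationOptions opts;
  opts.allowUnknownOps = true;
  pm.addPass(mlir::bufferization::createOneShotBufferizePass(opts));

  // Late cleanup, as at the end of the lowering pipeline.
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createCSEPass());

  return pm.run(module);
}
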
2 changes: 1 addition & 1 deletion frontends/comet_dsl/include/Lexer.h
@@ -35,7 +35,7 @@
#include <string>

// *********** For debug purpose *********//
//#define COMET_DEBUG_MODE
// #define COMET_DEBUG_MODE
#include "comet/Utils/debug.h"
#undef COMET_DEBUG_MODE
// *********** For debug purpose *********//
64 changes: 55 additions & 9 deletions frontends/comet_dsl/mlir/MLIRGen.cpp
@@ -71,7 +71,7 @@ using llvm::Twine;
using StringSet = std::set<std::string>;

// *********** For debug purpose *********//
//#define COMET_DEBUG_MODE
// #define COMET_DEBUG_MODE
#include "comet/Utils/debug.h"
#undef COMET_DEBUG_MODE
// *********** For debug purpose *********//
@@ -591,23 +591,41 @@ namespace
comet_debug() << "\n";

auto lhs_tensor = lhs.getDefiningOp()->getOpResult(0).getType();
assert(lhs_tensor.isa<mlir::TensorType>());

comet_pdump(lhs.getDefiningOp());

auto lhs_labeledtensor = lhs.getDefiningOp()->getOpResult(0);

comet_vdump(lhs_labeledtensor); // ta.labeled_tensor
auto lhs_el_type = lhs_tensor.cast<mlir::TensorType>().getElementType();
mlir::Type lhs_el_type;
if(auto tensor_type = llvm::dyn_cast<mlir::TensorType>(lhs_tensor)){
lhs_el_type = tensor_type.getElementType();
}
else if(auto tensor_type = llvm::dyn_cast<SparseTensorType>(lhs_tensor)){
lhs_el_type = tensor_type.getElementType();
}
else {
assert(false && "Expected a tensor input");
}

auto rhs_tensor = rhs.getDefiningOp()->getOpResult(0).getType();

comet_pdump(rhs.getDefiningOp());
assert(rhs_tensor.isa<mlir::TensorType>());

auto rhs_labeledtensor = rhs.getDefiningOp()->getOpResult(0);

comet_vdump(rhs_labeledtensor);
auto rhs_el_type = rhs_tensor.cast<mlir::TensorType>().getElementType();
mlir::Type rhs_el_type;
if(auto tensor_type = llvm::dyn_cast<mlir::TensorType>(rhs_tensor)){
rhs_el_type = tensor_type.getElementType();
}
else if(auto tensor_type = llvm::dyn_cast<SparseTensorType>(rhs_tensor)){
rhs_el_type = tensor_type.getElementType();
}
else {
assert(false && "Expected a tensor input");
}

auto result_type = getBinOpResultType(lhs_el_type, rhs_el_type);
comet_debug() << __LINE__ << " ";
comet_vdump(result_type);
@@ -817,8 +835,6 @@ namespace
}

std::vector<int64_t> result_dims = getDimSizes(ret_lbls_value);
auto ret_tensor_type = mlir::RankedTensorType::get(result_dims, result_type);

auto affineMapArrayAttr = builder.getAffineMapArrayAttr(affine_maps);

SmallVector<mlir::StringRef, 8> formats;
@@ -1000,18 +1016,29 @@ namespace
}
comet_debug() << __LINE__ << " formats.size(): " << formats.size() << "\n";
assert(formats.size() == 2 && " less than 2 input tensors\n");
mlir::Type ret_tensor_type;
if (formats[0].compare("CSR") == 0 && formats[1].compare("CSR") == 0)
{
formats.push_back("CSR");
std::vector format_array = getFormats("CSR", result_dims.size(), builder.getContext());
ret_tensor_type = SparseTensorType::get(builder.getContext(), result_type, result_dims, format_array);
}
else if (formats[0].compare("Dense") == 0 && formats[1].compare("Dense") == 0)
{
formats.push_back("Dense");
ret_tensor_type = mlir::RankedTensorType::get(result_dims, result_type);
}
else if (out_format.length() > 0) // non-empty format string provided.
{
comet_debug() << " Output Format: " << out_format << "\n";
formats.push_back(out_format);
if(out_format.compare("Dense") == 0)
{
ret_tensor_type = mlir::RankedTensorType::get(result_dims, result_type);
} else {
std::vector format_array = getFormats(out_format, result_dims.size(), builder.getContext());
ret_tensor_type = SparseTensorType::get(builder.getContext(), result_type, result_dims, format_array);
}
}
else
{
@@ -1604,9 +1631,24 @@ namespace
if (isDense(formats_str, ", ") == false)
{
/// BoolAttr is false because there is an explicit sparse tensor declaration.
/// SparseTensorDeclOp is not for temporaries in compound expressions
/// SparseTensorDeclOp is not for temporaries in compound expression
std::vector<int32_t> format = mlir::tensorAlgebra::getFormats(tensor_format, dims_sizes.size(), builder.getContext());
mlir::Type element_type;
switch (vartype.elt_ty)
{
case VarType::TY_FLOAT:
element_type = builder.getF32Type();
break;
case VarType::TY_DOUBLE:
element_type = builder.getF64Type();
break;
case VarType::TY_INT:
element_type = builder.getIntegerType(64);
break;
}
auto sp_tensor_type = SparseTensorType::get(builder.getContext(), element_type, dims_sizes, format);
value = builder.create<SparseTensorDeclOp>(loc(tensordecl.loc()),
tensor_type, labels, tensor_format, false);
sp_tensor_type, labels, tensor_format, false);
comet_debug() << "MLIRGen SparseTensorDeclaration creation\n";
comet_vdump(value);
}
@@ -1864,6 +1906,10 @@ namespace
mlir::StringRef format_strref = dyn_cast<SparseTensorDeclOp>(rhs_tensor.getDefiningOp()).getFormat();
mlir::StringAttr formatAttr = builder.getStringAttr(format_strref);

std::vector<int32_t> format = mlir::tensorAlgebra::getFormats(format_strref, result_dims.size(), builder.getContext());
mlir::Type element_type = builder.getF64Type();
return_type = SparseTensorType::get(builder.getContext(), element_type, result_dims, format);

/// no lhs_LabeledTensor has been created. The output tensor of transpose doesn't have an explicit declaration,
/// BoolAttr is true to specify SparseTensorDeclOp is for temporaries
lhs_tensor = builder.create<SparseTensorDeclOp>(loc(transpose.loc()), return_type, lhs_labels_val, formatAttr, builder.getBoolAttr(true));
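
The element-type dispatch added above in MLIRGen.cpp repeats the same dyn_cast chain for the LHS and RHS operands. As a hedged sketch only, the same dispatch could be written once with llvm::TypeSwitch; SparseTensorType is COMET's new type from this PR, so its case is left as a comment rather than assumed.

#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinTypes.h"

// Sketch: returns the element type for (ranked or unranked) tensor types and a
// null mlir::Type otherwise. A .Case<SparseTensorType>(...) branch would cover
// COMET's new sparse tensor type.
static mlir::Type getElementTypeOrNull(mlir::Type type) {
  return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
      .Case<mlir::TensorType>([](mlir::TensorType t) { return t.getElementType(); })
      .Default([](mlir::Type) { return mlir::Type(); });
}
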
7 changes: 2 additions & 5 deletions include/comet/Conversion/IndexTreeToSCF/IndexTreeToSCF.h
@@ -34,16 +34,13 @@ namespace mlir
namespace comet
{
#define GEN_PASS_DECL_CONVERTINDEXTREETOSCF
#define GEN_PASS_DECL_CONVERTSYMBOLICDOMAINS
#include "comet/Conversion/Passes.h.inc"

/// Collect a set of patterns to convert IndexTree operations to SCF
/// operations within the SCF dialect.
void populateIndexTreeToSCFConversionPatterns(RewritePatternSet &patterns);

/// Lowers indexTree operations (e.g., IndexTreeComputeLHSOp, IndexTreeComputeRHSOp and IndexTreeComputeOp)
/// to equivalent scf constructs including basic blocks and arithmetic
/// primitives).
std::unique_ptr<Pass> createLowerIndexTreeToSCFPass();
std::unique_ptr<Pass> createConvertSymbolicDomainsPass();
}
} // namespace mlir
