Skip to content

Commit

Permalink
Adds SpMM test
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikbk committed May 30, 2019
1 parent b7aecb6 commit 4f4ac27
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 20 deletions.
6 changes: 3 additions & 3 deletions include/taco/lower/lowerer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,13 @@ class LowererImpl : public util::Uncopyable {
/**
* Replace scalar tensor pointers with stack scalar for lowering.
*/
ir::Stmt declareScalarVariable(TensorVar var, bool zero);
ir::Stmt defineScalarVariable(TensorVar var, bool zero);

/**
* Creates code to declare temporaries.
*/
ir::Stmt declareTemporaries(std::vector<TensorVar> temporaries,
std::map<TensorVar,ir::Expr> scalars);
ir::Stmt defineTemporaries(std::vector<TensorVar> temporaries,
std::map<TensorVar,ir::Expr> scalars);

ir::Stmt initResultArrays(IndexVar var, std::vector<Access> writes,
std::vector<Access> reads,
Expand Down
34 changes: 21 additions & 13 deletions src/lower/lowerer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
vector<Stmt> headerStmts;
vector<Stmt> footerStmts;

// Declare and initialize dimension variables
// Define and initialize dimension variables
vector<IndexVar> indexVars = getIndexVars(stmt);
for (auto& indexVar : indexVars) {
Expr dimension;
Expand Down Expand Up @@ -168,22 +168,22 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
dimensions.insert({indexVar, dimension});
}

// Declare and initialize scalar results and arguments
// Define and initialize scalar results and arguments
if (generateComputeCode()) {
for (auto& result : results) {
if (isScalar(result.getType())) {
taco_iassert(!util::contains(scalars, result));
taco_iassert(util::contains(tensorVars, result));
scalars.insert({result, tensorVars.at(result)});
headerStmts.push_back(declareScalarVariable(result, true));
headerStmts.push_back(defineScalarVariable(result, true));
}
}
for (auto& argument : arguments) {
if (isScalar(argument.getType())) {
taco_iassert(!util::contains(scalars, argument));
taco_iassert(util::contains(tensorVars, argument));
scalars.insert({argument, tensorVars.at(argument)});
headerStmts.push_back(declareScalarVariable(argument, false));
headerStmts.push_back(defineScalarVariable(argument, false));
}
}
}
Expand All @@ -204,7 +204,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
reducedAccesses);

// Declare and initialize non-scalar temporaries
Stmt declTemporaries = declareTemporaries(temporaries, scalars);
Stmt tempDefinitions = defineTemporaries(temporaries, scalars);

// Lower the index statement to compute and/or assemble
Stmt body = lower(stmt);
Expand All @@ -231,7 +231,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble,
Stmt footer = footerStmts.empty() ? Stmt() : Block::make(footerStmts);
return Function::make(name, resultsIR, argumentsIR,
Block::blanks(header,
declTemporaries,
tempDefinitions,
initializeResults,
body,
finalizeResults,
Expand Down Expand Up @@ -661,7 +661,7 @@ Stmt LowererImpl::lowerWhere(Where where) {
TensorVar temporary = where.getTemporary();
Stmt declareTemporary;
if (isScalar(temporary.getType())) {
declareTemporary = declareScalarVariable(temporary, true);
declareTemporary = defineScalarVariable(temporary, true);
}
else {
taco_not_supported_yet;
Expand Down Expand Up @@ -1062,21 +1062,29 @@ ir::Stmt LowererImpl::finalizeResultArrays(std::vector<Access> writes) {
}


Stmt LowererImpl::declareTemporaries(vector<TensorVar> temporaries,
map<TensorVar, Expr> scalars) {
Stmt LowererImpl::defineTemporaries(vector<TensorVar> temporaries,
map<TensorVar, Expr> scalars) {
vector<Stmt> result;
if (generateComputeCode()) {
for (auto& temporary : temporaries) {
if (!isScalar(temporary.getType())) {
taco_not_supported_yet;
if (isScalar(temporary.getType())) {
// We will define scalar temporaries locally where they are initialized
continue;
}
// We will declare scalar temporaries locally when they are initialized

taco_not_supported_yet;
Expr temporaryPtr = Var::make(temporary.getName(),
temporary.getType().getDataType(),
true, true);

Stmt definition = VarDecl::make(temporaryPtr, 0);
result.push_back(definition);
}
}
return result.empty() ? Stmt() : Block::make(result);
}

Stmt LowererImpl::declareScalarVariable(TensorVar var, bool zero) {
Stmt LowererImpl::defineScalarVariable(TensorVar var, bool zero) {
Datatype type = var.getType().getDataType();
Expr varValueIR = Var::make(var.getName() + "_val", type, false, false);
Expr init = (zero) ? ir::Literal::zero(type)
Expand Down
32 changes: 28 additions & 4 deletions test/tests-lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,26 @@ TEST_STMT(where_matrix_vector_mul,
}
)

TEST_STMT(DISABLED_where_spmm,
forall(i,
where(forall(j,
A(i,j) = w(j)),
forall(k,
forall(j,
w(j) += B(i,k) * C(k,j))))),
Values(
Formats({{A,Format({dense,dense})},
{B,Format({dense,dense})}, {C,Format({dense,dense})}}),
Formats({{A,Format({dense,sparse})},
{B,Format({dense,sparse})}, {C,Format({dense,sparse})}})
),
{
TestCase({{B, { {{0,1}, 2.0}, {{2,0}, 3.0}, {{2,2}, 4.0}} },
{C, { {{0,0},10.0}, {{0,1}, 20.0}, {{2,1},30.0}} }},
{{A, { {{2,0},30.0}, {{2,1},180.0} }}})
}
)


// Test sequence statements

Expand Down Expand Up @@ -782,10 +802,14 @@ TEST_STMT(matrix_transposed_input,
A(i,j) = B(i,j) + C(j,i)
)),
Values(
Formats({{A,Format({ dense, dense})}, {B,Format({ dense, dense})}, {C,Format({dense,dense})}}),
Formats({{A,Format({ dense,sparse})}, {B,Format({ dense,sparse})}, {C,Format({dense,dense})}}),
Formats({{A,Format({sparse, dense})}, {B,Format({sparse, dense})}, {C,Format({dense,dense})}}),
Formats({{A,Format({sparse,sparse})}, {B,Format({sparse,sparse})}, {C,Format({dense,dense})}})
Formats({{A,Format({ dense, dense})}, {B,Format({ dense, dense})},
{C,Format({dense,dense})}}),
Formats({{A,Format({ dense,sparse})}, {B,Format({ dense,sparse})},
{C,Format({dense,dense})}}),
Formats({{A,Format({sparse, dense})}, {B,Format({sparse, dense})},
{C,Format({dense,dense})}}),
Formats({{A,Format({sparse,sparse})}, {B,Format({sparse,sparse})},
{C,Format({dense,dense})}})
),
{
TestCase({{B, {{{0,0}, 42.0}, {{0,2}, 2.0}, {{1,3}, 3.0}, {{3,2}, 4.0}}},
Expand Down

0 comments on commit 4f4ac27

Please sign in to comment.