diff --git a/include/taco/lower/lowerer_impl.h b/include/taco/lower/lowerer_impl.h index baa0e2ead..7f53bc669 100644 --- a/include/taco/lower/lowerer_impl.h +++ b/include/taco/lower/lowerer_impl.h @@ -266,13 +266,13 @@ class LowererImpl : public util::Uncopyable { /** * Replace scalar tensor pointers with stack scalar for lowering. */ - ir::Stmt declareScalarVariable(TensorVar var, bool zero); + ir::Stmt defineScalarVariable(TensorVar var, bool zero); /** * Creates code to declare temporaries. */ - ir::Stmt declareTemporaries(std::vector temporaries, - std::map scalars); + ir::Stmt defineTemporaries(std::vector temporaries, + std::map scalars); ir::Stmt initResultArrays(IndexVar var, std::vector writes, std::vector reads, diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index a9f8e8867..a2a49f065 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -138,7 +138,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble, vector headerStmts; vector footerStmts; - // Declare and initialize dimension variables + // Define and initialize dimension variables vector indexVars = getIndexVars(stmt); for (auto& indexVar : indexVars) { Expr dimension; @@ -168,14 +168,14 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble, dimensions.insert({indexVar, dimension}); } - // Declare and initialize scalar results and arguments + // Define and initialize scalar results and arguments if (generateComputeCode()) { for (auto& result : results) { if (isScalar(result.getType())) { taco_iassert(!util::contains(scalars, result)); taco_iassert(util::contains(tensorVars, result)); scalars.insert({result, tensorVars.at(result)}); - headerStmts.push_back(declareScalarVariable(result, true)); + headerStmts.push_back(defineScalarVariable(result, true)); } } for (auto& argument : arguments) { @@ -183,7 +183,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble, taco_iassert(!util::contains(scalars, argument)); taco_iassert(util::contains(tensorVars, argument)); scalars.insert({argument, tensorVars.at(argument)}); - headerStmts.push_back(declareScalarVariable(argument, false)); + headerStmts.push_back(defineScalarVariable(argument, false)); } } } @@ -204,7 +204,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble, reducedAccesses); // Declare and initialize non-scalar temporaries - Stmt declTemporaries = declareTemporaries(temporaries, scalars); + Stmt tempDefinitions = defineTemporaries(temporaries, scalars); // Lower the index statement to compute and/or assemble Stmt body = lower(stmt); @@ -231,7 +231,7 @@ Stmt LowererImpl::lower(IndexStmt stmt, string name, bool assemble, Stmt footer = footerStmts.empty() ? Stmt() : Block::make(footerStmts); return Function::make(name, resultsIR, argumentsIR, Block::blanks(header, - declTemporaries, + tempDefinitions, initializeResults, body, finalizeResults, @@ -661,7 +661,7 @@ Stmt LowererImpl::lowerWhere(Where where) { TensorVar temporary = where.getTemporary(); Stmt declareTemporary; if (isScalar(temporary.getType())) { - declareTemporary = declareScalarVariable(temporary, true); + declareTemporary = defineScalarVariable(temporary, true); } else { taco_not_supported_yet; @@ -1062,21 +1062,29 @@ ir::Stmt LowererImpl::finalizeResultArrays(std::vector writes) { } -Stmt LowererImpl::declareTemporaries(vector temporaries, - map scalars) { +Stmt LowererImpl::defineTemporaries(vector temporaries, + map scalars) { vector result; if (generateComputeCode()) { for (auto& temporary : temporaries) { - if (!isScalar(temporary.getType())) { - taco_not_supported_yet; + if (isScalar(temporary.getType())) { + // We will define scalar temporaries locally where they are initialized + continue; } - // We will declare scalar temporaries locally when they are initialized + + taco_not_supported_yet; + Expr temporaryPtr = Var::make(temporary.getName(), + temporary.getType().getDataType(), + true, true); + + Stmt definition = VarDecl::make(temporaryPtr, 0); + result.push_back(definition); } } return result.empty() ? Stmt() : Block::make(result); } -Stmt LowererImpl::declareScalarVariable(TensorVar var, bool zero) { +Stmt LowererImpl::defineScalarVariable(TensorVar var, bool zero) { Datatype type = var.getType().getDataType(); Expr varValueIR = Var::make(var.getName() + "_val", type, false, false); Expr init = (zero) ? ir::Literal::zero(type) diff --git a/test/tests-lower.cpp b/test/tests-lower.cpp index c5b34697d..9017bab48 100644 --- a/test/tests-lower.cpp +++ b/test/tests-lower.cpp @@ -726,6 +726,26 @@ TEST_STMT(where_matrix_vector_mul, } ) +TEST_STMT(DISABLED_where_spmm, + forall(i, + where(forall(j, + A(i,j) = w(j)), + forall(k, + forall(j, + w(j) += B(i,k) * C(k,j))))), + Values( + Formats({{A,Format({dense,dense})}, + {B,Format({dense,dense})}, {C,Format({dense,dense})}}), + Formats({{A,Format({dense,sparse})}, + {B,Format({dense,sparse})}, {C,Format({dense,sparse})}}) + ), + { + TestCase({{B, { {{0,1}, 2.0}, {{2,0}, 3.0}, {{2,2}, 4.0}} }, + {C, { {{0,0},10.0}, {{0,1}, 20.0}, {{2,1},30.0}} }}, + {{A, { {{2,0},30.0}, {{2,1},180.0} }}}) + } +) + // Test sequence statements @@ -782,10 +802,14 @@ TEST_STMT(matrix_transposed_input, A(i,j) = B(i,j) + C(j,i) )), Values( - Formats({{A,Format({ dense, dense})}, {B,Format({ dense, dense})}, {C,Format({dense,dense})}}), - Formats({{A,Format({ dense,sparse})}, {B,Format({ dense,sparse})}, {C,Format({dense,dense})}}), - Formats({{A,Format({sparse, dense})}, {B,Format({sparse, dense})}, {C,Format({dense,dense})}}), - Formats({{A,Format({sparse,sparse})}, {B,Format({sparse,sparse})}, {C,Format({dense,dense})}}) + Formats({{A,Format({ dense, dense})}, {B,Format({ dense, dense})}, + {C,Format({dense,dense})}}), + Formats({{A,Format({ dense,sparse})}, {B,Format({ dense,sparse})}, + {C,Format({dense,dense})}}), + Formats({{A,Format({sparse, dense})}, {B,Format({sparse, dense})}, + {C,Format({dense,dense})}}), + Formats({{A,Format({sparse,sparse})}, {B,Format({sparse,sparse})}, + {C,Format({dense,dense})}}) ), { TestCase({{B, {{{0,0}, 42.0}, {{0,2}, 2.0}, {{1,3}, 3.0}, {{3,2}, 4.0}}},