diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index b09258410..fa62df2e4 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -654,7 +654,9 @@ static IndexStmt optimizeSpMM(IndexStmt stmt) { } // It's an SpMM statement so return an optimized SpMM statement - TensorVar w("w", Type(Float64, {Dimension()}), dense); + TensorVar w("w", + Type(Float64, {A.getType().getShape().getDimension(1)}), + dense); return forall(i, where(forall(j, A(i,j) = w(j)), diff --git a/src/tensor.cpp b/src/tensor.cpp index 0a9fa9f5c..2090b7a66 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -461,6 +461,7 @@ void TensorBase::compile(bool assembleWhileCompute) { IndexStmt stmt = makeConcrete(assignment); string reason; stmt = reorderLoopsTopologically(stmt); + stmt = insertTemporaries(stmt); taco_uassert(stmt != IndexStmt()) << reason; stmt = parallelizeOuterLoop(stmt); content->assembleFunc = lower(stmt, "assemble", true, false); diff --git a/tools/taco.cpp b/tools/taco.cpp index 397d39e6a..51a8bee60 100644 --- a/tools/taco.cpp +++ b/tools/taco.cpp @@ -817,6 +817,7 @@ int main(int argc, char* argv[]) { string reason; stmt = reorderLoopsTopologically(stmt); + stmt = insertTemporaries(stmt); taco_uassert(stmt != IndexStmt()) << reason; stmt = parallelizeOuterLoop(stmt); compute = lower(stmt, "compute", false, true);