diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index f0728ef2b..5310455f6 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -1114,16 +1114,19 @@ static IndexStmt optimizeSpMM(IndexStmt stmt) { return stmt; } + // I think we can to linear combination of rows as long as there are no permutations in the format and the + // level formats are ordered. The i -> k -> j loops should iterate over the data structures without issue. TensorVar B = Baccess.getTensorVar(); - if (B.getFormat().getModeOrdering()[0] != 0 || + if (!B.getFormat().getModeFormats()[0].isOrdered() || + !B.getFormat().getModeFormats()[1].isOrdered() || + B.getFormat().getModeOrdering()[0] != 0 || B.getFormat().getModeOrdering()[1] != 1) { return stmt; } - // We need random access into the first mode or this tensor in order to perform a linear combination of rows - // algorithm. (I think?) TensorVar C = Caccess.getTensorVar(); - if (!C.getFormat().getModeFormats()[0].hasLocate() || + if (!C.getFormat().getModeFormats()[0].isOrdered() || + !C.getFormat().getModeFormats()[1].isOrdered() || C.getFormat().getModeOrdering()[0] != 0 || C.getFormat().getModeOrdering()[1] != 1) { return stmt;