Skip to content

Commit

Permalink
src,test: fix some bugs around ivar recovery and pos splits
Browse files Browse the repository at this point in the history
This commit fixes several bugs around the recovery of index variables
when splits/divides of the position spaces are involved. In particular,
the old code would emit position loops like
```
for (int jposo = A2_pos[i] / chunkSize; jposo < (A2_pos[i+1] + (chunkSize - 1)) / chunkSize; jposo++) {
  for (int jposi = 0; jposi < chunkSize; jposi++) {
    int jposA = jposo * chunkSize + jposi;
    ...
  }
}
```
This does not correctly iterate over the position space. A correct code is:
```
for (int jposo = 0; jposo < ((A2_pos[i+1] - A2_pos[i]) + (chunkSize - 1)) / chunkSize; jposo++) {
  for (int jposi = 0; jposi < chunkSize; jposi++) {
    int jposA = jposo * chunkSize + jposi + A2_pos[i];
    ...
  }
}
```
  • Loading branch information
rohany committed Jan 2, 2022
1 parent 0ede002 commit 94961f7
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 12 deletions.
36 changes: 25 additions & 11 deletions src/index_notation/provenance_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,16 @@ std::vector<ir::Expr> SplitRelNode::deriveIterBounds(taco::IndexVar indexVar,
std::vector<ir::Expr> parentBound = parentIterBounds.at(getParentVar());
Datatype splitFactorType = parentBound[0].type();
if (indexVar == getOuterVar()) {
ir::Expr minBound = ir::Div::make(parentBound[0], ir::Literal::make(getSplitFactor(), splitFactorType));
ir::Expr maxBound = ir::Div::make(ir::Add::make(parentBound[1], ir::Literal::make(getSplitFactor()-1, splitFactorType)), ir::Literal::make(getSplitFactor(), splitFactorType));
// The outer variable must always range from 0 to the extent of the bounds (chunked up).
// This is a noop for the common case where all of our loops start at 0. However, it is
// important when doing a pos split, where the resulting bounds may not start at 0,
// and instead start at a position like T_pos[i].lo.
ir::Expr minBound = 0;
auto upper = ir::Sub::make(parentBound[1], parentBound[0]);
ir::Expr maxBound = ir::Div::make(ir::Add::make(upper, ir::Literal::make(getSplitFactor() - 1, splitFactorType)),
ir::Literal::make(getSplitFactor(), splitFactorType));
return {minBound, maxBound};
}
else if (indexVar == getInnerVar()) {
} else if (indexVar == getInnerVar()) {
ir::Expr minBound = 0;
ir::Expr maxBound = ir::Literal::make(getSplitFactor(), splitFactorType);
return {minBound, maxBound};
Expand All @@ -231,7 +236,12 @@ ir::Expr SplitRelNode::recoverVariable(taco::IndexVar indexVar,
taco_iassert(indexVar == getParentVar());
taco_iassert(variableNames.count(getParentVar()) && variableNames.count(getOuterVar()) && variableNames.count(getInnerVar()));
Datatype splitFactorType = variableNames[getParentVar()].type();
return ir::Add::make(ir::Mul::make(variableNames[getOuterVar()], ir::Literal::make(getSplitFactor(), splitFactorType)), variableNames[getInnerVar()]);
// Include the lower bound of the variable being recovered. Normally, this is 0, but
// in the case of a position split it is not.
return ir::Add::make(
ir::Add::make(ir::Mul::make(variableNames[getOuterVar()], ir::Literal::make(getSplitFactor(), splitFactorType)), variableNames[getInnerVar()]),
0 //parentIterBounds[indexVar][0]
);
}

ir::Stmt SplitRelNode::recoverChild(taco::IndexVar indexVar,
Expand Down Expand Up @@ -364,11 +374,12 @@ std::vector<ir::Expr> DivideRelNode::deriveIterBounds(taco::IndexVar indexVar,
ir::Expr minBound = 0;
ir::Expr maxBound = divFactor;
return {minBound, maxBound};
}
else if (indexVar == getInnerVar()) {
// The inner loop ranges over a chunk of size parentBound / divFactor.
ir::Expr minBound = ir::Div::make(parentBound[0], divFactor);
ir::Expr maxBound = ir::Div::make(ir::Add::make(parentBound[1], ir::Literal::make(getDivFactor()-1, divFactorType)), divFactor);
} else if (indexVar == getInnerVar()) {
// The inner loop ranges over a chunk of size parentBound / divFactor. Similarly
// to split, the loop must range from 0 to parentBound.
ir::Expr minBound = 0;
auto upper = ir::Sub::make(parentBound[1], parentBound[0]);
ir::Expr maxBound = ir::Div::make(ir::Add::make(upper, ir::Literal::make(getDivFactor()-1, divFactorType)), divFactor);
return {minBound, maxBound};
}
taco_ierror;
Expand All @@ -390,7 +401,10 @@ ir::Expr DivideRelNode::recoverVariable(taco::IndexVar indexVar,
// The bounds for the dimension are adjusted so that dimensions that aren't
// divisible by divFactor have the last piece included.
auto bounds = ir::Div::make(ir::Add::make(dimSize, divFactorMinusOne), divFactor);
return ir::Add::make(ir::Mul::make(variableNames[getOuterVar()], bounds), variableNames[getInnerVar()]);
// We multiply this all together and then add on the base of the parentBounds
// to shift up into the range of the parent. This is normally 0, but for cases
// like position loops it is not.
return ir::Add::make(ir::Add::make(ir::Mul::make(variableNames[getOuterVar()], bounds), variableNames[getInnerVar()]), parentIterBounds[indexVar][0]);
}

ir::Stmt DivideRelNode::recoverChild(taco::IndexVar indexVar,
Expand Down
6 changes: 5 additions & 1 deletion src/lower/lowerer_impl_imperative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -695,11 +695,15 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
// Find the iteration bounds of the inner variable -- that is the size
// that the outer loop was broken into.
auto bounds = this->provGraph.deriveIterBounds(inner, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators);
auto parentBounds = this->provGraph.deriveIterBounds(varToRecover, this->definedIndexVarsOrdered, this->underivedBounds, this->indexVarToExprMap, this->iterators);
// Use the difference between the bounds to find the size of the loop.
auto dimLen = ir::Sub::make(bounds[1], bounds[0]);
// For a variable f divided into into f1 and f2, the guard ensures that
// for iteration f, f should be within f1 * dimLen and (f1 + 1) * dimLen.
auto guard = ir::Gte::make(this->indexVarToExprMap[varToRecover], ir::Mul::make(ir::Add::make(this->indexVarToExprMap[outer], 1), dimLen));
// Additionally, similarly to the recovery of variables in the DivideRelNode,
// we also need to include the lower bound of the original variable here.
auto upper = ir::Add::make(ir::Mul::make(ir::Add::make(this->indexVarToExprMap[outer], 1), dimLen), parentBounds[0]);
auto guard = ir::Gte::make(this->indexVarToExprMap[varToRecover], ir::simplify(upper));
recoverySteps.push_back(IfThenElse::make(guard, ir::Continue::make()));
}
}
Expand Down
48 changes: 48 additions & 0 deletions test/tests-scheduling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -990,3 +990,51 @@ TEST(scheduling, divide) {
return stmt.fuse(i, j, f).pos(f, fpos, A(i, j)).divide(fpos, f0, f1, 4).split(f1, i1, i2, 16).split(i2, i3, i4, 8);
});
}

TEST(scheduling, posSplitAndDivide) {
int dim = 10;
Tensor<int> A("A", {dim, dim}, {Dense, Sparse});
Tensor<int> x("x", {dim}, Dense);

auto sparsity = 0.5;
srand(59393);
for (int i = 0; i < dim; i++) {
x.insert({i}, i);
for (int j = 0; j < dim; j++) {
auto rand_float = (float)rand()/(float)(RAND_MAX);
if (rand_float < sparsity) {
A.insert({i, j},((int)(rand_float * 10 / sparsity)));
}
}
}
A.pack();
x.pack();

IndexVar i("i"), j("j"), ipos("ipos"), iposo("iposo"), iposi("iposi");
auto test = [&](std::function<IndexStmt(IndexStmt)> f) {
Tensor<int> y("y", {dim}, Dense);
y(i) = A(i, j) * x(j);
auto stmt = f(y.getAssignment().concretize());
y.compile(stmt);
y.evaluate();
Tensor<int> expected("expected", {dim}, Dense);
expected(i) = A(i, j) * x(j);
expected.evaluate();
ASSERT_TRUE(equals(expected, y)) << expected << endl << y << endl;
};

// TODO (rohany): The old split code did not work prior to this commit on large
// problem instances from suitesparse, but I'm not able to reproduce the bug on
// small test cases here.
// test([&](IndexStmt stmt) {
// return stmt.pos(j, ipos, A(i, j))
// .split(ipos, iposo, iposi, 12)
// ;
// });

test([&](IndexStmt stmt) {
return stmt.pos(j, ipos, A(i, j))
.divide(ipos, iposo, iposi, 8)
;
});
}

0 comments on commit 94961f7

Please sign in to comment.