Skip to content

Commit

Permalink
Merge pull request #369 from rohany/refactor-break
Browse files Browse the repository at this point in the history
include,src: introduce a true break statement, rename current to continue
  • Loading branch information
stephenchouca authored Jan 20, 2021
2 parents dafe2ba + f35573d commit 0bc58c6
Show file tree
Hide file tree
Showing 11 changed files with 47 additions and 15 deletions.
14 changes: 11 additions & 3 deletions include/taco/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ enum class IRNodeType {
BlankLine,
Print,
GetProperty,
Break,
Sort
Continue,
Sort,
Break
};

enum class TensorProperty {
Expand Down Expand Up @@ -719,7 +720,14 @@ struct BlankLine : public StmtNode<BlankLine> {
static const IRNodeType _type_info = IRNodeType::BlankLine;
};

/** Breaks current loop */
/** Continues past current iteration of current loop */
struct Continue : public StmtNode<Continue> {
static Stmt make();

static const IRNodeType _type_info = IRNodeType::Continue;
};

/** Breaks out of the current loop */
struct Break : public StmtNode<Break> {
static Stmt make();

Expand Down
3 changes: 2 additions & 1 deletion include/taco/ir/ir_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ class IRPrinter : public IRVisitorStrict {
virtual void visit(const Free*);
virtual void visit(const Comment*);
virtual void visit(const BlankLine*);
virtual void visit(const Break*);
virtual void visit(const Continue*);
virtual void visit(const Print*);
virtual void visit(const GetProperty*);
virtual void visit(const Sort*);
virtual void visit(const Break*);

std::ostream &stream;
int indent;
Expand Down
3 changes: 2 additions & 1 deletion include/taco/ir/ir_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ class IRRewriter : public IRVisitorStrict {
virtual void visit(const Free* op);
virtual void visit(const Comment* op);
virtual void visit(const BlankLine* op);
virtual void visit(const Break* op);
virtual void visit(const Continue* op);
virtual void visit(const Print* op);
virtual void visit(const GetProperty* op);
virtual void visit(const Sort *op);
virtual void visit(const Break *op);
};

}}
Expand Down
9 changes: 6 additions & 3 deletions include/taco/ir/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ struct Allocate;
struct Free;
struct Comment;
struct BlankLine;
struct Break;
struct Continue;
struct Print;
struct GetProperty;
struct Sort;
struct Break;

/// Extend this class to visit every node in the IR.
class IRVisitorStrict {
Expand Down Expand Up @@ -96,10 +97,11 @@ class IRVisitorStrict {
virtual void visit(const Free*) = 0;
virtual void visit(const Comment*) = 0;
virtual void visit(const BlankLine*) = 0;
virtual void visit(const Break*) = 0;
virtual void visit(const Continue*) = 0;
virtual void visit(const Print*) = 0;
virtual void visit(const GetProperty*) = 0;
virtual void visit(const Sort*) = 0;
virtual void visit(const Break*) = 0;
};


Expand Down Expand Up @@ -150,10 +152,11 @@ class IRVisitor : public IRVisitorStrict {
virtual void visit(const Free* op);
virtual void visit(const Comment* op);
virtual void visit(const BlankLine* op);
virtual void visit(const Break* op);
virtual void visit(const Continue* op);
virtual void visit(const Print* op);
virtual void visit(const GetProperty* op);
virtual void visit(const Sort* op);
virtual void visit(const Break* op);
};

}}
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,7 @@ void CodeGen_CUDA::visit(const Sqrt* op) {
stream << ")";
}

void CodeGen_CUDA::visit(const Break*) {
void CodeGen_CUDA::visit(const Continue*) {
doIndent();
if(!isHostFunction && deviceFunctionLoopDepth == 0) {
// can't break out of kernel
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class CodeGen_CUDA : public CodeGen {
void visit(const Call*);
void visit(const Store*);
void visit(const Assign*);
void visit(const Break*);
void visit(const Continue*);
void visit(const Free* op);
std::string printDeviceFuncName(const std::vector<std::pair<std::string, Expr>> currentParameters, int index);
void printDeviceFuncCall(const std::vector<std::pair<std::string, Expr>> currentParameters, Expr blockSize, int index, Expr gridSize);
Expand Down
11 changes: 9 additions & 2 deletions src/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,11 @@ Stmt BlankLine::make() {
return new BlankLine;
}

// Continue
Stmt Continue::make() {
return new Continue;
}

// Break
Stmt Break::make() {
return new Break;
Expand Down Expand Up @@ -954,14 +959,16 @@ template<> void StmtNode<Comment>::accept(IRVisitorStrict *v)
const { v->visit((const Comment*)this); }
template<> void StmtNode<BlankLine>::accept(IRVisitorStrict *v)
const { v->visit((const BlankLine*)this); }
template<> void StmtNode<Break>::accept(IRVisitorStrict *v)
const { v->visit((const Break*)this); }
template<> void StmtNode<Continue>::accept(IRVisitorStrict *v)
const { v->visit((const Continue*)this); }
template<> void StmtNode<Print>::accept(IRVisitorStrict *v)
const { v->visit((const Print*)this); }
template<> void ExprNode<GetProperty>::accept(IRVisitorStrict *v)
const { v->visit((const GetProperty*)this); }
template<> void StmtNode<Sort>::accept(IRVisitorStrict *v)
const { v->visit((const Sort*)this); }
template<> void StmtNode<Break>::accept(IRVisitorStrict *v)
const { v->visit((const Break*)this); }

// printing methods
std::ostream& operator<<(std::ostream& os, const Stmt& stmt) {
Expand Down
7 changes: 6 additions & 1 deletion src/ir/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,14 @@ void IRPrinter::visit(const BlankLine*) {
stream << endl;
}

void IRPrinter::visit(const Continue*) {
doIndent();
stream << "continue;" << endl;
}

void IRPrinter::visit(const Break*) {
doIndent();
stream << "continue;" << endl; // TODO: add continue statement
stream << "break;" << endl;
}

void IRPrinter::visit(const Print* op) {
Expand Down
4 changes: 4 additions & 0 deletions src/ir/ir_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,10 @@ void IRRewriter::visit(const BlankLine* op) {
stmt = op;
}

void IRRewriter::visit(const Continue* op) {
stmt = op;
}

void IRRewriter::visit(const Break* op) {
stmt = op;
}
Expand Down
3 changes: 3 additions & 0 deletions src/ir/ir_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ void IRVisitor::visit(const Comment*) {
void IRVisitor::visit(const BlankLine*) {
}

void IRVisitor::visit(const Continue*) {
}

void IRVisitor::visit(const Break*) {
}

Expand Down
4 changes: 2 additions & 2 deletions src/lower/lowerer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ Stmt LowererImpl::lowerForall(Forall forall)
if (isa<ir::Literal>(ir::simplify(iterBounds[0])) && ir::simplify(iterBounds[0]).as<ir::Literal>()->equalsScalar(0)) {
guardCondition = maxGuard;
}
ir::Stmt guard = ir::IfThenElse::make(guardCondition, ir::Break::make());
ir::Stmt guard = ir::IfThenElse::make(guardCondition, ir::Continue::make());
recoverySteps.push_back(guard);
}

Expand Down Expand Up @@ -438,7 +438,7 @@ Stmt LowererImpl::lowerForall(Forall forall)
}
if (!hasDirectDivBound) {
Stmt guard = IfThenElse::make(Gte::make(indexVarToExprMap[varToRecover], underivedBounds[varToRecover][1]),
Break::make());
Continue::make());
recoverySteps.push_back(guard);
}
}
Expand Down

0 comments on commit 0bc58c6

Please sign in to comment.