diff --git a/include/taco/ir/ir.h b/include/taco/ir/ir.h index bbb36c12b..0fd9a45b4 100644 --- a/include/taco/ir/ir.h +++ b/include/taco/ir/ir.h @@ -65,8 +65,9 @@ enum class IRNodeType { BlankLine, Print, GetProperty, - Break, - Sort + Continue, + Sort, + Break }; enum class TensorProperty { @@ -719,7 +720,14 @@ struct BlankLine : public StmtNode { static const IRNodeType _type_info = IRNodeType::BlankLine; }; -/** Breaks current loop */ +/** Continues past current iteration of current loop */ +struct Continue : public StmtNode { + static Stmt make(); + + static const IRNodeType _type_info = IRNodeType::Continue; +}; + +/** Breaks out of the current loop */ struct Break : public StmtNode { static Stmt make(); diff --git a/include/taco/ir/ir_printer.h b/include/taco/ir/ir_printer.h index a02995fce..4e50764e9 100644 --- a/include/taco/ir/ir_printer.h +++ b/include/taco/ir/ir_printer.h @@ -65,10 +65,11 @@ class IRPrinter : public IRVisitorStrict { virtual void visit(const Free*); virtual void visit(const Comment*); virtual void visit(const BlankLine*); - virtual void visit(const Break*); + virtual void visit(const Continue*); virtual void visit(const Print*); virtual void visit(const GetProperty*); virtual void visit(const Sort*); + virtual void visit(const Break*); std::ostream &stream; int indent; diff --git a/include/taco/ir/ir_rewriter.h b/include/taco/ir/ir_rewriter.h index 81ad43705..6b58accc2 100644 --- a/include/taco/ir/ir_rewriter.h +++ b/include/taco/ir/ir_rewriter.h @@ -65,10 +65,11 @@ class IRRewriter : public IRVisitorStrict { virtual void visit(const Free* op); virtual void visit(const Comment* op); virtual void visit(const BlankLine* op); - virtual void visit(const Break* op); + virtual void visit(const Continue* op); virtual void visit(const Print* op); virtual void visit(const GetProperty* op); virtual void visit(const Sort *op); + virtual void visit(const Break *op); }; }} diff --git a/include/taco/ir/ir_visitor.h b/include/taco/ir/ir_visitor.h index 810e4f758..0e5f844d6 100644 --- a/include/taco/ir/ir_visitor.h +++ b/include/taco/ir/ir_visitor.h @@ -45,10 +45,11 @@ struct Allocate; struct Free; struct Comment; struct BlankLine; -struct Break; +struct Continue; struct Print; struct GetProperty; struct Sort; +struct Break; /// Extend this class to visit every node in the IR. class IRVisitorStrict { @@ -96,10 +97,11 @@ class IRVisitorStrict { virtual void visit(const Free*) = 0; virtual void visit(const Comment*) = 0; virtual void visit(const BlankLine*) = 0; - virtual void visit(const Break*) = 0; + virtual void visit(const Continue*) = 0; virtual void visit(const Print*) = 0; virtual void visit(const GetProperty*) = 0; virtual void visit(const Sort*) = 0; + virtual void visit(const Break*) = 0; }; @@ -150,10 +152,11 @@ class IRVisitor : public IRVisitorStrict { virtual void visit(const Free* op); virtual void visit(const Comment* op); virtual void visit(const BlankLine* op); - virtual void visit(const Break* op); + virtual void visit(const Continue* op); virtual void visit(const Print* op); virtual void visit(const GetProperty* op); virtual void visit(const Sort* op); + virtual void visit(const Break* op); }; }} diff --git a/src/codegen/codegen_cuda.cpp b/src/codegen/codegen_cuda.cpp index 44e19c36d..d0c69ffd8 100644 --- a/src/codegen/codegen_cuda.cpp +++ b/src/codegen/codegen_cuda.cpp @@ -1142,7 +1142,7 @@ void CodeGen_CUDA::visit(const Sqrt* op) { stream << ")"; } -void CodeGen_CUDA::visit(const Break*) { +void CodeGen_CUDA::visit(const Continue*) { doIndent(); if(!isHostFunction && deviceFunctionLoopDepth == 0) { // can't break out of kernel diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index 2bc8e000d..4614f6eda 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -46,7 +46,7 @@ class CodeGen_CUDA : public CodeGen { void visit(const Call*); void visit(const Store*); void visit(const Assign*); - void visit(const Break*); + void visit(const Continue*); void visit(const Free* op); std::string printDeviceFuncName(const std::vector> currentParameters, int index); void printDeviceFuncCall(const std::vector> currentParameters, Expr blockSize, int index, Expr gridSize); diff --git a/src/ir/ir.cpp b/src/ir/ir.cpp index a714dddfc..6c5dd8fcb 100644 --- a/src/ir/ir.cpp +++ b/src/ir/ir.cpp @@ -786,6 +786,11 @@ Stmt BlankLine::make() { return new BlankLine; } +// Continue +Stmt Continue::make() { + return new Continue; +} + // Break Stmt Break::make() { return new Break; @@ -954,14 +959,16 @@ template<> void StmtNode::accept(IRVisitorStrict *v) const { v->visit((const Comment*)this); } template<> void StmtNode::accept(IRVisitorStrict *v) const { v->visit((const BlankLine*)this); } -template<> void StmtNode::accept(IRVisitorStrict *v) - const { v->visit((const Break*)this); } +template<> void StmtNode::accept(IRVisitorStrict *v) + const { v->visit((const Continue*)this); } template<> void StmtNode::accept(IRVisitorStrict *v) const { v->visit((const Print*)this); } template<> void ExprNode::accept(IRVisitorStrict *v) const { v->visit((const GetProperty*)this); } template<> void StmtNode::accept(IRVisitorStrict *v) const { v->visit((const Sort*)this); } +template<> void StmtNode::accept(IRVisitorStrict *v) + const { v->visit((const Break*)this); } // printing methods std::ostream& operator<<(std::ostream& os, const Stmt& stmt) { diff --git a/src/ir/ir_printer.cpp b/src/ir/ir_printer.cpp index 879716793..044c8f4fa 100644 --- a/src/ir/ir_printer.cpp +++ b/src/ir/ir_printer.cpp @@ -558,9 +558,14 @@ void IRPrinter::visit(const BlankLine*) { stream << endl; } +void IRPrinter::visit(const Continue*) { + doIndent(); + stream << "continue;" << endl; +} + void IRPrinter::visit(const Break*) { doIndent(); - stream << "continue;" << endl; // TODO: add continue statement + stream << "break;" << endl; } void IRPrinter::visit(const Print* op) { diff --git a/src/ir/ir_rewriter.cpp b/src/ir/ir_rewriter.cpp index fd1423a00..eed6f2bab 100644 --- a/src/ir/ir_rewriter.cpp +++ b/src/ir/ir_rewriter.cpp @@ -447,6 +447,10 @@ void IRRewriter::visit(const BlankLine* op) { stmt = op; } +void IRRewriter::visit(const Continue* op) { + stmt = op; +} + void IRRewriter::visit(const Break* op) { stmt = op; } diff --git a/src/ir/ir_visitor.cpp b/src/ir/ir_visitor.cpp index 19fbfbfdf..96eaee348 100644 --- a/src/ir/ir_visitor.cpp +++ b/src/ir/ir_visitor.cpp @@ -228,6 +228,9 @@ void IRVisitor::visit(const Comment*) { void IRVisitor::visit(const BlankLine*) { } +void IRVisitor::visit(const Continue*) { +} + void IRVisitor::visit(const Break*) { } diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index 17a4dab3b..6925340ec 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -410,7 +410,7 @@ Stmt LowererImpl::lowerForall(Forall forall) if (isa(ir::simplify(iterBounds[0])) && ir::simplify(iterBounds[0]).as()->equalsScalar(0)) { guardCondition = maxGuard; } - ir::Stmt guard = ir::IfThenElse::make(guardCondition, ir::Break::make()); + ir::Stmt guard = ir::IfThenElse::make(guardCondition, ir::Continue::make()); recoverySteps.push_back(guard); } @@ -438,7 +438,7 @@ Stmt LowererImpl::lowerForall(Forall forall) } if (!hasDirectDivBound) { Stmt guard = IfThenElse::make(Gte::make(indexVarToExprMap[varToRecover], underivedBounds[varToRecover][1]), - Break::make()); + Continue::make()); recoverySteps.push_back(guard); } }