diff --git a/python/heterocl/compute_api.py b/python/heterocl/compute_api.py
index 620766fd5..5401b8bd5 100644
--- a/python/heterocl/compute_api.py
+++ b/python/heterocl/compute_api.py
@@ -165,6 +165,11 @@ def compute_body(name,
             index, _, _ = get_index(shape, indices, 0)
             stage.emit(_make.Store(buffer_var, _make.Cast(dtype, ret), index))
             stmt = make_for(indices, stage.pop_stmt(), 0, name)
+        elif isinstance(ret, str):
+            indices = lambda_ivs
+            index, _, _ = get_index(shape, indices, 0)
+            stage.emit(_make.Store(buffer_var, _make.CastStr(dtype, ret), index))
+            stmt = make_for(indices, stage.pop_stmt(), 0, name)
         elif isinstance(ret, Tensor): # reduction
             ret_ivs = [_IterVar((0, ret.shape[i]), ret.name+"_i" + str(i), 0)
                        for i in range(0, len(ret.shape))]
diff --git a/tests/test_scalar.py b/tests/test_scalar.py
new file mode 100644
index 000000000..1c276590f
--- /dev/null
+++ b/tests/test_scalar.py
@@ -0,0 +1,76 @@
+"""Tests for initializing ``hcl.scalar`` from a hexadecimal string literal."""
+import heterocl as hcl
+import numpy as np
+
+# (bit width, literal) pairs. Every literal sets the top bit of its width, so
+# a signed result reads back negative; results are compared modulo 2**width.
+HEX_CASES = [
+    (7, "0x4A"),
+    (15, "0x40FF"),
+    (31, "0x4F0000FF"),
+    (62, "0x2A000000FF0000FF"),
+]
+
+# An 80-bit literal, used to exercise scalars wider than 64 bits.
+WIDE_LITERAL = "0x8ABCFFFFFFBAFFFA000A"
+
+
+def run_kernel(kernel, in_dtype, out_dtype):
+    """Build ``kernel`` on a one-element input and return its output as an int.
+
+    The output buffer starts at 99 so that a kernel which stores nothing is
+    told apart from one that stores the expected value.
+    """
+    hcl.init()
+    A = hcl.placeholder((1,), "A", dtype=in_dtype)
+    s = hcl.create_schedule([A], kernel)
+    m = hcl.build(s)
+    hcl_A = hcl.asarray([0], dtype=in_dtype)
+    hcl_R = hcl.asarray([99], dtype=out_dtype)
+    m(hcl_A, hcl_R)
+    return int(hcl_R.asnumpy()[0])
+
+
+def scalar_kernel(literal, dtype):
+    """Return a kernel that stores ``hcl.scalar(literal)`` of type ``dtype``."""
+    def kernel(A):
+        def doit(x):
+            v = hcl.scalar(literal, "v", dtype=dtype)
+            return v.v
+        return hcl.compute(A.shape, lambda i: doit(i), "doit", dtype=dtype)
+    return kernel
+
+
+def wide_kernel(upper):
+    """Return a kernel storing the upper 63 or lower 64 bits of WIDE_LITERAL."""
+    out_dtype = hcl.UInt(63) if upper else hcl.UInt(64)
+
+    def kernel(A):
+        def doit(x):
+            v = hcl.scalar(WIDE_LITERAL, "v", dtype=hcl.UInt(127))
+            if upper:
+                part = hcl.scalar(v >> 64, "b", dtype=out_dtype)
+            else:
+                part = hcl.scalar(v & 0xFFFFFFFFFFFFFFFF, "c", dtype=out_dtype)
+            return part.v
+        return hcl.compute(A.shape, lambda i: doit(i), "doit", dtype=out_dtype)
+    return kernel
+
+
+def test_scalar_from_hex_string():
+    """Int and UInt scalars of several widths keep the literal's bit pattern."""
+    for bits, literal in HEX_CASES:
+        for make_dtype in (hcl.Int, hcl.UInt):
+            dtype = make_dtype(bits)
+            result = run_kernel(scalar_kernel(literal, dtype), dtype, dtype)
+            assert result % (1 << bits) == int(literal, 16)
+
+
+def test_wide_scalar_halves():
+    """A scalar wider than 64 bits can be split into two machine words."""
+    value = int(WIDE_LITERAL, 16)
+    in_dtype = hcl.UInt(63)
+    lower = run_kernel(wide_kernel(False), in_dtype, hcl.UInt(64))
+    upper = run_kernel(wide_kernel(True), in_dtype, hcl.UInt(63))
+    assert lower % (1 << 64) == value & 0xFFFFFFFFFFFFFFFF
+    assert upper % (1 << 63) == value >> 64
diff --git a/tvm/HalideIR/src/ir/Expr.h b/tvm/HalideIR/src/ir/Expr.h
index bc5485046..60a991924 100644
--- a/tvm/HalideIR/src/ir/Expr.h
+++ b/tvm/HalideIR/src/ir/Expr.h
@@ -38,6 +38,7 @@ enum class IRNodeType : int {
     FloatImm,
     StringImm,
     Cast,
+    CastStr,
     Variable,
     Add,
     Sub,
diff --git a/tvm/HalideIR/src/ir/IR.cpp b/tvm/HalideIR/src/ir/IR.cpp
index 53f8ff89a..bf30618a2 100644
--- a/tvm/HalideIR/src/ir/IR.cpp
+++ b/tvm/HalideIR/src/ir/IR.cpp
@@ -101,6 +101,13 @@ Expr Cast::make(Type t, Expr v) {
     return Expr(node);
 }
 
+Expr CastStr::make(Type t, const std::string &val) {
+    std::shared_ptr<CastStr> node = std::make_shared<CastStr>();
+    node->type = t;
+    node->value = val;
+    return Expr(node);
+}
+
 Expr And::make(Expr a, Expr b) {
     internal_assert(a.defined()) << "And of undefined\n";
     internal_assert(b.defined()) << "And of undefined\n";
@@ -1050,6 +1057,10 @@ void ExprNode::accept(IRVisitor *v, const Expr &e) const {
     v->visit((const Cast *)this, e);
 }
 template <>
+void ExprNode<CastStr>::accept(IRVisitor *v, const Expr &e) const {
+    v->visit((const CastStr *)this, e);
+}
+template <>
 void ExprNode::accept(IRVisitor *v, const Expr &e) const {
     v->visit((const Variable *)this, e);
 }
diff --git a/tvm/HalideIR/src/ir/IR.h b/tvm/HalideIR/src/ir/IR.h
index 3e4004743..29cb41de0 100644
--- a/tvm/HalideIR/src/ir/IR.h
+++ b/tvm/HalideIR/src/ir/IR.h
@@ -103,6 +103,19 @@ struct Cast : public ExprNode {
     static constexpr
const char *_type_key = "Cast";
 };
 
+/** Cast a node from string to other datatype. */
+struct CastStr : public ExprNode<CastStr> {
+    std::string value;
+    EXPORT static Expr make(Type t, const std::string &val);
+    void VisitAttrs(IR::AttrVisitor *v) final {
+        v->Visit("dtype", &type);
+        v->Visit("value", &value);
+    }
+    static const IRNodeType _type_info = IRNodeType::CastStr;
+    static constexpr const char *_type_key = "CastStr";
+};
+
+
 /** base class of all Binary arithematic ops */
 template
 struct BinaryOpNode : public ExprNode {
diff --git a/tvm/HalideIR/src/ir/IREquality.cpp b/tvm/HalideIR/src/ir/IREquality.cpp
index 52e7c4b9c..e57a9b839 100644
--- a/tvm/HalideIR/src/ir/IREquality.cpp
+++ b/tvm/HalideIR/src/ir/IREquality.cpp
@@ -60,6 +60,7 @@ class IRComparer : public IRVisitor {
     void visit(const FloatImm *, const Expr &);
     void visit(const StringImm *, const Expr &);
     void visit(const Cast *, const Expr &);
+    void visit(const CastStr *, const Expr &);
     void visit(const Variable *, const Expr &);
     void visit(const Add *, const Expr &);
     void visit(const Sub *, const Expr &);
@@ -293,6 +294,10 @@ void IRComparer::visit(const StringImm *op, const Expr &e) {
     compare_names(node->value, op->value);
 }
 
+void IRComparer::visit(const CastStr *op, const Expr &e) {
+    const CastStr *node = expr_.as<CastStr>();
+    compare_names(node->value, op->value);
+}
 void IRComparer::visit(const Cast *op, const Expr &e) {
     compare_expr(expr_.as()->value, op->value);
 }
diff --git a/tvm/HalideIR/src/ir/IRMutator.cpp b/tvm/HalideIR/src/ir/IRMutator.cpp
index ad2fcad01..3410710c1 100644
--- a/tvm/HalideIR/src/ir/IRMutator.cpp
+++ b/tvm/HalideIR/src/ir/IRMutator.cpp
@@ -43,6 +43,11 @@ void IRMutator::visit(const Cast *op, const Expr &e) {
     }
 }
 
+void IRMutator::visit(const CastStr *, const Expr &e) {
+    // CastStr holds only a type and a string literal: there is no child
+    // expression to mutate, so the node is always passed through unchanged.
+    expr = e;
+}
+
 // use macro to access private function.
#define MUTATE_BINARY_OP(op, e, T) \
     Expr a = mutate(op->a); \
diff --git a/tvm/HalideIR/src/ir/IRMutator.h b/tvm/HalideIR/src/ir/IRMutator.h
index 6b5649b2c..421f85100 100644
--- a/tvm/HalideIR/src/ir/IRMutator.h
+++ b/tvm/HalideIR/src/ir/IRMutator.h
@@ -49,6 +49,7 @@ class IRMutator : public IRVisitor {
     EXPORT virtual void visit(const FloatImm *, const Expr &);
     EXPORT virtual void visit(const StringImm *, const Expr &);
     EXPORT virtual void visit(const Cast *, const Expr &);
+    EXPORT virtual void visit(const CastStr *, const Expr &);
     EXPORT virtual void visit(const Variable *, const Expr &);
     EXPORT virtual void visit(const Add *, const Expr &);
     EXPORT virtual void visit(const Sub *, const Expr &);
diff --git a/tvm/HalideIR/src/ir/IRPrinter.cpp b/tvm/HalideIR/src/ir/IRPrinter.cpp
index ac20c0320..f1f8f33d8 100644
--- a/tvm/HalideIR/src/ir/IRPrinter.cpp
+++ b/tvm/HalideIR/src/ir/IRPrinter.cpp
@@ -196,6 +196,41 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
     p->print(op->value);
     p->stream << ')';
 })
+.set_dispatch<CastStr>([](const CastStr *op, IRPrinter *p) {
+    // Prints `type("literal")`, escaping the literal like a C string.
+    static const char hex_digits[] = "0123456789ABCDEF";
+    auto &stream = p->stream;
+    stream << op->type << "(\"";
+    for (unsigned char c : op->value) {
+        if (c >= ' ' && c <= '~' && c != '\\' && c != '"') {
+            // printable ASCII that needs no escaping
+            stream << c;
+            continue;
+        }
+        stream << '\\';
+        switch (c) {
+        case '"':
+            stream << '"';
+            break;
+        case '\\':
+            stream << '\\';
+            break;
+        case '\t':
+            stream << 't';
+            break;
+        case '\r':
+            stream << 'r';
+            break;
+        case '\n':
+            stream << 'n';
+            break;
+        default:
+            // any other byte becomes a two-digit \xHH escape
+            stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf];
+        }
+    }
+    stream << "\")";
+})
 .set_dispatch([](const Variable *op, IRPrinter *p) {
     // omit the type
     // stream << op->name << "."
<< op->type;
diff --git a/tvm/HalideIR/src/ir/IRVisitor.cpp b/tvm/HalideIR/src/ir/IRVisitor.cpp
index 1a3f2d5c3..bb6e19c8e 100644
--- a/tvm/HalideIR/src/ir/IRVisitor.cpp
+++ b/tvm/HalideIR/src/ir/IRVisitor.cpp
@@ -17,6 +17,7 @@ void IRVisitor::visit(const FloatImm *, const Expr &) {}
 void IRVisitor::visit(const StringImm *, const Expr &) {}
 
 void IRVisitor::visit(const Cast *op, const Expr &) { op->value.accept(this); }
+void IRVisitor::visit(const CastStr *, const Expr &) {}
 
 void IRVisitor::visit(const Variable *, const Expr &) {}
 
@@ -356,6 +357,8 @@ void IRGraphVisitor::visit(const StringImm *, const Expr &) {}
 
 void IRGraphVisitor::visit(const Cast *op, const Expr &) { include(op->value); }
 
+void IRGraphVisitor::visit(const CastStr *, const Expr &) {}
+
 void IRGraphVisitor::visit(const Variable *op, const Expr &) {}
 
 void IRGraphVisitor::visit(const Add *op, const Expr &) {
diff --git a/tvm/HalideIR/src/ir/IRVisitor.h b/tvm/HalideIR/src/ir/IRVisitor.h
index 0cdda2a67..4f13d6c9a 100644
--- a/tvm/HalideIR/src/ir/IRVisitor.h
+++ b/tvm/HalideIR/src/ir/IRVisitor.h
@@ -30,6 +30,7 @@ class IRVisitor {
     EXPORT virtual void visit(const FloatImm *, const Expr &);
     EXPORT virtual void visit(const StringImm *, const Expr &);
     EXPORT virtual void visit(const Cast *, const Expr &);
+    EXPORT virtual void visit(const CastStr *, const Expr &);
     EXPORT virtual void visit(const Variable *, const Expr &);
     EXPORT virtual void visit(const Add *, const Expr &);
     EXPORT virtual void visit(const Sub *, const Expr &);
@@ -116,6 +117,7 @@ class IRGraphVisitor : public IRVisitor {
     EXPORT virtual void visit(const FloatImm *, const Expr &);
     EXPORT virtual void visit(const StringImm *, const Expr &);
     EXPORT virtual void visit(const Cast *, const Expr &);
+    EXPORT virtual void visit(const CastStr *, const Expr &);
     EXPORT virtual void visit(const Variable *, const Expr &);
     EXPORT virtual void visit(const Add *, const Expr &);
     EXPORT virtual void visit(const Sub *, const Expr &);
diff --git
a/tvm/include/tvm/ir.h b/tvm/include/tvm/ir.h index 03fdb7d0a..c0988e15f 100644 --- a/tvm/include/tvm/ir.h +++ b/tvm/include/tvm/ir.h @@ -483,6 +483,7 @@ using Halide::Internal::Break; using Halide::Internal::Broadcast; using Halide::Internal::Call; using Halide::Internal::Cast; +using Halide::Internal::CastStr; using Halide::Internal::Div; using Halide::Internal::EQ; using Halide::Internal::Evaluate; diff --git a/tvm/include/tvm/ir_functor_ext.h b/tvm/include/tvm/ir_functor_ext.h index d18c0fa28..eb8879830 100644 --- a/tvm/include/tvm/ir_functor_ext.h +++ b/tvm/include/tvm/ir_functor_ext.h @@ -132,6 +132,7 @@ class ExprFunctor { virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const CastStr* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Ramp* op, Args... 
args) EXPR_FUNCTOR_DEFAULT; @@ -179,6 +180,7 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(Or); IR_EXPR_FUNCTOR_DISPATCH(Reduce); IR_EXPR_FUNCTOR_DISPATCH(Cast); + IR_EXPR_FUNCTOR_DISPATCH(CastStr); IR_EXPR_FUNCTOR_DISPATCH(Not); IR_EXPR_FUNCTOR_DISPATCH(Select); IR_EXPR_FUNCTOR_DISPATCH(Ramp); diff --git a/tvm/include/tvm/ir_mutator.h b/tvm/include/tvm/ir_mutator.h index f7c08d60d..5d34fe2b4 100644 --- a/tvm/include/tvm/ir_mutator.h +++ b/tvm/include/tvm/ir_mutator.h @@ -104,6 +104,7 @@ class TVM_DLL IRMutator { virtual Expr Mutate_(const Or* op, const Expr& e); virtual Expr Mutate_(const Reduce* op, const Expr& e); virtual Expr Mutate_(const Cast* op, const Expr& e); + virtual Expr Mutate_(const CastStr* op, const Expr& e); virtual Expr Mutate_(const Not* op, const Expr& e); virtual Expr Mutate_(const Select* op, const Expr& e); virtual Expr Mutate_(const Ramp* op, const Expr& e); diff --git a/tvm/include/tvm/ir_visitor.h b/tvm/include/tvm/ir_visitor.h index 8e8525bee..d7c6e4867 100644 --- a/tvm/include/tvm/ir_visitor.h +++ b/tvm/include/tvm/ir_visitor.h @@ -109,6 +109,7 @@ class TVM_DLL IRVisitor { virtual void Visit_(const Or* op); virtual void Visit_(const Reduce* op); virtual void Visit_(const Cast* op); + virtual void Visit_(const CastStr* op); virtual void Visit_(const Not* op); virtual void Visit_(const Select* op); virtual void Visit_(const Ramp* op); diff --git a/tvm/src/api/api_ir.cc b/tvm/src/api/api_ir.cc index 9404f3d0c..5162f51bd 100644 --- a/tvm/src/api/api_ir.cc +++ b/tvm/src/api/api_ir.cc @@ -205,6 +205,7 @@ REGISTER_MAKE_BINARY_OP(Or); REGISTER_MAKE1(Not); REGISTER_MAKE3(Ramp); REGISTER_MAKE2(Cast); +REGISTER_MAKE2(CastStr); REGISTER_MAKE2(Broadcast); REGISTER_MAKE3(Let); REGISTER_MAKE3(LetStmt); diff --git a/tvm/src/codegen/build_module.cc b/tvm/src/codegen/build_module.cc index c6c88b04e..6b2c4658a 100644 --- a/tvm/src/codegen/build_module.cc +++ b/tvm/src/codegen/build_module.cc @@ -302,7 +302,6 @@ runtime::Module build(const Array& funcs, 
const Target& target,
     func = ir::CombineContextCall(func);
     fhost.Set(i, func);
   }
-
   auto mhost = codegen::Build(fhost, target_host_val.str());
 
   if (fdevice.size() > 0) {
diff --git a/tvm/src/codegen/llvm/codegen_llvm.cc b/tvm/src/codegen/llvm/codegen_llvm.cc
index 08ed0e041..df15e2484 100644
--- a/tvm/src/codegen/llvm/codegen_llvm.cc
+++ b/tvm/src/codegen/llvm/codegen_llvm.cc
@@ -573,6 +573,26 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
   }
 }
 
+llvm::Value* CodeGenLLVM::CreateCastStr(Type to, const std::string& str) {
+  // Only integer targets are supported. The literal's bit pattern is used
+  // as-is, so Int(n) and UInt(n) produce the same n-bit constant.
+  CHECK(to.is_int() || to.is_uint())
+      << "CastStr only supports integer types, but got " << to;
+  CHECK(str.size() > 2 && str[0] == '0' && (str[1] == 'x' || str[1] == 'X'))
+      << "CastStr expects a hexadecimal literal such as 0x4A, but got " << str;
+  llvm::StringRef digits = llvm::StringRef(str).substr(2);
+  CHECK(digits.find_first_not_of("0123456789abcdefABCDEF") ==
+        llvm::StringRef::npos)
+      << "CastStr got a non-hexadecimal digit in " << str;
+  // Parse at 4 bits per digit so APInt always has room for the literal, then
+  // check that the value fits the target width before narrowing to it.
+  const unsigned num_bits = static_cast<unsigned>(to.bits());
+  llvm::APInt wide(static_cast<unsigned>(digits.size()) * 4, digits, 16);
+  CHECK_LE(wide.getActiveBits(), num_bits)
+      << str << " does not fit in " << to;
+  return builder_->getInt(wide.zextOrTrunc(num_bits));
+}
+
 llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
   auto it = str_map_.find(str);
   if (it != str_map_.end()) return it->second;
@@ -789,6 +809,11 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) {
       CreateCast(op->value.type(), op->type, MakeValue(op->value));
   return val;
 }
+llvm::Value* CodeGenLLVM::VisitExpr_(const CastStr* op) {
+  llvm::Value* val =
+      CreateCastStr(op->type, op->value);
+  return val;
+}
 llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) {
   llvm::Value* val =
       llvm::ConstantInt::getSigned(LLVMType(op->type), op->value);
diff --git a/tvm/src/codegen/llvm/codegen_llvm.h b/tvm/src/codegen/llvm/codegen_llvm.h
index 733df2177..2ad9eedfd 100644
--- a/tvm/src/codegen/llvm/codegen_llvm.h
+++ b/tvm/src/codegen/llvm/codegen_llvm.h
@@ -80,6 +80,7 @@ class CodeGenLLVM : public ExprFunctor,
   //
override codegen
   llvm::Value* VisitExpr_(const Variable* op) override;
   llvm::Value* VisitExpr_(const Cast* op) override;
+  llvm::Value* VisitExpr_(const CastStr* op) override;
   llvm::Value* VisitExpr_(const IntImm* op) override;
   llvm::Value* VisitExpr_(const UIntImm* op) override;
   llvm::Value* VisitExpr_(const FloatImm* op) override;
@@ -192,6 +193,8 @@ class CodeGenLLVM : public ExprFunctor,
                                const std::vector& args);
   // cast operatpr
   llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
+  // caststr operator
+  llvm::Value* CreateCastStr(Type to, const std::string& str);
   // comparison op
   llvm::Value* GetVarValue(const Variable* v) const;
   llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
diff --git a/tvm/src/codegen/llvm/llvm_common.h b/tvm/src/codegen/llvm/llvm_common.h
index 8328ab26f..a0dfe807f 100644
--- a/tvm/src/codegen/llvm/llvm_common.h
+++ b/tvm/src/codegen/llvm/llvm_common.h
@@ -34,6 +34,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
diff --git a/tvm/src/codegen/llvm/llvm_module.cc b/tvm/src/codegen/llvm/llvm_module.cc
index d1e6c624b..0a2616d1e 100644
--- a/tvm/src/codegen/llvm/llvm_module.cc
+++ b/tvm/src/codegen/llvm/llvm_module.cc
@@ -159,7 +159,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     llvm::SMDiagnostic err;
     module_ = llvm::parseIRFile(file_name, err, *ctx_);
     if (module_.get() == nullptr) {
-      std::string msg = err.getMessage();
+      std::string msg = err.getMessage().str();
       LOG(FATAL) << "Fail to load ir file " << file_name << "\n"
                  << "line " << err.getLineNo() << ":" << msg;
     }
@@ -168,7 +168,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     if (mtarget != nullptr) {
       llvm::MDString* pstr =
llvm::dyn_cast(mtarget);
       CHECK(pstr != nullptr);
-      target_ = pstr->getString();
+      target_ = pstr->getString().str();
     } else {
       std::ostringstream os;
       os << "llvm -target " << module_->getTargetTriple();
@@ -278,6 +278,7 @@ TVM_REGISTER_API("codegen.llvm_target_enabled")
     InitializeLLVM();
     *rv = (GetLLVMTargetMachine(args[0], true) != nullptr);
   });
+
 }  // namespace codegen
 }  // namespace TVM
 #endif  // TVM_LLVM_VERSION
diff --git a/tvm/src/lang/ir.cc b/tvm/src/lang/ir.cc
index 3f8a9cbdb..61020226a 100644
--- a/tvm/src/lang/ir.cc
+++ b/tvm/src/lang/ir.cc
@@ -102,6 +102,7 @@ TVM_REGISTER_NODE_TYPE(IntImm);
 TVM_REGISTER_NODE_TYPE(UIntImm);
 TVM_REGISTER_NODE_TYPE(StringImm);
 TVM_REGISTER_NODE_TYPE(Cast);
+TVM_REGISTER_NODE_TYPE(CastStr);
 TVM_REGISTER_NODE_TYPE(Variable);
 TVM_REGISTER_NODE_TYPE(Add);
 TVM_REGISTER_NODE_TYPE(Sub);
diff --git a/tvm/src/pass/ir_mutator.cc b/tvm/src/pass/ir_mutator.cc
index e16491a23..24648b218 100644
--- a/tvm/src/pass/ir_mutator.cc
+++ b/tvm/src/pass/ir_mutator.cc
@@ -541,6 +541,11 @@ Expr IRMutator::Mutate_(const Cast* op, const Expr& e) {
   }
 }
 
+Expr IRMutator::Mutate_(const CastStr*, const Expr& e) {
+  // CastStr holds only a type and a string literal: there is no child
+  // expression to mutate, so the node is always returned unchanged.
+  return e;
+}
+
 Expr IRMutator::Mutate_(const Not* op, const Expr& e) {
   Expr a = this->Mutate(op->a);
   if (a.same_as(op->a)) {
@@ -686,6 +691,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
 .DISPATCH_TO_MUTATE_EXPR(Or)
 .DISPATCH_TO_MUTATE_EXPR(Reduce)
 .DISPATCH_TO_MUTATE_EXPR(Cast)
+.DISPATCH_TO_MUTATE_EXPR(CastStr)
 .DISPATCH_TO_MUTATE_EXPR(Not)
 .DISPATCH_TO_MUTATE_EXPR(Select)
 .DISPATCH_TO_MUTATE_EXPR(Ramp)
diff --git a/tvm/src/pass/ir_visitor.cc b/tvm/src/pass/ir_visitor.cc
index 58e8967ef..55ac140c9 100644
--- a/tvm/src/pass/ir_visitor.cc
+++ b/tvm/src/pass/ir_visitor.cc
@@ -291,6 +291,7 @@ DEFINE_OP_NO_VISIT_(IntImm)
 DEFINE_OP_NO_VISIT_(UIntImm)
 DEFINE_OP_NO_VISIT_(FloatImm)
 DEFINE_OP_NO_VISIT_(StringImm)
+DEFINE_OP_NO_VISIT_(CastStr) #define DISPATCH_TO_VISIT(OP) \ set_dispatch([](const OP *op, IRVisitor *v) { v->Visit_(op); }) @@ -324,6 +325,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .DISPATCH_TO_VISIT(Or) .DISPATCH_TO_VISIT(Reduce) .DISPATCH_TO_VISIT(Cast) + .DISPATCH_TO_VISIT(CastStr) .DISPATCH_TO_VISIT(Not) .DISPATCH_TO_VISIT(Select) .DISPATCH_TO_VISIT(Ramp)