diff --git a/ODLA/platforms/tensorrt/odla_tensorrt.cc b/ODLA/platforms/tensorrt/odla_tensorrt.cc
index 89394f35e..e090ef296 100644
--- a/ODLA/platforms/tensorrt/odla_tensorrt.cc
+++ b/ODLA/platforms/tensorrt/odla_tensorrt.cc
@@ -1614,6 +1614,22 @@ odla_value odla_Concat(odla_values inputs, odla_int32 axis,
   return CreateValue(concat, output_type, id);
 }
 
+#if NV_TENSORRT_MAJOR > 8 || (NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR >= 2)
+odla_value odla_Einsum(odla_values inputs, const odla_char* equation,
+                       odla_value_shape output_dims, const odla_value_id id) {
+  int num = inputs.size;
+  std::vector<nvinfer1::ITensor*> input_tensors(num);
+  for (int i = 0; i < num; ++i) {
+    input_tensors[i] = inputs.values[i]->tensor;
+  }
+  auto ret = g_comp->network->addEinsum(input_tensors.data(), num, equation);
+  odla_value_type output_type{
+      .element_type = inputs.values[0]->type.element_type,
+      .shape = output_dims};
+  return CreateValue(ret, output_type, id);
+}
+#endif
+
 odla_value odla_MaxPool(odla_value input, odla_memory_layout input_layout,
                         const odla_uint32* window_dims,
                         const odla_uint32* strides,
diff --git a/lib/transforms/type_legalizer.cc b/lib/transforms/type_legalizer.cc
index 49cc09895..36cf94b8d 100644
--- a/lib/transforms/type_legalizer.cc
+++ b/lib/transforms/type_legalizer.cc
@@ -1097,6 +1097,103 @@ static void RunOnInstruction(SetDiff1DInst* inst) {
   ;
 }
 
+static void RunOnInstruction(EinsumInst* inst) {
+  int n = inst->GetNumOfOperands();
+  for (int i = 0; i < n; ++i) {
+    if (!inst->GetOperand(i).GetType().IsValid()) {
+      return;
+    }
+  }
+  HLCHECK(n >= 1);
+  std::vector<std::string> terms;
+  const std::string equ = inst->GetEquation();
+  std::unordered_map<char, int64_t> char2dim;
+  bool has_output = false;
+  for (int i = 0, e = equ.size(), new_term = 1; i < e; ++i) {
+    auto c = equ[i];
+    if (new_term == 1) {
+      terms.push_back("");
+      new_term = 0;
+    }
+    if (c == ' ') {
+      continue;
+    }
+    if (c == '.') {
+      HLCHECK(i + 2 < e && equ[i + 1] == '.' && equ[i + 2] == '.');
+      terms.back().push_back('.');
+      i += 2;
+    }
+    if (c == ',') {
+      new_term = 1;
+      continue;
+    }
+    if (c == '-') {
+      HLCHECK(i + 1 < e && equ[i + 1] == '>');
+      HLCHECK(!has_output);
+      i += 1;
+      has_output = true;
+      new_term = 1;
+      continue;
+    }
+    if (std::isalpha(c) != 0) {
+      std::string& term = terms.back();
+      term.push_back(c);
+    }
+  }
+
+  int num_terms = terms.size();
+  auto elem_ty = inst->GetOperand(0).GetType().GetDataType();
+  if (!has_output) {
+    HLCHECK(num_terms == n);
+    inst->GetResultsTypes()[0] = Type{elem_ty, {}};
+    return;
+  }
+
+  HLCHECK(num_terms == n + 1);
+  // Setup character to dimension mapping for inputs.
+  std::vector<int64_t> ellipsis_dims;
+  for (int i = 0; i < n; ++i) {
+    const auto& ty = inst->GetOperand(i).GetType();
+    unsigned rank = ty.GetNumOfDims();
+    const auto& term = terms[i];
+    HLCHECK(term.size() <= rank);
+    for (unsigned j = 0, s = term.size(), dim_idx = 0; j < s; ++j, ++dim_idx) {
+      char c = terms[i][j];
+      if (c == '.') {
+        bool init = ellipsis_dims.empty();
+        for (unsigned k = 0; k < rank - term.size(); ++k) {
+          auto v = ty.GetNumOfElementsInDim(dim_idx++);
+          if (init) {
+            ellipsis_dims.push_back(v);
+          } else {
+            HLCHECK(ellipsis_dims[k] == v);
+          }
+        }
+        continue;
+      }
+      int64_t d = ty.GetNumOfElementsInDim(dim_idx);
+      if (char2dim.count(c) == 0) {
+        char2dim[c] = d;
+      } else {
+        HLCHECK(char2dim[c] == d);
+      }
+    }
+  }
+  const auto& out_term = terms.back();
+  std::vector<int64_t> out_shape;
+  out_shape.reserve(out_term.size());
+  for (auto c : out_term) {
+    if (c == '.') {
+      out_shape.insert(out_shape.end(), ellipsis_dims.begin(),
+                       ellipsis_dims.end());
+    } else {
+      HLCHECK(char2dim.count(c) > 0);
+      out_shape.push_back(char2dim[c]);
+    }
+  }
+  inst->GetResultsTypes()[0] = Type{elem_ty, out_shape};
+}
+
 static void RunOnInstruction(ExpandDimsInst* inst) {
   const auto& idx_op = inst->GetOperand(1);
   const auto& idx_type = idx_op.GetType();