Einsum: Add shape inferencing and odla trt implementation
Weiming Zhao authored and weimingzha0 committed Nov 29, 2021
1 parent 4e6dc0d commit edf5770
Showing 2 changed files with 113 additions and 0 deletions.
16 changes: 16 additions & 0 deletions ODLA/platforms/tensorrt/odla_tensorrt.cc
@@ -1614,6 +1614,22 @@ odla_value odla_Concat(odla_values inputs, odla_int32 axis,
  return CreateValue(concat, output_type, id);
}

#if NV_TENSORRT_MAJOR > 8 || (NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR >= 2)
// Einsum is lowered to TensorRT's IEinsumLayer, which is only available in
// TensorRT 8.2 and later.
odla_value odla_Einsum(odla_values inputs, const odla_char* equation,
                       odla_value_shape output_dims, const odla_value_id id) {
  int num = inputs.size;
  // Collect the underlying TensorRT tensors of all ODLA inputs.
  std::vector<nvinfer1::ITensor*> input_tensors(num);
  for (int i = 0; i < num; ++i) {
    input_tensors[i] = inputs.values[i]->tensor;
  }
  auto ret = g_comp->network->addEinsum(input_tensors.data(), num, equation);
  // Einsum keeps the element type of its inputs; the output shape is the one
  // computed during shape inference and passed in by the caller.
  odla_value_type output_type{
      .element_type = inputs.values[0]->type.element_type,
      .shape = output_dims};
  return CreateValue(ret, output_type, id);
}
#endif
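As a quick usage illustration (not part of the commit), a batched matrix multiply could be driven through the new entry point roughly as in the sketch below. The value names lhs and rhs, their shapes ([8, 64, 128] and [8, 128, 32]), the id string, and the brace-initialization of the ODLA structs are all assumptions made for the example:

// Hypothetical usage sketch (illustrative only): a batched matmul expressed
// as einsum over two previously created ODLA values lhs and rhs.
odla_value_shape out_dims{.size = 3, .dims = {8, 64, 32}};
odla_values args{.size = 2, .values = {lhs, rhs}};
odla_value out = odla_Einsum(args, "bij,bjk->bik", out_dims,
                             (const odla_value_id) "einsum");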

odla_value odla_MaxPool(odla_value input, odla_memory_layout input_layout,
                        const odla_uint32* window_dims,
                        const odla_uint32* strides,
97 changes: 97 additions & 0 deletions lib/transforms/type_legalizer.cc
@@ -1097,6 +1097,103 @@ static void RunOnInstruction(SetDiff1DInst* inst) {
  ;
}

static void RunOnInstruction(EinsumInst* inst) {
  // Infer the result type of einsum. Defer until all operand types are known.
  int n = inst->GetNumOfOperands();
  for (int i = 0; i < n; ++i) {
    if (!inst->GetOperand(i).GetType().IsValid()) {
      return;
    }
  }
  HLCHECK(n >= 1);
  // Split the equation into one term per operand plus an optional output
  // term; an ellipsis ("...") is recorded as a single '.' inside a term.
  std::vector<std::string> terms;
  const std::string equ = inst->GetEquation();
  std::unordered_map<char, int64_t> char2dim;
  bool has_output = false;
  for (int i = 0, e = equ.size(), new_term = 1; i < e; ++i) {
    auto c = equ[i];
    if (new_term == 1) {
      terms.push_back("");
      new_term = 0;
    }
    if (c == ' ') {
      continue;
    }
    if (c == '.') {
      // An ellipsis must be exactly three dots.
      HLCHECK(i + 2 < e && equ[i + 1] == '.' && equ[i + 2] == '.');
      terms.back().push_back('.');
      i += 2;
      continue;
    }
    if (c == ',') {
      new_term = 1;
      continue;
    }
    if (c == '-') {
      // "->" separates the input terms from the output term.
      HLCHECK(i + 1 < e && equ[i + 1] == '>');
      HLCHECK(!has_output);
      i += 1;
      has_output = true;
      new_term = 1;
      continue;
    }
    if (std::isalpha(static_cast<unsigned char>(c)) != 0) {
      std::string& term = terms.back();
      term.push_back(c);
    }
  }

  int num_terms = terms.size();
  auto elem_ty = inst->GetOperand(0).GetType().GetDataType();
  if (!has_output) {
    // No explicit output term: treat the result as a full reduction to a
    // scalar.
    HLCHECK(num_terms == n);
    inst->GetResultsTypes()[0] = Type{elem_ty, {}};
    return;
  }

  HLCHECK(num_terms == n + 1);
  // Set up the subscript-to-dimension mapping for the inputs. Dimensions
  // covered by an ellipsis are collected in ellipsis_dims and must agree
  // across all operands.
  std::vector<int64_t> ellipsis_dims;
  for (int i = 0; i < n; ++i) {
    const auto& ty = inst->GetOperand(i).GetType();
    unsigned rank = ty.GetNumOfDims();
    const auto& term = terms[i];
    HLCHECK(term.size() <= rank);
    for (unsigned j = 0, s = term.size(), dim_idx = 0; j < s; ++j, ++dim_idx) {
      char c = terms[i][j];
      if (c == '.') {
        // The ellipsis covers every dimension that is not named by an
        // explicit subscript in this term.
        bool init = ellipsis_dims.empty();
        unsigned ellipsis_rank = rank - (term.size() - 1);
        for (unsigned k = 0; k < ellipsis_rank; ++k) {
          auto v = ty.GetNumOfElementsInDim(dim_idx++);
          if (init) {
            ellipsis_dims.push_back(v);
          } else {
            HLCHECK(ellipsis_dims[k] == v);
          }
        }
        // The loop header's ++dim_idx accounts for the '.' character itself.
        --dim_idx;
        continue;
      }
      // A subscript must map to the same extent everywhere it appears.
      int64_t d = ty.GetNumOfElementsInDim(dim_idx);
      if (char2dim.count(c) == 0) {
        char2dim[c] = d;
      } else {
        HLCHECK(char2dim[c] == d);
      }
    }
  }

  // Build the output shape from the output term.
  const auto& out_term = terms.back();
  std::vector<int64_t> out_shape;
  out_shape.reserve(out_term.size());
  for (auto c : out_term) {
    if (c == '.') {
      // Splice in the dimensions covered by the ellipsis.
      out_shape.insert(out_shape.end(), ellipsis_dims.begin(),
                       ellipsis_dims.end());
    } else {
      HLCHECK(char2dim.count(c) > 0);
      out_shape.push_back(char2dim[c]);
    }
  }
  inst->GetResultsTypes()[0] = Type{elem_ty, out_shape};
}
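For reference, a worked trace of the shape inference above; the equation and shapes are illustrative only, not taken from the commit:

// equation:  "...ij,...jk->...ik"
// operand 0: [5, 2, 3]    operand 1: [5, 3, 4]
// terms         -> {".ij", ".jk", ".ik"}
// ellipsis_dims -> [5]
// char2dim      -> {i: 2, j: 3, k: 4}
// inferred type -> [5, 2, 4], with the element type of operand 0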

static void RunOnInstruction(ExpandDimsInst* inst) {
  const auto& idx_op = inst->GetOperand(1);
  const auto& idx_type = idx_op.GetType();