From bc0d188b418b6a168521af307dc6347065de3a63 Mon Sep 17 00:00:00 2001 From: Mark Glines Date: Tue, 19 Jan 2021 10:01:19 -0500 Subject: [PATCH] Add an error message for invalid input tensor names. --- src/storage/storage.cpp | 3 +++ tools/taco.cpp | 37 +++++++++++++++++++++++++++++++++++-- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/src/storage/storage.cpp b/src/storage/storage.cpp index 80af9f0a0..aeb216054 100644 --- a/src/storage/storage.cpp +++ b/src/storage/storage.cpp @@ -32,6 +32,9 @@ struct TensorStorage::Content { int order = (int)dimensions.size(); taco_iassert(order <= INT_MAX && componentType.getNumBits() <= INT_MAX); + taco_uassert(order == format.getOrder()) << + "The number of format mode types (" << format.getOrder() << ") " << + "must match the tensor order (" << dimensions.size() << ")."; vector dimensionsInt32(order); vector modeOrdering(order); vector modeTypes(order); diff --git a/tools/taco.cpp b/tools/taco.cpp index fcc654e08..c86a5fbb6 100644 --- a/tools/taco.cpp +++ b/tools/taco.cpp @@ -839,8 +839,16 @@ int main(int argc, char* argv[]) { printCompute = true; } - // Load tensors + // pre-parse expression, to determine existence and order of loaded tensors map loadedTensors; + TensorBase temp_tensor; + parser::Parser temp_parser(exprStr, formats, dataTypes, tensorsDimensions, loadedTensors, 42); + try { + temp_parser.parse(); + temp_tensor = temp_parser.getResultTensor(); + } catch (parser::ParseError& e) { + return reportError(e.getMessage(), 6); + } // Load tensors for (auto& tensorNames : inputFilenames) { @@ -851,7 +859,32 @@ int main(int argc, char* argv[]) { return reportError("Loaded tensors can only be type double", 7); } - Format format = util::contains(formats, name) ? formats.at(name) : Dense; + // make sure the tensor exists in the expression (and stash its order) + int found_tensor_order; + bool found = false; + for (auto a : getArgumentAccesses(temp_tensor.getAssignment().concretize())) { + if (a.getTensorVar().getName() == name) { + found_tensor_order = a.getIndexVars().size(); + found = true; + break; + } + } + if(found == false) { + return reportError("Cannot load '" + filename + "': no tensor '" + name + "' found in expression", 8); + } + + Format format; + if(util::contains(formats, name)) { + // format of this tensor is specified on the command line, use it + format = formats.at(name); + } else { + // create a dense default format of the correct order + std::vector modes; + for(int i = 0; i < found_tensor_order; i++) { + modes.push_back(Dense); + } + format = Format({ModeFormatPack(modes)}); + } TensorBase tensor; TOOL_BENCHMARK_TIMER(tensor = read(filename,format,false), name+" file read:", timevalue);