From adea0abfb50534258ab5e92942deeb593a19bd19 Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Thu, 14 Mar 2024 18:23:32 -0700 Subject: [PATCH] Default batch size to 1 + sanity check value Improve common JSON error messge wording --- src/generators.cpp | 4 ++++ src/generators.h | 2 +- src/json.cpp | 4 ++-- src/python/python.cpp | 5 +++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index db27628da..640b6559e 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -71,6 +71,10 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_ throw std::runtime_error("search max_length is 0"); if (params.search.max_length > model.config_->model.context_length) throw std::runtime_error("max_length cannot be greater than model context_length"); + if (params.batch_size < 1) + throw std::runtime_error("batch_size must be 1 or greater"); + if (params.vocab_size < 1) + throw std::runtime_error("vocab_size must be 1 or greater"); search_ = CreateSearch(params); state_ = model.CreateState(search_->GetSequenceLengths(), params); diff --git a/src/generators.h b/src/generators.h index af98aea44..433fda103 100644 --- a/src/generators.h +++ b/src/generators.h @@ -56,7 +56,7 @@ struct GeneratorParams { int vocab_size{}; int context_length{}; - int batch_size{}; + int batch_size{1}; int sequence_length{}; int BatchBeamSize() const { return search.num_beams * batch_size; } diff --git a/src/json.cpp b/src/json.cpp index 8f3c4c88f..842660dd1 100644 --- a/src/json.cpp +++ b/src/json.cpp @@ -49,7 +49,7 @@ JSON::JSON(Element& element, std::string_view document) : begin_{document.data() int line = 1; const auto* last_cr = begin_; for (const auto* p = begin_; p < current_; p++) { - if (*p == '\r') { + if (*p == '\n') { line++; last_cr = p; } @@ -108,7 +108,7 @@ void JSON::Parse_Object(Element& element) { while (true) { if (!Skip('\"')) { - throw std::runtime_error("Expecting \""); + throw std::runtime_error("Expecting \" to start next object name, possibly due to an extra trailing ',' before this"); } auto name = Parse_String(); diff --git a/src/python/python.cpp b/src/python/python.cpp index 6bf85d89f..c9fd423a7 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -3,6 +3,7 @@ #include #include #include "../generators.h" +#include "../json.h" #include "../search.h" #include "../models/model.h" @@ -92,8 +93,8 @@ struct PyGeneratorParams : GeneratorParams { } else if (pybind11::isinstance(entry.second)) { SetSearchNumber(search, name, entry.second.cast()); } else - throw std::runtime_error("Unknown search option type, can be float/bool/int"); - } catch (const std::exception& e) { + throw std::runtime_error("Unknown search option type, can be float/bool/int:" + name); + } catch (JSON::unknown_value_error& e) { throw std::runtime_error("Unknown search option:" + name); } }