Skip to content

Commit

Permalink
Default batch size to 1 + sanity check value
Browse files Browse the repository at this point in the history
Improve common JSON error message wording
  • Loading branch information
RyanUnderhill committed Mar 15, 2024
1 parent 7eb8c3c commit adea0ab
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_
throw std::runtime_error("search max_length is 0");
if (params.search.max_length > model.config_->model.context_length)
throw std::runtime_error("max_length cannot be greater than model context_length");
if (params.batch_size < 1)
throw std::runtime_error("batch_size must be 1 or greater");
if (params.vocab_size < 1)
throw std::runtime_error("vocab_size must be 1 or greater");

search_ = CreateSearch(params);
state_ = model.CreateState(search_->GetSequenceLengths(), params);
Expand Down
2 changes: 1 addition & 1 deletion src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ struct GeneratorParams {
int vocab_size{};
int context_length{};

int batch_size{};
int batch_size{1};
int sequence_length{};
int BatchBeamSize() const { return search.num_beams * batch_size; }

Expand Down
4 changes: 2 additions & 2 deletions src/json.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ JSON::JSON(Element& element, std::string_view document) : begin_{document.data()
int line = 1;
const auto* last_cr = begin_;
for (const auto* p = begin_; p < current_; p++) {
if (*p == '\r') {
if (*p == '\n') {
line++;
last_cr = p;
}
Expand Down Expand Up @@ -108,7 +108,7 @@ void JSON::Parse_Object(Element& element) {

while (true) {
if (!Skip('\"')) {
throw std::runtime_error("Expecting \"");
throw std::runtime_error("Expecting \" to start next object name, possibly due to an extra trailing ',' before this");
}

auto name = Parse_String();
Expand Down
5 changes: 3 additions & 2 deletions src/python/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <pybind11/numpy.h>
#include <iostream>
#include "../generators.h"
#include "../json.h"
#include "../search.h"
#include "../models/model.h"

Expand Down Expand Up @@ -92,8 +93,8 @@ struct PyGeneratorParams : GeneratorParams {
} else if (pybind11::isinstance<pybind11::int_>(entry.second)) {
SetSearchNumber(search, name, entry.second.cast<int>());
} else
throw std::runtime_error("Unknown search option type, can be float/bool/int");
} catch (const std::exception& e) {
throw std::runtime_error("Unknown search option type, can be float/bool/int:" + name);
} catch (JSON::unknown_value_error& e) {
throw std::runtime_error("Unknown search option:" + name);
}
}
Expand Down

0 comments on commit adea0ab

Please sign in to comment.