Extend accuracy sample
iefode committed Oct 3, 2024
1 parent 1d7d31d commit 8204128
Showing 3 changed files with 31 additions and 9 deletions.
@@ -21,6 +21,7 @@ int main(int argc, char* argv[]) try {
("n,num_prompts", "A number of prompts", cxxopts::value<size_t>()->default_value("1"))
("dynamic_split_fuse", "Whether to use dynamic split-fuse or vLLM scheduling", cxxopts::value<bool>()->default_value("false"))
("m,model", "Path to model and tokenizers base directory", cxxopts::value<std::string>()->default_value("."))
("a,draft_model", "Path to assisting model base directory", cxxopts::value<std::string>()->default_value(""))
("d,device", "Target device to run the model", cxxopts::value<std::string>()->default_value("CPU"))
("use_prefix", "Whether to use a prefix or not", cxxopts::value<bool>()->default_value("false"))
("h,help", "Print usage");
@@ -42,6 +43,7 @@ int main(int argc, char* argv[]) try {
const size_t num_prompts = result["num_prompts"].as<size_t>();
const bool dynamic_split_fuse = result["dynamic_split_fuse"].as<bool>();
const std::string models_path = result["model"].as<std::string>();
+ const std::string draft_models_path = result["draft_model"].as<std::string>();
const std::string device = result["device"].as<std::string>();
const bool use_prefix = result["use_prefix"].as<bool>();

@@ -97,9 +99,14 @@ int main(int argc, char* argv[]) try {
scheduler_config.max_num_seqs = 2;
scheduler_config.enable_prefix_caching = use_prefix;

+ ov::AnyMap plugin_config;
+ if (!draft_models_path.empty()) {
+     plugin_config.insert({{ov::genai::draft_model.name(), draft_models_path}});
+ }
+
// It's possible to construct a Tokenizer from a different path.
// If the Tokenizer isn't specified, it's loaded from the same folder.
- ov::genai::ContinuousBatchingPipeline pipe(models_path, ov::genai::Tokenizer{models_path}, scheduler_config, device);
+ ov::genai::ContinuousBatchingPipeline pipe(models_path, ov::genai::Tokenizer{models_path}, scheduler_config, device, plugin_config);

if (use_prefix) {
std::cout << "Running inference for prefix to compute the shared prompt's KV cache..." << std::endl;
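For orientation, here is a minimal sketch of the construction pattern the updated sample now exercises: the draft model directory goes into the plugin config under the draft_model key, and, as the comment in the hunk above notes, the Tokenizer may be constructed from a different directory than the main model. Header names, paths, and the bare main() wrapper are illustrative assumptions, not part of the commit.

#include "openvino/genai/continuous_batching_pipeline.hpp"  // assumed public header for ContinuousBatchingPipeline
#include "openvino/genai/llm_pipeline.hpp"                   // declares the ov::genai::draft_model property (see below)

int main() {
    ov::genai::SchedulerConfig scheduler_config;  // defaults; the sample tunes cache size, max_num_seqs, etc.

    // The draft model is optional: adding this entry is what turns on speculative decoding.
    ov::AnyMap plugin_config;
    plugin_config.insert({{ov::genai::draft_model.name(), std::string("/models/draft-model")}});  // hypothetical path

    ov::genai::ContinuousBatchingPipeline pipe(
        "/models/main-model",                        // hypothetical main model directory
        ov::genai::Tokenizer{"/models/main-model"},  // a different tokenizer directory could be passed here
        scheduler_config,
        "CPU",
        plugin_config);
    return 0;
}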
4 changes: 2 additions & 2 deletions src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -31,11 +31,11 @@ using StringInputs = std::variant<std::string, std::vector<std::string>>;
static constexpr ov::Property<SchedulerConfig> scheduler_config{"scheduler_config"};

/**
- * @brief draft_model_path property serves to activate speculative decoding model in continuous batching pipeline.
+ * @brief draft_model property serves to activate speculative decoding model in continuous batching pipeline.
* Create SchedulerConfig and fill it with suitable values. Copy or move it to plugin_config.
* And create LLMPipeline instance with this config.
*/
- static constexpr ov::Property<SchedulerConfig> draft_model_path{"draft_model_path"};
+ static constexpr ov::Property<SchedulerConfig> draft_model{"draft_model"};

/**
* @brief Structure to store resulting batched tokens and scores for each batch sequence.
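The renamed property is meant to be used the way the comment above describes: fill a SchedulerConfig, put it into the plugin config together with the draft model path, and construct a pipeline with that config. A hedged sketch of that flow through LLMPipeline follows; the model paths and the exact constructor overload are assumptions rather than something this diff shows.

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Scheduler settings for the continuous batching backend.
    ov::genai::SchedulerConfig scheduler_config;

    // Copy the scheduler config and the draft model path into the plugin config,
    // as the draft_model property documentation suggests.
    ov::AnyMap plugin_config;
    plugin_config[ov::genai::scheduler_config.name()] = scheduler_config;
    plugin_config[ov::genai::draft_model.name()] = std::string("/models/draft-model");  // hypothetical path

    ov::genai::LLMPipeline pipe("/models/main-model", "CPU", plugin_config);  // hypothetical path
    return 0;
}

Generation then proceeds through the usual LLMPipeline::generate calls; the presence of the draft_model entry is what routes execution to the speculative decoding implementation, as the continuous_batching_pipeline.cpp changes below show.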
27 changes: 21 additions & 6 deletions src/cpp/src/continuous_batching_pipeline.cpp
@@ -17,18 +17,27 @@

using namespace ov::genai;

+ inline std::string
+ extract_draft_model_from_config(ov::AnyMap& config) {
+     std::string draft_model;
+     if (config.find(ov::genai::draft_model.name()) != config.end()) {
+         draft_model = config.at(ov::genai::draft_model.name()).as<std::string>();
+         config.erase(ov::genai::draft_model.name());
+     }
+     return draft_model;
+ }
+
ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::string& models_path,
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& llm_plugin_config,
const ov::AnyMap& tokenizer_plugin_config) {
- if (llm_plugin_config.find(ov::genai::draft_model_path.name()) == llm_plugin_config.end()) {
+ auto llm_plugin_config_without_draft_model = llm_plugin_config;
+ auto draft_model = extract_draft_model_from_config(llm_plugin_config_without_draft_model);
+ if (draft_model.empty()) {
m_impl = std::make_shared<ContinuousBatchingImpl>(models_path, scheduler_config, device, llm_plugin_config, tokenizer_plugin_config);
} else {
- std::string draft_model_path = llm_plugin_config.at(ov::genai::draft_model_path.name()).as<std::string>();
- auto llm_plugin_config_without_draft_model = llm_plugin_config;
- llm_plugin_config_without_draft_model.erase(ov::genai::draft_model_path.name());
- m_impl = std::make_shared<SpeculativeDecodingImpl>(models_path, draft_model_path, scheduler_config, device, llm_plugin_config_without_draft_model, tokenizer_plugin_config);
+ m_impl = std::make_shared<SpeculativeDecodingImpl>(models_path, draft_model, scheduler_config, device, llm_plugin_config_without_draft_model, tokenizer_plugin_config);
}
}

@@ -38,7 +47,13 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline(
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& plugin_config) {
- m_impl = std::make_shared<ContinuousBatchingImpl>(model_path, tokenizer, scheduler_config, device, plugin_config);
+ auto plugin_config_without_draft_model = plugin_config;
+ auto draft_model = extract_draft_model_from_config(plugin_config_without_draft_model);
+ if (draft_model.empty()) {
+     m_impl = std::make_shared<ContinuousBatchingImpl>(model_path, tokenizer, scheduler_config, device, plugin_config);
+ } else {
+     m_impl = std::make_shared<SpeculativeDecodingImpl>(model_path, draft_model, tokenizer, scheduler_config, device, plugin_config_without_draft_model);
+ }
}

ov::genai::Tokenizer ContinuousBatchingPipeline::get_tokenizer() {
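To make the behaviour of the new extract_draft_model_from_config helper concrete, here is a small self-contained check. The helper body is copied from the hunk above so the snippet compiles on its own; the extra PERFORMANCE_HINT entry, the paths, and the main() wrapper are illustrative assumptions, not part of the commit.

#include <cassert>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"  // ov::genai::draft_model, ov::AnyMap

// Copied from the hunk above so the example is self-contained.
inline std::string extract_draft_model_from_config(ov::AnyMap& config) {
    std::string draft_model;
    if (config.find(ov::genai::draft_model.name()) != config.end()) {
        draft_model = config.at(ov::genai::draft_model.name()).as<std::string>();
        config.erase(ov::genai::draft_model.name());
    }
    return draft_model;
}

int main() {
    ov::AnyMap config{
        {ov::genai::draft_model.name(), std::string("/models/draft-model")},  // hypothetical path
        {"PERFORMANCE_HINT", std::string("LATENCY")}                          // unrelated entry, for illustration
    };

    std::string draft = extract_draft_model_from_config(config);

    assert(draft == "/models/draft-model");                    // the path is handed to SpeculativeDecodingImpl
    assert(config.count(ov::genai::draft_model.name()) == 0);  // the draft_model key is consumed
    assert(config.count("PERFORMANCE_HINT") == 1);             // everything else is forwarded to the plugin as-is
    return 0;
}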
