ChatCompletion + Multi-EOS support (#1535)
* init

* support templates

* support for multiple eos token ids

* fix

* fix

* fix conda env for ci
goliaro authored Nov 4, 2024
1 parent 89f10f4 commit d09ba0c
Showing 14 changed files with 327 additions and 157 deletions.
6 changes: 3 additions & 3 deletions conda/flexflow.yml
@@ -16,9 +16,9 @@ dependencies:
- qualname>=0.1.0
- keras_preprocessing>=1.1.2
- numpy>=1.16.0
- torch>=1.13.1 --index-url https://download.pytorch.org/whl/cpu
- torchaudio>=0.13.1 --index-url https://download.pytorch.org/whl/cpu
- torchvision>=0.14.1 --index-url https://download.pytorch.org/whl/cpu
- torch>=1.13.1
- torchaudio>=0.13.1
- torchvision>=0.14.1
- regex
- onnx
- transformers>=4.31.0
7 changes: 6 additions & 1 deletion include/flexflow/flexflow_c.h
@@ -653,6 +653,7 @@ void flexflow_model_generate(flexflow_model_t handle_,
char **output_texts,
int *max_lengths,
int *max_new_tokens_,
bool *add_special_tokens_,
flexflow_peft_model_id_t *peft_model_ids,
char const **dataset_filepaths,
int *training_steps,
@@ -1019,14 +1020,18 @@ void flexflow_request_manager_set_max_spec_tree_token_num(
void flexflow_request_manager_set_max_sequence_length(
flexflow_request_manager_t handle_, int max_seq_length);

int flexflow_request_manager_get_max_sequence_length(
flexflow_request_manager_t handle_);

void flexflow_request_manager_set_enable_peft_finetuning(
flexflow_request_manager_t handle_, bool enable_peft_finetuning_);

void flexflow_request_manager_register_tokenizer(
flexflow_request_manager_t handle_,
enum ModelType model_type,
int bos_token_id,
int eos_token_id,
int num_eos_token_ids,
int *eos_token_ids,
char const *tokenizer_filepath);

void flexflow_request_manager_register_output_filepath(
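The C API above changes in two ways: `flexflow_request_manager_register_tokenizer` now takes a `num_eos_token_ids` count plus an `int *eos_token_ids` array instead of a single `eos_token_id`, and `flexflow_model_generate` gains a per-request `add_special_tokens_` flag. A plausible reason for the latter (an inference, not stated in the diff) is that chat-templated prompts already carry their special tokens, so the tokenizer should not prepend another BOS. A minimal illustration of what such a tokenizer-level flag controls, using Hugging Face transformers with `bert-base-uncased` purely as a small, ungated example tokenizer:

```python
from transformers import AutoTokenizer

# bert-base-uncased is used only because it is small and ungated; the point is
# the effect of add_special_tokens, not the model family.
tok = AutoTokenizer.from_pretrained("bert-base-uncased")

print(tok("hello", add_special_tokens=True).input_ids)   # [101, 7592, 102] -> [CLS] hello [SEP]
print(tok("hello", add_special_tokens=False).input_ids)  # [7592]           -> raw text only
```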
6 changes: 4 additions & 2 deletions include/flexflow/request_manager.h
@@ -69,6 +69,7 @@ struct Request {
PEFTModelID peft_model_id = PEFTModelID::NO_ID;
int max_length = -1;
int max_new_tokens = -1;
bool add_special_tokens = true;
int initial_len;
int ssm_cache_size = 0;
int llm_cache_size = 0;
@@ -146,7 +147,7 @@ class RequestManager {
int register_ssm_model(FFModel *model);
void register_tokenizer(ModelType model_type,
int bos_token_id,
int eos_token_id,
std::vector<int> eos_token_ids,
std::string const &path);
void register_output_filepath(std::string const &);
void initBitMask(BatchConfig::BitMask &bitmask, int initLength);
@@ -178,6 +179,7 @@ class RequestManager {
bool is_request_completed(RequestGuid const &guid);
void trigger_request_completion_future(RequestGuid const &guid);
// Methods for preparing next batches
bool is_eos_token(int token_id);
bool check_inf_req_completion(BatchConfig const &old_bc, int i);
void check_batch(BatchConfig const &old_bc, BatchConfig const &new_bc);
BatchConfig prepare_next_batch(BatchConfig const &bc,
@@ -301,7 +303,7 @@ class RequestManager {
bool verbose;
ModelType model_type;
int bos_token_id;
int eos_token_id;
std::vector<int> eos_token_ids;
bool old_llama_tokenizer = false;
std::string output_filepath;
std::queue<Request> pending_infr_request_queue;
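In `RequestManager`, the single `eos_token_id` member becomes a `std::vector<int> eos_token_ids`, and the new `is_eos_token` helper checks a sampled token against every registered EOS id. Its C++ body is not part of this diff; the sketch below is a Python rendering of the intended stopping behavior, with illustrative token ids and a hypothetical `should_stop` wrapper:

```python
# Illustrative values and helper names; the real logic lives in
# RequestManager::is_eos_token / check_inf_req_completion (not shown here).
EOS_TOKEN_IDS = [128001, 128009]  # e.g. a Llama-3-style pair


def is_eos_token(token_id: int) -> bool:
    # Multi-EOS support: stop on ANY registered EOS id, not a single one.
    return token_id in EOS_TOKEN_IDS


def should_stop(generated: list[int], max_new_tokens: int) -> bool:
    return bool(generated) and (
        is_eos_token(generated[-1]) or len(generated) >= max_new_tokens
    )


assert is_eos_token(128009) and not is_eos_token(42)
```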
19 changes: 15 additions & 4 deletions inference/incr_decoding/incr_decoding.cc
@@ -199,9 +199,20 @@ void FlexFlow::top_level_task(Task const *task,
int bos_token_id = model_config.find("bos_token_id") == model_config.end()
? -1
: (int)model_config.at("bos_token_id");
int eos_token_id = model_config.find("eos_token_id") == model_config.end()
? -1
: (int)model_config.at("eos_token_id");
  // parse eos token id, which can be either a single integer or an array of
  // integers. Convert to std::vector<int>
  std::vector<int> eos_token_ids;
  if (model_config.find("eos_token_id") != model_config.end()) {
    if (model_config["eos_token_id"].is_array()) {
      for (auto &eos_token_id : model_config["eos_token_id"]) {
        eos_token_ids.push_back(eos_token_id);
      }
    } else {
      eos_token_ids.push_back(model_config["eos_token_id"]);
    }
  } else {
    eos_token_ids.push_back(-1);
  }

assert(model_type != ModelType::UNKNOWN &&
"Invalid LLM model type passed (or no type was passed).");
@@ -212,7 +223,7 @@ void FlexFlow::top_level_task(Task const *task,
rm->set_max_tokens_per_batch(max_tokens_per_batch);
rm->set_max_sequence_length(max_sequence_length);
rm->register_tokenizer(
model_type, bos_token_id, eos_token_id, tokenizer_filepath);
model_type, bos_token_id, eos_token_ids, tokenizer_filepath);
rm->register_output_filepath(file_paths.output_file_path);

FFModel model(ffconfig, ffconfig.cpu_offload);
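The hunk above replaces the single-integer `eos_token_id` lookup with parsing that accepts either one integer or an array, falling back to a `-1` sentinel when the key is missing; the same block is repeated in the PEFT and benchmark drivers below. For reference, a Python equivalent of that normalization over a Hugging Face-style `config.json` (the concrete Llama ids in the comment are illustrative, not taken from this diff):

```python
import json


def parse_eos_token_ids(config_path: str) -> list[int]:
    """Mirror the C++ parsing above: normalize eos_token_id to a list of ints."""
    with open(config_path) as f:
        model_config = json.load(f)
    # "eos_token_id" may be a single integer (e.g. 2 for Llama-2) or an array
    # (e.g. [128001, 128009] for Llama-3-Instruct); -1 is the missing-key sentinel.
    value = model_config.get("eos_token_id", -1)
    return [int(v) for v in value] if isinstance(value, list) else [int(value)]
```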
19 changes: 15 additions & 4 deletions inference/peft/peft.cc
@@ -229,9 +229,20 @@ void FlexFlow::top_level_task(Task const *task,
int bos_token_id = model_config.find("bos_token_id") == model_config.end()
? -1
: (int)model_config.at("bos_token_id");
int eos_token_id = model_config.find("eos_token_id") == model_config.end()
? -1
: (int)model_config.at("eos_token_id");
  // parse eos token id, which can be either a single integer or an array of
  // integers. Convert to std::vector<int>
  std::vector<int> eos_token_ids;
  if (model_config.find("eos_token_id") != model_config.end()) {
    if (model_config["eos_token_id"].is_array()) {
      for (auto &eos_token_id : model_config["eos_token_id"]) {
        eos_token_ids.push_back(eos_token_id);
      }
    } else {
      eos_token_ids.push_back(model_config["eos_token_id"]);
    }
  } else {
    eos_token_ids.push_back(-1);
  }

assert(model_type != ModelType::UNKNOWN &&
"Invalid LLM model type passed (or no type was passed).");
@@ -267,7 +278,7 @@ void FlexFlow::top_level_task(Task const *task,
rm->set_max_tokens_per_batch(max_tokens_per_batch);
rm->set_max_sequence_length(max_sequence_length);
rm->register_tokenizer(
model_type, bos_token_id, eos_token_id, tokenizer_filepath);
model_type, bos_token_id, eos_token_ids, tokenizer_filepath);
rm->register_output_filepath(file_paths.output_file_path);
rm->set_enable_peft_finetuning(enable_peft_finetuning);

19 changes: 15 additions & 4 deletions inference/peft/peft_bwd_benchmark.cc
@@ -230,9 +230,20 @@ void FlexFlow::top_level_task(Task const *task,
int bos_token_id = model_config.find("bos_token_id") == model_config.end()
? -1
: (int)model_config.at("bos_token_id");
int eos_token_id = model_config.find("eos_token_id") == model_config.end()
? -1
: (int)model_config.at("eos_token_id");
  // parse eos token id, which can be either a single integer or an array of
  // integers. Convert to std::vector<int>
  std::vector<int> eos_token_ids;
  if (model_config.find("eos_token_id") != model_config.end()) {
    if (model_config["eos_token_id"].is_array()) {
      for (auto &eos_token_id : model_config["eos_token_id"]) {
        eos_token_ids.push_back(eos_token_id);
      }
    } else {
      eos_token_ids.push_back(model_config["eos_token_id"]);
    }
  } else {
    eos_token_ids.push_back(-1);
  }

assert(model_type != ModelType::UNKNOWN &&
"Invalid LLM model type passed (or no type was passed).");
@@ -251,7 +262,7 @@ void FlexFlow::top_level_task(Task const *task,
rm->set_max_tokens_per_batch(max_tokens_per_batch);
rm->set_max_sequence_length(max_sequence_length);
rm->register_tokenizer(
model_type, bos_token_id, eos_token_id, tokenizer_filepath);
model_type, bos_token_id, eos_token_ids, tokenizer_filepath);
rm->register_output_filepath(file_paths.output_file_path);
rm->set_enable_peft_finetuning(enable_peft_finetuning);

19 changes: 15 additions & 4 deletions inference/peft/peft_fwd_benchmark.cc
@@ -230,9 +230,20 @@ void FlexFlow::top_level_task(Task const *task,
int bos_token_id = model_config.find("bos_token_id") == model_config.end()
? -1
: (int)model_config.at("bos_token_id");
int eos_token_id = model_config.find("eos_token_id") == model_config.end()
? -1
: (int)model_config.at("eos_token_id");
  // parse eos token id, which can be either a single integer or an array of
  // integers. Convert to std::vector<int>
  std::vector<int> eos_token_ids;
  if (model_config.find("eos_token_id") != model_config.end()) {
    if (model_config["eos_token_id"].is_array()) {
      for (auto &eos_token_id : model_config["eos_token_id"]) {
        eos_token_ids.push_back(eos_token_id);
      }
    } else {
      eos_token_ids.push_back(model_config["eos_token_id"]);
    }
  } else {
    eos_token_ids.push_back(-1);
  }

assert(model_type != ModelType::UNKNOWN &&
"Invalid LLM model type passed (or no type was passed).");
@@ -251,7 +262,7 @@ void FlexFlow::top_level_task(Task const *task,
rm->set_max_tokens_per_batch(max_tokens_per_batch);
rm->set_max_sequence_length(max_sequence_length);
rm->register_tokenizer(
model_type, bos_token_id, eos_token_id, tokenizer_filepath);
model_type, bos_token_id, eos_token_ids, tokenizer_filepath);
rm->register_output_filepath(file_paths.output_file_path);
rm->set_enable_peft_finetuning(enable_peft_finetuning);

19 changes: 15 additions & 4 deletions inference/peft/req_rate_benchmark.cc
@@ -292,9 +292,20 @@ void FlexFlow::top_level_task(Task const *task,
int bos_token_id = model_config.find("bos_token_id") == model_config.end()
? -1
: (int)model_config.at("bos_token_id");
int eos_token_id = model_config.find("eos_token_id") == model_config.end()
? -1
: (int)model_config.at("eos_token_id");
  // parse eos token id, which can be either a single integer or an array of
  // integers. Convert to std::vector<int>
  std::vector<int> eos_token_ids;
  if (model_config.find("eos_token_id") != model_config.end()) {
    if (model_config["eos_token_id"].is_array()) {
      for (auto &eos_token_id : model_config["eos_token_id"]) {
        eos_token_ids.push_back(eos_token_id);
      }
    } else {
      eos_token_ids.push_back(model_config["eos_token_id"]);
    }
  } else {
    eos_token_ids.push_back(-1);
  }

assert(model_type != ModelType::UNKNOWN &&
"Invalid LLM model type passed (or no type was passed).");
@@ -313,7 +324,7 @@ void FlexFlow::top_level_task(Task const *task,
rm->set_max_tokens_per_batch(max_tokens_per_batch);
rm->set_max_sequence_length(max_sequence_length);
rm->register_tokenizer(
model_type, bos_token_id, eos_token_id, tokenizer_filepath);
model_type, bos_token_id, eos_token_ids, tokenizer_filepath);
rm->register_output_filepath(file_paths.output_file_path);
rm->set_enable_peft_finetuning(enable_peft_finetuning);

100 changes: 100 additions & 0 deletions inference/python/chat.py
@@ -0,0 +1,100 @@
# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flexflow.serve as ff
import argparse, json, os
from types import SimpleNamespace


def get_configs():
    # Define sample configs
    ff_init_configs = {
        # required parameters
        "num_gpus": 1,
        "memory_per_gpu": 30000,
        "zero_copy_memory_per_node": 60000,
        # optional parameters
        "num_cpus": 4,
        "legion_utility_processors": 4,
        "data_parallelism_degree": 1,
        "tensor_parallelism_degree": 1,
        "pipeline_parallelism_degree": 1,
        "offload": False,
        "offload_reserve_space_size": 8 * 1024,  # 8GB
        "use_4bit_quantization": False,
        "use_8bit_quantization": False,
        "enable_peft": False,
        "peft_activation_reserve_space_size": 1024,  # 1GB
        "peft_weight_reserve_space_size": 1024,  # 1GB
        "profiling": False,
        "benchmarking": False,
        "inference_debugging": False,
        "fusion": True,
    }
    llm_configs = {
        # required parameters
        "llm_model": "meta-llama/Meta-Llama-3-8B-Instruct",
        # optional parameters
        "cache_path": os.environ.get("FF_CACHE_PATH", ""),
        "refresh_cache": False,
        "full_precision": False,
    }
    # Merge dictionaries
    ff_init_configs.update(llm_configs)
    return ff_init_configs


def main():
    configs_dict = get_configs()
    configs = SimpleNamespace(**configs_dict)

    # Initialize the FlexFlow runtime. ff.init() takes a dictionary or the path to a JSON file with the configs
    ff.init(configs_dict)

    # Create the FlexFlow LLM
    ff_data_type = (
        ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF
    )
    llm = ff.LLM(
        configs.llm_model,
        data_type=ff_data_type,
        cache_path=configs.cache_path,
        refresh_cache=configs.refresh_cache,
    )

    # Compile the LLM for inference and load the weights into memory
    generation_config = ff.GenerationConfig(
        do_sample=False, temperature=0.9, topp=0.8, topk=1
    )
    llm.compile(
        generation_config,
        max_requests_per_batch=1,
        max_seq_length=2048,
        max_tokens_per_batch=256,
    )

    llm.start_server()

    messages = [
        {"role": "system", "content": "You are a helpful and honest programming assistant."},
        {"role": "user", "content": "Is Rust better than Python?"},
    ]
    llm.generate(messages, max_new_tokens=256)

    llm.stop_server()


if __name__ == "__main__":
    print("flexflow inference example (incremental decoding)")
    main()
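The script passes an OpenAI-style `messages` list to `llm.generate` instead of a raw prompt string, which is the user-facing side of the "support templates" change. Presumably the messages are rendered into a single prompt with the model's chat template before incremental decoding; the sketch below shows that step with Hugging Face transformers, as an assumption about the mechanism rather than FlexFlow's actual code path (note that the Meta-Llama-3 repository is gated, so an accepted license and HF token are required):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
messages = [
    {"role": "system", "content": "You are a helpful and honest programming assistant."},
    {"role": "user", "content": "Is Rust better than Python?"},
]
# Render the conversation into one templated prompt string and append the
# assistant header so generation continues from the assistant's turn.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)
```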