-
Notifications
You must be signed in to change notification settings - Fork 229
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'suffix_decoding' into new_tokenizer
- Loading branch information
Showing
14 changed files
with
1,001 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -296,6 +296,16 @@ if(NOT BUILD_LEGION_ONLY) | |
endif() | ||
|
||
set(FLEXFLOW_CPP_DRV_SRC ${FLEXFLOW_ROOT}/src/runtime/cpp_driver.cc) | ||
# SuffixDecoding | ||
include(FetchContent) | ||
FetchContent_Declare( | ||
suffix_decoding | ||
GIT_REPOSITORY [email protected]:Snowflake-Labs/suffix-tree-decoding.git | ||
GIT_TAG main # or a specific tag/commit hash | ||
) | ||
FetchContent_MakeAvailable(suffix_decoding) | ||
list(APPEND FLEXFLOW_INCLUDE_DIRS ${suffix_decoding_SOURCE_DIR}/src) | ||
list(APPEND FLEXFLOW_SRC ${suffix_decoding_SOURCE_DIR}/src/suffix_decoding.cc) | ||
|
||
add_library(substitution_loader SHARED | ||
${FLEXFLOW_ROOT}/src/runtime/substitution_loader.cc) | ||
|
@@ -534,6 +544,7 @@ if(NOT BUILD_LEGION_ONLY) | |
|
||
if(FF_BUILD_INFERENCE) | ||
add_subdirectory(inference/spec_infer) | ||
add_subdirectory(inference/suffix_decoding) | ||
add_subdirectory(inference/incr_decoding) | ||
add_subdirectory(inference/peft) | ||
endif() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
cmake_minimum_required(VERSION 3.10) | ||
|
||
project(FlexFlow_SpecInfer) | ||
set(project_target suffix_decoding) | ||
|
||
|
||
set(CPU_SRC | ||
${FLEXFLOW_CPP_DRV_SRC} | ||
suffix_decoding.cc | ||
utils.cc | ||
../models/llama.cc | ||
../models/opt.cc | ||
../models/falcon.cc | ||
../models/mpt.cc) | ||
|
||
if (FF_GPU_BACKEND STREQUAL "cuda" OR FF_GPU_BACKEND STREQUAL "hip_cuda") | ||
cuda_add_executable(${project_target} ${CPU_SRC}) | ||
if (FF_GPU_BACKEND STREQUAL "hip_cuda") | ||
target_compile_definitions(${project_target} PRIVATE __HIP_PLATFORM_NVIDIA__) | ||
endif() | ||
elseif(FF_GPU_BACKEND STREQUAL "hip_rocm") | ||
set_source_files_properties(${CPU_SRC} PROPERTIES LANGUAGE HIP) | ||
hip_add_executable(${project_target} ${CPU_SRC}) | ||
if (FF_HIP_ARCH STREQUAL "") | ||
message(FATAL_ERROR "FF_HIP_ARCH is empty!") | ||
endif() | ||
set_property(TARGET ${project_target} PROPERTY HIP_ARCHITECTURES "${FF_HIP_ARCH}") | ||
target_compile_definitions(${project_target} PRIVATE __HIP_PLATFORM_AMD__) | ||
else() | ||
message(FATAL_ERROR "Compilation of ${project_target} for ${FF_GPU_BACKEND} backend not yet supported") | ||
endif() | ||
|
||
target_include_directories(${project_target} PRIVATE ${FLEXFLOW_INCLUDE_DIRS} ${CMAKE_INSTALL_INCLUDEDIR}) | ||
target_include_directories(${project_target} PRIVATE ${CMAKE_SOURCE_DIR}/inference) | ||
target_link_libraries(${project_target} -Wl,--whole-archive flexflow -Wl,--no-whole-archive ${FLEXFLOW_EXT_LIBRARIES}) | ||
|
||
set(BIN_DEST "bin") | ||
install(TARGETS ${project_target} DESTINATION ${BIN_DEST}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
# Flags for directing the runtime makefile what to include | ||
DEBUG ?= 0 # Include debugging symbols | ||
MAX_DIM ?= 4 # Maximum number of dimensions | ||
OUTPUT_LEVEL ?= LEVEL_DEBUG # Compile time logging level | ||
USE_CUDA ?= 1 # Include CUDA support (requires CUDA) | ||
USE_GASNET ?= 0 # Include GASNet support (requires GASNet) | ||
USE_HDF ?= 1 # Include HDF5 support (requires HDF5) | ||
ALT_MAPPERS ?= 0 # Include alternative mappers (not recommended) | ||
|
||
# Put the binary file name here | ||
OUTFILE ?= llama_pipeline | ||
# List all the application source files here | ||
ifndef CUDA_HOME | ||
CUDA_HOME = $(patsubst %/bin/nvcc,%,$(shell which nvcc | head -1)) | ||
endif | ||
|
||
|
||
ifndef FF_HOME | ||
$(error FF_HOME variable is not defined, aborting build) | ||
endif | ||
|
||
include $(FF_HOME)/FlexFlow.mk |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "suffix_decoding/utils.h" | ||
|
||
using namespace FlexFlow; | ||
using namespace Legion; | ||
using json = nlohmann::json; | ||
|
||
Legion::Logger log_app("llama"); | ||
|
||
|
||
void process_partition(RequestManager *rm, std::string input_filename) { | ||
} | ||
|
||
void FlexFlow::top_level_task(Task const *task, | ||
std::vector<PhysicalRegion> const ®ions, | ||
Context ctx, | ||
Runtime *runtime) { | ||
FFConfig ffconfig; | ||
FilePaths file_paths; | ||
ModelMeta model_metadata; | ||
std::string partition_name; | ||
bool use_full_precision = false; | ||
bool verbose = false; | ||
int max_requests_per_batch = 16; | ||
int max_tokens_per_batch = 256; | ||
int max_sequence_length = 1024; | ||
int max_spec_tree_token_num = 23; | ||
int expansion_degree = 1; | ||
|
||
InputArgs const &command_args = HighLevelRuntime::get_input_args(); | ||
char **argv = command_args.argv; | ||
int argc = command_args.argc; | ||
parse_input_args(argv, | ||
argc, | ||
file_paths, | ||
model_metadata.model_names, | ||
partition_name, | ||
use_full_precision, | ||
verbose, | ||
max_requests_per_batch, | ||
max_tokens_per_batch, | ||
max_sequence_length, | ||
expansion_degree); | ||
|
||
get_model_meta(file_paths, model_metadata, use_full_precision); | ||
|
||
assert(ffconfig.data_parallelism_degree * ffconfig.tensor_parallelism_degree * | ||
ffconfig.pipeline_parallelism_degree == | ||
ffconfig.numNodes * ffconfig.workersPerNode); | ||
|
||
json trace = load_trace(file_paths.prompt_file_path); | ||
json training_entries = get_training_entries(trace, partition_name); | ||
json eval_entries = get_eval_entries(trace, partition_name); | ||
|
||
GenerationConfig generationConfig; | ||
InferenceManager *im = InferenceManager::get_inference_manager(); | ||
RequestManager *rm = RequestManager::get_request_manager(); | ||
init_request_manager(rm, | ||
model_metadata, | ||
file_paths, | ||
max_requests_per_batch, | ||
max_tokens_per_batch, | ||
max_spec_tree_token_num, | ||
max_sequence_length, | ||
expansion_degree); | ||
|
||
// Create LLM model | ||
FFModel tree_model(ffconfig, ffconfig.cpu_offload); | ||
init_llm(tree_model, model_metadata, generationConfig, use_full_precision); | ||
|
||
// Create SSM models | ||
int num_ssms = model_metadata.ssm_model_types.size(); | ||
std::vector<FFModel> ssm_models; | ||
FFConfig bm_config = ffconfig; | ||
bm_config.data_parallelism_degree = bm_config.tensor_parallelism_degree = | ||
bm_config.pipeline_parallelism_degree = 1; | ||
for (int ssm_id = 0; ssm_id < num_ssms; ssm_id++) { | ||
FFModel beam_model(bm_config); | ||
ssm_models.push_back(beam_model); | ||
} | ||
init_ssms(rm, ssm_models, num_ssms, model_metadata, generationConfig, use_full_precision); | ||
|
||
rm->start_background_server(&tree_model); | ||
|
||
int total_num_requests = 0; | ||
{ | ||
std::vector<Request> requests; | ||
for (auto entry: eval_entries) { | ||
std::string prompt = entry["prompt"]; | ||
int response_length = entry["response_length"]; | ||
// printf("Prompt[%d]: %s\n", total_num_requests, prompt.c_str()); | ||
// Add inference request | ||
Request inference_req; | ||
inference_req.prompt = prompt; | ||
inference_req.max_new_tokens = response_length; | ||
requests.push_back(inference_req); | ||
total_num_requests++; | ||
} | ||
tree_model.generate(requests); | ||
} | ||
|
||
// Register requests from prompt file | ||
// int total_num_requests = 0; | ||
// { | ||
// using json = nlohmann::json; | ||
// std::ifstream file_handle(file_paths.prompt_file_path); | ||
// assert(file_handle.good() && "Prompt file does not exist."); | ||
// json prompt_json = json::parse(file_handle, | ||
// /*parser_callback_t */ nullptr, | ||
// /*allow_exceptions */ true, | ||
// /*ignore_comments */ true); | ||
|
||
// std::vector<Request> requests; | ||
// for (auto &prompt : prompt_json) { | ||
// std::string text = prompt.get<std::string>(); | ||
// printf("Prompt[%d]: %s\n", total_num_requests, text.c_str()); | ||
// // Add inference request | ||
// Request inference_req; | ||
// inference_req.prompt = text; | ||
// inference_req.max_length = 128; | ||
// requests.push_back(inference_req); | ||
// total_num_requests++; | ||
// } | ||
// tree_model.generate(requests); | ||
// } | ||
|
||
// terminate the request manager by stopping the background thread | ||
rm->terminate_background_server(); | ||
|
||
// Execution fence | ||
{ | ||
Future future = runtime->issue_execution_fence(ctx); | ||
future.get_void_result(); | ||
} | ||
|
||
// float* data | ||
std::cout << "----------inference finished--------------" << std::endl; | ||
} | ||
|
||
void FlexFlow::register_custom_tasks() {} |
Oops, something went wrong.