Commit
load weights in parallel
goliaro committed Nov 9, 2024
1 parent 9a1eae5 commit c71c6b3
Showing 8 changed files with 200 additions and 18 deletions.
3 changes: 3 additions & 0 deletions include/flexflow/model.h
@@ -278,6 +278,9 @@ enum TaskIDs {
   RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID,
   RM_PREPARE_NEXT_BATCH_VERIFY_TASK_ID,
   RM_BACKGROUND_SERVING_TASK_ID,
+  LOAD_FLOAT_WEIGHT_TASK_ID,
+  LOAD_HALF_WEIGHT_TASK_ID,
+  LOAD_QUANT_WEIGHT_TASK_ID,
   // Custom tasks
   CUSTOM_GPU_TASK_ID_FIRST,
   CUSTOM_GPU_TASK_ID_1,
28 changes: 28 additions & 0 deletions include/flexflow/utils/file_loader.h
@@ -39,7 +39,26 @@ class FileDataLoader {
   void load_single_weight_tensor(FFModel *ff, Layer *l, int weight_idx);
 
   void load_quantization_weight(FFModel *ff, Layer *l, int weight_idx);
+#ifdef DEADCODE
   void load_weights(FFModel *ff);
+#endif
 
+  static void
+      load_float_weight_task(Legion::Task const *task,
+                             std::vector<Legion::PhysicalRegion> const &regions,
+                             Legion::Context ctx,
+                             Legion::Runtime *runtime);
+  static void
+      load_half_weight_task(Legion::Task const *task,
+                            std::vector<Legion::PhysicalRegion> const &regions,
+                            Legion::Context ctx,
+                            Legion::Runtime *runtime);
+  static void
+      load_quant_weight_task(Legion::Task const *task,
+                             std::vector<Legion::PhysicalRegion> const &regions,
+                             Legion::Context ctx,
+                             Legion::Runtime *runtime);
+  void load_weights_parallel(FFModel *ff, Context ctx, Runtime *runtime);
+
   void load_positions(FFModel *ff,
                       Tensor pt,
@@ -54,3 +73,12 @@ class FileDataLoader {
   std::string weights_folder;
   bool use_full_precision;
 };
+
+struct WeightLoadTaskArgs {
+  FFModel *ff;
+  FileDataLoader *loader;
+  Layer *layer;
+  int weight_idx;
+  WeightLoadTaskArgs(FFModel *_ff, FileDataLoader *_loader, Layer *_l, int _idx)
+      : ff(_ff), loader(_loader), layer(_l), weight_idx(_idx) {}
+};
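How the new pieces fit together (a sketch, not part of the commit): the three static members are ordinary Legion task bodies, and load_weights_parallel is the host-side driver that launches one of them per weight tensor. Assuming an already-initialized FFModel *model and a constructed FileDataLoader *loader, a caller obtains the Legion context and runtime from the model's config, exactly as the flexflow_c.cc hunk below does:

    // Sketch only; `model` and `loader` are assumed to be set up elsewhere.
    Context ctx = model->config.lg_ctx;      // Legion context held by FFModel
    Runtime *runtime = model->config.lg_hlr; // Legion high-level runtime
    loader->load_weights_parallel(model, ctx, runtime);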
22 changes: 13 additions & 9 deletions inference/python/chat.py
@@ -21,14 +21,14 @@ def get_configs():
     # Define sample configs
     ff_init_configs = {
         # required parameters
-        "num_gpus": 1,
-        "memory_per_gpu": 30000,
-        "zero_copy_memory_per_node": 60000,
+        "num_gpus": 8,
+        "memory_per_gpu": 34000,
+        "zero_copy_memory_per_node": 200000,
         # optional parameters
-        "num_cpus": 4,
-        "legion_utility_processors": 4,
+        "num_cpus": 16,
+        "legion_utility_processors": 16,
         "data_parallelism_degree": 1,
-        "tensor_parallelism_degree": 1,
+        "tensor_parallelism_degree": 8,
         "pipeline_parallelism_degree": 1,
         "offload": False,
         "offload_reserve_space_size": 8 * 1024,  # 8GB
@@ -43,7 +43,7 @@ def get_configs():
     }
     llm_configs = {
         # required parameters
-        "llm_model": "meta-llama/Meta-Llama-3-8B-Instruct",
+        "llm_model": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
         # optional parameters
         "cache_path": os.environ.get("FF_CACHE_PATH", ""),
         "refresh_cache": False,
@@ -85,11 +85,15 @@ def main():
 
     llm.start_server()
 
+    nemotron_system = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are positive in nature."
+    llama_generic_system = "You are a helpful and honest programming assistant."
+
+
     messages = [
-        {"role": "system", "content": "You are a helpful and honest programming assistant."},
+        {"role": "system", "content": nemotron_system},
         {"role": "user", "content": "Is Rust better than Python?"},
     ]
-    llm.generate(messages, max_new_tokens=256)
+    llm.generate(messages, max_new_tokens=1024)
 
     llm.stop_server()
 
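A note on the updated sample config (an observation on the values shown above, not a statement of FlexFlow's requirements): the three parallelism degrees now multiply out to the eight GPUs requested (1 × 8 × 1), and the larger per-GPU and zero-copy memory budgets track the move from an 8B-parameter checkpoint to a 70B one served with 8-way tensor parallelism.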
5 changes: 4 additions & 1 deletion src/c/flexflow_c.cc
@@ -2929,7 +2929,10 @@ void flexflow_file_data_loader_load_weights(flexflow_file_data_loader_t handle_,
                                             flexflow_model_t model_handle_) {
   FileDataLoader *handle = FFCObjectWrapper::unwrap(handle_);
   FFModel *model = FFCObjectWrapper::unwrap(model_handle_);
-  handle->load_weights(model);
+  // handle->load_weights(model);
+  Context ctx = model->config.lg_ctx;
+  Runtime *runtime = model->config.lg_hlr;
+  handle->load_weights_parallel(model, ctx, runtime);
 }
 
 // // -----------------------------------------------------------------------
6 changes: 6 additions & 0 deletions src/mapper/mapper.cc
@@ -288,6 +288,12 @@ void FFMapper::select_task_options(const MapperContext ctx,
     output.initial_proc = all_cpus[0];
     return;
   }
+  if ((task.task_id == LOAD_FLOAT_WEIGHT_TASK_ID) ||
+      (task.task_id == LOAD_HALF_WEIGHT_TASK_ID) ||
+      (task.task_id == LOAD_QUANT_WEIGHT_TASK_ID)) {
+    output.initial_proc = all_cpus[0];
+    return;
+  }
   if (task.task_id == TOP_LEVEL_TASK_ID) {
     output.initial_proc = all_cpus[0];
     // control replicate top level task
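All three loader tasks are pinned here to a CPU processor, like the other CPU-bound branches of select_task_options above: their work is dominated by reading weight files from disk, so a GPU mapping would gain nothing. This also matches the ProcessorConstraint(Processor::LOC_PROC) attached to the same tasks at registration time in model.cc below.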
91 changes: 91 additions & 0 deletions src/runtime/file_loader.cc
@@ -16,6 +16,7 @@
 #include "flexflow/utils/file_loader.h"
 #include "flexflow/ffconst_utils.h"
 #include "flexflow/inference.h"
+#include "flexflow/model.h"
 
 #include <vector>
 using namespace std;
@@ -851,6 +852,7 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff,
   delete data;
 }
 
+#ifdef DEADCODE
 void FileDataLoader::load_weights(FFModel *ff) {
   for (Layer *l : ff->layers) {
     if (l->numWeights < 1 || l->name == NULL || strlen(l->name) < 1) {
@@ -883,3 +885,92 @@ void FileDataLoader::load_weights(FFModel *ff) {
     }
   }
 }
+#endif
+
+void FileDataLoader::load_float_weight_task(
+    Legion::Task const *task,
+    std::vector<Legion::PhysicalRegion> const &regions,
+    Legion::Context ctx,
+    Legion::Runtime *runtime) {
+  WeightLoadTaskArgs const *args = (WeightLoadTaskArgs const *)task->args;
+  args->loader->load_single_weight_tensor<float>(
+      args->ff, args->layer, args->weight_idx);
+}
+
+void FileDataLoader::load_half_weight_task(
+    Legion::Task const *task,
+    std::vector<Legion::PhysicalRegion> const &regions,
+    Legion::Context ctx,
+    Legion::Runtime *runtime) {
+  WeightLoadTaskArgs const *args = (WeightLoadTaskArgs const *)task->args;
+  args->loader->load_single_weight_tensor<half>(
+      args->ff, args->layer, args->weight_idx);
+}
+
+void FileDataLoader::load_quant_weight_task(
+    Legion::Task const *task,
+    std::vector<Legion::PhysicalRegion> const &regions,
+    Legion::Context ctx,
+    Legion::Runtime *runtime) {
+  WeightLoadTaskArgs const *args = (WeightLoadTaskArgs const *)task->args;
+  args->loader->load_quantization_weight(
+      args->ff, args->layer, args->weight_idx);
+}
+
+void FileDataLoader::load_weights_parallel(FFModel *ff,
+                                           Context ctx,
+                                           Runtime *runtime) {
+  std::vector<Future> futures;
+
+  for (Layer *l : ff->layers) {
+    if (l->numWeights < 1 || l->name == NULL || strlen(l->name) < 1) {
+      continue;
+    }
+
+    for (int i = 0; i < l->numWeights; i++) {
+      Tensor weight = l->weights[i];
+      if (weight == NULL) {
+        continue;
+      }
+
+      if (l->op_type == OP_LORA) {
+        continue;
+      }
+
+      // Create task arguments
+      WeightLoadTaskArgs args(ff, this, l, i);
+
+      switch (weight->data_type) {
+        case DT_HALF: {
+          TaskLauncher launcher(
+              LOAD_HALF_WEIGHT_TASK_ID,
+              TaskArgument(&args, sizeof(WeightLoadTaskArgs)));
+          futures.push_back(runtime->execute_task(ctx, launcher));
+          break;
+        }
+        case DT_FLOAT: {
+          TaskLauncher launcher(
+              LOAD_FLOAT_WEIGHT_TASK_ID,
+              TaskArgument(&args, sizeof(WeightLoadTaskArgs)));
+          futures.push_back(runtime->execute_task(ctx, launcher));
+          break;
+        }
+        case DT_INT4:
+        case DT_INT8: {
+          TaskLauncher launcher(
+              LOAD_QUANT_WEIGHT_TASK_ID,
+              TaskArgument(&args, sizeof(WeightLoadTaskArgs)));
+          futures.push_back(runtime->execute_task(ctx, launcher));
+          break;
+        }
+        default:
+          assert(false && "Unsupported data type");
+      }
+    }
+  }
+
+  // Wait for all tasks to complete
+  for (Future &f : futures) {
+    f.get_void_result();
+  }
+}
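Two properties of this driver are worth noting. First, WeightLoadTaskArgs is passed through TaskArgument, which copies the struct's bytes at launch time, so reusing the stack-allocated args on every iteration is safe; but the FFModel*, FileDataLoader*, and Layer* it carries are raw pointers, which presumes each task executes in the same address space as the launcher (consistent with the CPU mapping in mapper.cc above). Second, the loop is a plain fan-out/join: every weight tensor becomes its own task, the returned Futures are collected in a vector, and get_void_result() blocks on each one, so load_weights_parallel does not return until every weight is loaded, while the individual file reads are free to proceed concurrently.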
57 changes: 52 additions & 5 deletions src/runtime/model.cc
@@ -3420,11 +3420,13 @@ bool FFModel::need_to_add_combine(int layer_idx) const {
 
 bool FFModel::need_to_add_allreduce(int layer_idx) const {
   auto const &l = layers[layer_idx];
-  if (config.computationMode == COMP_MODE_INFERENCE && config.tensor_parallelism_degree > 1 &&
-      ((l->op_type == OP_LINEAR && std::string(l->name).find("attn.o_proj") != std::string::npos) ||
-       is_mlp_block(layer_idx) ||
-       (l->op_type == OP_LINEAR && std::string(l->name).find("mlp.down_proj") != std::string::npos)
-      )) {
+  if (config.computationMode == COMP_MODE_INFERENCE &&
+      config.tensor_parallelism_degree > 1 &&
+      ((l->op_type == OP_LINEAR &&
+        std::string(l->name).find("attn.o_proj") != std::string::npos) ||
+       is_mlp_block(layer_idx) ||
+       (l->op_type == OP_LINEAR &&
+        std::string(l->name).find("mlp.down_proj") != std::string::npos))) {
     return true;
   }
   return false;
@@ -4798,6 +4800,51 @@ void register_flexflow_internal_tasks(Runtime *runtime,
         registrar);
     }
   }
+  {
+    TaskVariantRegistrar registrar(LOAD_FLOAT_WEIGHT_TASK_ID,
+                                   "load_float_weight_task");
+    registrar.add_constraint(ProcessorConstraint(Processor::LOC_PROC));
+    if (pre_register) {
+      Runtime::preregister_task_variant<FileDataLoader::load_float_weight_task>(
+          registrar, "load_float_weight_task");
+    } else {
+      if (enable_control_replication) {
+        registrar.global_registration = false;
+      }
+      runtime->register_task_variant<FileDataLoader::load_float_weight_task>(
+          registrar);
+    }
+  }
+  {
+    TaskVariantRegistrar registrar(LOAD_HALF_WEIGHT_TASK_ID,
+                                   "load_half_weight_task");
+    registrar.add_constraint(ProcessorConstraint(Processor::LOC_PROC));
+    if (pre_register) {
+      Runtime::preregister_task_variant<FileDataLoader::load_half_weight_task>(
+          registrar, "load_half_weight_task");
+    } else {
+      if (enable_control_replication) {
+        registrar.global_registration = false;
+      }
+      runtime->register_task_variant<FileDataLoader::load_half_weight_task>(
+          registrar);
+    }
+  }
+  {
+    TaskVariantRegistrar registrar(LOAD_QUANT_WEIGHT_TASK_ID,
+                                   "load_quant_weight_task");
+    registrar.add_constraint(ProcessorConstraint(Processor::LOC_PROC));
+    if (pre_register) {
+      Runtime::preregister_task_variant<FileDataLoader::load_quant_weight_task>(
+          registrar, "load_quant_weight_task");
+    } else {
+      if (enable_control_replication) {
+        registrar.global_registration = false;
+      }
+      runtime->register_task_variant<FileDataLoader::load_quant_weight_task>(
+          registrar);
+    }
+  }
 #endif
   // ElementUnary task
   {
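A note on the registration boilerplate, which is standard Legion practice: ProcessorConstraint(Processor::LOC_PROC) restricts each variant to CPU processors (matching the mapper change above), and under control replication the variants are registered with global_registration = false, so each replicated shard performs its own local registration rather than a global broadcast.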
6 changes: 3 additions & 3 deletions src/runtime/request_manager.cc
@@ -3025,7 +3025,7 @@ void RequestManager::serve_incr_decoding(FFModel *llm) {
   assert(im->model_weights_loaders.find(llm) !=
          im->model_weights_loaders.end());
   // Load model weights
-  im->model_weights_loaders[llm]->load_weights(llm);
+  im->model_weights_loaders[llm]->load_weights_parallel(llm, ctx, runtime);
   // init operators
   im->init_operators_inference(llm);
   // Legion futures for inc_decoding and spec_infer
@@ -3087,7 +3087,7 @@ void RequestManager::serve_spec_infer(FFModel *llm) {
     assert(im->model_weights_loaders.find(llm) !=
            im->model_weights_loaders.end());
     // Load model weights
-    im->model_weights_loaders[llm]->load_weights(llm);
+    im->model_weights_loaders[llm]->load_weights_parallel(llm, ctx, runtime);
     // init operators
     im->init_operators_inference(llm);
   }
@@ -3098,7 +3098,7 @@
     assert(im->model_weights_loaders.find(llm) !=
            im->model_weights_loaders.end());
     // Load model weights
-    im->model_weights_loaders[ssm]->load_weights(ssm);
+    im->model_weights_loaders[ssm]->load_weights_parallel(ssm, ctx, runtime);
     // init operators
     im->init_operators_inference(ssm);
   }
