Skip to content

Commit

Permalink
FlexLLM server demo (#1510)
Browse files Browse the repository at this point in the history
* init

* update

* update

* update

* update

* add max new tokens parameter

* backup

* update

* backup

* lora configs serialize / deserialize into single file

* backup

* .

* .

* .

* .

* frontend

* bug fix

* fixes

* fix

* updates

* fix

* fix

* fix

* small fix

* fix

* fix reset input grad for non-activated loras

* fix

* update

* demo fixes & readme

* load weights in parallel

* cleanup

* cleanup

* load weights faster in inference test

* fix

* cleanup and fixes

* linting

* fix

* cleanup

* docker run update
  • Loading branch information
goliaro authored Nov 18, 2024
1 parent 1bef1a3 commit 7dcbd62
Show file tree
Hide file tree
Showing 68 changed files with 2,326 additions and 1,284 deletions.
3 changes: 2 additions & 1 deletion docker/flexflow-environment/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ LABEL org.opencontainers.image.description="FlexFlow environment container"
SHELL ["/bin/bash", "-c"]

# Install basic dependencies
RUN apt-get update && apt-get install -y --no-install-recommends wget sudo binutils git zlib1g-dev lsb-release nano gdb libhdf5-dev jq && \
RUN apt-get update && apt-get install -y --no-install-recommends wget sudo binutils git zlib1g-dev lsb-release nano gdb libhdf5-dev jq openssh-client && \
rm -rf /var/lib/apt/lists/* /etc/apt/sources.list.d/cuda.list /etc/apt/sources.list.d/nvidia-ml.list && \
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends software-properties-common && \
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends build-essential apt-utils \
Expand Down Expand Up @@ -125,6 +125,7 @@ RUN pip3 install transformers>=4.31.0 sentencepiece einops
RUN pip3 install tensorflow notebook
# PEFT-related
RUN pip3 install scipy bitsandbytes datasets accelerate loralib triton peft
RUN pip3 install streamlit

# Install Rust
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
Expand Down
13 changes: 12 additions & 1 deletion docker/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ hip_version=${hip_version:-"empty"}
ATTACH_GPUS=${ATTACH_GPUS:-true}
gpu_arg=""
if $ATTACH_GPUS ; then gpu_arg="--gpus all" ; fi
# Publish container port 8501 on the same host port (docker `-p`) unless the
# caller sets FORWARD_STREAMLIT_PORT=false. Defaults to true when unset/empty.
FORWARD_STREAMLIT_PORT=${FORWARD_STREAMLIT_PORT:-true}
port_forward_arg=""
# The variable's value is executed as a command here, so it must be `true` or `false`.
if $FORWARD_STREAMLIT_PORT ; then
port_forward_arg+="-p 8501:8501"
fi


# Amount of shared memory to give the Docker container access to
Expand Down Expand Up @@ -120,4 +125,10 @@ if [ -f "$hf_token_path" ]; then
hf_token_volume+="-v $hf_token_path:/root/.cache/huggingface/token"
fi

eval docker run -it "$gpu_arg" "--shm-size=${SHM_SIZE}" "--cap-add=SYS_PTRACE" "${hf_token_volume}" "${image}-${FF_GPU_BACKEND}${gpu_backend_version}:latest"
# Optionally make the host's ~/.ssh/id_rsa available inside the container.
# The mount argument stays empty when the key file is absent.
ssh_key_volume=""
ssh_key_path="$HOME/.ssh/id_rsa"
if [ -f "$ssh_key_path" ]; then
# If the SSH key exists, add the volume mount to the Docker command
ssh_key_volume+="-v $ssh_key_path:/root/.ssh/id_rsa"
fi
eval docker run -it "$gpu_arg" "--shm-size=${SHM_SIZE}" "--cap-add=SYS_PTRACE" "${ssh_key_volume}" "${hf_token_volume}" "${port_forward_arg}" "${image}-${FF_GPU_BACKEND}${gpu_backend_version}:latest"
6 changes: 5 additions & 1 deletion include/flexflow/batch_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "legion.h"
#include <cstddef>
#include <cstdlib>
#include <cstring>

// #define MAX_SEQ_LEN 1024
// #define BATCH_SIZE 2
Expand Down Expand Up @@ -74,6 +75,7 @@ class BatchConfig {
static int const MAX_NUM_REQUESTS = 65;
static int const MAX_NUM_TOKENS = 1024;
static int const MAX_SPEC_TREE_TOKEN_NUM = 64;
static int const MAX_PEFT_CONFIG_SIZE = 1024;

// Set by update

Expand All @@ -89,11 +91,12 @@ class BatchConfig {
num_tokens_in_batch = 0;
max_length = 0;
request_guid = 0;
peft_model_id = PEFTModelID::NO_ID;
prompt_phase = false;
batch_config_request_id = -1;
peft_model_id = PEFTModelID::NO_ID;
peft_bwd = false;
optimizer_tasks = {true, false, false, false};
std::memset(peft_model_config_str, 0, MAX_PEFT_CONFIG_SIZE);
}
int first_token_depth_in_request;
int first_token_offset_in_batch;
Expand All @@ -106,6 +109,7 @@ class BatchConfig {
RequestGuid request_guid;
// PEFT fields
PEFTModelID peft_model_id;
char peft_model_config_str[MAX_PEFT_CONFIG_SIZE];
bool peft_bwd;
OptimizerTasks optimizer_tasks;
};
Expand Down
4 changes: 0 additions & 4 deletions include/flexflow/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ struct FFHandler {
// PEFT related fields
MemoryAllocator *peft_activation_allocator;
size_t peft_activation_reserve_space_size;
PEFTWeightAllocator *peft_weight_allocator;
size_t peft_weight_reserve_space_size;
// Quantization fields
DataType quantization_type;
bool allowTensorOpMathConversion;
Expand All @@ -118,7 +116,6 @@ struct FFInitInfo {
size_t workSpaceSize;
size_t offload_reserve_space_size;
size_t peft_activation_reserve_space_size;
size_t peft_weight_reserve_space_size;
DataType quantization_type;
bool allowTensorOpMathConversion;
// int myRank, allRanks;
Expand Down Expand Up @@ -179,7 +176,6 @@ class FFConfig {
// PEFT related fields
bool enable_peft;
size_t peft_activation_reserve_space_size;
size_t peft_weight_reserve_space_size;
// Control parallelizable dimensions
bool only_data_parallel;
bool enable_sample_parallel;
Expand Down
1 change: 1 addition & 0 deletions include/flexflow/fftype.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class PEFTModelID {
PEFTModelID(size_t id);
bool is_valid_id() const;
friend bool operator==(PEFTModelID const &lhs, PEFTModelID const &rhs);
friend bool operator!=(PEFTModelID const &lhs, PEFTModelID const &rhs);
friend std::ostream &operator<<(std::ostream &os,
PEFTModelID const &peft_model_id);

Expand Down
11 changes: 10 additions & 1 deletion include/flexflow/flexflow_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ int flexflow_config_get_tensor_parallelism_degree(flexflow_config_t handle_);

int flexflow_config_get_pipeline_parallelism_degree(flexflow_config_t handle_);

bool flexflow_config_get_enable_peft(flexflow_config_t handle_);

void flexflow_config_set_data_parallelism_degree(flexflow_config_t handle_,
int value);

Expand Down Expand Up @@ -622,7 +624,11 @@ flexflow_tensor_t flexflow_model_add_argmax(flexflow_model_t handle_,
bool beam_search,
char const *name);

flexflow_peft_model_id_t flexflow_model_add_lora_layer(
void flexflow_model_add_lora_layers(flexflow_model_t handle_,
int num_target_modules,
char const **target_modules_);

flexflow_peft_model_id_t flexflow_model_register_peft_adapter(
flexflow_model_t handle_, const flexflow_lora_linear_config_t peft_config_);

void flexflow_model_set_sgd_optimizer(flexflow_model_t handle,
Expand Down Expand Up @@ -1023,6 +1029,9 @@ void flexflow_request_manager_set_max_sequence_length(
int flexflow_request_manager_get_max_sequence_length(
flexflow_request_manager_t handle_);

void flexflow_request_manager_set_max_concurrent_adapters(
flexflow_request_manager_t handle_, int max_concurrent_adapters);

void flexflow_request_manager_set_enable_peft_finetuning(
flexflow_request_manager_t handle_, bool enable_peft_finetuning_);

Expand Down
11 changes: 7 additions & 4 deletions include/flexflow/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ enum TaskIDs {
RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID,
RM_PREPARE_NEXT_BATCH_VERIFY_TASK_ID,
RM_BACKGROUND_SERVING_TASK_ID,
LOAD_WEIGHT_TASK_ID,
// Custom tasks
CUSTOM_GPU_TASK_ID_FIRST,
CUSTOM_GPU_TASK_ID_1,
Expand Down Expand Up @@ -835,7 +836,9 @@ class FFModel {
// ========================================
// PEFT Layers
// ========================================
PEFTModelID *add_lora_layer(LoraLinearConfig const peft_config);
// PEFTModelID *add_lora_layer(LoraLinearConfig const peft_config);
void add_lora_layers(std::vector<std::string> target_modules);
PEFTModelID *register_peft_adapter(LoraLinearConfig const &peft_config);
// ========================================
// Inference APIs
// ========================================
Expand Down Expand Up @@ -1170,9 +1173,9 @@ class FFModel {
std::vector<ParallelTensor> parameters;
// PEFT related
std::unordered_map<Layer *, Layer *> base_layer_to_peft_layer;
std::unordered_map<Layer *, std::vector<PEFTModelID>> peft_layer_to_peft_id;
std::unordered_map<PEFTModelID, LoraLinearConfig> peft_configs;
// std::vector<Op *> peft_operators;
// std::unordered_map<Layer *, std::vector<PEFTModelID>>
//     peft_layer_to_peft_id;
// std::unordered_map<PEFTModelID, LoraLinearConfig> peft_configs;
// std::vector<Op *> peft_operators;

FFHandler handlers[MAX_NUM_WORKERS];
Legion::Future current_metrics;
Expand Down
2 changes: 1 addition & 1 deletion include/flexflow/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ class Op {
// get operator name and print it
std::string op_name_without_uid = get_op_name_without_uid(m);
std::cout << (fwd_pass ? "INF " : "BWD ") << op_name_without_uid
<< std::endl;
<< (before_kernel ? " (before kernel)" : "") << std::endl;
// build the path to save the tensor
fs::path dst_filepath;
if (fwd_pass) {
Expand Down
2 changes: 2 additions & 0 deletions include/flexflow/ops/kernels/linear_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ void inference_kernel_wrapper(LinearMeta *m,
int out_dim,
int batch_size);
void peft_bwd_kernel_wrapper(LinearMeta const *m,
BatchConfig const *bc,
void *input_grad_ptr,
void *output_grad_ptr,
void const *kernel_ptr,
Expand Down Expand Up @@ -94,6 +95,7 @@ void forward_kernel(LinearMeta const *m,
ffStream_t stream);
template <typename DT>
void peft_bwd_kernel(LinearMeta const *m,
BatchConfig const *bc,
void *input_grad_ptr,
void *output_grad_ptr,
void const *kernel_ptr,
Expand Down
38 changes: 12 additions & 26 deletions include/flexflow/ops/kernels/lora_linear_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,27 @@
#include "flexflow/fftype.h"
#include "flexflow/op_meta.h"
#include "flexflow/ops/lora_linear.h"
#include "flexflow/utils/peft_weight_allocator.h"

namespace FlexFlow {

using Legion::Context;
using Legion::Runtime;
struct LoraLinearWeight {
// weights
void *w0_ptr, *w1_ptr;
// gradients
void *w0_grad_ptr, *w1_grad_ptr;
// v values for SGD optimizer (when using momentum)
void *w0_v_values_ptr, *w1_v_values_ptr;
int in_dim, out_dim, rank, num_shards;
};

struct LoraLinearModelState {
LoraLinearWeight weights;
LoraOptimizerConfig const *optimizer_config;
float lora_alpha;
std::string cache_folder;
// Huggingface model ID (for download and/or upload)
std::string peft_model_id;
};

class LoraLinearMeta : public OpMeta {
public:
LoraLinearMeta(FFHandler handle, LoraLinear const *li);
~LoraLinearMeta(void);
// PEFT related fields
void *low_rank_activation;
void *input_activation;
std::unordered_map<PEFTModelID, LoraLinearModelState> model_state;
size_t allocated_peft_buffer_size1 = 0, allocated_peft_buffer_size2 = 0;
PEFTMemoryManager *peft_memory_manager;
};

namespace Kernels {
namespace LoraLinear {
void init_kernel_wrapper(LoraLinearMeta *m, int seed);

bool lora_applies_to_this_layer(LoraLinearMeta *m,
LoraLinearConfig const &config);

// void init_kernel_wrapper(LoraLinearMeta *m, int seed);
void inference_kernel_wrapper(LoraLinearMeta *m,
BatchConfig const *bc,
GenericTensorAccessorR const &input,
Expand All @@ -51,12 +35,13 @@ void peft_bwd_kernel_wrapper(Context ctx,
Runtime *runtime,
LoraLinearMeta *m,
BatchConfig const *bc,
int shard_id,
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &output_grad);

namespace Internal {
template <typename DT>
void init_kernel(LoraLinearMeta *m, int seed, ffStream_t stream);
// template <typename DT>
// void init_kernel(LoraLinearMeta *m, int seed, ffStream_t stream);
template <typename DT>
void inference_kernel(LoraLinearMeta *m,
BatchConfig const *bc,
Expand All @@ -70,6 +55,7 @@ void peft_bwd_kernel(Context ctx,
Runtime *runtime,
LoraLinearMeta *m,
BatchConfig const *bc,
int shard_id,
DT *input_grad_ptr,
DT const *output_grad_ptr,
int in_dim,
Expand Down
19 changes: 10 additions & 9 deletions include/flexflow/ops/lora_linear.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@ class LoraLinear : public Op {
using Params = LoraLinearParams;
using Input = std::pair<ParallelTensor, ParallelTensor>;

LoraLinear(
FFModel &model,
LayerID const &layer_guid,
OperatorType type,
ParallelTensor const input,
ParallelTensor const output,
std::unordered_map<PEFTModelID, LoraLinearConfig> const &_peft_configs,
char const *name = nullptr);
LoraLinear(FFModel &model,
LayerID const &layer_guid,
ParallelTensor const input,
ParallelTensor const output,
int max_rank,
int max_concurrent_adapters,
char const *name = nullptr);
LoraLinear(FFModel &model,
LoraLinear const &other,
ParallelTensor const input,
Expand Down Expand Up @@ -91,7 +90,9 @@ class LoraLinear : public Op {
// size_t get_params_hash() const override;
LoraLinearParams get_params() const;

std::unordered_map<PEFTModelID, LoraLinearConfig> peft_configs;
// std::unordered_map<PEFTModelID, LoraLinearConfig> peft_configs;
int max_rank;
int max_concurrent_adapters;
};

}; // namespace FlexFlow
Expand Down
Loading

0 comments on commit 7dcbd62

Please sign in to comment.