diff --git a/docker/flexflow-environment/Dockerfile b/docker/flexflow-environment/Dockerfile index d571befdda..2af81de11f 100644 --- a/docker/flexflow-environment/Dockerfile +++ b/docker/flexflow-environment/Dockerfile @@ -7,7 +7,7 @@ LABEL org.opencontainers.image.description="FlexFlow environment container" SHELL ["/bin/bash", "-c"] # Install basic dependencies -RUN apt-get update && apt-get install -y --no-install-recommends wget sudo binutils git zlib1g-dev lsb-release nano gdb libhdf5-dev jq && \ +RUN apt-get update && apt-get install -y --no-install-recommends wget sudo binutils git zlib1g-dev lsb-release nano gdb libhdf5-dev jq openssh-client && \ rm -rf /var/lib/apt/lists/* /etc/apt/sources.list.d/cuda.list /etc/apt/sources.list.d/nvidia-ml.list && \ apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends software-properties-common && \ apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends build-essential apt-utils \ @@ -125,6 +125,7 @@ RUN pip3 install transformers>=4.31.0 sentencepiece einops RUN pip3 install tensorflow notebook # PEFT-related RUN pip3 install scipy bitsandbytes datasets accelerate loralib triton peft +RUN pip3 install streamlit # Install Rust RUN curl https://sh.rustup.rs -sSf | sh -s -- -y diff --git a/docker/run.sh b/docker/run.sh index 46c63bab6f..759da521aa 100755 --- a/docker/run.sh +++ b/docker/run.sh @@ -17,6 +17,11 @@ hip_version=${hip_version:-"empty"} ATTACH_GPUS=${ATTACH_GPUS:-true} gpu_arg="" if $ATTACH_GPUS ; then gpu_arg="--gpus all" ; fi +FORWARD_STREAMLIT_PORT=${FORWARD_STREAMLIT_PORT:-true} +port_forward_arg="" +if $FORWARD_STREAMLIT_PORT ; then + port_forward_arg+="-p 8501:8501" +fi # Amount of shared memory to give the Docker container access to @@ -120,4 +125,10 @@ if [ -f "$hf_token_path" ]; then hf_token_volume+="-v $hf_token_path:/root/.cache/huggingface/token" fi -eval docker run -it "$gpu_arg" "--shm-size=${SHM_SIZE}" "--cap-add=SYS_PTRACE" "${hf_token_volume}" "${image}-${FF_GPU_BACKEND}${gpu_backend_version}:latest" +ssh_key_volume="" +ssh_key_path="$HOME/.ssh/id_rsa" +if [ -f "$ssh_key_path" ]; then + # If the SSH key exists, add the volume mount to the Docker command + ssh_key_volume+="-v $ssh_key_path:/root/.ssh/id_rsa" +fi +eval docker run -it "$gpu_arg" "--shm-size=${SHM_SIZE}" "--cap-add=SYS_PTRACE" "${ssh_key_volume}" "${hf_token_volume}" "${port_forward_arg}" "${image}-${FF_GPU_BACKEND}${gpu_backend_version}:latest" diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index a509af765c..bb8b4c67f6 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -20,6 +20,7 @@ #include "legion.h" #include #include +#include <cstring> // #define MAX_SEQ_LEN 1024 // #define BATCH_SIZE 2 @@ -74,6 +75,7 @@ class BatchConfig { static int const MAX_NUM_REQUESTS = 65; static int const MAX_NUM_TOKENS = 1024; static int const MAX_SPEC_TREE_TOKEN_NUM = 64; + static int const MAX_PEFT_CONFIG_SIZE = 1024; // Set by update @@ -89,11 +91,12 @@ class BatchConfig { num_tokens_in_batch = 0; max_length = 0; request_guid = 0; + peft_model_id = PEFTModelID::NO_ID; prompt_phase = false; batch_config_request_id = -1; - peft_model_id = PEFTModelID::NO_ID; peft_bwd = false; optimizer_tasks = {true, false, false, false}; + std::memset(peft_model_config_str, 0, MAX_PEFT_CONFIG_SIZE); } int first_token_depth_in_request; int first_token_offset_in_batch; @@ -106,6 +109,7 @@ class BatchConfig { RequestGuid request_guid; // PEFT fields PEFTModelID
peft_model_id; + char peft_model_config_str[MAX_PEFT_CONFIG_SIZE]; bool peft_bwd; OptimizerTasks optimizer_tasks; }; diff --git a/include/flexflow/config.h b/include/flexflow/config.h index dd9d657117..37afa0df27 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -104,8 +104,6 @@ struct FFHandler { // PEFT related fields MemoryAllocator *peft_activation_allocator; size_t peft_activation_reserve_space_size; - PEFTWeightAllocator *peft_weight_allocator; - size_t peft_weight_reserve_space_size; // Quantization fields DataType quantization_type; bool allowTensorOpMathConversion; @@ -118,7 +116,6 @@ struct FFInitInfo { size_t workSpaceSize; size_t offload_reserve_space_size; size_t peft_activation_reserve_space_size; - size_t peft_weight_reserve_space_size; DataType quantization_type; bool allowTensorOpMathConversion; // int myRank, allRanks; @@ -179,7 +176,6 @@ class FFConfig { // PEFT related fields bool enable_peft; size_t peft_activation_reserve_space_size; - size_t peft_weight_reserve_space_size; // Control parallelizable dimensions bool only_data_parallel; bool enable_sample_parallel; diff --git a/include/flexflow/fftype.h b/include/flexflow/fftype.h index 3e482b8d67..ebc811c262 100644 --- a/include/flexflow/fftype.h +++ b/include/flexflow/fftype.h @@ -27,6 +27,7 @@ class PEFTModelID { PEFTModelID(size_t id); bool is_valid_id() const; friend bool operator==(PEFTModelID const &lhs, PEFTModelID const &rhs); + friend bool operator!=(PEFTModelID const &lhs, PEFTModelID const &rhs); friend std::ostream &operator<<(std::ostream &os, PEFTModelID const &peft_model_id); diff --git a/include/flexflow/flexflow_c.h b/include/flexflow/flexflow_c.h index 6501b0658c..677f9915cd 100644 --- a/include/flexflow/flexflow_c.h +++ b/include/flexflow/flexflow_c.h @@ -91,6 +91,8 @@ int flexflow_config_get_tensor_parallelism_degree(flexflow_config_t handle_); int flexflow_config_get_pipeline_parallelism_degree(flexflow_config_t handle_); +bool flexflow_config_get_enable_peft(flexflow_config_t handle_); + void flexflow_config_set_data_parallelism_degree(flexflow_config_t handle_, int value); @@ -622,7 +624,11 @@ flexflow_tensor_t flexflow_model_add_argmax(flexflow_model_t handle_, bool beam_search, char const *name); -flexflow_peft_model_id_t flexflow_model_add_lora_layer( +void flexflow_model_add_lora_layers(flexflow_model_t handle_, + int num_target_modules, + char const **target_modules_); + +flexflow_peft_model_id_t flexflow_model_register_peft_adapter( flexflow_model_t handle_, const flexflow_lora_linear_config_t peft_config_); void flexflow_model_set_sgd_optimizer(flexflow_model_t handle, @@ -1023,6 +1029,9 @@ void flexflow_request_manager_set_max_sequence_length( int flexflow_request_manager_get_max_sequence_length( flexflow_request_manager_t handle_); +void flexflow_request_manager_set_max_concurrent_adapters( + flexflow_request_manager_t handle_, int max_concurrent_adapters); + void flexflow_request_manager_set_enable_peft_finetuning( flexflow_request_manager_t handle_, bool enable_peft_finetuning_); diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 51b7950db8..e352159af0 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -278,6 +278,7 @@ enum TaskIDs { RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID, RM_PREPARE_NEXT_BATCH_VERIFY_TASK_ID, RM_BACKGROUND_SERVING_TASK_ID, + LOAD_WEIGHT_TASK_ID, // Custom tasks CUSTOM_GPU_TASK_ID_FIRST, CUSTOM_GPU_TASK_ID_1, @@ -835,7 +836,9 @@ class FFModel { // ======================================== // PEFT Layers // 
======================================== - PEFTModelID *add_lora_layer(LoraLinearConfig const peft_config); + // PEFTModelID *add_lora_layer(LoraLinearConfig const peft_config); + void add_lora_layers(std::vector target_modules); + PEFTModelID *register_peft_adapter(LoraLinearConfig const &peft_config); // ======================================== // Inference APIs // ======================================== @@ -1170,9 +1173,9 @@ class FFModel { std::vector parameters; // PEFT related std::unordered_map base_layer_to_peft_layer; - std::unordered_map> peft_layer_to_peft_id; - std::unordered_map peft_configs; - // std::vector peft_operators; + // std::unordered_map> + // peft_layer_to_peft_id; std::unordered_map + // peft_configs; std::vector peft_operators; FFHandler handlers[MAX_NUM_WORKERS]; Legion::Future current_metrics; diff --git a/include/flexflow/operator.h b/include/flexflow/operator.h index 007314797a..c108740ef3 100644 --- a/include/flexflow/operator.h +++ b/include/flexflow/operator.h @@ -280,7 +280,7 @@ class Op { // get operator name and print it std::string op_name_without_uid = get_op_name_without_uid(m); std::cout << (fwd_pass ? "INF " : "BWD ") << op_name_without_uid - << std::endl; + << (before_kernel ? " (before kernel)" : "") << std::endl; // build the path to save the tensor fs::path dst_filepath; if (fwd_pass) { diff --git a/include/flexflow/ops/kernels/linear_kernels.h b/include/flexflow/ops/kernels/linear_kernels.h index 90e50a0c9a..aaa845db23 100644 --- a/include/flexflow/ops/kernels/linear_kernels.h +++ b/include/flexflow/ops/kernels/linear_kernels.h @@ -61,6 +61,7 @@ void inference_kernel_wrapper(LinearMeta *m, int out_dim, int batch_size); void peft_bwd_kernel_wrapper(LinearMeta const *m, + BatchConfig const *bc, void *input_grad_ptr, void *output_grad_ptr, void const *kernel_ptr, @@ -94,6 +95,7 @@ void forward_kernel(LinearMeta const *m, ffStream_t stream); template void peft_bwd_kernel(LinearMeta const *m, + BatchConfig const *bc, void *input_grad_ptr, void *output_grad_ptr, void const *kernel_ptr, diff --git a/include/flexflow/ops/kernels/lora_linear_kernels.h b/include/flexflow/ops/kernels/lora_linear_kernels.h index eee9875d30..fd86dc68c0 100644 --- a/include/flexflow/ops/kernels/lora_linear_kernels.h +++ b/include/flexflow/ops/kernels/lora_linear_kernels.h @@ -6,43 +6,27 @@ #include "flexflow/fftype.h" #include "flexflow/op_meta.h" #include "flexflow/ops/lora_linear.h" +#include "flexflow/utils/peft_weight_allocator.h" namespace FlexFlow { + using Legion::Context; using Legion::Runtime; -struct LoraLinearWeight { - // weights - void *w0_ptr, *w1_ptr; - // gradients - void *w0_grad_ptr, *w1_grad_ptr; - // v values for SGD optimizer (when using momentum) - void *w0_v_values_ptr, *w1_v_values_ptr; - int in_dim, out_dim, rank, num_shards; -}; - -struct LoraLinearModelState { - LoraLinearWeight weights; - LoraOptimizerConfig const *optimizer_config; - float lora_alpha; - std::string cache_folder; - // Huggingface model ID (for download and/or upload) - std::string peft_model_id; -}; class LoraLinearMeta : public OpMeta { public: LoraLinearMeta(FFHandler handle, LoraLinear const *li); ~LoraLinearMeta(void); - // PEFT related fields - void *low_rank_activation; - void *input_activation; - std::unordered_map model_state; - size_t allocated_peft_buffer_size1 = 0, allocated_peft_buffer_size2 = 0; + PEFTMemoryManager *peft_memory_manager; }; namespace Kernels { namespace LoraLinear { -void init_kernel_wrapper(LoraLinearMeta *m, int seed); + +bool 
lora_applies_to_this_layer(LoraLinearMeta *m, + LoraLinearConfig const &config); + +// void init_kernel_wrapper(LoraLinearMeta *m, int seed); void inference_kernel_wrapper(LoraLinearMeta *m, BatchConfig const *bc, GenericTensorAccessorR const &input, @@ -51,12 +35,13 @@ void peft_bwd_kernel_wrapper(Context ctx, Runtime *runtime, LoraLinearMeta *m, BatchConfig const *bc, + int shard_id, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &output_grad); namespace Internal { -template -void init_kernel(LoraLinearMeta *m, int seed, ffStream_t stream); +// template +// void init_kernel(LoraLinearMeta *m, int seed, ffStream_t stream); template void inference_kernel(LoraLinearMeta *m, BatchConfig const *bc, @@ -70,6 +55,7 @@ void peft_bwd_kernel(Context ctx, Runtime *runtime, LoraLinearMeta *m, BatchConfig const *bc, + int shard_id, DT *input_grad_ptr, DT const *output_grad_ptr, int in_dim, diff --git a/include/flexflow/ops/lora_linear.h b/include/flexflow/ops/lora_linear.h index 9e83c3f90e..cc625cafc2 100644 --- a/include/flexflow/ops/lora_linear.h +++ b/include/flexflow/ops/lora_linear.h @@ -17,14 +17,13 @@ class LoraLinear : public Op { using Params = LoraLinearParams; using Input = std::pair; - LoraLinear( - FFModel &model, - LayerID const &layer_guid, - OperatorType type, - ParallelTensor const input, - ParallelTensor const output, - std::unordered_map const &_peft_configs, - char const *name = nullptr); + LoraLinear(FFModel &model, + LayerID const &layer_guid, + ParallelTensor const input, + ParallelTensor const output, + int max_rank, + int max_concurrent_adapters, + char const *name = nullptr); LoraLinear(FFModel &model, LoraLinear const &other, ParallelTensor const input, @@ -91,7 +90,9 @@ class LoraLinear : public Op { // size_t get_params_hash() const override; LoraLinearParams get_params() const; - std::unordered_map peft_configs; + // std::unordered_map peft_configs; + int max_rank; + int max_concurrent_adapters; }; }; // namespace FlexFlow diff --git a/include/flexflow/ops/lora_linear_params.h b/include/flexflow/ops/lora_linear_params.h index 70539271f2..46b88c9690 100644 --- a/include/flexflow/ops/lora_linear_params.h +++ b/include/flexflow/ops/lora_linear_params.h @@ -17,6 +17,9 @@ namespace FlexFlow { class LoraOptimizerConfig { public: LoraOptimizerConfig(); + virtual std::string getType() const = 0; + virtual nlohmann::json toJson() const = 0; + static LoraOptimizerConfig *fromJson(nlohmann::json const &j); virtual ~LoraOptimizerConfig() {} }; @@ -29,9 +32,11 @@ class LoraSGDOptimizerConfig : public LoraOptimizerConfig { bool weight_decay_ = 0.0f); friend std::ostream &operator<<(std::ostream &os, LoraSGDOptimizerConfig const &llc); - - NLOHMANN_DEFINE_TYPE_INTRUSIVE( - LoraSGDOptimizerConfig, lr, momentum, nesterov, weight_decay) + std::string getType() const override { + return "SGD"; + } + nlohmann::json toJson() const override; + static LoraSGDOptimizerConfig *fromJson(nlohmann::json const &j); public: double lr = 0.001f; @@ -51,8 +56,11 @@ class LoraAdamOptimizerConfig : public LoraOptimizerConfig { friend std::ostream &operator<<(std::ostream &os, LoraAdamOptimizerConfig const &llc); - NLOHMANN_DEFINE_TYPE_INTRUSIVE( - LoraAdamOptimizerConfig, alpha, beta1, beta2, weight_decay, epsilon) + std::string getType() const override { + return "Adam"; + } + nlohmann::json toJson() const override; + static LoraAdamOptimizerConfig *fromJson(nlohmann::json const &j); public: // Adam @@ -63,14 +71,6 @@ class LoraAdamOptimizerConfig : public LoraOptimizerConfig { 
double epsilon = 1e-8; }; -// Serialization helpers -template -void serialize_to_json_file(T const &obj, fs::path const &filepath); - -// Function to deserialize JSON from file and create object -template -std::unique_ptr deserialize_from_json_file(fs::path const &filepath); - class LoraLinearConfig { public: static const LoraLinearConfig EmptyConfig; @@ -92,17 +92,14 @@ class LoraLinearConfig { friend std::ostream &operator<<(std::ostream &os, LoraLinearConfig const &llc); - NLOHMANN_DEFINE_TYPE_INTRUSIVE(LoraLinearConfig, - cache_folder, - peft_model_id, - rank, - lora_alpha, - lora_dropout, - target_modules, - trainable, - init_lora_weights, - base_model_name_or_path, - precision) + std::string serialize_to_json_string(int indent = -1) const; + void serialize_to_json_file(std::string const &filename) const; + // Deserialization method + static LoraLinearConfig + deserialize_from_json_string(std::string const &json_string); + // Deserialization method + static LoraLinearConfig + deserialize_from_json_file(std::string const &filename); std::string cache_folder; // Huggingface model ID (for download and/or upload) @@ -128,8 +125,8 @@ class LoraLinearConfig { class LoraLinearParams { public: LayerID layer_guid; - OperatorType type; - std::unordered_map peft_configs; + int max_rank; + int max_concurrent_adapters; char name[MAX_OPNAME]; bool is_valid(std::pair const @@ -147,4 +144,4 @@ struct hash { }; } // namespace std -#endif // _FLEXFLOW_LORA_LINEAR_PARAMS_H +#endif // _FLEXFLOW_LORA_LINEAR_PARAMS_H \ No newline at end of file diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index d62b610f3d..c15c0ff8b4 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -150,6 +150,13 @@ class RequestManager { std::vector eos_token_ids, std::string const &path); void register_output_filepath(std::string const &); + void set_peft_config(PEFTModelID const &peft_model_id, + LoraLinearConfig const &peft_config); + LoraLinearConfig const &get_peft_config(PEFTModelID const &peft_model_id); + void set_max_lora_rank(int max_lora_rank); + void set_max_concurrent_adapters(int max_concurrent_adapters); + int get_max_lora_rank(); + int get_max_concurrent_adapters(); void initBitMask(BatchConfig::BitMask &bitmask, int initLength); void appendPendingRequest(BatchConfig::BitMask &bitmask, int initLength); void appendBitMask(BatchConfig::BitMask &bitmask, @@ -182,6 +189,9 @@ class RequestManager { bool is_eos_token(int token_id); bool check_inf_req_completion(BatchConfig const &old_bc, int i); void check_batch(BatchConfig const &old_bc, BatchConfig const &new_bc); + void add_peft_config_to_request_info(BatchConfig &bc, + int req_idx, + LoraLinearConfig const &peft_config); BatchConfig prepare_next_batch(BatchConfig const &bc, InferenceResult const &result); BatchConfigFuture prepare_next_batch(BatchConfigFuture const &bc, @@ -291,6 +301,10 @@ class RequestManager { int max_sequence_length; Status request_manager_status; + // peft + std::unordered_map peft_configs; + int max_lora_rank = 32; + int max_concurrent_adapters = 0; // peft benchmarking bool enable_peft_finetuning = false; static bool inference_finished; diff --git a/include/flexflow/utils/file_loader.h b/include/flexflow/utils/file_loader.h index 646eb18da2..8735f23571 100644 --- a/include/flexflow/utils/file_loader.h +++ b/include/flexflow/utils/file_loader.h @@ -39,7 +39,13 @@ class FileDataLoader { void load_single_weight_tensor(FFModel *ff, Layer *l, int weight_idx); void 
load_quantization_weight(FFModel *ff, Layer *l, int weight_idx); - void load_weights(FFModel *ff); + + static void + load_weight_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + void load_weights_parallel(FFModel *ff, Context ctx, Runtime *runtime); void load_positions(FFModel *ff, Tensor pt, @@ -54,3 +60,18 @@ class FileDataLoader { std::string weights_folder; bool use_full_precision; }; + +struct WeightLoadTaskArgs { + FFModel *ff; + FileDataLoader *loader; + Layer *layer; + int weight_idx; + DataType data_type; + WeightLoadTaskArgs(FFModel *_ff, + FileDataLoader *_loader, + Layer *_l, + int _idx, + DataType _data_type) + : ff(_ff), loader(_loader), layer(_l), weight_idx(_idx), + data_type(_data_type) {} +}; diff --git a/include/flexflow/utils/peft_weight_allocator.h b/include/flexflow/utils/peft_weight_allocator.h index dae46a8af1..21ac9bf426 100644 --- a/include/flexflow/utils/peft_weight_allocator.h +++ b/include/flexflow/utils/peft_weight_allocator.h @@ -17,76 +17,121 @@ #define _FLEXFLOW_UTILS_PEFT_WEIGHT_ALLOCATOR_H_ #include "flexflow/config.h" -#include +#include "flexflow/ffconst_utils.h" +#include "flexflow/ops/lora_linear_params.h" +// #include namespace FlexFlow { -class PEFTWeightAllocator { -public: - PEFTWeightAllocator(void *_base_ptr, size_t _total_size) - : base_ptr(_base_ptr), total_size(_total_size), sync_offset(0), - local_offset(_total_size) {} +struct LoraLinearWeight { + // weights + void *w0_ptr, *w1_ptr; + // gradients + void *w0_grad_ptr, *w1_grad_ptr; + // activations + void *input_activation; + void *low_rank_activation; + // v values for SGD optimizer (when using momentum) + void *w0_v_values_ptr, *w1_v_values_ptr; + LoraLinearWeight(void *w0 = nullptr, + void *w1 = nullptr, + void *w0_grad = nullptr, + void *w1_grad = nullptr, + void *w0_v_values = nullptr, + void *w1_v_values = nullptr, + void *low_rank_activation_ = nullptr, + void *input_activation_ = nullptr) + : w0_ptr(w0), w1_ptr(w1), w0_grad_ptr(w0_grad), w1_grad_ptr(w1_grad), + w0_v_values_ptr(w0_v_values), w1_v_values_ptr(w1_v_values), + low_rank_activation(low_rank_activation_), + input_activation(input_activation_) {} +}; - inline void *allocate_sync_weights_untyped(PEFTModelID const &peft_model_id, - size_t datalen) { - const std::lock_guard lock(peft_weight_allocator_mutex); - void *ptr = static_cast(base_ptr) + sync_offset; - off_t model_sync_weights_offset = sync_offset; - size_t model_sync_weights_size = datalen; - if (sync_weights.find(peft_model_id) != sync_weights.end()) { - // Assert that sync weights for each PEFT model is consecutive - std::pair offset_and_size = sync_weights[peft_model_id]; - assert(sync_offset == offset_and_size.first + offset_and_size.second); - model_sync_weights_offset = offset_and_size.first; - model_sync_weights_size = offset_and_size.second + datalen; - } - sync_offset += datalen; - assert(sync_offset < local_offset); - sync_weights[peft_model_id] = - std::make_pair(model_sync_weights_offset, model_sync_weights_size); - return ptr; - } +void init_peft_weight_wrapper(LoraLinearWeight const &weight, + int in_dim, + int out_dim, + int rank, + DataType dt, + int seed); - std::pair - get_sync_weights_ptr_and_size(PEFTModelID const &peft_model_id) { - const std::lock_guard lock(peft_weight_allocator_mutex); - assert(sync_weights.find(peft_model_id) != sync_weights.end()); - std::pair offset_and_size = sync_weights[peft_model_id]; - return std::make_pair(static_cast(base_ptr) + offset_and_size.first, - 
offset_and_size.second); +class PEFTMemoryManager { +public: + PEFTMemoryManager(Legion::Memory gpu_mem_, + int max_rank_, + int max_concurrent_adapters_, + int max_peft_tokens_, + int in_dim_, + int out_dim_, + int num_shards_, + int shard_id_, + std::string const &lora_layername_substr_, + DataType dt_) + : gpu_mem(gpu_mem_), max_concurrent_adapters(max_concurrent_adapters_), + max_rank(max_rank_), in_dim(in_dim_), out_dim(out_dim_), + num_shards(num_shards_), shard_id(shard_id_), + max_peft_tokens(max_peft_tokens_), + lora_layername_substr(lora_layername_substr_), dt(dt_), + base_ptr(nullptr), finetuning_ptr(nullptr), + finetuning_model_id(PEFTModelID::NO_ID) { + max_lora_size = + data_type_size(dt) * (max_rank * in_dim + max_rank * out_dim); + assert(max_concurrent_adapters > 0 && + "PEFT Memory Manager max_concurrent_adapters must be > 0"); + assert(max_lora_size > 0 && + "PEFT Memory Manager max_lora_size must be > 0"); + allocate_inference_memory(); + // finetuning memory is allocated upon the first finetuning request, so we + // can skip for inference-only workloads } - inline void *allocate_local_weights_untyped(PEFTModelID const &peft_model_id, - size_t datalen) { - const std::lock_guard lock(peft_weight_allocator_mutex); - local_offset -= datalen; - assert(sync_offset < local_offset); - void *ptr = static_cast(base_ptr) + local_offset; - return ptr; - } + // allocate memory for all the PEFT adapters for a given layer on a given + // shard + void allocate_inference_memory(); + // allocate memory for the PEFT adapter for a finetuning request for a given + // layer and shard + void allocate_finetuning_memory(); - template - inline DT *allocate_sync_weights(PEFTModelID const &peft_model_id, - size_t count) { - return static_cast
( - allocate_sync_weights_untyped(peft_model_id, sizeof(DT) * count)); - } + LoraLinearWeight get_peft(PEFTModelID const &model_id, + LoraLinearConfig const &lora_config); + void check_ft_model_id(PEFTModelID const &model_id); - template - inline DT *allocate_local_weights(PEFTModelID const &peft_model_id, - size_t count) { - return static_cast
( - allocate_local_weights_untyped(peft_model_id, sizeof(DT) * count)); - } +private: + // Check if the PEFT adapter for the given model is in memory. If not, sets + // the cache_miss flag to true. If this is the first finetuning request, + // allocate memory for the finetuning adapter. + void get_finetuning_slot(PEFTModelID const &model_id, bool *cache_miss); + // Returns the slot in memory where the peft model weights are/will be stored. + // If the model is not in memory (cache miss), set the cache_miss flag to + // true. + int get_inference_peft_slot(PEFTModelID const &model_id, bool *cache_miss); + void load_peft_model(LoraLinearWeight &weight, + LoraLinearConfig const &lora_config); + LoraLinearWeight get_inference_peft(PEFTModelID const &model_id, + LoraLinearConfig const &lora_config); + LoraLinearWeight get_finetuning_peft(PEFTModelID const &model_id, + LoraLinearConfig const &lora_config); -public: - void *base_ptr; - size_t total_size; - off_t sync_offset, local_offset; - std::unordered_map> sync_weights; - std::mutex peft_weight_allocator_mutex; + // Legion memory management apparatus + Legion::Memory gpu_mem; + Realm::RegionInstance peftLegionInst; + void *base_ptr, *finetuning_ptr; + // Size and shapes + int max_concurrent_adapters; + int max_rank; + int max_lora_size; + int in_dim, out_dim, num_shards, shard_id; + int max_peft_tokens; + // LRU cache apparatus + std::unordered_map lru_hashtable; + std::vector + lru_list; // head = least recently used, tail=most recently used + std::unordered_map peft2mem_slot; + // Miscellanea + std::string lora_layername_substr; + DataType dt; + PEFTModelID finetuning_model_id; }; -}; // namespace FlexFlow +} // namespace FlexFlow #endif // _FLEXFLOW_UTILS_PEFT_WEIGHT_ALLOCATOR_H_ diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc index fd4da87b99..b4f961b006 100644 --- a/inference/models/falcon.cc +++ b/inference/models/falcon.cc @@ -269,6 +269,14 @@ void FALCON::create_falcon_model(FFModel &ff, output = ff.argmax(lm_head, /*beam_Search*/ false); } + // If PEFT is enabled, add LoRA layers + if (ff.config.enable_peft) { + // todo: add attention projections + std::vector target_modules = {"dense_h_to_4h", + "dense_4h_to_h"}; + ff.add_lora_layers(target_modules); + } + FileDataLoader *fileloader = new FileDataLoader("", weight_file_path, diff --git a/inference/models/llama.cc b/inference/models/llama.cc index bd5243bd4b..7b4a14b472 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -250,9 +250,6 @@ void LLAMA::create_llama_model(FFModel &ff, REG_MODE_NONE, 0.0f, std::string("layers." + std::to_string(i) + ".mlp.down_proj").c_str()); - // Low-Rank Adapter (LoRA) for the second linear layer - // ff.lora_linear(std::string("down_proj"), std::string("layers." 
+ - // std::to_string(i) + ".mlp.down_proj.lora").c_str()); } // final normalization and linear Tensor final_rms_norm_output[2] = {nullptr, nullptr}; @@ -297,6 +294,14 @@ void LLAMA::create_llama_model(FFModel &ff, } } + // If PEFT is enabled, add LoRA layers + if (ff.config.enable_peft) { + // todo: add attention projections + std::vector target_modules = { + "gate_proj", "up_proj", "down_proj"}; + ff.add_lora_layers(target_modules); + } + FileDataLoader *fileloader = new FileDataLoader( "", weight_file_path, diff --git a/inference/models/mpt.cc b/inference/models/mpt.cc index d02c0f3b82..6807266ef4 100644 --- a/inference/models/mpt.cc +++ b/inference/models/mpt.cc @@ -272,6 +272,14 @@ void MPT::create_mpt_model(FFModel &ff, } else { output = ff.argmax(lm_head, /*beam_Search*/ false); } + + // If PEFT is enabled, add LoRA layers + if (ff.config.enable_peft) { + // todo: add attention projections + std::vector target_modules = {"up_proj", "down_proj"}; + ff.add_lora_layers(target_modules); + } + FileDataLoader *fileloader = new FileDataLoader("", weight_file_path, diff --git a/inference/models/opt.cc b/inference/models/opt.cc index 34a6bb0f02..cb3d5290cf 100644 --- a/inference/models/opt.cc +++ b/inference/models/opt.cc @@ -243,9 +243,6 @@ void OPT::create_opt_model(FFModel &ff, REG_MODE_NONE, 0.0f, std::string("layers." + std::to_string(i) + ".fc2").c_str()); - // Low-Rank Adapter (LoRA) for the second linear layer - // ff.lora_linear(std::string("fc2"), std::string("layers." + - // std::to_string(i) + ".fc2.lora").c_str()); } // final @@ -286,6 +283,13 @@ void OPT::create_opt_model(FFModel &ff, output = ff.argmax(softmax, /*beam_Search*/ false); } + // If PEFT is enabled, add LoRA layers + if (ff.config.enable_peft) { + // todo: add attention projections + std::vector target_modules = {"fc1", "fc2"}; + ff.add_lora_layers(target_modules); + } + FileDataLoader *fileloader = new FileDataLoader( "", weight_file_path, diff --git a/inference/models/starcoder.cc b/inference/models/starcoder.cc index 2429b1ec1b..3dd61be983 100644 --- a/inference/models/starcoder.cc +++ b/inference/models/starcoder.cc @@ -253,6 +253,13 @@ void STARCODER::create_starcoder_model( } } + // If PEFT is enabled, add LoRA layers + if (ff.config.enable_peft) { + // todo: add attention projections + std::vector target_modules = {"c_fc", "c_proj"}; + ff.add_lora_layers(target_modules); + } + InferenceManager *im = InferenceManager::get_inference_manager(); FileDataLoader *fileloader = new FileDataLoader( "", diff --git a/inference/peft/peft.cc b/inference/peft/peft.cc index 0ab0b62ee8..4f2d47055a 100644 --- a/inference/peft/peft.cc +++ b/inference/peft/peft.cc @@ -256,7 +256,7 @@ void FlexFlow::top_level_task(Task const *task, LoraOptimizerConfig *optim_config = nullptr; if (enable_peft_finetuning) { // float sgd_learning_rate = 2e-1; - float sgd_learning_rate = 1.0f; + float sgd_learning_rate = 0.001f; optim_config = new LoraSGDOptimizerConfig(sgd_learning_rate); } LoraLinearConfig peft_config_finetuning = @@ -275,6 +275,8 @@ void FlexFlow::top_level_task(Task const *task, rm->set_max_requests_per_batch( max_requests_per_batch + (int)enable_peft_finetuning); // add one slot for finetuning if needed + rm->set_max_concurrent_adapters(max_requests_per_batch + + (int)enable_peft_finetuning); rm->set_max_tokens_per_batch(max_tokens_per_batch); rm->set_max_sequence_length(max_sequence_length); rm->register_tokenizer( @@ -320,18 +322,19 @@ void FlexFlow::top_level_task(Task const *task, assert(false && "unknow model type"); } - 
// Add PEFT layer + // Start background server + rm->start_background_server(&model); + + // Add PEFT adapter(s) PEFTModelID *peft_model_id = nullptr, *peft_model_id_finetuning = nullptr; if (!peft_model_name.empty()) { - peft_model_id = model.add_lora_layer(peft_config); + peft_model_id = model.register_peft_adapter(peft_config); if (enable_peft_finetuning) { - peft_model_id_finetuning = model.add_lora_layer(peft_config_finetuning); + peft_model_id_finetuning = + model.register_peft_adapter(peft_config_finetuning); } } - // Start background server - rm->start_background_server(&model); - // Run workload { std::vector requests; diff --git a/inference/peft/peft_bwd_benchmark.cc b/inference/peft/peft_bwd_benchmark.cc index 85e97ec4e8..9da4fa1994 100644 --- a/inference/peft/peft_bwd_benchmark.cc +++ b/inference/peft/peft_bwd_benchmark.cc @@ -304,15 +304,15 @@ void FlexFlow::top_level_task(Task const *task, assert(false && "unknow model type"); } + // Start background server + rm->start_background_server(&model); + // Add PEFT layer PEFTModelID *peft_model_id = nullptr; if (!peft_model_name.empty()) { - peft_model_id = model.add_lora_layer(peft_config); + peft_model_id = model.register_peft_adapter(peft_config); } - // Start background server - rm->start_background_server(&model); - // Warmup stage { std::vector requests; diff --git a/inference/peft/peft_fwd_benchmark.cc b/inference/peft/peft_fwd_benchmark.cc index 87322a42dd..3274f2e535 100644 --- a/inference/peft/peft_fwd_benchmark.cc +++ b/inference/peft/peft_fwd_benchmark.cc @@ -304,15 +304,15 @@ void FlexFlow::top_level_task(Task const *task, assert(false && "unknow model type"); } + // Start background server + rm->start_background_server(&model); + // Add PEFT layer PEFTModelID *peft_model_id = nullptr; if (!peft_model_name.empty()) { - peft_model_id = model.add_lora_layer(peft_config); + peft_model_id = model.register_peft_adapter(peft_config); } - // Start background server - rm->start_background_server(&model); - // Run workload { std::vector requests; diff --git a/inference/peft/req_rate_benchmark.cc b/inference/peft/req_rate_benchmark.cc index ffa77478e1..8a94f6e68b 100644 --- a/inference/peft/req_rate_benchmark.cc +++ b/inference/peft/req_rate_benchmark.cc @@ -366,14 +366,14 @@ void FlexFlow::top_level_task(Task const *task, assert(false && "unknow model type"); } + rm->start_background_server(&model); + // Add PEFT layer PEFTModelID *peft_model_id = nullptr; if (!peft_model_name.empty()) { - peft_model_id = model.add_lora_layer(peft_config); + peft_model_id = model.register_peft_adapter(peft_config); } - rm->start_background_server(&model); - // Warmup stage { std::vector requests; diff --git a/inference/python/chat.py b/inference/python/chat.py index 13ece116a6..95132443a2 100644 --- a/inference/python/chat.py +++ b/inference/python/chat.py @@ -21,14 +21,14 @@ def get_configs(): # Define sample configs ff_init_configs = { # required parameters - "num_gpus": 1, - "memory_per_gpu": 30000, - "zero_copy_memory_per_node": 60000, + "num_gpus": 8, + "memory_per_gpu": 34000, + "zero_copy_memory_per_node": 200000, # optional parameters - "num_cpus": 4, - "legion_utility_processors": 4, + "num_cpus": 16, + "legion_utility_processors": 16, "data_parallelism_degree": 1, - "tensor_parallelism_degree": 1, + "tensor_parallelism_degree": 8, "pipeline_parallelism_degree": 1, "offload": False, "offload_reserve_space_size": 8 * 1024, # 8GB @@ -36,7 +36,6 @@ def get_configs(): "use_8bit_quantization": False, "enable_peft": False, 
"peft_activation_reserve_space_size": 1024, # 1GB - "peft_weight_reserve_space_size": 1024, # 1GB "profiling": False, "benchmarking": False, "inference_debugging": False, @@ -44,7 +43,7 @@ def get_configs(): } llm_configs = { # required parameters - "llm_model": "meta-llama/Meta-Llama-3-8B-Instruct", + "llm_model": "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", # optional parameters "cache_path": os.environ.get("FF_CACHE_PATH", ""), "refresh_cache": False, @@ -86,11 +85,15 @@ def main(): llm.start_server() + nemotron_system = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are positive in nature." + llama_generic_system = "You are a helpful an honest programming assistant." + + messages=[ - {"role": "system", "content": "You are a helpful an honest programming assistant."}, + {"role": "system", "content": nemotron_system}, {"role": "user", "content": "Is Rust better than Python?"}, ] - llm.generate(messages, max_new_tokens=256) + llm.generate(messages, max_new_tokens=1024) llm.stop_server() diff --git a/inference/python/ff_peft.py b/inference/python/ff_peft.py index 13da7aee20..0167cecebc 100644 --- a/inference/python/ff_peft.py +++ b/inference/python/ff_peft.py @@ -41,14 +41,14 @@ def get_configs(): # Define sample configs ff_init_configs = { # required parameters - "num_gpus": 2, + "num_gpus": 4, "memory_per_gpu": 14000, "zero_copy_memory_per_node": 10000, # optional parameters "num_cpus": 4, "legion_utility_processors": 4, "data_parallelism_degree": 1, - "tensor_parallelism_degree": 2, + "tensor_parallelism_degree": 4, "pipeline_parallelism_degree": 1, "offload": False, "offload_reserve_space_size": 8 * 1024, # 8GB @@ -56,7 +56,6 @@ def get_configs(): "use_8bit_quantization": False, "enable_peft": True, "peft_activation_reserve_space_size": 1024, # 1GB - "peft_weight_reserve_space_size": 1024, # 1GB "profiling": False, "inference_debugging": True, "fusion": False, @@ -103,6 +102,23 @@ def main(): refresh_cache=configs.refresh_cache, output_file=configs.output_file, ) + + # Compile the LLM for inference and load the weights into memory + generation_config = ff.GenerationConfig( + do_sample=False, temperature=0.9, topp=0.8, topk=1 + ) + enable_peft_finetuning = len(configs.finetuning_dataset) > 0 + llm.compile( + generation_config, + max_requests_per_batch=1 if not enable_peft_finetuning else 2, + max_seq_length=256, + max_tokens_per_batch=128, + max_concurrent_adapters=1 if not enable_peft_finetuning else 2, + enable_peft_finetuning=enable_peft_finetuning, + ) + + llm.start_server() + # Add inference and/or finetuning lora lora_inference_config = None lora_finetuning_config = None @@ -112,18 +128,8 @@ def main(): configs.inference_peft_model_id, base_model_name_or_path=configs.base_model, ) - llm.add_peft(lora_inference_config) + llm.register_peft_adapter(lora_inference_config) if len(configs.finetuning_dataset) > 0: - # lora_finetuning_config = ff.LoraLinearConfig( - # llm.cache_path, - # configs.finetuning_peft_model_id, - # target_modules=["down_proj"], - # rank=16, - # lora_alpha=16, - # trainable=True, - # init_lora_weights=True, - # optimizer_type=ff.OptimizerType.OPTIMIZER_TYPE_SGD, - # ) lora_finetuning_config = ff.LoraLinearConfig( llm.cache_path, configs.inference_peft_model_id, @@ -137,22 +143,7 @@ def main(): "nesterov": False, }, ) - llm.add_peft(lora_finetuning_config) - - # Compile the LLM for inference and load the weights into memory - generation_config = ff.GenerationConfig( - 
do_sample=False, temperature=0.9, topp=0.8, topk=1 - ) - enable_peft_finetuning = len(configs.finetuning_dataset) > 0 - llm.compile( - generation_config, - enable_peft_finetuning=enable_peft_finetuning, - max_requests_per_batch=1 if not enable_peft_finetuning else 2, - max_seq_length=256, - max_tokens_per_batch=128, - ) - - llm.start_server() + llm.register_peft_adapter(lora_finetuning_config) requests = [] # Serving diff --git a/inference/python/incr_decoding.py b/inference/python/incr_decoding.py index 232ef1699c..4bb6892a6b 100644 --- a/inference/python/incr_decoding.py +++ b/inference/python/incr_decoding.py @@ -56,7 +56,6 @@ def get_configs(): "use_8bit_quantization": False, "enable_peft": False, "peft_activation_reserve_space_size": 1024, # 1GB - "peft_weight_reserve_space_size": 1024, # 1GB "profiling": False, "benchmarking": False, "inference_debugging": False, diff --git a/inference/python/peft_demo/INSTRUCTIONS.md b/inference/python/peft_demo/INSTRUCTIONS.md index 9b2a7a53b2..0f78efdea9 100644 --- a/inference/python/peft_demo/INSTRUCTIONS.md +++ b/inference/python/peft_demo/INSTRUCTIONS.md @@ -13,7 +13,7 @@ * `export HUGGINGFACE_TOKEN="[Your token]"` * `huggingface-cli login --token "$HUGGINGFACE_TOKEN"` - * `python3 inference/utils/download_peft_model.py "goliaro/llama-2-7b-lora-full" --base_model_name "meta-llama/Llama-2-7b-hf"` + * `python3 inference/utils/download_peft_model.py "goliaro/llama-2-7b-lora-full"` * Run the demo ``` diff --git a/inference/python/peft_demo/demo.ipynb b/inference/python/peft_demo/demo.ipynb index dfb5193a1d..ea2b8417b6 100644 --- a/inference/python/peft_demo/demo.ipynb +++ b/inference/python/peft_demo/demo.ipynb @@ -91,7 +91,6 @@ " \"use_8bit_quantization\": False,\n", " \"enable_peft\": True,\n", " \"peft_activation_reserve_space_size\": 1024, # 1GB\n", - " \"peft_weight_reserve_space_size\": 1024, # 1GB\n", " \"profiling\": False,\n", " \"inference_debugging\": False,\n", " \"fusion\": False,\n", @@ -195,7 +194,7 @@ } ], "source": [ - "args = [configs.inference_peft_model_id, '--base_model_name', configs.base_model]\n", + "args = [configs.inference_peft_model_id]\n", "subprocess.run(['python', '../../utils/download_peft_model.py'] + args)" ] }, @@ -1773,7 +1772,6 @@ " \"use_8bit_quantization\": False,\n", " \"enable_peft\": True,\n", " \"peft_activation_reserve_space_size\": 1024, # 1GB\n", - " \"peft_weight_reserve_space_size\": 1024, # 1GB\n", " \"profiling\": False,\n", " \"inference_debugging\": False,\n", " \"fusion\": False,\n", @@ -1815,7 +1813,7 @@ "configs = SimpleNamespace(**configs_dict)\n", "\n", "\n", - "args = [configs.finetuning_peft_model_id+\"-dolly\", '--base_model_name', configs.base_model]\n", + "args = [configs.finetuning_peft_model_id+\"-dolly\"]\n", "subprocess.run(['python', '../../utils/download_peft_model.py'] + args)\n", "\n", "# Initialize the FlexFlow runtime. 
ff.init() takes a dictionary or the path to a JSON file with the configs\n", diff --git a/inference/python/peft_demo/demo.py b/inference/python/peft_demo/demo.py index 9e01b4645b..b70f3c8966 100644 --- a/inference/python/peft_demo/demo.py +++ b/inference/python/peft_demo/demo.py @@ -47,7 +47,6 @@ def create_datasets(finetune_dataset_size=2, inference_file_path='inference_data "use_8bit_quantization": False, "enable_peft": True, "peft_activation_reserve_space_size": 1024, # 1GB - "peft_weight_reserve_space_size": 1024, # 1GB "profiling": False, "inference_debugging": False, "fusion": False, @@ -99,7 +98,7 @@ def create_datasets(finetune_dataset_size=2, inference_file_path='inference_data file.write('') # Download base and peft inference models -args = [configs.inference_peft_model_id, '--base_model_name', configs.base_model] +args = [configs.inference_peft_model_id] # hf_token = input("Please enter your HuggingFace personal access token: ") # subprocess.run(['huggingface-cli', 'login', '--token', hf_token]) subprocess.run(['python', '../../utils/download_peft_model.py'] + args) @@ -207,7 +206,7 @@ def create_datasets(finetune_dataset_size=2, inference_file_path='inference_data ) llm.add_peft(lora_inference_config) -args = [configs.finetuning_peft_model_id, '--base_model_name', configs.base_model] +args = [configs.finetuning_peft_model_id] #hf_token = input("Please enter your HuggingFace personal access token: ") # subprocess.run(['huggingface-cli', 'login', '--token', hf_token]) # subprocess.run(['python', '../../utils/download_peft_model.py'] + args) diff --git a/inference/python/spec_infer.py b/inference/python/spec_infer.py index 7ae752cffc..8cf96c1eba 100644 --- a/inference/python/spec_infer.py +++ b/inference/python/spec_infer.py @@ -56,7 +56,6 @@ def get_configs(): "use_8bit_quantization": False, "enable_peft": False, "peft_activation_reserve_space_size": 1024, # 1GB - "peft_weight_reserve_space_size": 1024, # 1GB "profiling": False, "benchmarking": False, "inference_debugging": False, diff --git a/inference/python/streamlit/README.md b/inference/python/streamlit/README.md new file mode 100644 index 0000000000..86a15e2d6d --- /dev/null +++ b/inference/python/streamlit/README.md @@ -0,0 +1,18 @@ +# Streamlit demo + +## Instructions + +1. Build and install FlexFlow, or build and run `source ./set_python_envs.sh` from the build folder +2. Edit the FlexFlow/inference/python/streamlit/fastapi_incr.py to configure the model to run and the system configs (num gpus, amount of memory, etc) +3. In one terminal, launch the LLM engine with the commands below, and wait until the model's weights loading completes +``` +cd FlexFlow/inference/python/streamlit +python fastapi_incr.py +``` +4. In another terminal, launch the streamlit app: +``` +cd FlexFlow/inference/python/streamlit +streamlit run app.py +``` +5. Open the URL printed to the terminal, e.g. 
`http://localhost:8501` and interact with the app via browser + diff --git a/inference/python/streamlit/app.py b/inference/python/streamlit/app.py new file mode 100644 index 0000000000..9788765a3a --- /dev/null +++ b/inference/python/streamlit/app.py @@ -0,0 +1,188 @@ +import streamlit as st +import requests +import os, json +from huggingface_hub import model_info + + +# App title +st.set_page_config(page_title="🚀💻 FlexLLM Server", layout="wide") + +# FastAPI server URL +FASTAPI_URL = "http://localhost:8000/chat/completions" # Adjust the port if necessary +FINETUNE_URL = "http://localhost:8000/finetuning" + +# Initialize session state variables +if 'added_adapters' not in st.session_state: + st.session_state.added_adapters = [] + +# Store LLM generated responses +if "messages" not in st.session_state.keys(): + st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}] + +def check_model_availability(model_name): + try: + info = model_info(model_name) + return True + except Exception: + return False + +def clear_chat_history(): + st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}] + +# Function for generating LLaMA2 response +def generate_llama3_response(prompt_input): + system_prompt="You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are positive in nature." + + # Send request to FastAPI server + response = requests.post(FASTAPI_URL, json={"max_new_tokens": 1024, "messages": [{"role": "system", "content": system_prompt}] + st.session_state.messages + [{"role": "user", "content": prompt_input}]}) + + if response.status_code == 200: + return response.json()["response"] + else: + return f"Error: {response.status_code} - {response.text}" + +# Sidebar +with st.sidebar: + st.title('🚀 FlexLLM Server') + page = st.radio("Choose a page", ["Chat", "Finetune"]) + if page == "Chat": + st.header('🦙 Llama Chatbot') + # st.success('Using local FastAPI server', icon='✅') + st.sidebar.button('Clear Chat History', on_click=clear_chat_history) + + st.subheader('Generation parameters') + max_length = st.sidebar.slider('Max generation length', min_value=64, max_value=2048, value=1024, step=8) + # selected_model = st.sidebar.selectbox('Choose a Llama2 model', ['Llama2-7B', 'Llama2-13B', 'Llama2-70B'], key='selected_model') + decoding_method = st.sidebar.selectbox('Decoding method', ['Greedy decoding (default)', 'Sampling'], key='decoding_method') + temperature = st.sidebar.slider('temperature', min_value=0.01, max_value=5.0, value=0.1, step=0.01, disabled=decoding_method == 'Greedy decoding (default)') + top_p = st.sidebar.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01, disabled=decoding_method == 'Greedy decoding (default)') + + # lora_adapter = st.sidebar.text_input('Lora adapter', placeholder='None') + st.subheader("LoRA Adapters (optional)") + # Text input for PEFT model ID + peft_id = st.text_input("Add a LoRA Adapter", placeholder="Enter the Huggingface PEFT model ID") + # Button to load the adapter + if st.button("Load Adapter"): + if peft_id: + with st.spinner("Checking PEFT availability..."): + is_available = check_model_availability(peft_id) + if is_available: + if peft_id not in st.session_state.added_adapters: + st.session_state.added_adapters.append(peft_id) + st.success(f"Successfully added PEFT: {peft_id}") + else: + st.warning(f"PEFT {peft_id} is already in the list.") + else: + st.error(f"PEFT {peft_id} is 
not available on Hugging Face. Please check the ID and try again.") + else: + st.warning("Please enter a PEFT Model ID.") + # Button to remove all adapters + if st.button("Remove All Adapters"): + st.session_state.added_adapters = [] + st.success("All adapters have been removed.") + # Display the list of added adapters + st.markdown("**Added Adapters:**") + if st.session_state.added_adapters: + for adapter in st.session_state.added_adapters: + st.write(f"- {adapter}") + else: + st.write("No adapters added yet.") + # st.markdown('📖 Learn how to build this app in this [blog](https://blog.streamlit.io/how-to-build-a-llama-2-chatbot/)!') + elif page == "Finetune": + st.header("🏋️‍♂️ LoRA Finetuning") + + # Hugging Face token input + # hf_token = st.text_input("Enter your Hugging Face token:", type="password") + if 'hf_token' in st.session_state.keys(): + st.success('HF token already provided!', icon='✅') + hf_token = st.session_state.hf_token + else: + hf_token = st.text_input('Enter your Hugging Face token:', type='password') + if not (hf_token.startswith('hf_') and len(hf_token)==37): + st.warning('Please enter a valid token', icon='⚠️') + else: + st.success('Proceed to finetuning your model!', icon='👉') + st.session_state.hf_token = hf_token + + # PEFT model name + peft_model_name = st.text_input("Enter the PEFT model name:", help="The name of the PEFT model should start with the username associated with the provided HF token, followed by '/'. E.g. 'username/peft-base-uncased'") + + # Dataset selection + dataset_option = st.radio("Choose dataset source:", ["Upload JSON", "Hugging Face Dataset"]) + + if dataset_option == "Upload JSON": + uploaded_file = st.file_uploader("Upload JSON dataset", type="json") + if uploaded_file is not None: + dataset = json.load(uploaded_file) + st.success("Dataset uploaded successfully!") + else: + dataset_name = st.text_input("Enter Hugging Face dataset name:") + + # Finetuning parameters + st.subheader("Finetuning parameters") + lora_rank = st.number_input("LoRA rank", min_value=2, max_value=64, value=16, step=2) + lora_alpha = st.number_input("LoRA alpha", min_value=2, max_value=64, value=16, step=2) + target_modules = st.multiselect("Target modules", ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"], default=["down_proj"]) + learning_rate = st.number_input("Learning rate", min_value=1e-6, max_value=1e-3, value=1e-5, step=1e-6) + optimizer_type = st.selectbox("Optimizer type", ["SGD", "Adam", "AdamW", "Adagrad", "Adadelta", "Adamax", "RMSprop"]) + momentum = st.number_input("Momentum", min_value=0.0, max_value=1.0, value=0.0, step=0.01) + weight_decay = st.number_input("Weight decay", min_value=0.0, max_value=1.0, value=0.0, step=0.01) + nesterov = st.checkbox("Nesterov") + max_steps = st.number_input("Max steps", min_value=1000, max_value=100000, value=10000, step=1000) + + # Start finetuning button + if st.button("Start Finetuning"): + if not hf_token: + st.error("Please enter your Hugging Face token.") + elif dataset_option == "Upload JSON" and uploaded_file is None: + st.error("Please upload a JSON dataset.") + elif dataset_option == "Hugging Face Dataset" and not dataset_name: + st.error("Please enter a Hugging Face dataset name.") + else: + # Prepare the request data + request_data = { + "token": hf_token, + "dataset_source": dataset_option, + } + + if dataset_option == "Upload JSON": + request_data["dataset"] = dataset + else: + request_data["dataset_name"] = dataset_name + + # Send finetuning request to FastAPI
server + with st.spinner("Finetuning in progress..."): + response = requests.post(FINETUNE_URL, json=request_data) + + if response.status_code == 200: + st.success("Finetuning completed successfully!") + else: + st.error(f"Finetuning failed. Error: {response.status_code} - {response.text}") + +if page == "Chat": + # Display or clear chat messages + for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.write(message["content"]) + + # User-provided prompt + if prompt := st.chat_input(): + st.session_state.messages.append({"role": "user", "content": prompt}) + with st.chat_message("user"): + st.write(prompt) + + # Generate a new response if last message is not from assistant + if st.session_state.messages[-1]["role"] != "assistant": + with st.chat_message("assistant"): + with st.spinner("Running..."): + response = generate_llama3_response(prompt) + placeholder = st.empty() + full_response = '' + for item in response: + full_response += item + placeholder.markdown(full_response) + placeholder.markdown(full_response) + message = {"role": "assistant", "content": full_response} + st.session_state.messages.append(message) +elif page == "Finetune": + st.write("Use the sidebar to configure and start finetuning.") \ No newline at end of file diff --git a/inference/python/streamlit/fastapi_incr.py b/inference/python/streamlit/fastapi_incr.py new file mode 100644 index 0000000000..6ac7f4149a --- /dev/null +++ b/inference/python/streamlit/fastapi_incr.py @@ -0,0 +1,207 @@ +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Running Instructions: +- To run this FastAPI application, make sure you have FastAPI and Uvicorn installed. +- Save this script as 'fastapi_incr.py'. +- Run the application using the command: `uvicorn fastapi_incr:app --reload --port PORT_NUMBER` +- The server will start on `http://localhost:PORT_NUMBER`. Use this base URL to make API requests. +- Go to `http://localhost:PORT_NUMBER/docs` for API documentation. 
+""" + + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel, Field +import flexflow.serve as ff +import uvicorn +import json, os, argparse +from types import SimpleNamespace +from typing import Optional, List +import time + + +# Initialize FastAPI application +app = FastAPI() + +# Define the request model +class PromptRequest(BaseModel): + prompt: str + +# data models +class Message(BaseModel): + role: str + content: str + + +# class ChatCompletionRequest(BaseModel): +# model: Optional[str] = "mock-gpt-model" +# messages: List[Message] +# max_tokens: Optional[int] = 512 +# temperature: Optional[float] = 0.1 +# stream: Optional[bool] = False + +class ChatCompletionRequest(BaseModel): + max_new_tokens: Optional[int] = 1024 + messages: List[Message] + +# Global variable to store the LLM model +llm = None + + +def get_configs(): + + # Fetch configuration file path from environment variable + config_file = os.getenv("CONFIG_FILE", "") + + # Load configs from JSON file (if specified) + if config_file: + if not os.path.isfile(config_file): + raise FileNotFoundError(f"Config file {config_file} not found.") + try: + with open(config_file) as f: + return json.load(f) + except json.JSONDecodeError as e: + print("JSON format error:") + print(e) + else: + # Define sample configs + ff_init_configs = { + # required parameters + "num_gpus": 8, + "memory_per_gpu": 20000, + "zero_copy_memory_per_node": 40000, + # optional parameters + "num_cpus": 4, + "legion_utility_processors": 8, + "data_parallelism_degree": 1, + "tensor_parallelism_degree": 4, + "pipeline_parallelism_degree": 1, + "offload": False, + "offload_reserve_space_size": 8 * 1024, # 8GB + "use_4bit_quantization": False, + "use_8bit_quantization": False, + "enable_peft": False, + "peft_activation_reserve_space_size": 1024, # 1GB + "profiling": False, + "benchmarking": False, + "inference_debugging": False, + "fusion": True, + } + llm_configs = { + # required parameters + "llm_model": "meta-llama/Llama-3.1-8B-Instruct", + # optional parameters + "cache_path": os.environ.get("FF_CACHE_PATH", ""), + "refresh_cache": False, + "full_precision": False, + "prompt": "", + "output_file": "", + } + # Merge dictionaries + ff_init_configs.update(llm_configs) + return ff_init_configs + + +# Initialize model on startup +@app.on_event("startup") +async def startup_event(): + global llm + + # Initialize your LLM model configuration here + configs_dict = get_configs() + configs = SimpleNamespace(**configs_dict) + ff.init(configs_dict) + + ff_data_type = ( + ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF + ) + llm = ff.LLM( + configs.llm_model, + data_type=ff_data_type, + cache_path=configs.cache_path, + refresh_cache=configs.refresh_cache, + output_file=configs.output_file, + ) + + generation_config = ff.GenerationConfig( + do_sample=False, temperature=0.9, topp=0.8, topk=1 + ) + llm.compile( + generation_config, + max_requests_per_batch=16, + max_seq_length=2048, + max_tokens_per_batch=1024, + ) + llm.start_server() + +# API endpoint to generate response +@app.post("/generate/") +async def generate(prompt_request: PromptRequest): + if llm is None: + raise HTTPException(status_code=503, detail="LLM model is not initialized.") + + # Call the model to generate a response + full_output = llm.generate([prompt_request.prompt])[0].output_text.decode('utf-8') + + # Separate the prompt and response + split_output = full_output.split('\n', 1) + if len(split_output) > 1: + response_text = split_output[1] + else: + 
response_text = "" + + # Return the prompt and the response in JSON format + return { + "prompt": prompt_request.prompt, + "response": response_text + } + +@app.post("/chat/completions") +async def chat_completions(request: ChatCompletionRequest): + + if llm is None: + raise HTTPException(status_code=503, detail="LLM model is not initialized.") + + print("received request:", request) + result = llm.generate([message.dict() for message in request.messages], max_new_tokens=request.max_new_tokens)[0].output_text.decode('utf-8') + print("returning response:", result) + return { + "response": result + } + return { + "id": "1337", + "object": "chat.completion", + "created": time.time(), + "model": request.model, + "choices": [{"message": Message(role="assistant", content=resp_content)}], + } + +# Shutdown event to stop the model server +@app.on_event("shutdown") +async def shutdown_event(): + global llm + if llm is not None: + llm.stop_server() + +# Main function to run Uvicorn server +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) + +# Running within the entrypoint folder: +# uvicorn fastapi_incr:app --reload --port + +# Running within the python folder: +# uvicorn entrypoint.fastapi_incr:app --reload --port 3000 diff --git a/inference/utils/download_peft_model.py b/inference/utils/download_peft_model.py index 38dd577574..2ee63b10bc 100644 --- a/inference/utils/download_peft_model.py +++ b/inference/utils/download_peft_model.py @@ -1,13 +1,11 @@ #!/usr/bin/env python import flexflow.serve as ff import argparse, os +from peft import PeftConfig def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--base_model_name", type=str, help="Name of the model to download" - ) parser.add_argument( "peft_model_ids", type=str, @@ -48,19 +46,21 @@ def main(args): else: data_types = (ff.DataType.DT_FLOAT, ff.DataType.DT_HALF) - for data_type in data_types: - llm = ff.LLM( - args.base_model_name, - data_type=data_type, - cache_path=args.cache_folder, - refresh_cache=args.refresh_cache, - ) - for peft_model_id in args.peft_model_ids: - lora_config = ff.LoraLinearConfig(llm.cache_path, peft_model_id) - llm.add_peft(lora_config) - llm.download_hf_weights_if_needed() - llm.download_hf_config() - llm.download_hf_tokenizer_if_needed() + for peft_model_id in args.peft_model_ids: + hf_config = PeftConfig.from_pretrained(peft_model_id) + for data_type in data_types: + llm = ff.LLM( + hf_config.base_model_name_or_path, + data_type=data_type, + cache_path=args.cache_folder, + refresh_cache=args.refresh_cache, + ) + # Download base model config, weights and tokenizer + llm.download_hf_config() + llm.download_hf_weights_if_needed() + llm.download_hf_tokenizer_if_needed() + # Download PEFT adapter + llm.download_peft_adapter_if_needed(peft_model_id) if __name__ == "__main__": diff --git a/python/flexflow/core/__init__.py b/python/flexflow/core/__init__.py index b8ed15eaea..52fe331bf3 100644 --- a/python/flexflow/core/__init__.py +++ b/python/flexflow/core/__init__.py @@ -91,7 +91,6 @@ "use_8bit_quantization": "--8bit-quantization", "enable_peft": "-enable-peft", "peft_activation_reserve_space_size": "-peft-activation-reserve-space-size", - "peft_weight_reserve_space_size": "-peft-weight-reserve-space-size", } diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index 59e62ea023..02eff0ca76 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -811,6 +811,10 @@ def 
pipeline_parallelism_degree(self, value): @property def python_data_loader_type(self): return ffc().flexflow_config_get_python_data_loader_type(self.handle) + + @property + def enable_peft(self): + return ffc().flexflow_config_get_enable_peft(self.handle) @property def cpu_offload(self): @@ -1629,6 +1633,11 @@ def set_max_sequence_length(self, max_length): def get_max_sequence_length(self): return ffc().flexflow_request_manager_get_max_sequence_length(self.handle) + + def set_max_concurrent_adapters(self, max_adapters): + return ffc().flexflow_request_manager_set_max_concurrent_adapters( + self.handle, max_adapters + ) def set_enable_peft_finetuning(self, enable_peft_finetuning): return ffc().flexflow_request_manager_set_enable_peft_finetuning( @@ -4288,8 +4297,12 @@ def argmax(self, input, beam_search, name=None): self.add_layer(OpType.ARGMAX, name) return Tensor(handle, owner_op_type=OpType.ARGMAX) - def add_lora_layer(self, peft_config): - return ffc().flexflow_model_add_lora_layer(self.handle, peft_config.handle) + def add_lora_layers(self, target_modules: List[str]): + c_target_modules = [get_c_name(module) for module in target_modules] + return ffc().flexflow_model_add_lora_layers(self.handle, len(target_modules), c_target_modules) + + def register_peft_adapter(self, peft_config): + return ffc().flexflow_model_register_peft_adapter(self.handle, peft_config.handle) def reset_metrics(self): """Reset performance metrics. @@ -4751,6 +4764,7 @@ def generate(self, requests_list: List[Request]): finetuning_losses=finetuning_losses, ) ) + return results def set_position_offset(self, offset): ffc().flexflow_model_set_position_offset(self.handle, offset) diff --git a/python/flexflow/serve/__init__.py b/python/flexflow/serve/__init__.py index fd29080a6a..55044d1838 100644 --- a/python/flexflow/serve/__init__.py +++ b/python/flexflow/serve/__init__.py @@ -55,7 +55,6 @@ def init( use_8bit_quantization: Optional[bool] = None, enable_peft: Optional[bool] = None, peft_activation_reserve_space_size: Optional[int] = None, - peft_weight_reserve_space_size: Optional[int] = None, profiling: Optional[bool] = None, benchmarking: Optional[bool] = None, inference_debugging: Optional[bool] = None, @@ -86,7 +85,6 @@ def init( - use_8bit_quantization: whether to use 8-bit quantization, defaults to False - enable_peft: whether to enable the use of PEFT, defaults to False - peft_activation_reserve_space_size: the space (in MB) to reserve on GPU for PEFT activations, default to 1 GB - - peft_weight_reserve_space_size: the space (in MB) to reserve on GPU for PEFT weights, default to 1 GB - profiling: whether to enable the FlexFlow profiling mode, defaults to False - benchmarking: whether to run benchmaking only, without loading real weights, defaults to False - inference_debugging: whether to run inference in debugging mode, saving all inputs/outputs/weights to file, defaults to False @@ -125,8 +123,6 @@ def init( :type enable_peft: Optional[bool], optional :param peft_activation_reserve_space_size: the space (in MB) to reserve on GPU for PEFT activations, default to 1 GB :type peft_activation_reserve_space_size: Optional[int], optional - :param peft_weight_reserve_space_size: the space (in MB) to reserve on GPU for PEFT weights, default to 1 GB - :type peft_weight_reserve_space_size: Optional[int], optional :param profiling: whether to enable the FlexFlow profiling mode, defaults to False :type profiling: Optional[bool], optional :param benchmarking: whether to run benchmaking only, without loading real weights, 
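The bindings above split LoRA support into a graph-time call (add_lora_layers) and a runtime call (register_peft_adapter), with the adapter budget configured on the request manager. A rough ordering sketch; ffconfig, ffmodel, rm, and lora_config are assumed to already exist and are not names introduced by the patch.

```python
# Sketch only: `ffconfig`, `ffmodel`, `rm`, and `lora_config` are assumed to
# be an FFConfig, an FFModel, a RequestManager, and a LoraLinearConfig.

# Cap on simultaneously loaded adapters, set before serving starts.
rm.set_max_concurrent_adapters(1)

if ffconfig.enable_peft:
    # 1) Graph construction: attach empty LoRA slots next to target modules.
    ffmodel.add_lora_layers(["gate_proj", "up_proj", "down_proj"])

# 2) Serving time: bind a concrete adapter to those slots and get its id.
peft_model_id = ffmodel.register_peft_adapter(lora_config)
```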
defaults to False @@ -158,7 +154,6 @@ def init( use_8bit_quantization is not None, enable_peft is not None, peft_activation_reserve_space_size is not None, - peft_weight_reserve_space_size is not None, profiling is not None, benchmarking is not None, inference_debugging is not None, @@ -187,7 +182,6 @@ def init( "use_8bit_quantization": use_8bit_quantization, "enable_peft": enable_peft, "peft_activation_reserve_space_size": peft_activation_reserve_space_size, - "peft_weight_reserve_space_size": peft_weight_reserve_space_size, "profiling": profiling, "benchmarking": benchmarking, "inference_debugging": inference_debugging, @@ -210,7 +204,6 @@ def init( "pipeline_parallelism_degree", "offload_reserve_space_size", "peft_activation_reserve_space_size", - "peft_weight_reserve_space_size", ] for param in positive_int_params: __check_positive_int(configs_dict, param) @@ -238,8 +231,6 @@ def init( configs_dict["enable_peft"] = False if configs_dict.get("peft_activation_reserve_space_size", None) is None: configs_dict["peft_activation_reserve_space_size"] = 8 * 1024**3 - if configs_dict.get("peft_weight_reserve_space_size", None) is None: - configs_dict["peft_weight_reserve_space_size"] = 1024**3 if configs_dict.get("profiling", None) is None: configs_dict["profiling"] = False if configs_dict.get("benchmarking", None) is None: diff --git a/python/flexflow/serve/models/falcon.py b/python/flexflow/serve/models/falcon.py index 0c6102406f..60aa3c27e9 100644 --- a/python/flexflow/serve/models/falcon.py +++ b/python/flexflow/serve/models/falcon.py @@ -257,6 +257,10 @@ def build_model(self, max_tokens_per_batch): # output = ffmodel.arg_top_k(lm_head, 1, False) softmax = ffmodel.softmax(lm_head, -1) output = ffmodel.argmax(softmax, False) + + if self.ffconfig.enable_peft: + # TODO: add attention projections + ffmodel.add_lora_layers(["dense_h_to_4h", "dense_4h_to_h"]) self.ffmodel = ffmodel diff --git a/python/flexflow/serve/models/llama.py b/python/flexflow/serve/models/llama.py index e149834603..ceea9e96b0 100644 --- a/python/flexflow/serve/models/llama.py +++ b/python/flexflow/serve/models/llama.py @@ -264,6 +264,10 @@ def build_model(self, max_tokens_per_batch): # output = ffmodel.arg_top_k(dense, 1, False) softmax = ffmodel.softmax(dense, -1) output = ffmodel.argmax(softmax, False) + + if self.ffconfig.enable_peft: + # TODO: add attention projections + ffmodel.add_lora_layers(["gate_proj", "up_proj", "down_proj"]) self.ffmodel = ffmodel diff --git a/python/flexflow/serve/models/mpt.py b/python/flexflow/serve/models/mpt.py index a0e70b381a..d927a1fbb3 100644 --- a/python/flexflow/serve/models/mpt.py +++ b/python/flexflow/serve/models/mpt.py @@ -258,6 +258,10 @@ def build_model(self, max_tokens_per_batch): softmax = ffmodel.softmax(lm_head, -1) output = ffmodel.argmax(softmax, False) + if self.ffconfig.enable_peft: + # TODO: add attention projections + ffmodel.add_lora_layers(["up_proj", "down_proj"]) + self.ffmodel = ffmodel # TODO: finish this diff --git a/python/flexflow/serve/models/opt.py b/python/flexflow/serve/models/opt.py index ba2e21b690..e8d6fec9af 100644 --- a/python/flexflow/serve/models/opt.py +++ b/python/flexflow/serve/models/opt.py @@ -287,6 +287,10 @@ def build_model(self, max_tokens_per_batch): softmax = ffmodel.softmax(lm_head, -1) output = ffmodel.argmax(softmax, False) + if self.ffconfig.enable_peft: + # TODO: add attention projections + ffmodel.add_lora_layers(["fc1", "fc2"]) + self.ffmodel = ffmodel def convert_hf_weight_name(name): diff --git 
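With peft_weight_reserve_space_size removed, enabling PEFT at initialization only involves enable_peft and, optionally, the activation reservation. A minimal sketch for a single-GPU machine; the memory sizes are illustrative.

```python
import flexflow.serve as ff

ff.init(
    num_gpus=1,
    memory_per_gpu=14000,
    zero_copy_memory_per_node=30000,
    enable_peft=True,
    # Space (in MB, per the docstring) reserved for PEFT activations; the
    # separate weight reservation knob no longer exists after this patch.
    peft_activation_reserve_space_size=1024,
)
```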
a/python/flexflow/serve/models/starcoder.py b/python/flexflow/serve/models/starcoder.py index dc5faf175f..107614e9dd 100644 --- a/python/flexflow/serve/models/starcoder.py +++ b/python/flexflow/serve/models/starcoder.py @@ -228,6 +228,10 @@ def build_model(self, max_tokens_per_batch): softmax = ffmodel.softmax(lm_head, -1) output = ffmodel.argmax(softmax, False) + if self.ffconfig.enable_peft: + # TODO: add attention projections + ffmodel.add_lora_layers(["c_fc", "c_proj"]) + self.ffmodel = ffmodel def convert_hf_model(model, dst_folder): diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py index e4248a2fc1..c2804b6966 100644 --- a/python/flexflow/serve/serve.py +++ b/python/flexflow/serve/serve.py @@ -31,9 +31,17 @@ from peft import PeftModel, PeftConfig, LoraConfig from huggingface_hub import HfApi import torch, shutil, hashlib, json, gc -from typing import Union, List +from typing import Union, List, Tuple +from safetensors import safe_open from huggingface_hub import snapshot_download +from enum import Enum + + +class CachedResourceType(Enum): + TOKENIZER = "tokenizer" + WEIGHTS = "weights" + class _SupportedModels: def __init__( @@ -104,14 +112,14 @@ def __init__( self.output_file = output_file self.rm = None self.pefts = {} - self.tokenizer=None + self.tokenizer = None def __del__(self): # Stop the background server before deleting the object if type(self) == LLM and self.rm is not None: self.rm.stop_server() - def add_peft(self, lora_config: LoraLinearConfig): + def register_peft_adapter(self, lora_config: LoraLinearConfig): """Add a PEFT adapter to the LLM""" if lora_config is None: raise ValueError("lora_config cannot be None") @@ -145,9 +153,12 @@ def add_peft(self, lora_config: LoraLinearConfig): f"Attempting to add PEFT with base model name {peft_config.base_model_name_or_path} to LLM {self.model_name}" ) + lora_config.ff_compile() + self.pefts[lora_config] = { "peft_config": peft_config, "peft_type": peft_config.peft_type, + "ff_peft_model_id": self.model.ffmodel.register_peft_adapter(lora_config), } def get_ff_peft_id(self, lora_config: LoraLinearConfig) -> PEFTModelID: @@ -175,34 +186,33 @@ def download_hf_config(self): os.makedirs(config_dir, exist_ok=True) print(f"Creating directory {config_dir} (if it doesn't exist)...") print(f"Saving {self.model_name} configs to file {config_path}...") - self.hf_config.to_json_file(config_path) - - # Save PEFT configs if the LLM has any registered PEFTs - for ff_peft_config, peft_dict in self.pefts.items(): - peft_config = peft_dict["peft_config"] - peft_model_id = ff_peft_config.peft_model_id - peft_config_dir = os.path.join( - os.path.expanduser(self.cache_path), "configs", peft_model_id.lower() - ) - os.makedirs(peft_config_dir, exist_ok=True) - peft_config_path = os.path.join(peft_config_dir, "config.json") - print(f"Saving {peft_model_id} configs to file {peft_config_path}...") - with open(peft_config_path, "w") as json_file: - - class SetEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, set): - return list(obj) - return super().default(obj) - - json.dump(peft_config.to_dict(), json_file, indent=2, cls=SetEncoder) - - def __get_revision_hashes(self, model_name: str, folder: str): + # self.hf_config.to_json_file(config_path) + src_folder = snapshot_download( + repo_id=self.model_name, allow_patterns="config.json" + ) + src_path = os.path.join(src_folder, "config.json") + if os.path.exists(src_path): + shutil.copy(src_path, config_path) + + def __get_revision_hashes( + self, model_name: 
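After the rename, adapters are attached to an LLM through register_peft_adapter() rather than add_peft(); because the new implementation immediately registers the config with the built FFModel, it is called after compile(). A usage sketch; the adapter repository id is a placeholder.

```python
import flexflow.serve as ff

# Sketch: assumes ff.init(..., enable_peft=True) has been called and that
# llm.compile(...) has already built the model.
llm = ff.LLM("meta-llama/Llama-3.1-8B-Instruct", data_type=ff.DataType.DT_HALF)
# ... llm.compile(...) ...
lora_config = ff.LoraLinearConfig(llm.cache_path, "some-user/some-llama-lora")
llm.register_peft_adapter(lora_config)
peft_model_id = llm.get_ff_peft_id(lora_config)
```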
str, folder: str + ) -> Tuple[Union[str, None], str, str]: + """Return the commit hash of the object (weight, tokenizer, etc) cached by FlexFlow and the latest commit hash of the object from HuggingFace (or other source) + + Args: + model_name (str): Name of the model cached by FlexFlow + folder (str): Folder where the cached object is stored + + Returns: + ff_revision: Commit hash of the object cached by FlexFlow + ff_revision_filepath: Path to the file containing the commit hash of the object cached by FlexFlow + latest_revision: Latest commit hash of the object from HuggingFace (or other source) + """ ff_revision = None - ff_revision_file = os.path.join(folder, "rev_sha.txt") + ff_revision_filepath = os.path.join(folder, "rev_sha.txt") - if os.path.exists(ff_revision_file): - ff_revision = "".join(open(ff_revision_file).read().split()) + if os.path.exists(ff_revision_filepath): + ff_revision = "".join(open(ff_revision_filepath).read().split()) if os.path.exists(model_name) and os.path.isdir(model_name): # Local model @@ -215,16 +225,21 @@ def __get_revision_hashes(self, model_name: str, folder: str): # Remote HuggingFace model hf_api = HfApi() latest_revision = hf_api.model_info(self.model_name).sha - return ff_revision, ff_revision_file, latest_revision + return ff_revision, latest_revision - def download_hf_weights_if_needed(self): - """Check in the folder specified by the cache_path whether the LLM's model weights are available and up to date. - If not, or if the refresh_cache parameter is set to True, download new weights. + def __get_resource_path( + self, model_name: str, resource_type: CachedResourceType + ) -> str: + """Returns the path to the folder where the model weights or tokenizer files are stored - If any PEFT adapter is registered, perform the same operation for PEFT. - """ + Args: + model_name (str): Name of the model + resource_type (CachedResourceType): Whether to get the path to the weights or the tokenizer - def get_weights_path(model_name): + Returns: + str: Path to the folder where the model weights or tokenizer files are stored + """ + if resource_type == CachedResourceType.WEIGHTS: return os.path.join( os.path.expanduser(self.cache_path), "weights", @@ -235,19 +250,49 @@ def get_weights_path(model_name): else "half-precision" ), ) + elif resource_type == CachedResourceType.TOKENIZER: + return os.path.join( + os.path.expanduser(self.cache_path), "tokenizers", model_name.lower() + ) + else: + raise ValueError(f"Invalid resource type {resource_type}") - def refresh_cache_if_needed(model_name): - weights_path = get_weights_path(model_name) - if self.refresh_cache: - print( - f"Refreshing weights in cache for model {model_name} at path {weights_path} ..." - ) - if os.path.exists(weights_path): - shutil.rmtree(weights_path) - os.makedirs(weights_path, exist_ok=True) + def __need_cache_refresh( + self, model_name: str, resource_type: CachedResourceType + ) -> bool: + """Check whether the model weights or tokenizer files are available and up to date. + If they need a refresh, create the folder for the resource, save the new commit hash to the rev_sha.txt file, delete any existing files, and return true. 
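__get_resource_path() pins down a single cache layout for weights, tokenizers, and configs. A standalone sketch of that layout; the cache root shown assumes the usual ~/.cache/flexflow default, which this hunk does not spell out.

```python
import os

cache_root = os.path.expanduser("~/.cache/flexflow")   # assumed default
model_name = "meta-llama/Llama-3.1-8B-Instruct".lower()

weights_dir = os.path.join(cache_root, "weights", model_name, "half-precision")
tokenizer_dir = os.path.join(cache_root, "tokenizers", model_name)
config_file = os.path.join(cache_root, "configs", model_name, "config.json")

for path in (weights_dir, tokenizer_dir, config_file):
    print(path)
```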
- def get_hf_llm(model_name): - return AutoModelForCausalLM.from_pretrained( + Args: + model_name (str): Name of the model to check + resource_type (CachedResourceType): Whether to check the weights or the tokenizer + + Returns: + bool: True if the weights or tokenizer need a refresh, False otherwise + """ + resource_path = self.__get_resource_path(model_name, resource_type) + ff_revision, latest_revision = self.__get_revision_hashes(self.model_name, resource_path) + if self.refresh_cache or not os.path.exists(resource_path) or ff_revision != latest_revision: + print( + f"Refreshing {resource_type} in cache for model {model_name} at path {resource_path} ..." + ) + if os.path.exists(resource_path): + shutil.rmtree(resource_path) + os.makedirs(resource_path, exist_ok=True) + ff_revision_file = os.path.join(resource_path, "rev_sha.txt") + with open(ff_revision_file, "w+") as f: + f.write(latest_revision) + return True + return False + + def download_hf_weights_if_needed(self) -> None: + """Check in the folder specified by the cache_path whether the LLM's model weights are available and up to date. + If not, or if the refresh_cache parameter is set to True, download new weights and convert them. + """ + + # TODO: edit this to download the weights using snapshot_download and convert them to FlexFlow format without loading them to GPU + def download_and_convert_llm_weights(model_name): + hf_model = AutoModelForCausalLM.from_pretrained( model_name, trust_remote_code=True, torch_dtype=( @@ -256,73 +301,26 @@ def get_hf_llm(model_name): else torch.float16 ), ) - - def download_llm_weights(): - refresh_cache_if_needed(self.model_name) - ff_revision, ff_revision_file, latest_revision = self.__get_revision_hashes( - self.model_name, self.weights_path + # Convert the model to FlexFlow format + weights_path = self.__get_resource_path( + model_name, CachedResourceType.WEIGHTS ) - if ff_revision != latest_revision: - print( - f"'{self.model_name}' local model weights need updating! Downloading/converting new weights now..." - ) - hf_model = get_hf_llm(self.model_name) - # Convert the model to FlexFlow format - self.model_class.convert_hf_model(hf_model, self.weights_path) - # Save new revision hash to file - with open(ff_revision_file, "w+") as f: - f.write(latest_revision) - print(f"Done converting the weights for model {self.model_name}") - # Deallocate hf model - del hf_model - gc.collect() - torch.cuda.empty_cache() - - def convert_peft_model(hf_peft_model, peft_type, weights_path): - for name, params in hf_peft_model.named_parameters(): - if peft_type.lower() in name: - name = name.replace("base_model.model.model.", "").replace( - ".default", "" - ) - name = self.model_class.convert_hf_weight_name(name) - params.detach().cpu().numpy().tofile(f"{weights_path}/{name}") - - def download_peft_weights(): - for ff_peft_config, peft_dict in self.pefts.items(): - if not ff_peft_config.init_lora_weights: - peft_config = peft_dict["peft_config"] - peft_type = peft_dict["peft_type"] - peft_model_id = ff_peft_config.peft_model_id - - weights_path = get_weights_path(peft_model_id) - refresh_cache_if_needed(peft_model_id) - ff_revision, ff_revision_file, latest_revision = ( - self.__get_revision_hashes(peft_model_id, weights_path) - ) - - if ff_revision != latest_revision: - print( - f"'{peft_model_id}' local model weights need updating! Downloading/converting new weights now..." 
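__need_cache_refresh() reduces the old per-resource logic to one check: compare the commit SHA stored in rev_sha.txt against the latest SHA on the Hub. A standalone sketch of that freshness test.

```python
import os
from huggingface_hub import HfApi

def needs_refresh(model_name: str, resource_dir: str) -> bool:
    """Sketch of the rev_sha.txt freshness check used above."""
    rev_file = os.path.join(resource_dir, "rev_sha.txt")
    cached_sha = None
    if os.path.exists(rev_file):
        cached_sha = "".join(open(rev_file).read().split())
    latest_sha = HfApi().model_info(model_name).sha
    return cached_sha != latest_sha

# Example (requires network access; the cache directory is illustrative):
# print(needs_refresh("meta-llama/Llama-3.1-8B-Instruct", "/tmp/ff-cache"))
```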
- ) - hf_model = get_hf_llm(peft_model_id) - hf_peft_model = PeftModel.from_pretrained( - hf_model, peft_model_id, config=peft_config - ) - # Convert the model to FlexFlow format - convert_peft_model(hf_peft_model, peft_type, weights_path) - # Save new revision hash to file - with open(ff_revision_file, "w+") as f: - f.write(latest_revision) - print(f"Done converting the weights for model {peft_model_id}") - # Deallocate hf model - del hf_peft_model - del hf_model - gc.collect() - torch.cuda.empty_cache() - - self.weights_path = get_weights_path(self.model_name) - download_llm_weights() - download_peft_weights() + self.model_class.convert_hf_model(hf_model, weights_path) + # Save new revision hash to file + print(f"Done converting the weights for model {self.model_name}") + # Deallocate hf model + del hf_model + gc.collect() + torch.cuda.empty_cache() + + need_refresh = self.__need_cache_refresh( + self.model_name, CachedResourceType.WEIGHTS + ) + if need_refresh: + print( + f"'{self.model_name}' local model weights need updating! Downloading/converting new weights now..." + ) + download_and_convert_llm_weights(self.model_name) def download_hf_tokenizer_if_needed(self): """Check in the folder specified by the cache_path whether the LLM's tokenizer files are available and up to date. @@ -331,25 +329,10 @@ def download_hf_tokenizer_if_needed(self): print("Loading tokenizer...") # Use local cache, or download new version - self.tokenizer_path = os.path.join( - os.path.expanduser(self.cache_path), "tokenizers", self.model_name.lower() + need_refresh = self.__need_cache_refresh( + self.model_name, CachedResourceType.TOKENIZER ) - if self.refresh_cache: - print( - f"Refreshing cached tokenizer for model {self.model_name} at path {self.tokenizer_path} ..." - ) - if os.path.exists(self.tokenizer_path): - shutil.rmtree(self.tokenizer_path) - if not os.path.exists(self.tokenizer_path): - print(f"Creating directory {self.tokenizer_path} (if it doesn't exist)...") - os.makedirs(self.tokenizer_path, exist_ok=True) - - # Get local revision SHA, check if it matches latest one on huggingface - ff_revision, ff_revision_file, latest_revision = self.__get_revision_hashes( - self.model_name, self.tokenizer_path - ) - - if ff_revision != latest_revision: + if need_refresh: print( f"'{self.model_name}' tokenizer needs updating! Downloading tokenizer now..." ) @@ -367,15 +350,76 @@ def download_hf_tokenizer_if_needed(self): hf_tokenizer_path = snapshot_download( repo_id=self.model_name, allow_patterns=target_tokenizer_files ) + tokenizer_path = self.__get_resource_path( + self.model_name, CachedResourceType.TOKENIZER + ) for file in target_tokenizer_files: src_path = os.path.join(hf_tokenizer_path, file) - dst_path = os.path.join(self.tokenizer_path, file) + dst_path = os.path.join(tokenizer_path, file) if os.path.exists(src_path): shutil.copy(src_path, dst_path) print("Done updating HF tokenizer.") - # Save new revision hash to file - with open(ff_revision_file, "w+") as f: - f.write(latest_revision) + + def download_peft_adapter_if_needed(self, hf_peft_model_id: str): + """Check in the folder specified by the cache_path whether the PEFT model weights are available and up to date. + If not, or if the refresh_cache parameter is set to True, download new weights and convert them. 
+ """ + + def download_and_convert_peft_model(hf_peft_model_id: str): + if ( + self.data_type != DataType.DT_FLOAT + and self.data_type != DataType.DT_HALF + ): + raise ValueError( + "data_type must be either DataType.DT_FLOAT or DataType.DT_HALF" + ) + + # Save peft config to file + peft_config_dir = os.path.join( + os.path.expanduser(self.cache_path), "configs", hf_peft_model_id.lower() + ) + dst_path = os.path.join(peft_config_dir, "config.json") + os.makedirs(peft_config_dir, exist_ok=True) + print(f"Saving {hf_peft_model_id} configs to file {dst_path}...") + config_path = snapshot_download( + repo_id=hf_peft_model_id, allow_patterns="adapter_config.json" + ) + src_path = os.path.join(config_path, "adapter_config.json") + if os.path.exists(src_path): + shutil.copy(src_path, dst_path) + + # Save peft weights to file + adapter_path = snapshot_download( + repo_id=hf_peft_model_id, allow_patterns="adapter_model.safetensors" + ) + weights_path = self.__get_resource_path( + hf_peft_model_id.lower(), CachedResourceType.WEIGHTS + ) + with safe_open(adapter_path, framework="pt", device="cpu") as f: + for tensor_name in f.keys(): + tensor = f.get_tensor(tensor_name) + if self.data_type == DataType.DT_HALF: + tensor = tensor.half() + else: + tensor = tensor.float() + tensor_name = tensor_name.replace( + "base_model.model.model.", "" + ).replace(".default", "") + print(tensor_name) + + tensor_name = self.model_class.convert_hf_weight_name(tensor_name) + tensor.detach().cpu().numpy().tofile( + f"{weights_path}/{tensor_name}" + ) + + need_refresh = self.__need_cache_refresh( + hf_peft_model_id, CachedResourceType.WEIGHTS + ) + if need_refresh: + print( + f"'{hf_peft_model_id}' local model weights need updating! Downloading/converting new weights now..." + ) + download_and_convert_peft_model(hf_peft_model_id) def compile( self, @@ -383,10 +427,8 @@ def compile( max_requests_per_batch: int = 1, max_seq_length: int = 256, max_tokens_per_batch: int = 64, + max_concurrent_adapters: int = 1, enable_peft_finetuning: bool = False, - model_specific_data_parallelism_degree: int = None, - model_specific_tensor_parallelism_degree: int = None, - model_specific_pipeline_parallelism_degree: int = None, ssms: list = [], ): """Compile the LLM for inference and load the weights into memory @@ -399,14 +441,10 @@ def compile( :type max_seq_length: int, optional :param max_tokens_per_batch: The maximum number of tokens (across requests) to allow per batch, defaults to 64 :type max_tokens_per_batch: int, optional + :param max_concurrent_adapters: The maximum number of concurrent LoRA adapters, defaults to 1 + :type max_concurrent_adapters: int, optional :param enable_peft_finetuning: Whether to enable support for PEFT fine-tuning, defaults to False :type enable_peft_finetuning: bool, optional - :param model_specific_data_parallelism_degree: Use this parameter if you want to give the LLM a different data parallelism degree than the one used to initialize the runtime, defaults to None - :type model_specific_data_parallelism_degree: int, optional - :param model_specific_tensor_parallelism_degree: Use this parameter if you want to give the LLM a different tensor parallelism degree than the one used to initialize the runtime, defaults to None - :type model_specific_tensor_parallelism_degree: int, optional - :param model_specific_pipeline_parallelism_degree: Use this parameter if you want to give the LLM a different pipeline parallelism degree than the one used to initialize the runtime, defaults to None - :type 
model_specific_pipeline_parallelism_degree: int, optional :param ssms: The SSMs to use when operating in speculative inference mode, defaults to [] :type ssms: list, optional """ @@ -418,24 +456,13 @@ def compile( mode = InferenceMode.TREE_VERIFY_MODE elif type(self) == SSM: mode = InferenceMode.BEAM_SEARCH_MODE + self.ffconfig.data_parallelism_degree = 1 + self.ffconfig.tensor_parallelism_degree = 1 + self.ffconfig.pipeline_parallelism_degree = 1 else: assert type(self) == LLM mode = InferenceMode.INC_DECODING_MODE - # Apply model-specific parallelism degrees, if needed - if model_specific_data_parallelism_degree: - self.ffconfig.data_parallelism_degree = ( - model_specific_data_parallelism_degree - ) - if model_specific_tensor_parallelism_degree: - self.ffconfig.tensor_parallelism_degree = ( - model_specific_tensor_parallelism_degree - ) - if model_specific_pipeline_parallelism_degree: - self.ffconfig.pipeline_parallelism_degree = ( - model_specific_pipeline_parallelism_degree - ) - self.max_seq_length = max_seq_length # Create request manager and set serving configuration @@ -443,6 +470,7 @@ def compile( self.rm.set_max_requests_per_batch(max_requests_per_batch) self.rm.set_max_tokens_per_batch(max_tokens_per_batch) self.rm.set_max_sequence_length(max_seq_length) + self.rm.set_max_concurrent_adapters(max_concurrent_adapters) self.rm.set_enable_peft_finetuning(enable_peft_finetuning) # Instantiate the relevant model @@ -464,12 +492,6 @@ def compile( # Download the weights from huggingface (if needed) self.download_hf_weights_if_needed() - # Add PEFT layer if registered - for ff_peft_config, peft_dict in self.pefts.items(): - ff_peft_config.ff_compile() - ff_peft_model_id = self.model.ffmodel.add_lora_layer(ff_peft_config) - peft_dict["ff_peft_model_id"] = ff_peft_model_id - # Create file data loader, load weights into tensors model_configs = self.config_class(self.hf_config) @@ -479,8 +501,11 @@ def compile( else 20 ) + weights_path = self.__get_resource_path( + self.model_name, CachedResourceType.WEIGHTS + ) self.fileloader = FileDataLoader( - self.weights_path, + weights_path, model_configs.num_attention_heads, model_configs.num_key_value_heads, model_configs.hidden_size, @@ -504,8 +529,11 @@ def compile( eos_token_id = [eos_token_id] elif type(eos_token_id) != list: raise ValueError("eos_token_id must be an integer or a list of integers") + tokenizer_path = self.__get_resource_path( + self.model_name, CachedResourceType.TOKENIZER + ) self.rm.register_tokenizer( - self.model_type, bos_token_id, eos_token_id, self.tokenizer_path + self.model_type, bos_token_id, eos_token_id, tokenizer_path ) self.rm.register_output_filepath(self.output_file) @@ -520,14 +548,14 @@ def compile( atexit.register(self.rm.stop_server) - def _generate(self, requests: List[Request]): + def _generate(self, requests: List[Request]) -> List[GenerationResult]: if len(requests) == 0: return [] for req in requests: if req.req_type == RequestType.REQ_INFERENCE: # check max_length and max_new_tokens parameters if req.max_length == -1 and req.max_new_tokens == -1: - req.max_length = self.max_seq_length -1 + req.max_length = self.max_seq_length - 1 elif req.max_length != -1 and req.max_new_tokens != -1: warnings.warn( f"Both `max_new_tokens` (={req.max_new_tokens}) and `max_length`(={req.max_length}) seem to have been set. `max_new_tokens` will take precedence." @@ -546,14 +574,14 @@ def _generate(self, requests: List[Request]): f"max_new_tokens ({req.max_new_tokens}) is not allowed for finetuning requests." 
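On the serving side, compile() drops the per-model parallelism overrides and gains max_concurrent_adapters, which it forwards to the RequestManager. A sketch; llm is assumed to be an ff.LLM created after ff.init().

```python
import flexflow.serve as ff

# Sketch only: `llm` is assumed to be an ff.LLM instance.
llm.compile(
    ff.GenerationConfig(do_sample=False),
    max_requests_per_batch=4,
    max_seq_length=1024,
    max_tokens_per_batch=256,
    max_concurrent_adapters=1,      # new knob, forwarded to the RequestManager
    enable_peft_finetuning=False,
)
llm.start_server()
```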
) if req.max_length == -1: - req.max_length = self.max_seq_length -1 + req.max_length = self.max_seq_length - 1 if req.max_length >= self.max_seq_length: raise ValueError( f"max_length ({req.max_length}) exceeds the maximum sequence length ({self.max_seq_length})" ) return self.model.ffmodel.generate(requests) - def __chat2prompt(self, messages: List[dict]): + def __chat2prompt(self, messages: List[dict]) -> str: """Convert a list of messages to a single prompt string :param messages: The list of messages to convert @@ -563,15 +591,31 @@ def __chat2prompt(self, messages: List[dict]): """ # ensure that each element is a dictionary, containing the "role" and "content" keys for message in messages: - if type(message) != dict or "role" not in message or "content" not in message: + if ( + type(message) != dict + or "role" not in message + or "content" not in message + ): raise ValueError( "Each element in the list must be a dictionary with the keys 'role' and 'content'" ) if self.tokenizer is None: self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) if self.tokenizer.chat_template is None: - raise ValueError(f"Model {self.model_name} does not support chat completion") - return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + raise ValueError( + f"Model {self.model_name} does not support chat completion" + ) + return self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + def __output2chat_response( + self, requests: List[Request], outputs: List[GenerationResult] + ) -> List[GenerationResult]: + assert len(requests) == len(outputs) + for i in range(len(outputs)): + outputs[i].output_text = outputs[i].output_text[len(requests[i].prompt) :] + return outputs def generate( self, @@ -625,9 +669,12 @@ def generate( max_new_tokens=max_new_tokens, add_special_tokens=False, ) - return self._generate([request]) + outputs = self._generate([request]) + return self.__output2chat_response([request], outputs) elif type(requests_or_prompts[0]) == list: - prompts = [self.__chat2prompt(messages) for messages in requests_or_prompts] + prompts = [ + self.__chat2prompt(messages) for messages in requests_or_prompts + ] requests = [ Request( req_type=RequestType.REQ_INFERENCE, @@ -638,12 +685,15 @@ def generate( ) for prompt in prompts ] - return self._generate(requests) + outputs = self._generate(requests) + return self.__output2chat_response(requests, outputs) elif type(requests_or_prompts[0]) == Request: print(requests_or_prompts) return self._generate(requests_or_prompts) else: - assert False, "Please pass a string, list of strings, Request, or list of Requests" + assert ( + False + ), "Please pass a string, list of strings, Request, or list of Requests" def start_server(self): self.rm.start_server(self.model.ffmodel) @@ -685,11 +735,9 @@ def compile( generation_config: GenerationConfig = GenerationConfig(), max_requests_per_batch: int = 16, max_seq_length: int = 256, - max_tokens_per_batch: int = 128, + max_tokens_per_batch: int = 2048, + max_concurrent_adapters: int = 1, enable_peft_finetuning: bool = False, - model_specific_data_parallelism_degree: int = 1, - model_specific_tensor_parallelism_degree: int = 1, - model_specific_pipeline_parallelism_degree: int = 1, ssms: list = [], ): """Compile the SSM for inference and load the weights into memory @@ -699,16 +747,12 @@ def compile( :type max_requests_per_batch: int, optional :param max_seq_length: The maximum sequence length to allow per batch, defaults to 256 :type 
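The chat path accepts a list of message dictionaries, renders them through the tokenizer's chat template, and, via __output2chat_response(), returns only the newly generated text. A usage sketch; llm is assumed to be compiled and serving, with a tokenizer that defines a chat template.

```python
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize what LoRA does in one sentence."},
]
# Sketch only: `llm` is assumed to be a compiled, running ff.LLM.
result = llm.generate(messages, max_new_tokens=64)[0]
# The echoed prompt has already been stripped from the output here.
print(result.output_text)
```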
max_seq_length: int, optional - :param max_tokens_per_batch: The maximum number of tokens (across requests) to allow per batch, defaults to 128 + :param max_tokens_per_batch: The maximum number of tokens (across requests) to allow per batch, defaults to 2048 :type max_tokens_per_batch: int, optional + :param max_concurrent_adapters: The maximum number of concurrent LoRA adapters, defaults to 1 + :type max_concurrent_adapters: int, optional :param enable_peft_finetuning: Whether to enable support for PEFT fine-tuning, defaults to False :type enable_peft_finetuning: bool, optional - :param model_specific_data_parallelism_degree: Use this parameter if you want to give the SSM a different data parallelism degree than the default one, defaults to 1 - :type model_specific_data_parallelism_degree: int, optional - :param model_specific_tensor_parallelism_degree: Use this parameter if you want to give the SSM a different tensor parallelism degree than the default one, defaults to 1 - :type model_specific_tensor_parallelism_degree: int, optional - :param model_specific_pipeline_parallelism_degree: Use this parameter if you want to give the SSM a different pipeline parallelism degree than the default one, defaults to 1 - :type model_specific_pipeline_parallelism_degree: int, optional :param ssms: The SSMs to use when operating in speculative inference mode, defaults to [] :type ssms: list, optional """ @@ -717,9 +761,7 @@ def compile( max_requests_per_batch, max_seq_length, max_tokens_per_batch, + max_concurrent_adapters, enable_peft_finetuning, - model_specific_data_parallelism_degree, - model_specific_tensor_parallelism_degree, - model_specific_pipeline_parallelism_degree, ssms, ) diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index da90c586e3..e16b0e87bd 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -177,6 +177,11 @@ void flexflow_config_set_pipeline_parallelism_degree(flexflow_config_t handle_, handle->pipeline_parallelism_degree = value; } +bool flexflow_config_get_enable_peft(flexflow_config_t handle_) { + FFConfig *handle = FFCObjectWrapper::unwrap(handle_); + return handle->enable_peft; +} + int flexflow_config_get_python_data_loader_type(flexflow_config_t handle_) { FFConfig *handle = FFCObjectWrapper::unwrap(handle_); return handle->python_data_loader_type; @@ -1608,18 +1613,33 @@ flexflow_tensor_t flexflow_model_add_argmax(flexflow_model_t handle_, } #ifdef FF_BUILD_INFERENCE -flexflow_peft_model_id_t flexflow_model_add_lora_layer( +void flexflow_model_add_lora_layers(flexflow_model_t handle_, + int num_target_modules, + char const **target_modules_) { + FFModel *handle = FFCObjectWrapper::unwrap(handle_); + std::vector target_modules; + for (int i = 0; i < num_target_modules; i++) { + target_modules.push_back(target_modules_[i]); + } + DEBUG_PRINT("[Add Lora Layers] model handle: %p, num_target_modules %d", + handle, + num_target_modules); + handle->add_lora_layers(target_modules); +} + +flexflow_peft_model_id_t flexflow_model_register_peft_adapter( flexflow_model_t handle_, const flexflow_lora_linear_config_t peft_config_) { FFModel *handle = FFCObjectWrapper::unwrap(handle_); LoraLinearConfig const *peft_config = FFCObjectWrapper::unwrap(peft_config_); - PEFTModelID *peft_model_id = handle->add_lora_layer(*peft_config); + PEFTModelID *peft_model_id = handle->register_peft_adapter(*peft_config); - DEBUG_PRINT("[Add Lora Layer] model handle: %p, peft_config handle %p, " - "peft_model_id: %p", - handle, - peft_config, - peft_model_id); + DEBUG_PRINT( + "[Register 
PEFT Adapter] model handle: %p, peft_config handle %p, " + "peft_model_id: %p", + handle, + peft_config, + peft_model_id); return FFCObjectWrapper::wrap(peft_model_id); } #endif @@ -2765,6 +2785,14 @@ int flexflow_request_manager_get_max_sequence_length( return handle->get_max_sequence_length(); } +void flexflow_request_manager_set_max_concurrent_adapters( + flexflow_request_manager_t handle_, int max_concurrent_adapters) { + RequestManager *handle = FFCObjectWrapper::unwrap(handle_); + handle->set_max_concurrent_adapters(max_concurrent_adapters); + DEBUG_PRINT("[RequestManager] set max_concurrent_adapters %d", + max_concurrent_adapters); +} + void flexflow_request_manager_set_enable_peft_finetuning( flexflow_request_manager_t handle_, bool enable_peft_finetuning_) { RequestManager *handle = FFCObjectWrapper::unwrap(handle_); @@ -2909,7 +2937,9 @@ void flexflow_file_data_loader_load_weights(flexflow_file_data_loader_t handle_, flexflow_model_t model_handle_) { FileDataLoader *handle = FFCObjectWrapper::unwrap(handle_); FFModel *model = FFCObjectWrapper::unwrap(model_handle_); - handle->load_weights(model); + Context ctx = model->config.lg_ctx; + Runtime *runtime = model->config.lg_hlr; + handle->load_weights_parallel(model, ctx, runtime); } // // ----------------------------------------------------------------------- diff --git a/src/mapper/mapper.cc b/src/mapper/mapper.cc index d7b9a5e99d..c02f70f752 100644 --- a/src/mapper/mapper.cc +++ b/src/mapper/mapper.cc @@ -288,6 +288,10 @@ void FFMapper::select_task_options(const MapperContext ctx, output.initial_proc = all_cpus[0]; return; } + if (task.task_id == LOAD_WEIGHT_TASK_ID) { + output.initial_proc = all_cpus[0]; + return; + } if (task.task_id == TOP_LEVEL_TASK_ID) { output.initial_proc = all_cpus[0]; // control replicate top level task diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 62845c0f8e..8635fd6a87 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -862,6 +862,7 @@ __host__ void FusedOp::peft_bwd_task(Task const *task, int num_infr_tokens = bc->num_active_infr_tokens(); int num_peft_tokens = bc->num_active_peft_tokens(); Kernels::Linear::peft_bwd_kernel_wrapper(m, + bc, my_input_grad_accessor[0].ptr, my_output_grad_accessor[0].ptr, my_weight_accessor[0].ptr, @@ -889,11 +890,13 @@ __host__ void FusedOp::peft_bwd_task(Task const *task, // Assert that the output and the second input are at the same place // since we ``inplace'' the output for LoRA assert(my_input_grad_accessor[1].ptr == my_output_grad_accessor[0].ptr); + int shard_id = task->index_point.point_data[0]; Kernels::LoraLinear::peft_bwd_kernel_wrapper( ctx, runtime, m, bc, + shard_id, my_input_grad_accessor[0], my_output_grad_accessor[0]); break; diff --git a/src/ops/kernels/linear_kernels.cu b/src/ops/kernels/linear_kernels.cu index 3832428c64..51954597d7 100644 --- a/src/ops/kernels/linear_kernels.cu +++ b/src/ops/kernels/linear_kernels.cu @@ -16,6 +16,7 @@ #include "flexflow/ffconst_utils.h" #include "flexflow/ops/kernels/decompress_kernels.h" #include "flexflow/ops/kernels/linear_kernels.h" +#include "flexflow/ops/lora_linear_params.h" #include "flexflow/utils/cuda_helper.h" namespace FlexFlow { @@ -73,6 +74,17 @@ LinearMeta::~LinearMeta(void) { } } +bool lora_applies_to_this_layer(LinearMeta const *m, + LoraLinearConfig const &config) { + for (std::string s : config.target_modules) { + std::string n(m->op_name); + if (n.find(s) != std::string::npos) { + return true; + } + } + return false; +} + namespace Kernels { namespace Linear { @@ -285,6 +297,7 @@ 
void inference_kernel_wrapper(LinearMeta *m, } void peft_bwd_kernel_wrapper(LinearMeta const *m, + BatchConfig const *bc, void *input_grad_ptr, void *output_grad_ptr, void const *weight_ptr, @@ -302,6 +315,7 @@ void peft_bwd_kernel_wrapper(LinearMeta const *m, } if (m->input_type[0] == DT_FLOAT) { Internal::peft_bwd_kernel(m, + bc, input_grad_ptr, output_grad_ptr, weight_ptr, @@ -312,6 +326,7 @@ void peft_bwd_kernel_wrapper(LinearMeta const *m, stream); } else if (m->input_type[0] == DT_HALF) { Internal::peft_bwd_kernel(m, + bc, input_grad_ptr, output_grad_ptr, weight_ptr, @@ -568,6 +583,7 @@ void forward_kernel(LinearMeta const *m, template void peft_bwd_kernel(LinearMeta const *m, + BatchConfig const *bc, void *input_grad_ptr, void *output_grad_ptr, void const *kernel_ptr, @@ -611,6 +627,35 @@ void peft_bwd_kernel(LinearMeta const *m, // NOTE: we use beta=1 for input_grad to accumulate gradients when needed DT alpha = 1.0f; DT beta = m->reset_input_grads[0] ? 0.0f : 1.0f; + + // ensure that we only have one finetuning request, with a single lora + int num_peft_requests = 0; + bool lora_applies = false; + for (int i = 0; i < bc->max_requests_per_batch(); i++) { + if (bc->request_completed[i] || + bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID || + !bc->requestsInfo[i].peft_bwd) { + continue; + } + num_peft_requests++; + std::string peft_model_config_str = + std::string(bc->requestsInfo[i].peft_model_config_str); + LoraLinearConfig lora_config = + LoraLinearConfig::deserialize_from_json_string(peft_model_config_str); + if (!lora_applies_to_this_layer(m, lora_config)) { + continue; + } + lora_applies = true; + } + assert(num_peft_requests == 1 && + "Exactly one PEFT finetuning request is required"); + // if the request does not have any active lora in the current layer, reset + // beta to 0 std::cout << m->op_name << " original beta: " << (float)beta << " + // lora_applies: " << lora_applies << std::endl; + if (lora_applies) { + beta = 1.0f; + } + if (input_grad_ptr != NULL) { checkCUDA(cublasGemmEx(m->handle.blas, CUBLAS_OP_N, diff --git a/src/ops/kernels/lora_linear_kernels.cu b/src/ops/kernels/lora_linear_kernels.cu index 638cee8cae..40095484b5 100644 --- a/src/ops/kernels/lora_linear_kernels.cu +++ b/src/ops/kernels/lora_linear_kernels.cu @@ -23,29 +23,32 @@ namespace FlexFlow { LoraLinearMeta::LoraLinearMeta(FFHandler handler, LoraLinear const *li) - : OpMeta(handler, li) { - allocated_peft_buffer_size1 = 0; - allocated_peft_buffer_size2 = 0; -} + : OpMeta(handler, li) {} LoraLinearMeta::~LoraLinearMeta(void) {} -namespace Kernels { -namespace LoraLinear { - -void init_kernel_wrapper(LoraLinearMeta *m, int seed) { - cudaStream_t stream; - checkCUDA(get_legion_stream(&stream)); - - if (m->input_type[0] == DT_FLOAT) { - Internal::init_kernel(m, seed, stream); - } else if (m->input_type[0] == DT_HALF) { - Internal::init_kernel(m, seed, stream); +std::string + get_peft_dbg_folder(LoraLinearMeta const *m, int shard_id, bool is_fwd) { + std::string op_name_without_uid = LoraLinear::get_op_name_without_uid(m); + fs::path dst_filepath; + if (is_fwd) { + dst_filepath = get_dst_folder("fwd", m->decoding_step, shard_id); } else { - assert(false && "Unsupported data type"); + dst_filepath = get_dst_folder("bwd", m->bwd_step, shard_id); } + if (m->layer_guid.model_id > 0) { + assert(false && "Model ID > 0 not supported yet"); + } + std::string layername = "layers." + + std::to_string(m->layer_guid.transformer_layer_id) + + "." 
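lora_applies_to_this_layer() decides, per request, whether the deserialized LoRA config touches the current layer by substring-matching the target module names against the layer's op name; the Linear backward kernel uses it only to decide whether to accumulate into the input gradient (beta = 1). A pure-Python illustration of the same test.

```python
def lora_applies(op_name, target_modules):
    """Mirror of the substring test in lora_applies_to_this_layer()."""
    return any(t in op_name for t in target_modules)

print(lora_applies("layers.11.mlp.down_proj",
                   ["gate_proj", "up_proj", "down_proj"]))   # True
print(lora_applies("layers.11.attn.qkv_proj",
                   ["gate_proj", "up_proj", "down_proj"]))   # False
```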
+ op_name_without_uid; + dst_filepath /= layername; + return dst_filepath.string(); } +namespace Kernels { +namespace LoraLinear { + void inference_kernel_wrapper(LoraLinearMeta *m, BatchConfig const *bc, GenericTensorAccessorR const &input, @@ -100,6 +103,7 @@ void peft_bwd_kernel_wrapper(Context ctx, Runtime *runtime, LoraLinearMeta *m, BatchConfig const *bc, + int shard_id, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &output_grad) { cudaStream_t stream; @@ -117,6 +121,7 @@ void peft_bwd_kernel_wrapper(Context ctx, runtime, m, bc, + shard_id, input_grad.get_float_ptr(), output_grad.get_float_ptr(), in_dim, @@ -127,6 +132,7 @@ void peft_bwd_kernel_wrapper(Context ctx, runtime, m, bc, + shard_id, input_grad.get_half_ptr(), output_grad.get_half_ptr(), in_dim, @@ -151,58 +157,19 @@ void peft_bwd_kernel_wrapper(Context ctx, } } -namespace Internal { - -template -void init_kernel(LoraLinearMeta *m, int seed, cudaStream_t stream) { - // Initialize generator - std::mt19937 gen(seed); - - // Get handle to weights by iterating over m->model_state to get each - // LoraLinearWeight object - for (auto &model_state : m->model_state) { - LoraLinearWeight weight = model_state.second.weights; - int w0_num_elements = weight.rank * weight.in_dim; - int w1_num_elements = weight.rank * weight.out_dim; - - // LoRA_A weight: [in_dim, rank] - float stdv_lora_a = 1.0f / sqrt(weight.in_dim); - std::uniform_real_distribution dis_lora_a(-stdv_lora_a, stdv_lora_a); - std::vector
lora_a_random_init(w0_num_elements); - for (auto &num : lora_a_random_init) { - float num_float = dis_lora_a(gen); - if (std::is_same::value) { - num = __float2half(num_float); - } else { - num = num_float; - } - } - checkCUDA(cudaMemcpyAsync(static_cast
(weight.w0_ptr), - lora_a_random_init.data(), - w0_num_elements * sizeof(DT), - cudaMemcpyHostToDevice, - stream)); - - // LoRA_B weight: [rank, out_dim] - float stdv_lora_b = 1.0f / sqrt(weight.rank); - std::uniform_real_distribution dis_lora_b(-stdv_lora_b, stdv_lora_b); - std::vector lora_b_random_init(w1_num_elements); - for (auto &num : lora_b_random_init) { - float num_float = dis_lora_b(gen); - if (std::is_same::value) { - num = __float2half(num_float); - } else { - num = num_float; - } +bool lora_applies_to_this_layer(LoraLinearMeta *m, + LoraLinearConfig const &config) { + for (std::string s : config.target_modules) { + std::string n(m->op_name); + if (n.find(s) != std::string::npos) { + return true; } - checkCUDA(cudaMemcpyAsync(static_cast
(weight.w1_ptr), - lora_b_random_init.data(), - w1_num_elements * sizeof(DT), - cudaMemcpyHostToDevice, - stream)); } + return false; } +namespace Internal { + template void inference_kernel(LoraLinearMeta *m, BatchConfig const *bc, @@ -213,91 +180,60 @@ void inference_kernel(LoraLinearMeta *m, ffStream_t stream) { checkCUDA(cublasSetStream(m->handle.blas, stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - DT alpha = 1.0f, beta = 0.0f; cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type[0]); cudaDataType_t output_type = ff_to_cuda_datatype(m->input_type[1]); cudaDataType_t lr_actv_type = output_type; assert(input_type == output_type); cudaDataType_t weight_type = output_type; cudaDataType_t compute_type = output_type; - // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - // cudaDataType_t compute_type = output_type; - // #else - // // For best performance, set the default cublas compute type to - // // CUBLAS_COMPUTE_16F for half precision and to - // // CUBLAS_COMPUTE_32F_FAST_16F for full precision - // cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; - // if (m->input_type[0] == DT_FLOAT) { - // compute_type = CUBLAS_COMPUTE_32F_FAST_16F; - // } - // #endif + int num_peft_requests = 0; for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i]) { - continue; - } - if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) { + if (bc->request_completed[i] || + bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) { continue; } if (bc->requestsInfo[i].peft_bwd) { num_peft_requests++; } - } - // Assert that we have at most one request that requires peft_bwd - assert(num_peft_requests <= 1); - for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i]) { - continue; - } - // Skip non-PEFT requests - if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) { + std::string peft_model_config_str = + std::string(bc->requestsInfo[i].peft_model_config_str); + LoraLinearConfig lora_config = + LoraLinearConfig::deserialize_from_json_string(peft_model_config_str); + if (!lora_applies_to_this_layer(m, lora_config)) { continue; } + // std::cout << "Lora layer activated!" << std::endl; + // std::cout << "Lora Config: " << peft_model_config_str << std::endl; + assert(lora_config.trainable == bc->requestsInfo[i].peft_bwd && + "Trainable flag mismatch"); int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int max_peft_tokens = bc->requestsInfo[i].max_length; + // int max_peft_tokens = bc->requestsInfo[i].max_length; int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch; - assert(m->model_state.find(bc->requestsInfo[i].peft_model_id) != - m->model_state.end()); - LoraLinearWeight weight = - m->model_state[bc->requestsInfo[i].peft_model_id].weights; - int rank = weight.rank; - void *intermediate_result_ptr = nullptr; + LoraLinearWeight weight = m->peft_memory_manager->get_peft( + bc->requestsInfo[i].peft_model_id, lora_config); + void *intermediate_result_ptr = (bc->requestsInfo[i].peft_bwd) + ? 
weight.low_rank_activation + : m->handle.workSpace; if (bc->requestsInfo[i].peft_bwd) { - size_t activation_size_needed1 = - data_type_size(m->input_type[0]) * max_peft_tokens * in_dim; - size_t activation_size_needed2 = - data_type_size(m->input_type[1]) * max_peft_tokens * rank; - MemoryAllocator *allocator = m->handle.peft_activation_allocator; - if (activation_size_needed1 > m->allocated_peft_buffer_size1) { - m->input_activation = - allocator->allocate_instance_untyped(activation_size_needed1); - m->allocated_peft_buffer_size1 = activation_size_needed1; - } - if (activation_size_needed2 > m->allocated_peft_buffer_size2) { - m->low_rank_activation = - allocator->allocate_instance_untyped(activation_size_needed2); - m->allocated_peft_buffer_size2 = activation_size_needed2; - } - // copy input activation - checkCUDA(cudaMemcpyAsync(m->input_activation, + checkCUDA(cudaMemcpyAsync(weight.input_activation, input_ptr + first_token_offset * in_dim, data_type_size(m->input_type[0]) * num_peft_tokens * in_dim, cudaMemcpyDeviceToDevice, stream)); - intermediate_result_ptr = m->low_rank_activation; } else { // use workspace to save intermediate result - assert(m->handle.workSpaceSize >= - data_type_size(m->input_type[1]) * num_peft_tokens * rank); - intermediate_result_ptr = m->handle.workSpace; + assert(m->handle.workSpaceSize >= data_type_size(m->input_type[1]) * + num_peft_tokens * lora_config.rank); } + DT alpha = 1.0f, beta = 0.0f; // buffer = weight_first * input // [rank, num_peft_tokens] = [in_dim, rank].T * [in_dim, num_peft_tokens] checkCUDA(cublasGemmEx(m->handle.blas, CUBLAS_OP_T, CUBLAS_OP_N, - rank, + lora_config.rank, num_peft_tokens, in_dim, &alpha, @@ -310,29 +246,27 @@ void inference_kernel(LoraLinearMeta *m, &beta, intermediate_result_ptr, lr_actv_type, - rank, + lora_config.rank, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // output = weight_second * buffer // [out_dim, num_peft_tokens] = [rank, out_dim].T * [rank, num_peft_tokens] // Note that we use alpha in both places since we do // an in-place update for LoraLinear - float lora_alpha = - m->model_state[bc->requestsInfo[i].peft_model_id].lora_alpha; - DT scaling_constant = (DT)(lora_alpha / rank); + DT scaling_constant = (DT)(lora_config.lora_alpha / lora_config.rank); checkCUDA(cublasGemmEx(m->handle.blas, CUBLAS_OP_T, CUBLAS_OP_N, out_dim, num_peft_tokens, - rank, + lora_config.rank, &scaling_constant, weight.w1_ptr, weight_type, - rank, + lora_config.rank, intermediate_result_ptr, lr_actv_type, - rank, + lora_config.rank, &alpha, output_ptr + first_token_offset * out_dim, output_type, @@ -340,6 +274,7 @@ void inference_kernel(LoraLinearMeta *m, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } + assert(num_peft_requests <= 1); } template @@ -371,6 +306,7 @@ void peft_bwd_kernel(Context ctx, Runtime *runtime, LoraLinearMeta *m, BatchConfig const *bc, + int shard_id, DT *input_grad_ptr, DT const *output_grad_ptr, int in_dim, @@ -384,39 +320,33 @@ void peft_bwd_kernel(Context ctx, cudaDataType_t weight_type = output_type; cudaDataType_t lr_actv_type = output_type; cudaDataType_t compute_type = output_type; - // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - // cudaDataType_t compute_type = output_type; - // #else - // // For best performance, set the default cublas compute type to - // // CUBLAS_COMPUTE_16F for half precision and to - // // CUBLAS_COMPUTE_32F_FAST_16F for full precision - // cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; - // if (m->output_type[0] == DT_FLOAT) { - // compute_type = 
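Instead of looking adapters up in per-op model_state, each request now ships its serialized LoraLinearConfig in peft_model_config_str, and the kernels rebuild it with LoraLinearConfig::deserialize_from_json_string(). The sketch below shows a plausible shape for that JSON; the field names are inferred from what the kernels read (rank, lora_alpha, target_modules, trainable, optimizer_config) and may not match the exact serialization schema.

```python
import json

# Plausible, non-authoritative shape of a per-request LoRA config string.
example_config = {
    "rank": 16,
    "lora_alpha": 32.0,
    "target_modules": ["gate_proj", "up_proj", "down_proj"],
    "trainable": True,
    "optimizer_config": {"type": "SGD", "lr": 1e-3},
}
config_str = json.dumps(example_config)
# The string must fit the fixed-size per-request buffer carried in BatchConfig.
print(len(config_str), config_str)
```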
CUBLAS_COMPUTE_32F_FAST_16F; - // } - // #endif + for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i]) { - continue; - } - // Skip non-PEFT requests - if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) { + // Skip completed, non-PEFT and PEFT forward-only requests + if (bc->request_completed[i] || + bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID || + !bc->requestsInfo[i].peft_bwd) { continue; } - // Skip PEFT forward-only requests - if (!bc->requestsInfo[i].peft_bwd) { + std::string peft_model_config_str = + std::string(bc->requestsInfo[i].peft_model_config_str); + LoraLinearConfig lora_config = + LoraLinearConfig::deserialize_from_json_string(peft_model_config_str); + if (!lora_applies_to_this_layer(m, lora_config)) { continue; } + // std::cout << "Lora layer activated!" << std::endl; + // std::cout << "Lora Config: " << peft_model_config_str << std::endl; + assert(lora_config.trainable == bc->requestsInfo[i].peft_bwd && + "Trainable flag mismatch"); + m->peft_memory_manager->check_ft_model_id( + bc->requestsInfo[i].peft_model_id); int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch; + // int max_peft_tokens = bc->requestsInfo[i].max_length; // int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch; - assert(m->model_state.find(bc->requestsInfo[i].peft_model_id) != - m->model_state.end()); - LoraLinearWeight weight = - m->model_state[bc->requestsInfo[i].peft_model_id].weights; - int rank = weight.rank; - float lora_alpha = - m->model_state[bc->requestsInfo[i].peft_model_id].lora_alpha; - DT scaling_constant = (DT)(lora_alpha / rank); + LoraLinearWeight weight = m->peft_memory_manager->get_peft( + bc->requestsInfo[i].peft_model_id, lora_config); + DT scaling_constant = (DT)(lora_config.lora_alpha / lora_config.rank); // Compute LORA_B weight's gradient if (bc->requestsInfo[i].optimizer_tasks.compute_gradients) { @@ -424,23 +354,35 @@ void peft_bwd_kernel(Context ctx, DT beta = (bc->requestsInfo[i].optimizer_tasks.reset_gradients_to_zero) ? 
0.0f : 1.0f; + // std::cout << "Lora B gradient computation, beta = " << (float) beta << + // std::endl; + if (m->inference_debugging) { + // save result to file for checking + std::string filename = + get_peft_dbg_folder(m, shard_id, false) + ".low_rank_activation"; + std::cout << "Save low_rank_activation (" << lora_config.rank << ", " + << num_peft_tokens << ") to " << filename << std::endl; + save_tensor(static_cast(weight.low_rank_activation), + lora_config.rank * num_peft_tokens, + filename.c_str()); + } checkCUDA(cublasGemmEx(m->handle.blas, CUBLAS_OP_N, CUBLAS_OP_T, - rank, + lora_config.rank, out_dim, num_peft_tokens, &scaling_constant, - m->low_rank_activation, + weight.low_rank_activation, lr_actv_type, - rank, + lora_config.rank, output_grad_ptr, output_type, out_dim, &beta, weight.w1_grad_ptr, weight_type, - rank, + lora_config.rank, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } @@ -452,20 +394,20 @@ void peft_bwd_kernel(Context ctx, checkCUDA(cublasGemmEx(m->handle.blas, CUBLAS_OP_N, CUBLAS_OP_N, - rank, + lora_config.rank, num_peft_tokens, out_dim, &scaling_constant, weight.w1_ptr, weight_type, - rank, + lora_config.rank, output_grad_ptr, output_type, out_dim, &beta, - m->low_rank_activation, + weight.low_rank_activation, lr_actv_type, - rank, + lora_config.rank, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } @@ -480,15 +422,15 @@ void peft_bwd_kernel(Context ctx, CUBLAS_OP_N, CUBLAS_OP_T, in_dim, - rank, + lora_config.rank, num_peft_tokens, &alpha, - m->input_activation, + weight.input_activation, input_type, in_dim, - m->low_rank_activation, + weight.low_rank_activation, lr_actv_type, - rank, + lora_config.rank, &beta, weight.w0_grad_ptr, weight_type, @@ -506,14 +448,14 @@ void peft_bwd_kernel(Context ctx, CUBLAS_OP_N, in_dim, num_peft_tokens, - rank, + lora_config.rank, &alpha, weight.w0_ptr, weight_type, in_dim, - m->low_rank_activation, + weight.low_rank_activation, lr_actv_type, - rank, + lora_config.rank, &beta, input_grad_ptr, input_type, @@ -523,17 +465,16 @@ void peft_bwd_kernel(Context ctx, } if (bc->requestsInfo[i].optimizer_tasks.update_weights) { - LoraOptimizerConfig const *optimizer_config = - m->model_state[bc->requestsInfo[i].peft_model_id].optimizer_config; - assert(optimizer_config != nullptr); - assert(typeid(*optimizer_config) != typeid(LoraOptimizerConfig)); - int w0_num_elements = rank * in_dim; - int w1_num_elements = rank * out_dim; + assert(lora_config.optimizer_config != nullptr); + int w0_num_elements = lora_config.rank * in_dim; + int w1_num_elements = lora_config.rank * out_dim; // Get optimizer config - if (typeid(*optimizer_config) == typeid(LoraSGDOptimizerConfig)) { + + if (lora_config.optimizer_config->getType() == "SGD") { LoraSGDOptimizerConfig const *sgd_config = - (LoraSGDOptimizerConfig const *)optimizer_config; + static_cast( + lora_config.optimizer_config); // LoRA_A weight is split in tensor parallelism, so no need to apply // all-reduce sgd_update<<(weight.w1_grad_ptr), static_cast
(weight.w1_v_values_ptr), static_cast
(weight.w1_ptr)); - } else if (typeid(*optimizer_config) == typeid(LoraAdamOptimizerConfig)) { + } else if (lora_config.optimizer_config->getType() == "Adam") { assert(false && "Adam optimizer type not implemented yet"); } else { assert(false && "Unsupported optimizer type"); diff --git a/src/ops/linear.cc b/src/ops/linear.cc index 09170d3c28..8c2120e283 100644 --- a/src/ops/linear.cc +++ b/src/ops/linear.cc @@ -769,6 +769,7 @@ void Linear::peft_bwd_task(Task const *task, num_peft_tokens); } peft_bwd_kernel_wrapper(m, + bc, input_grad.ptr, output_grad.ptr, weight.ptr, diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc index 3749cce994..68605160a5 100644 --- a/src/ops/lora_linear.cc +++ b/src/ops/lora_linear.cc @@ -3,6 +3,7 @@ #include "flexflow/layer.h" #include "flexflow/model.h" #include "flexflow/ops/kernels/lora_linear_kernels.h" +#include "flexflow/request_manager.h" #include "flexflow/utils/hash_utils.h" #include "flexflow/utils/peft_weight_allocator.h" #include "legion/legion_utilities.h" @@ -51,18 +52,18 @@ bool check_lora_layer_match(Layer *potential_target, return false; } -PEFTModelID *FFModel::add_lora_layer(LoraLinearConfig const peft_config) { +void FFModel::add_lora_layers(std::vector target_modules) { assert(config.enable_peft && "Cannot add a LoRA layer if PEFT mode is not enabled"); - if (peft_config.target_modules.size() == 0) { - printf("PEFT config does not contain any target module\n"); - std::cout << peft_config << std::endl; - assert(false); - } - PEFTModelID *peft_model_id = new PEFTModelID(peft_model_global_guid++); - peft_configs[*peft_model_id] = peft_config; - - for (std::string target_module_name : peft_config.target_modules) { + assert(target_modules.size() > 0 && "LoRA target module name is empty"); + RequestManager *rm = RequestManager::get_request_manager(); + int max_lora_rank = rm->get_max_lora_rank(); + int max_concurrent_adapters = rm->get_max_concurrent_adapters(); + assert(max_lora_rank > 1 && max_lora_rank <= 32 && "Invalid max LoRA rank"); + assert(max_concurrent_adapters > 0 && + "Invalid number of LoRA concurrent adapters"); + + for (std::string target_module_name : target_modules) { assert(target_module_name.length() > 0 && "LoRA target module name is empty"); // find target layer @@ -72,127 +73,84 @@ PEFTModelID *FFModel::add_lora_layer(LoraLinearConfig const peft_config) { if (!match) { continue; } - - if (base_layer_to_peft_layer.find(target_module) != - base_layer_to_peft_layer.end()) { - // lora linear layer already added, no need to add again - Layer *peft_layer = base_layer_to_peft_layer[target_module]; - peft_layer_to_peft_id[peft_layer].push_back(*peft_model_id); - } else { - Tensor const input = target_module->inputs[0]; - Tensor const output = target_module->outputs[0]; - assert(input->data_type == output->data_type); - std::string name_ = target_module->name - ? 
std::string(target_module->name) - : std::string(""); - size_t last_underscore = name_.length() - 1; - for (int i = name_.length() - 1; i > 0; i--) { - if (!(std::isdigit(target_module->name[i]) || - target_module->name[i] == '_')) { - break; - } else if (target_module->name[i] == '_') { - last_underscore = i; - } + assert(base_layer_to_peft_layer.find(target_module) == + base_layer_to_peft_layer.end() && + "LoRA layer already added, attempting to add again"); + // Get input and output tensors from target module + Tensor const input = target_module->inputs[0]; + Tensor const output = target_module->outputs[0]; + assert(input->data_type == output->data_type); + // Compute OP_LORA layer name, based on target module name + std::string name_ = target_module->name ? std::string(target_module->name) + : std::string(""); + size_t last_underscore = name_.length() - 1; + for (int i = name_.length() - 1; i > 0; i--) { + if (!(std::isdigit(target_module->name[i]) || + target_module->name[i] == '_')) { + break; + } else if (target_module->name[i] == '_') { + last_underscore = i; } - name_.erase(last_underscore); - - name_ += ".lora"; - std::cout << "Adding layer " << name_ << std::endl; - Layer *peft_layer = new Layer(this, - OP_LORA, - output->data_type, - name_.c_str(), - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - input, - output); - // fix LoRA layer's transformer layer ID and model ID - peft_layer->layer_guid.transformer_layer_id = - target_module->layer_guid.transformer_layer_id; - peft_layer->layer_guid.model_id = target_module->layer_guid.model_id; - { - int numdims = output->num_dims; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdims; i++) { - dims[i] = output->dims[i]; - } - peft_layer->outputs[0] = - create_tensor_legion_ordering(numdims, - dims, - output->data_type, - peft_layer, - 0, - true /*create_grad*/); + } + name_.erase(last_underscore); + name_ += ".lora"; + std::cout << "Adding layer " << name_ << std::endl; + // Create OP_LORA layer given input, output and name + Layer *peft_layer = new Layer(this, + OP_LORA, + output->data_type, + name_.c_str(), + 2 /*inputs*/, + 0 /*weights*/, + 1 /*outputs*/, + input, + output); + // fix LoRA layer's transformer layer ID and model ID (to be the same as + // target module) + peft_layer->layer_guid.transformer_layer_id = + target_module->layer_guid.transformer_layer_id; + peft_layer->layer_guid.model_id = target_module->layer_guid.model_id; + // set up output tensor for OP_LORA layer + { + int numdims = output->num_dims; + int dims[MAX_TENSOR_DIM]; + for (int i = 0; i < numdims; i++) { + dims[i] = output->dims[i]; } - it = layers.insert(it + 1, peft_layer); - ++it; - base_layer_to_peft_layer[target_module] = peft_layer; - peft_layer_to_peft_id[peft_layer] = std::vector(); - peft_layer_to_peft_id[peft_layer].push_back(*peft_model_id); + peft_layer->outputs[0] = + create_tensor_legion_ordering(numdims, + dims, + output->data_type, + peft_layer, + 0, + true /*create_grad*/); } + // pass max_rank and max_concurrent_adapters to OP_LORA layer + peft_layer->add_int_property("max_rank", max_lora_rank); + peft_layer->add_int_property("max_concurrent_adapters", + max_concurrent_adapters); + it = layers.insert(it + 1, peft_layer); + ++it; + base_layer_to_peft_layer[target_module] = peft_layer; } } - - // save finetuned lora model configs to file - if (peft_config.trainable) { - std::string finetuned_model_folder = join_path({ - peft_config.cache_folder, - "finetuned_models", - peft_config.peft_model_id, - }); - 
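// Usage sketch (illustrative, not part of the diff): the hunk above splits the old
// add_lora_layer() into a build-time add_lora_layers() plus a serving-time
// register_peft_adapter() (defined later in request_manager.cc). Assuming `model`
// is an FFModel and the cache path / adapter name are placeholders, and that the
// adapter config carries non-empty target_modules, a minimal call sequence might be:
RequestManager *rm = RequestManager::get_request_manager();
rm->set_max_lora_rank(16);            // asserted above to satisfy 1 < rank <= 32
rm->set_max_concurrent_adapters(4);   // number of adapter slots per LoRA layer
model.add_lora_layers({"down_proj"}); // one OP_LORA layer per matching target module
// later, once the concrete adapter is known:
LoraLinearConfig peft_config("/path/to/cache", "some-user/llama-lora-adapter");
PEFTModelID *peft_model_id = model.register_peft_adapter(peft_config);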
fs::remove_all(finetuned_model_folder); - std::string finetuned_model_config_folder = join_path({ - finetuned_model_folder, - "config", - }); - fs::create_directories(finetuned_model_config_folder); - std::string lora_linear_config_filepath = join_path({ - finetuned_model_config_folder, - "ff_config.json", - }); - serialize_to_json_file(peft_config, lora_linear_config_filepath); - std::string optimizer_config_filepath = join_path({ - finetuned_model_config_folder, - "ff_optimizer_config.json", - }); - if (typeid(*peft_config.optimizer_config) == - typeid(LoraSGDOptimizerConfig)) { - LoraSGDOptimizerConfig const *sgd_config = - static_cast( - peft_config.optimizer_config); - serialize_to_json_file(*sgd_config, optimizer_config_filepath); - } else if (typeid(*peft_config.optimizer_config) == - typeid(LoraAdamOptimizerConfig)) { - LoraAdamOptimizerConfig const *adam_config = - static_cast( - peft_config.optimizer_config); - serialize_to_json_file(*adam_config, optimizer_config_filepath); - } else { - assert(false && "Optimizer not supported"); - } - } - - return peft_model_id; } Op *LoraLinear::create_operator_from_layer( FFModel &model, Layer const *layer, std::vector const &inputs) { - std::unordered_map _peft_configs; - std::vector const &peft_ids = - model.peft_layer_to_peft_id[(Layer *)layer]; - for (int i = 0; i < peft_ids.size(); i++) { - _peft_configs.emplace( - std::make_pair(peft_ids[i], model.peft_configs[peft_ids[i]])); - } + long long value; + layer->get_int_property("max_rank", value); + int max_rank = value; + layer->get_int_property("max_concurrent_adapters", value); + int max_concurrent_adapters = value; return new LoraLinear(model, layer->layer_guid, - layer->op_type, inputs[0], inputs[1], - _peft_configs, + max_rank, + max_concurrent_adapters, layer->name); } @@ -202,10 +160,10 @@ LoraLinear::LoraLinear(FFModel &model, ParallelTensor const output) : LoraLinear(model, other.layer_guid, - other.op_type, input, output, - other.peft_configs, + other.max_rank, + other.max_concurrent_adapters, other.name) {} LoraLinear::LoraLinear(FFModel &model, @@ -214,22 +172,23 @@ LoraLinear::LoraLinear(FFModel &model, char const *name) : LoraLinear(model, params.layer_guid, - params.type, inputs.first, inputs.second, - params.peft_configs, + params.max_rank, + params.max_concurrent_adapters, params.name) {} LoraLinear::LoraLinear( FFModel &model, LayerID const &_layer_guid, - OperatorType _op_type, ParallelTensor const _input, ParallelTensor const _output, - std::unordered_map const &_peft_configs, + int _max_rank, + int _max_concurrent_adapters, + // std::unordered_map const &_peft_configs, char const *name) : Op(model, - _op_type, + OP_LORA, _output->data_type, name, 2 /*inputs*/, @@ -256,9 +215,11 @@ LoraLinear::LoraLinear( outputs[0] = model.create_parallel_tensor_legion_ordering( numdim, dims, inputs[1]->data_type, this); } - for (auto const &kv : _peft_configs) { - peft_configs.insert(kv); - } + // for (auto const &kv : _peft_configs) { + // peft_configs.insert(kv); + // } + max_rank = _max_rank; + max_concurrent_adapters = _max_concurrent_adapters; // assert(check_output_input_weight_parallel_dims(allocate_weights)); } @@ -313,56 +274,6 @@ void LoraLinear::init_inference( set_opmeta_from_futuremap_inference(ff, fm, output_tensor); } -template -void load_peft_from_file(DT *ptr, - size_t num_rows, - size_t num_columns, - int num_shards, - int shard_id, - std::string filepath) { - std::ifstream in(filepath, std::ios::in | std::ios::binary); - if (!in.good()) { - printf("Could not open 
file: %s\n", filepath.c_str()); - } - assert(in.good() && "incorrect weight file path"); - - // HuggingFace dims (serialized in row-major order) - // lora_A: [rank, intermediate_dim] - // lora_B: [hidden_dim, rank] - // FlexFlow dims (serialized in column-major order) - // lora_A: [intermediate_dim, rank] - // lora_B: [rank, out_dim] - // Tensor parallelism: shard lora_A along intermediate_dim, replicate lora_B - assert(num_rows % num_shards == 0); - size_t chunk_size = num_rows / num_shards; - size_t offset = (num_shards > 1) ? shard_id * chunk_size : 0; - - // Allocate memory for the weight shard - std::vector
host_array(chunk_size * num_columns); - // Read the chunk - size_t total_size_read = 0; - for (int i = 0; i < num_columns; ++i) { - in.seekg((i * num_rows + offset) * sizeof(DT)); - in.read(reinterpret_cast(host_array.data() + i * chunk_size), - chunk_size * sizeof(DT)); - total_size_read += in.gcount(); - } - // Check weight shard size - size_t expected_data_size = chunk_size * num_columns * sizeof(DT); - if (total_size_read != expected_data_size) { - printf("load weight data error: expected %lu bytes, got: %lu bytes, data " - "size: %lu\n", - expected_data_size, - total_size_read, - sizeof(DT)); - assert(false); - } - assert(host_array.size() == chunk_size * num_columns); - // Copy weight to device memory - copy_tensor_host_to_dev(ptr, host_array.data(), chunk_size * num_columns); - in.close(); -} - /* regions[0](O): output regions[1](I): kernel @@ -428,162 +339,20 @@ OpMeta *LoraLinear::init_task(Task const *task, std::string lora_layername_substr = lora_layername.substr(0, found + searchString.length()); - for (auto const &kv : lora->peft_configs) { - PEFTModelID const &model_id = kv.first; - LoraLinearConfig const &lora_config = kv.second; - - int rank = lora_config.rank; - - int w0_num_elements = rank * in_dim; - int w1_num_elements = rank * out_dim; - // values below represent total weight sizes before sharding. Lora B is not - // sharded. - int lora_A_num_rows = in_dim * num_shards; - int lora_A_num_cols = rank; - int lora_B_num_rows = rank; - int lora_B_num_cols = out_dim; - int lora_A_num_shards = num_shards; - int lora_B_num_shards = 1; - - LoraLinearWeight weight; - weight.in_dim = in_dim; - weight.out_dim = out_dim; - weight.rank = rank; - weight.num_shards = num_shards; - PEFTWeightAllocator *allocator = m->handle.peft_weight_allocator; - weight.w0_ptr = allocator->allocate_local_weights_untyped( - model_id, w0_num_elements * data_type_size(dt)); - weight.w1_ptr = allocator->allocate_local_weights_untyped( - model_id, w1_num_elements * data_type_size(dt)); - - if (!lora_config.init_lora_weights) { - // load weights from file - std::string weights_folder_filepath = join_path({ - lora_config.cache_folder, - "weights", - lora_config.peft_model_id, - dt == DT_FLOAT ? 
"full-precision" : "half-precision", - }); - std::string w0_filepath = join_path( - {weights_folder_filepath, lora_layername_substr + "_A.weight"}); - std::string w1_filepath = join_path( - {weights_folder_filepath, lora_layername_substr + "_B.weight"}); - if (dt == DT_FLOAT) { - std::cout << "Loading LORA weight " - << lora_layername_substr + "_A.weight" - << ", num_rows: " << lora_A_num_rows - << ", num_cols: " << lora_A_num_cols - << ", num_shards: " << lora_A_num_shards - << ", shard_id: " << shard_id << std::endl; - load_peft_from_file((float *)weight.w0_ptr, - lora_A_num_rows, - lora_A_num_cols, - lora_A_num_shards, - shard_id, - w0_filepath); - std::cout << "Loading LORA weight " - << lora_layername_substr + "_B.weight" - << ", num_rows: " << lora_B_num_rows - << ", num_cols: " << lora_B_num_cols - << ", num_shards: " << lora_B_num_shards - << ", shard_id: " << shard_id << std::endl; - load_peft_from_file((float *)weight.w1_ptr, - lora_B_num_rows, - lora_B_num_cols, - lora_B_num_shards, - shard_id, - w1_filepath); - } else if (dt == DT_HALF) { - std::cout << "Loading LORA weight " - << lora_layername_substr + "_A.weight" - << ", num_rows: " << lora_A_num_rows - << ", num_cols: " << lora_A_num_cols - << ", num_shards: " << lora_A_num_shards - << ", shard_id: " << shard_id << std::endl; - load_peft_from_file((half *)weight.w0_ptr, - lora_A_num_rows, - lora_A_num_cols, - lora_A_num_shards, - shard_id, - w0_filepath); - std::cout << "Loading LORA weight " - << lora_layername_substr + "_B.weight" - << ", num_rows: " << lora_B_num_rows - << ", num_cols: " << lora_B_num_cols - << ", num_shards: " << lora_B_num_shards - << ", shard_id: " << shard_id << std::endl; - load_peft_from_file((half *)weight.w1_ptr, - lora_B_num_rows, - lora_B_num_cols, - lora_B_num_shards, + // allocate space for lora weights + Memory gpu_mem = get_proc_mem(Machine::get_machine(), task->target_proc); + m->peft_memory_manager = + new PEFTMemoryManager(gpu_mem, + lora->max_rank, + lora->max_concurrent_adapters, + BatchConfig::max_sequence_length(), + in_dim, + out_dim, + num_shards, shard_id, - w1_filepath); - } else { - assert(false && "Data type not supported"); - } - } else { - // initialize weights - int seed = 0; - init_kernel_wrapper(m, seed); - } - - // allocate space for gradients if the LoRA layer is trainable - if (lora_config.trainable) { - // Ensure we have an optimizer - assert(lora_config.optimizer_config != nullptr && "Optimizer not set"); - assert(typeid(*lora_config.optimizer_config) != - typeid(LoraOptimizerConfig) && - "Optimizer config is not a subclass of LoraOptimizerConfig"); - if (lora->inputs[0]->dims[num_dims - 1].degree == 1) { - // Input is partitioned (no replication) - // w0_grad is local weight gradients - weight.w0_grad_ptr = allocator->allocate_local_weights_untyped( - model_id, w0_num_elements * data_type_size(dt)); - // w1_grad is sync weight gradients - weight.w1_grad_ptr = allocator->allocate_sync_weights_untyped( - model_id, w1_num_elements * data_type_size(dt)); - } else { - // Input is replicated - // w0_grad is sync weight gradients - weight.w0_grad_ptr = allocator->allocate_sync_weights_untyped( - model_id, w0_num_elements * data_type_size(dt)); - // w1_grad is local weight gradients - weight.w1_grad_ptr = allocator->allocate_local_weights_untyped( - model_id, w1_num_elements * data_type_size(dt)); - } - // allocate space for v_values if needed by optimizer - if (typeid(*lora_config.optimizer_config) == - typeid(LoraSGDOptimizerConfig)) { - LoraSGDOptimizerConfig const 
*sgd_config = - static_cast( - lora_config.optimizer_config); - if (sgd_config->momentum > 0.0f) { - if (lora->inputs[0]->dims[num_dims - 1].degree == 1) { - weight.w0_v_values_ptr = allocator->allocate_local_weights_untyped( - model_id, w0_num_elements * data_type_size(dt)); - weight.w1_v_values_ptr = allocator->allocate_sync_weights_untyped( - model_id, w1_num_elements * data_type_size(dt)); - } else { - weight.w0_v_values_ptr = allocator->allocate_sync_weights_untyped( - model_id, w0_num_elements * data_type_size(dt)); - weight.w1_v_values_ptr = allocator->allocate_local_weights_untyped( - model_id, w1_num_elements * data_type_size(dt)); - } - } - } else if (typeid(*lora_config.optimizer_config) == - typeid(LoraAdamOptimizerConfig)) { - assert(false && "Adam optim not yet implemented"); - } else { - assert(false && "Optimizer not supported"); - } - } - assert(m->model_state.find(model_id) == m->model_state.end()); - m->model_state[model_id].weights = weight; - m->model_state[model_id].optimizer_config = lora_config.optimizer_config; - m->model_state[model_id].lora_alpha = lora_config.lora_alpha; - m->model_state[model_id].cache_folder = lora_config.cache_folder; - m->model_state[model_id].peft_model_id = lora_config.peft_model_id; - } + lora_layername_substr, + dt); + m->peft_memory_manager->allocate_inference_memory(); return m; } @@ -655,8 +424,8 @@ void LoraLinear::inference_task(Task const *task, m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorRW( m->input_type[1], regions[1], task->regions[1], FID_DATA, ctx, runtime); - // int in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1; - // int out_dim = output.domain.hi()[0] - output.domain.lo()[0] + 1; + int in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1; + int out_dim = output.domain.hi()[0] - output.domain.lo()[0] + 1; // int num_infr_tokens = bc->num_active_infr_tokens(); // int num_peft_tokens = bc->num_active_peft_tokens(); @@ -707,12 +476,20 @@ void LoraLinear::inference_task(Task const *task, assert(false); } - int rank, num_tokens; - for (auto it = m->model_state.begin(); it != m->model_state.end(); ++it) { - PEFTModelID peft_model_id = it->first; - LoraLinearWeight weight = m->model_state[peft_model_id].weights; - rank = weight.rank; - num_tokens = input.domain.get_volume() / weight.in_dim; + for (int i = 0; i < bc->max_requests_per_batch(); i++) { + if (bc->request_completed[i] || + bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) { + continue; + } + std::string peft_model_config_str = + std::string(bc->requestsInfo[i].peft_model_config_str); + LoraLinearConfig lora_config = + LoraLinearConfig::deserialize_from_json_string(peft_model_config_str); + if (!lora_applies_to_this_layer(m, lora_config)) { + continue; + } + LoraLinearWeight weight = m->peft_memory_manager->get_peft( + bc->requestsInfo[i].peft_model_id, lora_config); fs::path dst_filepath_weights = get_dst_folder("weights", m->decoding_step, shard_id) / layername; std::string filenameA = @@ -721,21 +498,38 @@ void LoraLinear::inference_task(Task const *task, dst_filepath_weights.string() + ".weight_B.original"; if (m->input_type[0] == DT_FLOAT) { save_tensor((float *)weight.w0_ptr, - weight.rank * weight.in_dim, + lora_config.rank * in_dim, filenameA.c_str()); save_tensor((float *)weight.w1_ptr, - weight.rank * weight.out_dim, + lora_config.rank * out_dim, filenameB.c_str()); } else if (m->input_type[0] == DT_HALF) { save_tensor((half *)weight.w0_ptr, - 
weight.rank * weight.in_dim, + lora_config.rank * in_dim, filenameA.c_str()); save_tensor((half *)weight.w1_ptr, - weight.rank * weight.out_dim, + lora_config.rank * out_dim, filenameB.c_str()); } else { assert(false && "Data type not supported"); } + + if (bc->requestsInfo[i].peft_bwd) { + int num_tokens = input.domain.get_volume() / in_dim; + // input activation (intermediate) + filename = dst_filepath.string() + ".low_rank_activation"; + if (output.data_type == DT_FLOAT) { + save_tensor((float *)weight.low_rank_activation, + lora_config.rank * num_tokens, + filename.c_str()); + } else if (output.data_type == DT_HALF) { + save_tensor((half *)weight.low_rank_activation, + lora_config.rank * num_tokens, + filename.c_str()); + } else { + assert(false); + } + } } filename = dst_filepath.string() + ".output_0"; @@ -749,21 +543,6 @@ void LoraLinear::inference_task(Task const *task, assert(false); } - if (bc->num_active_peft_tokens() > 0) { - // input activation (intermediate) - filename = dst_filepath.string() + ".low_rank_activation"; - if (output.data_type == DT_FLOAT) { - save_tensor((float *)m->low_rank_activation, - rank * num_tokens, - filename.c_str()); - } else if (output.data_type == DT_HALF) { - save_tensor((half *)m->low_rank_activation, - rank * num_tokens, - filename.c_str()); - } else { - assert(false); - } - } m->decoding_step++; } } @@ -819,6 +598,8 @@ void lora_inference_debugging(LoraLinearMeta *m, GenericTensorAccessorW input_grad, GenericTensorAccessorR output_grad, int shard_id) { + int in_dim = input_grad.domain.hi()[0] - input_grad.domain.lo()[0] + 1; + int out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1; // get layer name std::string lora_layername = std::string(m->op_name); std::string searchString = "lora"; @@ -852,10 +633,22 @@ void lora_inference_debugging(LoraLinearMeta *m, // weights, weights gradients fs::path dst_filepath_weights = get_dst_folder("weights", m->bwd_step, shard_id) / layername; - assert(m->model_state.size() >= 1 && "Model state empty!"); - for (auto it = m->model_state.begin(); it != m->model_state.end(); ++it) { - PEFTModelID peft_model_id = it->first; - LoraLinearWeight weight = m->model_state[peft_model_id].weights; + + for (int i = 0; i < bc->max_requests_per_batch(); i++) { + if (bc->request_completed[i] || + bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID || + !bc->requestsInfo[i].peft_bwd) { + continue; + } + std::string peft_model_config_str = + std::string(bc->requestsInfo[i].peft_model_config_str); + LoraLinearConfig lora_config = + LoraLinearConfig::deserialize_from_json_string(peft_model_config_str); + if (!lora_applies_to_this_layer(m, lora_config)) { + continue; + } + LoraLinearWeight weight = m->peft_memory_manager->get_peft( + bc->requestsInfo[i].peft_model_id, lora_config); std::string filename_weight_A = dst_filepath_weights.string() + ".weight_A.finetuned"; std::string filename_weight_B = @@ -867,36 +660,36 @@ void lora_inference_debugging(LoraLinearMeta *m, if (m->input_type[0] == DT_FLOAT) { // weight A save_tensor((float *)weight.w0_ptr, - weight.rank * weight.in_dim, + lora_config.rank * in_dim, filename_weight_A.c_str()); // weight grad A save_tensor((float *)weight.w0_grad_ptr, - weight.rank * weight.in_dim, + lora_config.rank * in_dim, filename_grad_A.c_str()); // weight B save_tensor((float *)weight.w1_ptr, - weight.rank * weight.out_dim, + lora_config.rank * out_dim, filename_weight_B.c_str()); // weight grad B save_tensor((float *)weight.w1_grad_ptr, - weight.rank * weight.out_dim, + 
lora_config.rank * out_dim, filename_grad_B.c_str()); } else if (m->input_type[0] == DT_HALF) { // weight A save_tensor((half *)weight.w0_ptr, - weight.rank * weight.in_dim, + lora_config.rank * in_dim, filename_weight_A.c_str()); // weight grad A save_tensor((half *)weight.w0_grad_ptr, - weight.rank * weight.in_dim, + lora_config.rank * in_dim, filename_grad_A.c_str()); // weight B save_tensor((half *)weight.w1_ptr, - weight.rank * weight.out_dim, + lora_config.rank * out_dim, filename_weight_B.c_str()); // weight grad B save_tensor((half *)weight.w1_grad_ptr, - weight.rank * weight.out_dim, + lora_config.rank * out_dim, filename_grad_B.c_str()); } else { assert(false && "Data type not supported"); @@ -975,62 +768,50 @@ void save_peft_weights_if_needed(LoraLinearMeta *m, } std::string lora_layername_substr = lora_layername.substr(0, found + searchString.length()); + for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i]) { - continue; - } - // Skip non-PEFT requests - if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) { + if (bc->request_completed[i] || + bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID || + !bc->requestsInfo[i].peft_bwd) { continue; } - // Skip PEFT forward-only requests - if (!bc->requestsInfo[i].peft_bwd) { + std::string peft_model_config_str = + std::string(bc->requestsInfo[i].peft_model_config_str); + LoraLinearConfig lora_config = + LoraLinearConfig::deserialize_from_json_string(peft_model_config_str); + if (!lora_applies_to_this_layer(m, lora_config)) { continue; } if (bc->requestsInfo[i].optimizer_tasks.save_updated_weights) { - assert(m->model_state.find(bc->requestsInfo[i].peft_model_id) != - m->model_state.end()); std::string weight_export_folder = join_path({ - m->model_state[bc->requestsInfo[i].peft_model_id].cache_folder, + lora_config.cache_folder, "finetuned_models", - m->model_state[bc->requestsInfo[i].peft_model_id].peft_model_id, + lora_config.peft_model_id, "weights", "shard_" + std::to_string(shard_id), }); fs::create_directories(weight_export_folder); - int rank = m->model_state[bc->requestsInfo[i].peft_model_id].weights.rank; + int rank = lora_config.rank; int w0_num_elements = rank * in_dim; int w1_num_elements = rank * out_dim; std::string w0_filepath = join_path( {weight_export_folder, lora_layername_substr + "_A.weight"}); std::string w1_filepath = join_path( {weight_export_folder, lora_layername_substr + "_B.weight"}); + LoraLinearWeight weight = m->peft_memory_manager->get_peft( + bc->requestsInfo[i].peft_model_id, lora_config); if (m->input_type[0] == DT_FLOAT) { - save_peft_to_file( - (float *)m->model_state[bc->requestsInfo[i].peft_model_id] - .weights.w0_ptr, - w0_num_elements, - w0_filepath); + save_peft_to_file((float *)weight.w0_ptr, w0_num_elements, w0_filepath); if (shard_id == 0) { save_peft_to_file( - (float *)m->model_state[bc->requestsInfo[i].peft_model_id] - .weights.w1_ptr, - w1_num_elements, - w1_filepath); + (float *)weight.w1_ptr, w1_num_elements, w1_filepath); } } else if (m->input_type[0] == DT_HALF) { - save_peft_to_file( - (half *)m->model_state[bc->requestsInfo[i].peft_model_id] - .weights.w0_ptr, - w0_num_elements, - w0_filepath); + save_peft_to_file((half *)weight.w0_ptr, w0_num_elements, w0_filepath); if (shard_id == 0) { save_peft_to_file( - (half *)m->model_state[bc->requestsInfo[i].peft_model_id] - .weights.w1_ptr, - w1_num_elements, - w1_filepath); + (half *)weight.w1_ptr, w1_num_elements, w1_filepath); } } else { assert(false && "Data type not supported"); @@ 
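// Condensed view (for readability only) of the per-request pattern that the hunks
// above now use in place of the removed m->model_state map; every symbol here
// (peft_model_config_str, lora_applies_to_this_layer, PEFTMemoryManager::get_peft)
// is introduced elsewhere in this patch, and the loop body is a placeholder:
for (int i = 0; i < bc->max_requests_per_batch(); i++) {
  if (bc->request_completed[i] ||
      bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) {
    continue; // skip empty slots and non-PEFT requests
  }
  LoraLinearConfig lora_config = LoraLinearConfig::deserialize_from_json_string(
      std::string(bc->requestsInfo[i].peft_model_config_str));
  if (!lora_applies_to_this_layer(m, lora_config)) {
    continue; // this adapter does not target the current module
  }
  LoraLinearWeight weight = m->peft_memory_manager->get_peft(
      bc->requestsInfo[i].peft_model_id, lora_config);
  // ... per-request work, sized with lora_config.rank, in_dim, and out_dim ...
}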
-1065,7 +846,8 @@ void LoraLinear::peft_bwd_task(Task const *task, int out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1; // int num_infr_tokens = bc->num_active_infr_tokens(); // int num_peft_tokens = bc->num_active_peft_tokens(); - peft_bwd_kernel_wrapper(ctx, runtime, m, bc, input_grad, output_grad); + peft_bwd_kernel_wrapper( + ctx, runtime, m, bc, shard_id, input_grad, output_grad); save_peft_weights_if_needed(m, bc, in_dim, out_dim, shard_id); @@ -1098,14 +880,9 @@ bool LoraLinear::measure_operator_cost(Simulator *sim, } bool operator==(LoraLinearParams const &lhs, LoraLinearParams const &rhs) { - if (lhs.layer_guid == rhs.layer_guid && lhs.type == rhs.type && - lhs.peft_configs.size() == rhs.peft_configs.size()) { - for (auto const &kv : lhs.peft_configs) { - auto it = rhs.peft_configs.find(kv.first); - if (it == rhs.peft_configs.end() || !(it->second == kv.second)) { - return false; - } - } + if (lhs.layer_guid == rhs.layer_guid && lhs.max_rank == rhs.max_rank && + lhs.max_concurrent_adapters == rhs.max_concurrent_adapters && + strcmp(lhs.name, rhs.name) == 0) { return true; } return false; @@ -1144,48 +921,8 @@ void LoraLinear::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); sez.serialize(this->layer_guid.transformer_layer_id); sez.serialize(this->layer_guid.model_id); - sez.serialize(this->op_type); - sez.serialize(this->peft_configs.size()); - for (auto const &kv : this->peft_configs) { - // Serialize PEFTModelID - sez.serialize(kv.first.id); - - // Serialize LoraLinearConfig and OptimizerConfig to tmp folder - // 1. Create tmp dir and serialize it - fs::path unique_temp_dir = create_unique_temp_directory(); - serialize_string(sez, unique_temp_dir.string()); - // 2. Dump LoraLinearConfig to json file in tmp dir - std::string lora_config_filename = std::string("lora_linear_config_") + - std::to_string(kv.first.id) + - std::string(".json"); - fs::path lora_config_json_filepath = unique_temp_dir / lora_config_filename; - serialize_to_json_file(kv.second, lora_config_json_filepath); - // 3. 
Dump optimizer to json file in tmp dir, and serialize optimizer type - std::string optimizer_filename = std::string("optimizer_config_") + - std::to_string(kv.first.id) + - std::string(".json"); - fs::path optim_config_filepath = unique_temp_dir / optimizer_filename; - assert((kv.second.trainable) == (kv.second.optimizer_config != nullptr)); - if (kv.second.trainable) { - if (typeid(*kv.second.optimizer_config) == - typeid(LoraSGDOptimizerConfig)) { - sez.serialize(OPTIMIZER_TYPE_SGD); - LoraSGDOptimizerConfig const *sgd_config = - static_cast( - kv.second.optimizer_config); - serialize_to_json_file(*sgd_config, optim_config_filepath); - } else if (typeid(*kv.second.optimizer_config) == - typeid(LoraAdamOptimizerConfig)) { - sez.serialize(OPTIMIZER_TYPE_ADAM); - LoraAdamOptimizerConfig const *adam_config = - static_cast( - kv.second.optimizer_config); - serialize_to_json_file(*adam_config, optim_config_filepath); - } else { - assert(false && "Optimizer type not yet supported"); - } - } - } + sez.serialize(this->max_rank); + sez.serialize(this->max_concurrent_adapters); sez.serialize(strlen(this->name)); sez.serialize(this->name, strlen(this->name)); } @@ -1198,8 +935,9 @@ Node LoraLinear::deserialize(FFModel &ff, int num_inputs) { assert(num_inputs == 2); size_t id, transformer_layer_id, deserialized_model_id; - OperatorType op_type; - size_t num_pefts; + int max_rank, max_concurrent_adapters; + // OperatorType op_type; + // size_t num_pefts; size_t name_len; char name[MAX_OPNAME] = {0}; @@ -1208,62 +946,16 @@ Node LoraLinear::deserialize(FFModel &ff, dez.deserialize(id); dez.deserialize(transformer_layer_id); dez.deserialize(deserialized_model_id); - dez.deserialize(op_type); - dez.deserialize(num_pefts); - for (int i = 0; i < num_pefts; i++) { - // Deserialize PEFTModelID - size_t pid; - dez.deserialize(pid); - PEFTModelID peft_model_id(pid); - // Deserialize tmp folder containing LoraLinearConfig and optimizer config - fs::path unique_temp_dir = fs::path(deserialize_string(dez)); - // 1. Deserialize LoraLinearConfig - std::string lora_config_filename = std::string("lora_linear_config_") + - std::to_string(pid) + - std::string(".json"); - fs::path lora_config_json_filepath = unique_temp_dir / lora_config_filename; - std::unique_ptr lora_linear_config = - deserialize_from_json_file(lora_config_json_filepath); - // 2. 
Deserialize optimizer if needed - if (lora_linear_config->trainable) { - std::string optimizer_filename = std::string("optimizer_config_") + - std::to_string(pid) + - std::string(".json"); - fs::path optim_config_filepath = unique_temp_dir / optimizer_filename; - OptimizerType type_; - dez.deserialize(type_); - if (type_ == OPTIMIZER_TYPE_SGD) { - std::unique_ptr sgd_optimizer_config = - deserialize_from_json_file( - optim_config_filepath); - lora_linear_config->optimizer_config = - dynamic_cast(sgd_optimizer_config.release()); - } else if (type_ == OPTIMIZER_TYPE_ADAM) { - std::unique_ptr adam_optimizer_config = - deserialize_from_json_file( - optim_config_filepath); - lora_linear_config->optimizer_config = - dynamic_cast( - adam_optimizer_config.release()); - } else { - printf("Optimizer type: %d\n", type_); - assert(false && "Optimizer type not yet supported"); - } - } - try { - fs::remove_all(unique_temp_dir); - } catch (fs::filesystem_error const &e) { - std::cerr << "Error removing tmp directory: " << e.what() << std::endl; - } - params.peft_configs.emplace( - std::make_pair(peft_model_id, *lora_linear_config)); - } + dez.deserialize(max_rank); + dez.deserialize(max_concurrent_adapters); dez.deserialize(name_len); dez.deserialize(name, name_len); LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); params.layer_guid = layer_guid; - params.type = op_type; + // params.type = op_type; + params.max_rank = max_rank; + params.max_concurrent_adapters = max_concurrent_adapters; strcpy(params.name, name); return ff.get_or_create_node({inputs[0], inputs[1]}, params); } @@ -1278,11 +970,13 @@ Op *LoraLinear::materialize(FFModel &ff, LoraLinearParams LoraLinear::get_params() const { LoraLinearParams params; params.layer_guid = this->layer_guid; - params.type = this->op_type; + params.max_rank = this->max_rank; + params.max_concurrent_adapters = this->max_concurrent_adapters; + // params.type = this->op_type; if (strlen(this->name) < MAX_OPNAME) { strcpy(params.name, this->name); } - params.peft_configs = this->peft_configs; + // params.peft_configs = this->peft_configs; return params; } @@ -1301,17 +995,8 @@ size_t hash::operator()( hash_combine(key, params.layer_guid.id); hash_combine(key, params.layer_guid.transformer_layer_id); hash_combine(key, params.layer_guid.model_id); - for (auto const &kv : params.peft_configs) { - hash_combine(key, kv.first.id); - hash_combine(key, kv.second.rank); - hash_combine(key, kv.second.trainable); - hash_combine(key, kv.second.cache_folder); - hash_combine(key, kv.second.peft_model_id); - hash_combine(key, kv.second.lora_alpha); - hash_combine(key, kv.second.lora_dropout); - hash_combine(key, kv.second.target_modules); - hash_combine(key, kv.second.init_lora_weights); - } + hash_combine(key, params.max_rank); + hash_combine(key, params.max_concurrent_adapters); return key; } }; // namespace std diff --git a/src/ops/lora_linear_params.cc b/src/ops/lora_linear_params.cc index 6e0c60e057..69c0081ec9 100644 --- a/src/ops/lora_linear_params.cc +++ b/src/ops/lora_linear_params.cc @@ -12,6 +12,17 @@ namespace FlexFlow { // empty optimizer LoraOptimizerConfig::LoraOptimizerConfig() {} +LoraOptimizerConfig *LoraOptimizerConfig::fromJson(nlohmann::json const &j) { + std::string type = j["type"]; + if (type == "SGD") { + return LoraSGDOptimizerConfig::fromJson(j); + } + if (type == "Adam") { + return LoraAdamOptimizerConfig::fromJson(j); + } + throw std::runtime_error("Unknown optimizer type"); +} + // SGD optimizer 
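// Illustrative round-trip through the optimizer JSON factory added here (field
// values are made up; toJson() is defined just below, and getType() is the
// accessor the kernels use for dispatch):
LoraSGDOptimizerConfig sgd;
sgd.lr = 0.01f;
sgd.momentum = 0.9f;
nlohmann::json j = sgd.toJson(); // {"type":"SGD","lr":0.01,"momentum":0.9,...}
LoraOptimizerConfig *restored = LoraOptimizerConfig::fromJson(j);
assert(restored->getType() == "SGD");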
LoraSGDOptimizerConfig::LoraSGDOptimizerConfig() : lr(0.001f), momentum(0.0f), nesterov(false), weight_decay(0.0f) {} @@ -30,6 +41,24 @@ std::ostream &operator<<(std::ostream &os, LoraSGDOptimizerConfig const &llc) { return os; } +nlohmann::json LoraSGDOptimizerConfig::toJson() const { + return {{"type", "SGD"}, + {"lr", lr}, + {"momentum", momentum}, + {"nesterov", nesterov}, + {"weight_decay", weight_decay}}; +} + +LoraSGDOptimizerConfig * + LoraSGDOptimizerConfig::fromJson(nlohmann::json const &j) { + LoraSGDOptimizerConfig *sgd = new LoraSGDOptimizerConfig(); + sgd->lr = j["lr"]; + sgd->momentum = j["momentum"]; + sgd->nesterov = j["nesterov"]; + sgd->weight_decay = j["weight_decay"]; + return sgd; +} + // Adam optimizer LoraAdamOptimizerConfig::LoraAdamOptimizerConfig() : alpha(0.001f), beta1(0.9f), beta2(0.999f), weight_decay(0.0f), @@ -50,38 +79,26 @@ std::ostream &operator<<(std::ostream &os, LoraAdamOptimizerConfig const &llc) { return os; } -// Serialization helpers -template -void serialize_to_json_file(T const &obj, fs::path const &filepath) { - json j = obj; - std::ofstream file(filepath); - file << j.dump(4); +nlohmann::json LoraAdamOptimizerConfig::toJson() const { + return {{"type", "Adam"}, + {"alpha", alpha}, + {"beta1", beta1}, + {"beta2", beta2}, + {"weight_decay", weight_decay}, + {"epsilon", epsilon}}; } -template -std::unique_ptr deserialize_from_json_file(fs::path const &filepath) { - std::ifstream file(filepath); - json j; - file >> j; - return std::make_unique(j.get()); +LoraAdamOptimizerConfig * + LoraAdamOptimizerConfig::fromJson(nlohmann::json const &j) { + LoraAdamOptimizerConfig *adam = new LoraAdamOptimizerConfig(); + adam->alpha = j["alpha"]; + adam->beta1 = j["beta1"]; + adam->beta2 = j["beta2"]; + adam->weight_decay = j["weight_decay"]; + adam->epsilon = j["epsilon"]; + return adam; } -template void - serialize_to_json_file(LoraLinearConfig const &obj, - fs::path const &filepath); -template void serialize_to_json_file( - LoraSGDOptimizerConfig const &obj, fs::path const &filepath); -template void serialize_to_json_file( - LoraAdamOptimizerConfig const &obj, fs::path const &filepath); -template std::unique_ptr - deserialize_from_json_file(fs::path const &filepath); -template std::unique_ptr - deserialize_from_json_file( - fs::path const &filepath); -template std::unique_ptr - deserialize_from_json_file( - fs::path const &filepath); - // ------------------ LoRA configs ------------------- // --------------------------------------------------- const LoraLinearConfig LoraLinearConfig::EmptyConfig = LoraLinearConfig("", ""); @@ -218,4 +235,76 @@ std::ostream &operator<<(std::ostream &os, LoraLinearConfig const &llc) { return os; } -}; // namespace FlexFlow +double ToThreeDecimalPlaces(float f) { + double d = static_cast(f); + int i; + if (d >= 0) { + i = static_cast(d * 1000 + 0.5); + } else { + i = static_cast(d * 1000 - 0.5); + } + return (i / 1000.0); +} + +std::string LoraLinearConfig::serialize_to_json_string(int indent) const { + nlohmann::json j = {{"cache_folder", cache_folder}, + {"peft_model_id", peft_model_id}, + {"rank", rank}, + {"lora_alpha", ToThreeDecimalPlaces(lora_alpha)}, + {"lora_dropout", ToThreeDecimalPlaces(lora_dropout)}, + {"target_modules", target_modules}, + {"trainable", trainable}, + {"init_lora_weights", init_lora_weights}, + {"base_model_name_or_path", base_model_name_or_path}, + {"precision", precision}, + {"optimizer_config", + optimizer_config + ? 
nlohmann::json(optimizer_config->toJson()) + : nlohmann::json()}}; + + return j.dump(indent); // No indentation +} + +void LoraLinearConfig::serialize_to_json_file( + std::string const &filename) const { + std::string j = serialize_to_json_string(4); + std::ofstream file(filename); + file << j; +} + +// Deserialization method +LoraLinearConfig LoraLinearConfig::deserialize_from_json_string( + std::string const &json_string) { + // std::cout << "Attempting to deserialize from JSON string: " << json_string + // << std::endl; + nlohmann::json j = nlohmann::json::parse(json_string); + LoraOptimizerConfig *optimizer_config_ = nullptr; + if (!j["optimizer_config"].is_null()) { + optimizer_config_ = LoraOptimizerConfig::fromJson(j["optimizer_config"]); + } + LoraLinearConfig config = LoraLinearConfig::EmptyConfig; + config.cache_folder = j["cache_folder"].get(); + config.peft_model_id = j["peft_model_id"].get(); + config.rank = j["rank"].get(); + config.lora_alpha = j["lora_alpha"].get(); + config.lora_dropout = j["lora_dropout"].get(); + config.target_modules = j["target_modules"].get>(); + config.trainable = j["trainable"].get(); + config.init_lora_weights = j["init_lora_weights"].get(); + config.base_model_name_or_path = + j["base_model_name_or_path"].get(); + config.precision = j["precision"].get(); + config.optimizer_config = optimizer_config_; + return config; +} + +// Deserialization method +LoraLinearConfig + LoraLinearConfig::deserialize_from_json_file(std::string const &filename) { + std::ifstream file(filename); + std::string j; + file >> j; + return deserialize_from_json_string(j); +} + +}; // namespace FlexFlow \ No newline at end of file diff --git a/src/runtime/fftype.cc b/src/runtime/fftype.cc index 8213726e8a..31937cef66 100644 --- a/src/runtime/fftype.cc +++ b/src/runtime/fftype.cc @@ -46,6 +46,10 @@ bool operator==(PEFTModelID const &lhs, PEFTModelID const &rhs) { return lhs.id == rhs.id; } +bool operator!=(PEFTModelID const &lhs, PEFTModelID const &rhs) { + return !(lhs == rhs); +} + std::ostream &operator<<(std::ostream &os, PEFTModelID const &peft_model_id) { if (peft_model_id == PEFTModelID::NO_ID) { os << "NO_ID"; diff --git a/src/runtime/file_loader.cc b/src/runtime/file_loader.cc index e73893475c..3ebe6cf095 100644 --- a/src/runtime/file_loader.cc +++ b/src/runtime/file_loader.cc @@ -16,6 +16,7 @@ #include "flexflow/utils/file_loader.h" #include "flexflow/ffconst_utils.h" #include "flexflow/inference.h" +#include "flexflow/model.h" #include using namespace std; @@ -851,35 +852,70 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff, delete data; } -void FileDataLoader::load_weights(FFModel *ff) { +void FileDataLoader::load_weight_task( + Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime) { + WeightLoadTaskArgs const *args = (WeightLoadTaskArgs const *)task->args; + + switch (args->data_type) { + case DT_HALF: { + args->loader->load_single_weight_tensor( + args->ff, args->layer, args->weight_idx); + break; + } + case DT_FLOAT: { + args->loader->load_single_weight_tensor( + args->ff, args->layer, args->weight_idx); + break; + } + case DT_INT4: + case DT_INT8: { + args->loader->load_quantization_weight( + args->ff, args->layer, args->weight_idx); + break; + } + default: + assert(false && "Unsupported data type"); + } +} + +void FileDataLoader::load_weights_parallel(FFModel *ff, + Context ctx, + Runtime *runtime) { + std::vector futures; + for (Layer *l : ff->layers) { if (l->numWeights < 1 || l->name == NULL || 
strlen(l->name) < 1) { continue; } + for (int i = 0; i < l->numWeights; i++) { Tensor weight = l->weights[i]; if (weight == NULL) { continue; } - // TODO: currently skip Lora layers + if (l->op_type == OP_LORA) { continue; } - switch (weight->data_type) { - case DT_HALF: - load_single_weight_tensor(ff, l, i); - break; - case DT_FLOAT: - load_single_weight_tensor(ff, l, i); - break; - case DT_INT4: - case DT_INT8: - // load weights in quantization - load_quantization_weight(ff, l, i); - break; - default: - assert(false && "Unsupported data type"); + + if (weight->data_type != DT_FLOAT && weight->data_type != DT_HALF && + weight->data_type != DT_INT4 && weight->data_type != DT_INT8) { + assert(false && "Unsupported data type"); } + + // Create task arguments + WeightLoadTaskArgs args(ff, this, l, i, weight->data_type); + TaskLauncher launcher(LOAD_WEIGHT_TASK_ID, + TaskArgument(&args, sizeof(WeightLoadTaskArgs))); + futures.push_back(runtime->execute_task(ctx, launcher)); } } + + // Wait for all tasks to complete + for (Future &f : futures) { + f.get_void_result(); + } } diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index f39ea91f28..45b6ba0db8 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -273,7 +273,9 @@ void InferenceManager::compile_model_and_allocate_buffer(FFModel *model) { } reset_inputs.insert(op->inputs[i]->region); } else { - reset_inputs.insert(op->inputs[i]->region); + if (op->op_type != OP_LORA) { + reset_inputs.insert(op->inputs[i]->region); + } } } } diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 417cd2c056..2a95caf6cb 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -1550,8 +1550,6 @@ FFRuntime::FFRuntime(FFConfig &config) { config.cpu_offload ? config.offload_reserve_space_size : 0; info.peft_activation_reserve_space_size = config.enable_peft ? config.peft_activation_reserve_space_size : 0; - info.peft_weight_reserve_space_size = - config.enable_peft ? 
config.peft_weight_reserve_space_size : 0; info.quantization_type = config.quantization_type; info.allowTensorOpMathConversion = config.allow_tensor_op_math_conversion; argmap.set_point(*it, TaskArgument(&info, sizeof(FFInitInfo))); @@ -3423,62 +3421,29 @@ bool FFModel::need_to_add_combine(int layer_idx) const { bool FFModel::need_to_add_allreduce(int layer_idx) const { auto const &l = layers[layer_idx]; if (config.computationMode == COMP_MODE_INFERENCE && - config.tensor_parallelism_degree > 1 && - ( - // l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || - // l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION || - (std::string(l->name).find("attn.o_proj") != std::string::npos) || - // mlp layer - is_mlp_block(layer_idx) || - // llama mlp layer - (l->op_type == OP_LINEAR && layer_idx >= 2 && - layers[layer_idx - 1]->op_type == OP_GELU && - layers[layer_idx - 2]->op_type == OP_LINEAR) || - // LLAMA without element-wise operator fusion - (l->op_type == OP_LINEAR && layer_idx >= 5 && - layers[layer_idx - 1]->op_type == OP_EW_MUL && - layers[layer_idx - 2]->op_type == OP_EW_MUL && - layers[layer_idx - 3]->op_type == OP_SIGMOID && - layers[layer_idx - 4]->op_type == OP_LINEAR && - layers[layer_idx - 5]->op_type == OP_LINEAR) || - // LLAMA with element-wise operator fusion - (l->op_type == OP_LINEAR && layer_idx >= 3 && - layers[layer_idx - 1]->op_type == OP_SIGMOID_SILU_MULTI && - layers[layer_idx - 2]->op_type == OP_LINEAR && - layers[layer_idx - 3]->op_type == OP_LINEAR))) { + config.tensor_parallelism_degree > 1 && l->op_type == OP_LINEAR && + (/*llama/mpt attention*/ + (std::string(l->name).find("attn.o_proj") != std::string::npos) || + /*opt/starcoder attention*/ + (std::string(l->name).find("self_attn.o_proj") != std::string::npos) || + /*falcon attention*/ + (std::string(l->name).find("self_attention.o_proj") != + std::string::npos) || + /*llama mlp*/ + (std::string(l->name).find("mlp.down_proj") != std::string::npos) || + /*opt mlp*/ + (std::string(l->name).find("fc2") != std::string::npos) || + /*falcon mlp*/ + (std::string(l->name).find("mlp.dense_4h_to_h") != std::string::npos) || + /*mpt mlp*/ + (std::string(l->name).find("ffn.down_proj") != std::string::npos) || + /*starcoder mlp*/ + (std::string(l->name).find("mlp.c_proj") != std::string::npos))) { return true; } return false; } -#ifdef DEADCODE -bool FFModel::need_to_add_parallel_identity(int layer_idx) const { - auto const &l = layers[layer_idx]; - // add parallel identity (allreduce in the backward pass) before the lm head - // we find the lm head by looking for the linear layer right after a residual - // rms norm / layer norm, and before a softmax, followed by - // argmax/argtopk/sampling - if (config.computationMode == COMP_MODE_INFERENCE && - config.tensor_parallelism_degree > 1 && - ((l->op_type == OP_RESIDUAL_RMS_NORM || - l->op_type == OP_RESIDUAL_LAYERNORM) && - // there are at least 2 layers before the norm, and at least 3 following - // the norm - layer_idx >= 2 && layer_idx < layers.size() - 3 && - // norm is followed by linear layer (lm head) - layers[layer_idx + 1]->op_type == OP_LINEAR && - // lm head is followed by softmax - layers[layer_idx + 2]->op_type == OP_SOFTMAX && - // softmax is followed by argmax/argtopk/sampling - (layers[layer_idx + 3]->op_type == OP_ARG_TOPK || - layers[layer_idx + 3]->op_type == OP_SAMPLING || - layers[layer_idx + 3]->op_type == OP_ARGMAX || - layers[layer_idx + 3]->op_type == OP_SCALAR_TRUE_DIV))) { - return true; - } - return false; -} -#endif bool 
FFModel::need_to_add_parallel_identity(int layer_idx) const { auto const &l = layers[layer_idx]; // add parallel identity (allreduce in the backward pass) before the lm head @@ -4400,7 +4365,6 @@ FFConfig::FFConfig() { enable_peft = DefaultConfig::enablePeft; peft_activation_reserve_space_size = DefaultConfig::peftActivationReserveSpaceSize; - peft_weight_reserve_space_size = DefaultConfig::peftWeightReserveSpaceSize; quantization_type = DT_NONE; only_data_parallel = DefaultConfig::onlyDataParallel; data_parallelism_degree = 1; @@ -4535,10 +4499,6 @@ void FFConfig::parse_args(char **argv, int argc) { peft_activation_reserve_space_size = atoll(argv[++i]) * 1024 * 1024; continue; } - if (!strcmp(argv[i], "-peft-weight-reserve-space-size")) { - peft_weight_reserve_space_size = atoll(argv[++i]) * 1024 * 1024; - continue; - } if ((!strcmp(argv[i], "--only-data-parallel"))) { only_data_parallel = true; continue; @@ -4852,6 +4812,20 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar); } } + { + TaskVariantRegistrar registrar(LOAD_WEIGHT_TASK_ID, "load_weight_task"); + registrar.add_constraint(ProcessorConstraint(Processor::LOC_PROC)); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "load_weight_task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant( + registrar); + } + } #endif // ElementUnary task { diff --git a/src/runtime/model.cu b/src/runtime/model.cu index 5dab73e1a4..3a250539c7 100644 --- a/src/runtime/model.cu +++ b/src/runtime/model.cu @@ -168,7 +168,7 @@ FFHandler } else { handle.batch_config_metadata = nullptr; } - + // #ifdef DEADCODE if (info->peft_activation_reserve_space_size > 0) { // allocate memory for peft activation reserve space Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) @@ -182,33 +182,8 @@ FFHandler } else { handle.peft_activation_allocator = nullptr; } - - if (info->peft_weight_reserve_space_size > 0) { - // allocate memory for peft weight reserve space - Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) - .only_kind(Memory::GPU_FB_MEM) - .best_affinity_to(task->target_proc) - .first(); - Realm::Rect<1, coord_t> bounds( - Realm::Point<1, coord_t>(0), - Realm::Point<1, coord_t>(info->peft_weight_reserve_space_size - 1)); - std::vector field_sizes; - field_sizes.push_back(sizeof(char)); - Realm::RegionInstance workspaceInst; - Realm::RegionInstance::create_instance(workspaceInst, - gpu_mem, - bounds, - field_sizes, - 0, - Realm::ProfilingRequestSet()) - .wait(); - void *ptr = workspaceInst.pointer_untyped(0, sizeof(char)); - handle.peft_weight_allocator = - new PEFTWeightAllocator(ptr, info->peft_weight_reserve_space_size); - } else { - handle.peft_weight_allocator = nullptr; - } - // checkCUDA(cudaMalloc(&handle.workSpace, handle.workSpaceSize)); +// #endif +// checkCUDA(cudaMalloc(&handle.workSpace, handle.workSpaceSize)); #ifdef FF_USE_NCCL handle.ncclComm = NULL; #endif diff --git a/src/runtime/peft_weight_allocator.cc b/src/runtime/peft_weight_allocator.cc new file mode 100644 index 0000000000..1fcef3678e --- /dev/null +++ b/src/runtime/peft_weight_allocator.cc @@ -0,0 +1,319 @@ +#include "flexflow/utils/peft_weight_allocator.h" + +namespace FlexFlow { +// declare legion names +using Legion::ArgumentMap; +using Legion::Context; +using Legion::coord_t; +using Legion::Domain; +using Legion::FutureMap; +using Legion::IndexLauncher; +using Legion::InlineLauncher; +using Legion::Machine; +using Legion::Memory; +using 
Legion::PhysicalRegion; +using Legion::Predicate; +using Legion::Rect; +using Legion::RegionRequirement; +using Legion::Runtime; +using Legion::Task; +using Legion::TaskArgument; +using Legion::TaskLauncher; + +void PEFTMemoryManager::allocate_inference_memory() { + // allocate chunk of memory for all the PEFT adapters + Realm::Rect<1, coord_t> bounds( + Realm::Point<1, coord_t>(0), + Realm::Point<1, coord_t>(max_lora_size * max_concurrent_adapters - 1)); + std::vector field_sizes; + field_sizes.push_back(sizeof(char)); + Realm::RegionInstance::create_instance(peftLegionInst, + gpu_mem, + bounds, + field_sizes, + 0, + Realm::ProfilingRequestSet()) + .wait(); + base_ptr = peftLegionInst.pointer_untyped(0, sizeof(char)); +} + +void PEFTMemoryManager::allocate_finetuning_memory() { + size_t ft_size = max_lora_size * 3; // weights, gradients, momentum values + ft_size += max_peft_tokens * (in_dim + max_rank) * + data_type_size(dt); // input, low-rank activations + // allocate chunk of memory for PEFT adapter + Realm::Rect<1, coord_t> bounds(Realm::Point<1, coord_t>(0), + Realm::Point<1, coord_t>(ft_size - 1)); + std::vector field_sizes; + field_sizes.push_back(sizeof(char)); + Realm::RegionInstance::create_instance(peftLegionInst, + gpu_mem, + bounds, + field_sizes, + 0, + Realm::ProfilingRequestSet()) + .wait(); + finetuning_ptr = peftLegionInst.pointer_untyped(0, sizeof(char)); +} + +void PEFTMemoryManager::get_finetuning_slot(PEFTModelID const &model_id, + bool *cache_miss) { + if (finetuning_ptr == nullptr) { + allocate_finetuning_memory(); + } + assert(finetuning_ptr != nullptr && + "PEFT Memory Manager finetuning_ptr is null"); + *cache_miss = (model_id.id != finetuning_model_id.id); + finetuning_model_id = model_id; +} + +int PEFTMemoryManager::get_inference_peft_slot(PEFTModelID const &model_id, + bool *cache_miss) { + assert(base_ptr != nullptr && "PEFT Memory Manager not initialized"); + assert(lru_hashtable.size() == lru_list.size() && + lru_list.size() == peft2mem_slot.size() && + "PEFT Memory Manager LRU hashtable/list and/or peft2mem_slot are out " + "of sync"); + // check for cache hit + if (lru_hashtable.find(model_id) != lru_hashtable.end()) { + int lru_list_index = lru_hashtable[model_id]; + assert(lru_list[lru_list_index] == model_id && + "PEFT Memory Manager LRU hashtable/list are out of sync"); + // move the model to the end of the LRU list + lru_list.erase(lru_list.begin() + lru_list_index); + lru_list.push_back(model_id); + // update the LRU hashtable + lru_hashtable[model_id] = lru_list.size() - 1; + // get memory slot + assert(peft2mem_slot.find(model_id) != peft2mem_slot.end() && + "PEFT Memory Manager peft2mem_slot is out of sync"); + *cache_miss = false; + } else { + // cache miss + // check if you need to evict + bool need_to_evict = lru_list.size() == max_concurrent_adapters; + int mem_slot = -1; + if (need_to_evict) { + // evict the least recently used model + PEFTModelID lru_model_id = lru_list[0]; + lru_list.erase(lru_list.begin()); + lru_hashtable.erase(lru_model_id); + mem_slot = peft2mem_slot[lru_model_id]; + peft2mem_slot.erase(lru_model_id); + } else { + mem_slot = lru_list.size(); + } + // update the LRU list and hashtable + lru_list.push_back(model_id); + lru_hashtable[model_id] = lru_list.size() - 1; + // update the memory slot + peft2mem_slot[model_id] = mem_slot; + *cache_miss = true; + } + assert(peft2mem_slot.find(model_id) != peft2mem_slot.end() && + "PEFT Memory Manager peft2mem_slot is out of sync"); + int slot = peft2mem_slot[model_id]; + 
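// Worked example of the LRU slot policy implemented above, assuming
// max_concurrent_adapters == 2 and illustrative model IDs A, B, C:
//   get_inference_peft_slot(A) -> slot 0, cache miss   (lru: [A])
//   get_inference_peft_slot(B) -> slot 1, cache miss   (lru: [A, B])
//   get_inference_peft_slot(A) -> slot 0, cache hit    (lru: [B, A])
//   get_inference_peft_slot(C) -> evicts B, reuses slot 1, cache miss (lru: [A, C])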
assert(slot >= 0 && slot < max_concurrent_adapters && + "PEFT Memory Manager peft2mem_slot is out of bounds"); + return slot; +} + +template +void load_peft_from_file(DT *ptr, + size_t num_rows, + size_t num_columns, + int num_shards, + int shard_id, + std::string filepath) { + std::ifstream in(filepath, std::ios::in | std::ios::binary); + if (!in.good()) { + printf("Could not open file: %s\n", filepath.c_str()); + } + assert(in.good() && "incorrect weight file path"); + + // HuggingFace dims (serialized in row-major order) + // lora_A: [rank, intermediate_dim] + // lora_B: [hidden_dim, rank] + // FlexFlow dims (serialized in column-major order) + // lora_A: [intermediate_dim, rank] + // lora_B: [rank, out_dim] + // Tensor parallelism: shard lora_A along intermediate_dim, replicate lora_B + assert(num_rows % num_shards == 0); + size_t chunk_size = num_rows / num_shards; + size_t offset = (num_shards > 1) ? shard_id * chunk_size : 0; + + // Allocate memory for the weight shard + std::vector
host_array(chunk_size * num_columns); + // Read the chunk + size_t total_size_read = 0; + for (int i = 0; i < num_columns; ++i) { + in.seekg((i * num_rows + offset) * sizeof(DT)); + in.read(reinterpret_cast(host_array.data() + i * chunk_size), + chunk_size * sizeof(DT)); + total_size_read += in.gcount(); + } + // Check weight shard size + size_t expected_data_size = chunk_size * num_columns * sizeof(DT); + if (total_size_read != expected_data_size) { + printf("load weight data error: expected %lu bytes, got: %lu bytes, data " + "size: %lu\n", + expected_data_size, + total_size_read, + sizeof(DT)); + assert(false); + } + assert(host_array.size() == chunk_size * num_columns); + // Copy weight to device memory + copy_tensor_host_to_dev(ptr, host_array.data(), chunk_size * num_columns); + in.close(); +} + +void PEFTMemoryManager::load_peft_model(LoraLinearWeight &weight, + LoraLinearConfig const &lora_config) { + // Load weights + assert(weight.w0_ptr != nullptr && weight.w1_ptr != nullptr && + "PEFT Memory Manager weight ptr null"); + int w0_num_elements = lora_config.rank * in_dim; + int w1_num_elements = lora_config.rank * out_dim; + // values below represent total weight sizes before sharding. Lora B is not + // sharded. + int lora_A_num_rows = in_dim * num_shards; + int lora_A_num_cols = lora_config.rank; + int lora_B_num_rows = lora_config.rank; + int lora_B_num_cols = out_dim; + int lora_A_num_shards = num_shards; + int lora_B_num_shards = 1; + if (lora_config.init_lora_weights) { + // initialize weights randomly + int seed = 0; + init_peft_weight_wrapper( + weight, in_dim, out_dim, lora_config.rank, dt, seed); + } else { + // load weights from file + std::string weights_folder_filepath = join_path({ + lora_config.cache_folder, + "weights", + lora_config.peft_model_id, + dt == DT_FLOAT ? 
"full-precision" : "half-precision", + }); + std::string w0_filepath = join_path( + {weights_folder_filepath, lora_layername_substr + "_A.weight"}); + std::string w1_filepath = join_path( + {weights_folder_filepath, lora_layername_substr + "_B.weight"}); + if (dt == DT_FLOAT) { + std::cout << "Loading LORA weight " << lora_layername_substr + "_A.weight" + << ", num_rows: " << lora_A_num_rows + << ", num_cols: " << lora_A_num_cols + << ", num_shards: " << lora_A_num_shards + << ", shard_id: " << shard_id << std::endl; + load_peft_from_file((float *)weight.w0_ptr, + lora_A_num_rows, + lora_A_num_cols, + lora_A_num_shards, + shard_id, + w0_filepath); + std::cout << "Loading LORA weight " << lora_layername_substr + "_B.weight" + << ", num_rows: " << lora_B_num_rows + << ", num_cols: " << lora_B_num_cols + << ", num_shards: " << lora_B_num_shards + << ", shard_id: " << shard_id << std::endl; + load_peft_from_file((float *)weight.w1_ptr, + lora_B_num_rows, + lora_B_num_cols, + lora_B_num_shards, + shard_id, + w1_filepath); + } else if (dt == DT_HALF) { + std::cout << "Loading LORA weight " << lora_layername_substr + "_A.weight" + << ", num_rows: " << lora_A_num_rows + << ", num_cols: " << lora_A_num_cols + << ", num_shards: " << lora_A_num_shards + << ", shard_id: " << shard_id << std::endl; + load_peft_from_file((half *)weight.w0_ptr, + lora_A_num_rows, + lora_A_num_cols, + lora_A_num_shards, + shard_id, + w0_filepath); + std::cout << "Loading LORA weight " << lora_layername_substr + "_B.weight" + << ", num_rows: " << lora_B_num_rows + << ", num_cols: " << lora_B_num_cols + << ", num_shards: " << lora_B_num_shards + << ", shard_id: " << shard_id << std::endl; + load_peft_from_file((half *)weight.w1_ptr, + lora_B_num_rows, + lora_B_num_cols, + lora_B_num_shards, + shard_id, + w1_filepath); + } else { + assert(false && "Data type not supported"); + } + } +} + +LoraLinearWeight + PEFTMemoryManager::get_inference_peft(PEFTModelID const &model_id, + LoraLinearConfig const &lora_config) { + assert(model_id != PEFTModelID::NO_ID && "PEFT Model ID is not set"); + bool cache_miss; + int mem_slot = get_inference_peft_slot(model_id, &cache_miss); + int w0_num_elements = lora_config.rank * in_dim; + int data_size = data_type_size(dt); + LoraLinearWeight result; + result.w0_ptr = static_cast(base_ptr) + mem_slot * max_lora_size; + result.w1_ptr = + static_cast(result.w0_ptr) + w0_num_elements * data_size; + if (cache_miss) { + load_peft_model(result, lora_config); + } + return result; +} + +LoraLinearWeight PEFTMemoryManager::get_finetuning_peft( + PEFTModelID const &model_id, LoraLinearConfig const &lora_config) { + assert(model_id != PEFTModelID::NO_ID && "PEFT Model ID is not set"); + bool cache_miss; + get_finetuning_slot(model_id, &cache_miss); + int w0_num_elements = lora_config.rank * in_dim; + int w1_num_elements = lora_config.rank * out_dim; + int data_size = data_type_size(dt); + LoraLinearWeight result; + result.w0_ptr = finetuning_ptr; + result.w1_ptr = + static_cast(result.w0_ptr) + w0_num_elements * data_size; + result.w0_grad_ptr = + static_cast(result.w1_ptr) + w1_num_elements * data_size; + result.w1_grad_ptr = + static_cast(result.w0_grad_ptr) + w0_num_elements * data_size; + result.w0_v_values_ptr = + static_cast(result.w1_grad_ptr) + w1_num_elements * data_size; + result.w1_v_values_ptr = + static_cast(result.w0_v_values_ptr) + w0_num_elements * data_size; + result.input_activation = + static_cast(result.w1_v_values_ptr) + + w1_num_elements * data_size; // max_peft_tokens*in_dim + 
+  result.low_rank_activation =
+      static_cast<char *>(result.input_activation) +
+      max_peft_tokens * in_dim * data_size; // max_peft_tokens*rank
+  if (cache_miss) {
+    load_peft_model(result, lora_config);
+  }
+  return result;
+}
+
+LoraLinearWeight
+    PEFTMemoryManager::get_peft(PEFTModelID const &model_id,
+                                LoraLinearConfig const &lora_config) {
+  if (lora_config.trainable) {
+    return get_finetuning_peft(model_id, lora_config);
+  } else {
+    return get_inference_peft(model_id, lora_config);
+  }
+}
+
+void PEFTMemoryManager::check_ft_model_id(PEFTModelID const &model_id) {
+  assert(finetuning_model_id == model_id && "PEFT bwd model is not in memory!");
+}
+
+}; // namespace FlexFlow
\ No newline at end of file
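Editor's note, not part of the patch: get_finetuning_peft() above lays out one contiguous slot starting at finetuning_ptr and carves it into sub-regions purely by pointer arithmetic, which is easy to misread. Below is a minimal standalone sketch of the resulting byte offsets; the values chosen for rank, in_dim, out_dim, max_peft_tokens and data_size are made-up examples, whereas in the code they come from the LoraLinearConfig and the PEFTMemoryManager members.

// Sketch only: prints the byte offset of each sub-region that
// get_finetuning_peft() carves out of the contiguous finetuning slot.
#include <cstddef>
#include <cstdio>

int main() {
  // Assumed example values for illustration.
  size_t rank = 16, in_dim = 768, out_dim = 768, max_peft_tokens = 128;
  size_t data_size = 4; // sizeof(float)
  size_t w0 = rank * in_dim * data_size;  // LoRA_A-sized region
  size_t w1 = rank * out_dim * data_size; // LoRA_B-sized region
  const char *names[] = {"w0", "w1", "w0_grad", "w1_grad",
                         "w0_v_values", "w1_v_values",
                         "input_activation", "low_rank_activation"};
  size_t sizes[] = {w0, w1, w0, w1, w0, w1,
                    max_peft_tokens * in_dim * data_size,
                    max_peft_tokens * rank * data_size};
  size_t off = 0;
  for (int i = 0; i < 8; i++) {
    std::printf("%-20s offset %zu, size %zu bytes\n", names[i], off, sizes[i]);
    off += sizes[i];
  }
  return 0;
}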
diff --git a/src/runtime/peft_weight_allocator.cu b/src/runtime/peft_weight_allocator.cu
new file mode 100644
index 0000000000..3c4ea91db3
--- /dev/null
+++ b/src/runtime/peft_weight_allocator.cu
@@ -0,0 +1,80 @@
+
+
+#include "flexflow/ops/kernels/decompress_kernels.h"
+#include "flexflow/utils/cuda_helper.h"
+#include "flexflow/utils/peft_weight_allocator.h"
+#include <random>
+#include <vector>
+namespace FlexFlow {
+
+template <typename DT>
+void lora_init_kernel(LoraLinearWeight const &weight,
+                      int in_dim,
+                      int out_dim,
+                      int rank,
+                      int seed,
+                      cudaStream_t stream) {
+  // Initialize generator
+  std::mt19937 gen(seed);
+
+  // Get handle to weights by iterating over m->model_state to get each
+  // LoraLinearWeight object
+  int w0_num_elements = rank * in_dim;
+  int w1_num_elements = rank * out_dim;
+
+  // LoRA_A weight: [in_dim, rank]
+  float stdv_lora_a = 1.0f / sqrt(in_dim);
+  std::uniform_real_distribution<float> dis_lora_a(-stdv_lora_a, stdv_lora_a);
+  std::vector<DT> lora_a_random_init(w0_num_elements);
+  for (auto &num : lora_a_random_init) {
+    float num_float = dis_lora_a(gen);
+    if (std::is_same<DT, half>::value) {
+      num = __float2half(num_float);
+    } else {
+      num = num_float;
+    }
+  }
+  checkCUDA(cudaMemcpyAsync(static_cast<DT *>(weight.w0_ptr),
+                            lora_a_random_init.data(),
+                            w0_num_elements * sizeof(DT),
+                            cudaMemcpyHostToDevice,
+                            stream));
+
+  // LoRA_B weight: [rank, out_dim]
+  float stdv_lora_b = 1.0f / sqrt(rank);
+  std::uniform_real_distribution<float> dis_lora_b(-stdv_lora_b, stdv_lora_b);
+  std::vector<DT> lora_b_random_init(w1_num_elements);
+  for (auto &num : lora_b_random_init) {
+    float num_float = dis_lora_b(gen);
+    if (std::is_same<DT, half>::value) {
+      num = __float2half(num_float);
+    } else {
+      num = num_float;
+    }
+  }
+  checkCUDA(cudaMemcpyAsync(static_cast<DT *>(weight.w1_ptr),
+                            lora_b_random_init.data(),
+                            w1_num_elements * sizeof(DT),
+                            cudaMemcpyHostToDevice,
+                            stream));
+}
+
+void init_peft_weight_wrapper(LoraLinearWeight const &weight,
+                              int in_dim,
+                              int out_dim,
+                              int rank,
+                              DataType dt,
+                              int seed) {
+  cudaStream_t stream;
+  checkCUDA(get_legion_stream(&stream));
+
+  if (dt == DT_FLOAT) {
+    lora_init_kernel<float>(weight, in_dim, out_dim, rank, seed, stream);
+  } else if (dt == DT_HALF) {
+    lora_init_kernel<half>(weight, in_dim, out_dim, rank, seed, stream);
+  } else {
+    assert(false && "Unsupported data type");
+  }
+}
+
+} // namespace FlexFlow
\ No newline at end of file
diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc
index 193abbb455..fddaae09ce 100644
--- a/src/runtime/request_manager.cc
+++ b/src/runtime/request_manager.cc
@@ -263,6 +263,73 @@ size_t RequestManager::get_num_ssms() {
   return ssm_models.size();
 }
 
+void RequestManager::set_peft_config(PEFTModelID const &peft_model_id,
+                                     LoraLinearConfig const &peft_config) {
+  // check that peft_model_id is not already in use
+  assert(peft_configs.find(peft_model_id) == peft_configs.end() &&
+         "PEFT model ID already in use");
+  // LoraLinearConfig new_config =
+  //     LoraLinearConfig::deserialize_from_json_string(
+  //         peft_config.serialize_to_json_string());
+  peft_configs[peft_model_id] = peft_config;
+}
+
+LoraLinearConfig const &
+    RequestManager::get_peft_config(PEFTModelID const &peft_model_id) {
+  assert(peft_configs.find(peft_model_id) != peft_configs.end() &&
+         "PEFT model ID not found");
+  return peft_configs[peft_model_id];
+}
+
+void RequestManager::set_max_lora_rank(int max_lora_rank_) {
+  max_lora_rank = max_lora_rank_;
+}
+
+void RequestManager::set_max_concurrent_adapters(int max_concurrent_adapters_) {
+  max_concurrent_adapters = max_concurrent_adapters_;
+}
+
+int RequestManager::get_max_lora_rank() {
+  return max_lora_rank;
+}
+
+int RequestManager::get_max_concurrent_adapters() {
+  return max_concurrent_adapters;
+}
+
+PEFTModelID *
+    FFModel::register_peft_adapter(LoraLinearConfig const &peft_config) {
+  assert(config.enable_peft &&
+         "Cannot add a LoRA layer if PEFT mode is not enabled");
+  if (peft_config.target_modules.size() == 0) {
+    printf("PEFT config does not contain any target module\n");
+    std::cout << peft_config << std::endl;
+    assert(false);
+  }
+  std::cout << "Registering PEFT adapter "
+            << peft_config.serialize_to_json_string() << std::endl;
+  // go over base_layer_to_peft_layer and check that you can find at least one
+  // match
+  for (int i = 0; i < peft_config.target_modules.size(); i++) {
+    bool found = false;
+    for (auto const &pair : base_layer_to_peft_layer) {
+      Layer *base_layer = pair.first;
+      if (base_layer->name != nullptr && strlen(base_layer->name) > 0 &&
+          std::string(base_layer->name).find(peft_config.target_modules[i]) !=
+              std::string::npos) {
+        found = true;
+        break;
+      }
+    }
+    assert(found && "Attempting to add LoRA to a LLM target module that does "
+                    "not exist or does not support LoRA");
+  }
+  PEFTModelID *peft_model_id = new PEFTModelID(peft_model_global_guid++);
+  RequestManager *rm = RequestManager::get_request_manager();
+  rm->set_peft_config(*peft_model_id, peft_config);
+  return peft_model_id;
+}
+
 RequestManager::RequestGuid
     RequestManager::register_new_request(Request const &request_) {
   const std::lock_guard<std::mutex> lock(request_queue_mutex);
@@ -628,6 +695,18 @@ void RequestManager::check_batch(BatchConfig const &old_bc,
   }
 }
 
+void RequestManager::add_peft_config_to_request_info(
+    BatchConfig &bc, int req_idx, LoraLinearConfig const &peft_config) {
+  std::memset(bc.requestsInfo[req_idx].peft_model_config_str,
+              0,
+              BatchConfig::MAX_PEFT_CONFIG_SIZE);
+  std::string peft_config_str = peft_config.serialize_to_json_string();
+  std::strcpy(bc.requestsInfo[req_idx].peft_model_config_str,
+              peft_config_str.c_str());
+  // std::cout << "Added PEFT config to request info: "
+  //           << bc.requestsInfo[req_idx].peft_model_config_str << std::endl;
+}
+
 BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
                                                InferenceResult const &result) {
   const std::lock_guard<std::mutex> lock(request_queue_mutex);
@@ -666,6 +745,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
   int inference_batch_size = BatchConfig::max_requests_per_batch() -
                              (int)enable_peft_finetuning;
 
+  int num_concurrent_adapters = 0;
+
   // Step 2: prepare the next batch for existing inference requests
   BatchConfig new_bc;
   for (int i = 0; i < inference_batch_size; i++) {
@@ -684,6 +765,10 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
       assert(processed_tokens < request.tokens.size());
       bool request_completed = check_inf_req_completion(old_bc, i);
       if (request_completed) {
+        if (is_eos_token(request.tokens.back())) {
+          // remove the EOS token
+          request.tokens.pop_back();
+        }
         std::string output = this->tokenizer_->Decode(request.tokens);
         // Unlike Huggingface, the sentencepiece C++ library automatically
         // removes the BOS token
@@ -760,6 +845,11 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
           old_bc.requestsInfo[i].request_guid;
       new_bc.requestsInfo[i].peft_model_id =
           old_bc.requestsInfo[i].peft_model_id;
+      std::strcpy(new_bc.requestsInfo[i].peft_model_config_str,
+                  old_bc.requestsInfo[i].peft_model_config_str);
+      if (old_bc.requestsInfo[i].peft_model_id != PEFTModelID::NO_ID) {
+        num_concurrent_adapters += 1;
+      }
      new_bc.requestsInfo[i].peft_bwd = old_bc.requestsInfo[i].peft_bwd;
       new_bc.requestsInfo[i].max_length = old_bc.requestsInfo[i].max_length;
       num_active_req++;
@@ -811,6 +901,9 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
   }
   new_bc.num_generation_tokens = num_generation_tokens;
 
+  assert(num_concurrent_adapters <= get_max_concurrent_adapters() &&
+         "Number of concurrent adapters exceeded the limit");
+
   // Step 3: add new inference requests to the next batch if there is space
   for (int i = 0; i < inference_batch_size; i++) {
     if (new_bc.request_completed[i]) {
@@ -818,6 +911,14 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
           new_bc.num_tokens < get_max_tokens_per_batch()) {
         Request new_request = pending_infr_request_queue.front();
         assert(new_request.req_type == RequestType::REQ_INFERENCE);
+
+        // if the request has peft adapters and we are at capacity, don't add it
+        // yet
+        if (new_request.peft_model_id != PEFTModelID::NO_ID &&
+            num_concurrent_adapters == get_max_concurrent_adapters()) {
+          break;
+        }
+
         pending_infr_request_queue.pop();
         // all_requests[new_request.guid] = new_request;
 
@@ -829,6 +930,10 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
                    (int)new_request.tokens.size());
         new_bc.requestsInfo[i].max_length = new_request.max_length;
         new_bc.requestsInfo[i].peft_model_id = new_request.peft_model_id;
+        if (new_request.peft_model_id != PEFTModelID::NO_ID) {
+          add_peft_config_to_request_info(
+              new_bc, i, get_peft_config(new_request.peft_model_id));
+        }
         new_bc.requestsInfo[i].peft_bwd = false;
         new_bc.request_completed[i] = false;
         new_bc.requestsInfo[i].prompt_phase = true;
@@ -983,7 +1088,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
       int num_peft_label_tokens = request.dataset[dataset_entry].second.size();
       assert(num_peft_label_tokens == 0);
 
-      if (num_peft_tokens > 0) {
+      if (num_peft_tokens > 0 &&
+          num_concurrent_adapters < get_max_concurrent_adapters()) {
         assert(new_bc.request_completed[inference_batch_size]);
         // request info
         new_bc.request_completed[inference_batch_size] = false;
@@ -995,9 +1101,11 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
            num_peft_tokens;
         new_bc.requestsInfo[inference_batch_size].max_length = request.max_length;
         new_bc.requestsInfo[inference_batch_size].request_guid = request.guid;
+        new_bc.requestsInfo[inference_batch_size].peft_bwd = true;
         new_bc.requestsInfo[inference_batch_size].peft_model_id =
             request.peft_model_id;
-        new_bc.requestsInfo[inference_batch_size].peft_bwd = true;
+        add_peft_config_to_request_info(
+            new_bc, inference_batch_size, get_peft_config(request.peft_model_id));
         set_optimizer_tasks(
             new_bc.requestsInfo[inference_batch_size].optimizer_tasks,
             request.max_training_steps,
@@ -1015,8 +1123,11 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
           new_bc.num_tokens++;
           new_bc.num_peft_tokens++;
         }
+        num_concurrent_adapters += 1;
      }
    }
+  assert(num_concurrent_adapters <= get_max_concurrent_adapters() &&
+         "Number of concurrent adapters exceeded the limit");
   return new_bc;
 }
 
@@ -2914,7 +3025,7 @@ void RequestManager::serve_incr_decoding(FFModel *llm) {
     assert(im->model_weights_loaders.find(llm) !=
            im->model_weights_loaders.end());
     // Load model weights
-    im->model_weights_loaders[llm]->load_weights(llm);
+    im->model_weights_loaders[llm]->load_weights_parallel(llm, ctx, runtime);
     // init operators
     im->init_operators_inference(llm);
     // Legion futures for inc_decoding and spec_infer
@@ -2976,7 +3087,7 @@ void RequestManager::serve_spec_infer(FFModel *llm) {
     assert(im->model_weights_loaders.find(llm) !=
            im->model_weights_loaders.end());
     // Load model weights
-    im->model_weights_loaders[llm]->load_weights(llm);
+    im->model_weights_loaders[llm]->load_weights_parallel(llm, ctx, runtime);
     // init operators
     im->init_operators_inference(llm);
   }
@@ -2987,7 +3098,7 @@ void RequestManager::serve_spec_infer(FFModel *llm) {
     assert(im->model_weights_loaders.find(llm) !=
            im->model_weights_loaders.end());
     // Load model weights
-    im->model_weights_loaders[ssm]->load_weights(ssm);
+    im->model_weights_loaders[ssm]->load_weights_parallel(ssm, ctx, runtime);
     // init operators
     im->init_operators_inference(ssm);
   }
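Editor's note, not part of the patch: prepare_next_batch() above ships each request's adapter configuration to the workers by serializing it into the fixed 1024-byte peft_model_config_str field of the per-request info. The following is a minimal sketch of that round trip; the explicit size check is an illustrative precaution rather than something this patch adds, and the deserialize_from_json_string() call is assumed from the commented-out code in set_peft_config() above. It is a fragment, not a full program; it would need <cassert>, <cstring> and <string> plus the FlexFlow headers.

// Sketch only. `bc` is a BatchConfig and `peft_config` a LoraLinearConfig,
// as in add_peft_config_to_request_info() above.
std::string json = peft_config.serialize_to_json_string();
// The destination buffer is BatchConfig::MAX_PEFT_CONFIG_SIZE (1024) bytes.
assert(json.size() + 1 <= BatchConfig::MAX_PEFT_CONFIG_SIZE);
std::memset(bc.requestsInfo[req_idx].peft_model_config_str, 0,
            BatchConfig::MAX_PEFT_CONFIG_SIZE);
std::memcpy(bc.requestsInfo[req_idx].peft_model_config_str, json.c_str(),
            json.size() + 1);
// On the consumer side the config can be rebuilt from the buffer (assumed
// static helper, hinted at by the commented-out code above):
LoraLinearConfig cfg = LoraLinearConfig::deserialize_from_json_string(
    std::string(bc.requestsInfo[req_idx].peft_model_config_str));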
diff --git a/tests/inference/huggingface_inference_simple.py b/tests/inference/huggingface_inference_simple.py
new file mode 100644
index 0000000000..f1cf8450b7
--- /dev/null
+++ b/tests/inference/huggingface_inference_simple.py
@@ -0,0 +1,51 @@
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    AutoConfig,
+    GenerationConfig,
+)
+
+model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+do_sample = False
+max_length = 128
+model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto",)
+hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+generation_config = GenerationConfig.from_pretrained(model_name)
+print(generation_config.do_sample)
+generation_config.do_sample = do_sample
+generation_config.num_beams = 1
+generation_config.temperature = None
+generation_config.top_p = None
+
+
+def run_text_completion():
+    prompt = "Help me plan a 1-week trip to Dubai"
+    batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
+    generated = model.generate(
+        batch["input_ids"],
+        max_new_tokens=max_length,
+        generation_config=generation_config,
+    )
+    out = tokenizer.decode(generated[0])
+    print(out)
+
+def run_chat_completion():
+    messages = [
+        {"role": "system", "content": "You are a helpful and honest programming assistant."},
+        {"role": "user", "content": "Is Rust better than Python?"},
+    ]
+    tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    batch = tokenizer(tokenized_chat, return_tensors="pt")
+
+    generated = model.generate(
+        batch["input_ids"],
+        max_new_tokens=max_length,
+        generation_config=generation_config,
+    )
+    out = tokenizer.decode(generated[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+    prompt_length = len(tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True, clean_up_tokenization_spaces=True))
+    all_text = out[prompt_length:]
+    print(all_text)
+run_chat_completion()
\ No newline at end of file
diff --git a/tests/inference/huggingface_pipeline.py b/tests/inference/huggingface_pipeline.py
new file mode 100644
index 0000000000..95388e0a4b
--- /dev/null
+++ b/tests/inference/huggingface_pipeline.py
@@ -0,0 +1,33 @@
+import transformers
+from transformers import GenerationConfig
+
+model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+do_sample = False
+
+generation_config = GenerationConfig.from_pretrained(model_id)
+generation_config.do_sample = do_sample
+generation_config.num_beams = 1
+# generation_config.max_length = 128
+generation_config.temperature = None
+generation_config.top_p = None
+print(generation_config)
+
+pipeline = transformers.pipeline(
+    "text-generation",
+    model=model_id,
+    # model_kwargs={"torch_dtype": torch.bfloat16},
+    device_map="auto",
+)
+
+messages = [
+    {"role": "system", "content": "You are a helpful and honest programming assistant."},
+    {"role": "user", "content": "Is Rust better than Python?"},
+]
+
+# messages="Help me plan a 1-week trip to Dubai"
+outputs = pipeline(
+    messages,
+    max_new_tokens=128,
+    generation_config=generation_config,
+)
+print(outputs[0]["generated_text"][-1]['content'])
\ No newline at end of file
diff --git a/tests/inference/inference_alignment_test.py b/tests/inference/inference_alignment_test.py
index 8dab7ff43b..1fe2bfbaae 100644
--- a/tests/inference/inference_alignment_test.py
+++ b/tests/inference/inference_alignment_test.py
@@ -361,7 +361,7 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
         hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison)
         ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE)[:,:,-1].squeeze()
         hf_tensor = hf_tensor.squeeze()
-        print(hf_tensor.shape, ff_tensor.shape)
+        # print(hf_tensor.shape, ff_tensor.shape)
         compare(hf_tensor, ff_tensor, label="LM head input")
         output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0)
         hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
diff --git a/tests/inference/python_test_configs/generate_configs.py b/tests/inference/python_test_configs/generate_configs.py
index 2720304d4f..afb7ffb9a7 100644
--- a/tests/inference/python_test_configs/generate_configs.py
+++ b/tests/inference/python_test_configs/generate_configs.py
@@ -8,8 +8,8 @@
     "memory_per_gpu": 14000,
     "zero_copy_memory_per_node": 40000,
     # optional parameters
-    "num_cpus": 4,
-    "legion_utility_processors": 4,
+    "num_cpus": 8,
+    "legion_utility_processors": 8,
     "data_parallelism_degree": 1,
     "tensor_parallelism_degree": 1,
     "pipeline_parallelism_degree": 4,
@@ -19,7 +19,6 @@
     "use_8bit_quantization": False,
     "enable_peft": False,
     "peft_activation_reserve_space_size": 1024, # 1GB
-    "peft_weight_reserve_space_size": 1024, # 1GB
     "profiling": False,
     "benchmarking": False,
     "inference_debugging": False,
@@ -63,15 +62,14 @@
 # starcoder_models = ["bigcode/starcoderbase-7b",]
 
 parallelism_settings = [(1, 4), (2, 2), (4, 1)]
 
-# The paths below should be with respect to the folder from which the tests are launched (FF_HOME/tests/inference)
-prompt_file = "../../inference/prompt/test.json"
-output_folder = "../../inference/output"
-
 # Change working dir to folder storing this script
 abspath = os.path.abspath(__file__)
 dname = os.path.dirname(abspath)
 os.chdir(dname)
 
+prompt_file = os.path.abspath("../../../inference/prompt/test.json")
+output_folder = os.path.abspath("../../../inference/output")
+
 # Generate incremental decoding configs
 all_models = llama_models + opt_models + falcon_models + mpt_models
 
diff --git a/tests/peft/alignment/align_test_utils.py b/tests/peft/alignment/align_test_utils.py
index f5ed8ae65b..a8a9be2f3b 100644
--- a/tests/peft/alignment/align_test_utils.py
+++ b/tests/peft/alignment/align_test_utils.py
@@ -430,7 +430,7 @@ def compare_loaded_tensors(hf_tensor, ff_tensor, tolerance=1e-2):
         print(f"HF: {hf_tensor}\nFF:{ff_tensor}")
         print(np.isclose(hf_tensor, ff_tensor, atol=tolerance))
         mismatches = np.where(~np.isclose(hf_tensor, ff_tensor, atol=tolerance))[0]
-        print(mismatches)
+        # print(mismatches)
         len_hf_tensor = hf_tensor.flatten().shape[0]
         assert len(mismatches) <= 0.05 * len_hf_tensor
         print("Ok!")
diff --git a/tests/peft/hf_finetune.py b/tests/peft/hf_finetune.py
index a2fc5548ab..8a53ef8c9c 100644
--- a/tests/peft/hf_finetune.py
+++ b/tests/peft/hf_finetune.py
@@ -77,7 +77,7 @@ def main():
     if args.save_peft_tensors:
         make_debug_dirs()
         register_peft_hooks(model)
-        save_model_weights(model, target_modules=["lora", "lm_head", "down_proj"])
+        save_model_weights(model, target_modules=["lora", "lm_head", "down_proj", "up_proj"])
 
     # Load fine-tuning dataset
     data = load_dataset("Abirate/english_quotes")
diff --git a/tests/peft/peft_alignment_test.py b/tests/peft/peft_alignment_test.py
index cc677cd51a..c4db87c099 100644
--- a/tests/peft/peft_alignment_test.py
+++ b/tests/peft/peft_alignment_test.py
@@ -17,7 +17,7 @@ def check_bwd_pass(self):
     def check_step(self, step_idx, learning_rate=0.001):
         raise NotImplementedError()
 
-class LllamaAlignmentTest(AlignmentTest):
+class LlamaAlignmentTest(AlignmentTest):
     def __init__(self, model_name, tp_degree=1):
         self.model_name = model_name
         self.peft_config = PeftConfig.from_pretrained(model_name)
@@ -485,12 +485,16 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
             hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
             ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE)
             compare(hf_tensor, ff_tensor, label=f"W2 {i} gradient output")
+            down_proj_grad_output_pre = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE, pre=True)
+            down_proj_grad_output = ff_tensor.clone()
+            compare_loaded_tensors(down_proj_grad_output, down_proj_grad_output_pre)
 
             # LoRA_B
             hf_tensor_name = f"layers.{i}.mlp.down_proj.lora_B.default"
             ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name)
             output_comparison = TensorComparisonIdxs(hf_tensor_type="output_gradient", ff_tensor_type="output_gradient", hf_tensor_idx=0, ff_tensor_idx=0)
             hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
+            lora_grad_output = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE)
             ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE) * self.lora_scaling_factor
             compare(hf_tensor, ff_tensor, label=f"LoRA_B {i} gradient output")
@@ -501,6 +505,7 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
             hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison)
             ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
             compare(hf_tensor, ff_tensor, label=f"LoRA_A {i} gradient input")
+            lora_a_grad_input = ff_tensor.clone()
 
             # W2 (down_proj) input
             hf_tensor_name = f"layers.{i}.mlp.down_proj"
@@ -508,7 +513,15 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
             input_comparison = TensorComparisonIdxs(hf_tensor_type="input_gradient", ff_tensor_type="input_gradient", hf_tensor_idx=0, ff_tensor_idx=0)
             hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison)
             ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
+            down_proj_grad_input_pre = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.PARTITION, pre=True)
             compare(hf_tensor, ff_tensor, label=f"W2 {i} gradient input")
+
+            # down proj output (before/after kernel) should match output of lora_b
+            compare_loaded_tensors(down_proj_grad_output, lora_grad_output)
+            # down proj input (before kernel) should match input of lora_a
+            compare_loaded_tensors(down_proj_grad_input_pre, lora_a_grad_input)
+            # compare_loaded_tensors(down_proj_grad_input_pre.squeeze(), ff_tensor.squeeze())
+
             # W2 input (HF) and SigmoidSiluMulti output (FF)
             hf_w2_input = hf_tensor.clone()
@@ -538,11 +551,47 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
             output_comparison = TensorComparisonIdxs(hf_tensor_type="output_gradient", ff_tensor_type="output_gradient", hf_tensor_idx=0, ff_tensor_idx=0)
             hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison)
             ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION)
+            # print(f"w3 {i} grad output")
+            # print("flexflow tensor shape:", ff_tensor.squeeze().shape)
+            # print(ff_tensor.squeeze())
+            # print("huggingface tensor shape:", hf_tensor.squeeze().T.shape)
+            # print(hf_tensor.squeeze().T)
             compare(hf_tensor, ff_tensor, label=f"W3 {i} gradient output")
+            # print(f"W3 {i} output matches!")
+            # print(f"FF shape: {ff_tensor.shape}")
+            # print(f"HF shape: {hf_tensor.shape}")
+
+            # hf_w3_output = hf_tensor.clone()
+
             # W3 (up_proj) input
             input_comparison = TensorComparisonIdxs(hf_tensor_type="input_gradient", ff_tensor_type="input_gradient", hf_tensor_idx=0, ff_tensor_idx=0)
             hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison)
             ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE)
+
+            # w3_input_torch = torch.matmul(hf_tensor, torch.transpose(ff_tensor, 0, 1))
+            # ff_up_proj_weight_path="/usr/.cache/flexflow/debug/flexflow/weights/step_0/shard_0/layers.11.layers.11.mlp.up_proj.weight_0"
+            # hf_up_proj_weight_path="/usr/.cache/flexflow/debug/huggingface/weights/step_0/layers.11.mlp.up_proj.weight"
+            # hf_up_proj_weight = torch.load(hf_up_proj_weight_path, map_location='cpu')
+            # print(hf_up_proj_weight.shape)
+            # ff_up_proj_weight = load_ff_tensor(ff_up_proj_weight_path, hf_up_proj_weight.shape[::-1])
+            # print(ff_up_proj_weight.shape)
+            # ff_up_proj_weight = torch.from_numpy(ff_up_proj_weight).to(hf_up_proj_weight.dtype)
+            # assert torch.allclose(hf_up_proj_weight.T, ff_up_proj_weight, atol=1e-5)
+
+            # print("HF W3 output shape:", hf_w3_output.shape)
+            # print("HF W3 weight shape:", hf_up_proj_weight.shape)
+            # print("HF W3 input shape:", hf_tensor.shape)
+
+            # simulated_w3_input = torch.matmul(hf_w3_output.squeeze(), hf_up_proj_weight)
+            # print("simulated W3 input shape:", simulated_w3_input.T.shape)
+            # print(simulated_w3_input.T)
+            # print(f"w3 {i} grad input")
+            # print("flexflow tensor shape:", ff_tensor.squeeze().shape)
+            # print(ff_tensor.squeeze())
+            # print("huggingface tensor shape:", hf_tensor.squeeze().T.shape)
+            # print(hf_tensor.squeeze().T)
+
             compare(hf_tensor, ff_tensor, label=f"W3 {i} gradient input")
 
             # Attn O-proj
@@ -606,7 +655,8 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance
             ff_tensor_name = f"layers.{i}.layers.{i}.input_layernorm"
             _output_comparison = TensorComparisonIdxs(hf_tensor_type="input_gradient", ff_tensor_type="output_gradient", hf_tensor_idx=0, ff_tensor_idx=1)
             input_layernorm_out1 = get_ff_tensor(ff_tensor_name, _output_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE)
-            torch.testing.assert_close(attn_input, input_layernorm_out1, rtol=1.3e-6, atol=1e-5)
+            compare_loaded_tensors(attn_input, input_layernorm_out1, tolerance=1e-5)
+            # torch.testing.assert_close(attn_input, input_layernorm_out1, rtol=1.3e-6, atol=1e-5)
 
             # Input layernorm
 
@@ -695,7 +745,24 @@ def compare(hf_tensor, ff_tensor, label="", tolerance=1e-4):
                 torch.testing.assert_close(hf_gradient, (hf_original_weight-hf_finetuned_weight)/learning_rate, rtol=1.3e-6, atol=1e-5)
             ff_gradient_name = convert_hf_filename_to_ff(hf_gradient_name)
             ff_gradient = get_ff_tensor(ff_gradient_name, hf_gradient.shape, tp_type=TPType.REPLICATE)
+
+            lora_low_rank_activation_fwd_path = f"/usr/.cache/flexflow/debug/flexflow/fwd/step_{step_idx}/shard_0/layers.{i}.layers.{i}.mlp.down_proj.lora.low_rank_activation"
+            lora_low_rank_activation_bwd_path = f"/usr/.cache/flexflow/debug/flexflow/bwd/step_{step_idx}/shard_0/layers.{i}.layers.{i}.mlp.down_proj.lora.low_rank_activation"
+            lora_low_rank_activation_fwd = load_ff_tensor(lora_low_rank_activation_fwd_path, [16, 128])[:,:self.num_tokens]
+            lora_low_rank_activation_fwd = torch.from_numpy(lora_low_rank_activation_fwd)
+            lora_low_rank_activation_bwd = load_ff_tensor(lora_low_rank_activation_bwd_path, [16, 24])
+            lora_low_rank_activation_bwd = torch.from_numpy(lora_low_rank_activation_bwd)
+            torch.testing.assert_close(lora_low_rank_activation_fwd, lora_low_rank_activation_bwd, rtol=1.3e-6, atol=1e-5)
+
+            # print(f"LoRA_B {i} gradient")
+            # print("FlexFlow shape: ", ff_gradient.shape)
+            # print(ff_gradient)
+            # print("HuggingFace shape: ", hf_gradient.shape)
+            # print(hf_gradient.squeeze().T)
             compare(hf_gradient, ff_gradient, label=f"LoRA_B {i} gradient")
+
+
+
             # ff_out_gradient_name = f"layers.{i}.layers.{i}.mlp.down_proj.lora.output_gradient_0"
             # ff_fwd_folder = os.path.join(ff_path, "fwd", f"step_{step_idx}", "shard_0")
             # ff_bwd_folder = os.path.join(ff_path, "bwd", f"step_{step_idx}", "shard_0")
@@ -737,7 +804,7 @@ def compare(hf_tensor, ff_tensor, label="", tolerance=1e-4):
     args = parser.parse_args()
 
 if __name__ == "__main__":
-    llama_alignment = LllamaAlignmentTest(args.model_name, tp_degree=args.tensor_parallelism_degree)
+    llama_alignment = LlamaAlignmentTest(args.model_name, tp_degree=args.tensor_parallelism_degree)
     # llama_alignment.check_weights_alignment()
     for i in range(args.num_steps):
         llama_alignment.check_fwd_pass(i)
diff --git a/tests/peft_test.sh b/tests/peft_test.sh
index 5600d57edf..e497d4224e 100755
--- a/tests/peft_test.sh
+++ b/tests/peft_test.sh
@@ -31,22 +31,22 @@ mkdir -p ./inference/output
 export LEGION_BACKTRACE=1
 
 # Download test model
-python ./inference/utils/download_peft_model.py goliaro/llama-160m-lora --base_model_name JackFram/llama-160m
+python ./inference/utils/download_peft_model.py goliaro/llama-160m-lora
 
 # Run PEFT in Huggingface to get ground truth tensors
-python ./tests/peft/hf_finetune.py --peft-model-id goliaro/llama-160m-lora --save-peft-tensors --use-full-precision
+python ./tests/peft/hf_finetune.py --peft-model-id goliaro/llama-160m-lora --save-peft-tensors --use-full-precision -lr 0.001
 
 # Python test
 echo "Python test"
 python ./inference/python/ff_peft.py
 # Check alignment
-python ./tests/peft/peft_alignment_test.py -tp 2
+python ./tests/peft/peft_alignment_test.py -tp 4 -lr 0.001
 
 # C++ test
 echo "C++ test"
 ./build/inference/peft/peft \
-    -ll:gpu 2 -ll:cpu 4 -ll:util 4 \
-    -tensor-parallelism-degree 2 \
+    -ll:gpu 4 -ll:cpu 4 -ll:util 4 \
+    -tensor-parallelism-degree 4 \
     -ll:fsize 8192 -ll:zsize 12000 \
     -llm-model JackFram/llama-160m \
     -finetuning-dataset ./inference/prompt/peft_dataset.json \
@@ -55,7 +55,7 @@ echo "C++ test"
     --use-full-precision \
     --inference-debugging
 # Check alignment
-python ./tests/peft/peft_alignment_test.py -tp 2
+python ./tests/peft/peft_alignment_test.py -tp 4 -lr 0.001
 
 # Print success message
 echo ""
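Editor's note, not part of the patch: for readers new to the adapter-registration flow introduced above, the following is a minimal usage sketch pieced together from the code in src/runtime/request_manager.cc. The LoraLinearConfig constructor arguments shown (a local cache folder plus a Hugging Face adapter id) and the value passed to set_max_lora_rank() are assumptions for illustration only; `cache_folder` stands for a user-supplied path.

// Sketch only (not part of the patch). Assumes an FFModel `model` built with
// PEFT enabled and a LoraLinearConfig constructible from a cache folder and
// an adapter id -- both assumptions for illustration.
RequestManager *rm = RequestManager::get_request_manager();
rm->set_max_concurrent_adapters(4); // cap on adapters resident in one batch
rm->set_max_lora_rank(32);          // assumed upper bound on adapter rank

LoraLinearConfig peft_config(cache_folder, "goliaro/llama-160m-lora");
// register_peft_adapter() checks the target modules against the base model's
// layers and records the config with the RequestManager via set_peft_config().
PEFTModelID *peft_model_id = model.register_peft_adapter(peft_config);
// Requests carrying peft_model_id then get the serialized config attached to
// their BatchConfig entry by add_peft_config_to_request_info().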