fix reset input grad for non-activated loras
goliaro committed Nov 8, 2024
1 parent 139b643 commit b56ebd3
Showing 13 changed files with 111 additions and 59 deletions.
2 changes: 1 addition & 1 deletion include/flexflow/operator.h
@@ -280,7 +280,7 @@ class Op {
// get operator name and print it
std::string op_name_without_uid = get_op_name_without_uid(m);
std::cout << (fwd_pass ? "INF " : "BWD ") << op_name_without_uid
<< std::endl;
<< (before_kernel ? " (before kernel)" : "") << std::endl;
// build the path to save the tensor
fs::path dst_filepath;
if (fwd_pass) {
2 changes: 2 additions & 0 deletions include/flexflow/ops/kernels/linear_kernels.h
@@ -61,6 +61,7 @@ void inference_kernel_wrapper(LinearMeta *m,
int out_dim,
int batch_size);
void peft_bwd_kernel_wrapper(LinearMeta const *m,
BatchConfig const *bc,
void *input_grad_ptr,
void *output_grad_ptr,
void const *kernel_ptr,
@@ -94,6 +95,7 @@ void forward_kernel(LinearMeta const *m,
ffStream_t stream);
template <typename DT>
void peft_bwd_kernel(LinearMeta const *m,
BatchConfig const *bc,
void *input_grad_ptr,
void *output_grad_ptr,
void const *kernel_ptr,
17 changes: 0 additions & 17 deletions include/flexflow/ops/kernels/lora_linear_kernels.h
@@ -13,27 +13,10 @@ namespace FlexFlow {
using Legion::Context;
using Legion::Runtime;

#ifdef DEADCODE
struct LoraLinearModelState {
LoraLinearWeight weights;
LoraOptimizerConfig const *optimizer_config;
float lora_alpha;
std::string cache_folder;
// Huggingface model ID (for download and/or upload)
std::string peft_model_id;
};
#endif

class LoraLinearMeta : public OpMeta {
public:
LoraLinearMeta(FFHandler handle, LoraLinear const *li);
~LoraLinearMeta(void);
// PEFT related fields
// void *low_rank_activation;
// void *input_activation;
// std::unordeded_map<PEFTModelID, LoraLinearWeight> model_state;
// std::unordered_map<PEFTModelID, LoraLinearModelState> model_state;
// size_t allocated_peft_buffer_size1 = 0, allocated_peft_buffer_size2 = 0;
PEFTMemoryManager *peft_memory_manager;
};

1 change: 1 addition & 0 deletions src/ops/fused.cu
@@ -862,6 +862,7 @@ __host__ void FusedOp::peft_bwd_task(Task const *task,
int num_infr_tokens = bc->num_active_infr_tokens();
int num_peft_tokens = bc->num_active_peft_tokens();
Kernels::Linear::peft_bwd_kernel_wrapper(m,
bc,
my_input_grad_accessor[0].ptr,
my_output_grad_accessor[0].ptr,
my_weight_accessor[0].ptr,
45 changes: 45 additions & 0 deletions src/ops/kernels/linear_kernels.cu
@@ -16,6 +16,7 @@
#include "flexflow/ffconst_utils.h"
#include "flexflow/ops/kernels/decompress_kernels.h"
#include "flexflow/ops/kernels/linear_kernels.h"
#include "flexflow/ops/lora_linear_params.h"
#include "flexflow/utils/cuda_helper.h"

namespace FlexFlow {
@@ -73,6 +74,17 @@ LinearMeta::~LinearMeta(void) {
}
}
bool lora_applies_to_this_layer(LinearMeta const *m,
LoraLinearConfig const &config) {
for (std::string s : config.target_modules) {
std::string n(m->op_name);
if (n.find(s) != std::string::npos) {
return true;
}
}
return false;
}
namespace Kernels {
namespace Linear {
@@ -285,6 +297,7 @@ void inference_kernel_wrapper(LinearMeta *m,
}
void peft_bwd_kernel_wrapper(LinearMeta const *m,
BatchConfig const *bc,
void *input_grad_ptr,
void *output_grad_ptr,
void const *weight_ptr,
@@ -302,6 +315,7 @@ void peft_bwd_kernel_wrapper(LinearMeta const *m,
}
if (m->input_type[0] == DT_FLOAT) {
Internal::peft_bwd_kernel<float>(m,
bc,
input_grad_ptr,
output_grad_ptr,
weight_ptr,
@@ -312,6 +326,7 @@ void peft_bwd_kernel_wrapper(LinearMeta const *m,
stream);
} else if (m->input_type[0] == DT_HALF) {
Internal::peft_bwd_kernel<half>(m,
bc,
input_grad_ptr,
output_grad_ptr,
weight_ptr,
@@ -568,6 +583,7 @@ void forward_kernel(LinearMeta const *m,
template <typename DT>
void peft_bwd_kernel(LinearMeta const *m,
BatchConfig const *bc,
void *input_grad_ptr,
void *output_grad_ptr,
void const *kernel_ptr,
@@ -611,6 +627,35 @@ void peft_bwd_kernel(LinearMeta const *m,
// NOTE: we use beta=1 for input_grad to accumulate gradients when needed
DT alpha = 1.0f;
DT beta = m->reset_input_grads[0] ? 0.0f : 1.0f;
// ensure that we only have one finetuning request, with a single lora
int num_peft_requests = 0;
bool lora_applies = false;
for (int i = 0; i < bc->max_requests_per_batch(); i++) {
if (bc->request_completed[i] ||
bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID ||
!bc->requestsInfo[i].peft_bwd) {
continue;
}
num_peft_requests++;
std::string peft_model_config_str =
std::string(bc->requestsInfo[i].peft_model_config_str);
LoraLinearConfig lora_config =
LoraLinearConfig::deserialize_from_json_string(peft_model_config_str);
if (!lora_applies_to_this_layer(m, lora_config)) {
continue;
}
lora_applies = true;
}
assert(num_peft_requests == 1 &&
"Exactly one PEFT finetuning request is required");
// if the request has an active lora in the current layer, accumulate the
// input gradient (beta = 1) instead of overwriting it
// std::cout << m->op_name << " original beta: " << (float)beta
//           << " lora_applies: " << lora_applies << std::endl;
if (lora_applies) {
beta = 1.0f;
}
if (input_grad_ptr != NULL) {
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_N,
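The hunk above adds both the target-module check (lora_applies_to_this_layer) and the beta override for layers where a LoRA adapter is active. A minimal standalone sketch of that decision, with hypothetical layer and module names (the real code takes them from LinearMeta::op_name and the deserialized LoraLinearConfig):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Substring match between the layer's operator name and the LoRA target
// modules, mirroring lora_applies_to_this_layer in the hunk above.
bool lora_applies(std::string const &op_name,
                  std::vector<std::string> const &target_modules) {
  for (std::string const &s : target_modules) {
    if (op_name.find(s) != std::string::npos) {
      return true;
    }
  }
  return false;
}

int main() {
  // Hypothetical names and flags; FlexFlow reads them from the op metadata
  // and the per-request PEFT config.
  std::vector<std::string> target_modules = {"down_proj", "up_proj"};
  bool reset_input_grads = true;
  for (std::string const &op_name :
       {std::string("layers_0_down_proj"), std::string("layers_0_attn_qkv")}) {
    bool applies = lora_applies(op_name, target_modules);
    // cuBLAS computes C = alpha*op(A)*op(B) + beta*C, so beta = 0 overwrites
    // the input gradient while beta = 1 accumulates into it.
    float beta = applies ? 1.0f : (reset_input_grads ? 0.0f : 1.0f);
    std::cout << op_name << ": lora_applies=" << applies
              << ", beta=" << beta << "\n";
  }
  return 0;
}
```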
34 changes: 17 additions & 17 deletions src/ops/kernels/lora_linear_kernels.cu
@@ -23,14 +23,12 @@
namespace FlexFlow {

LoraLinearMeta::LoraLinearMeta(FFHandler handler, LoraLinear const *li)
: OpMeta(handler, li) {
}
: OpMeta(handler, li) {}

LoraLinearMeta::~LoraLinearMeta(void) {}

std::string get_peft_dbg_folder(LoraLinearMeta const *m,
int shard_id,
bool is_fwd) {
std::string
get_peft_dbg_folder(LoraLinearMeta const *m, int shard_id, bool is_fwd) {
std::string op_name_without_uid = LoraLinear::get_op_name_without_uid(m);
fs::path dst_filepath;
if (is_fwd) {
@@ -51,8 +49,6 @@ std::string get_peft_dbg_folder(LoraLinearMeta const *m,
namespace Kernels {
namespace LoraLinear {



void inference_kernel_wrapper(LoraLinearMeta *m,
BatchConfig const *bc,
GenericTensorAccessorR const &input,
@@ -174,7 +170,6 @@ bool lora_applies_to_this_layer(LoraLinearMeta *m,

namespace Internal {


template <typename DT>
void inference_kernel(LoraLinearMeta *m,
BatchConfig const *bc,
@@ -208,8 +203,8 @@ void inference_kernel(LoraLinearMeta *m,
if (!lora_applies_to_this_layer(m, lora_config)) {
continue;
}
std::cout << "Lora layer activated!" << std::endl;
std::cout << "Lora Config: " << peft_model_config_str << std::endl;
// std::cout << "Lora layer activated!" << std::endl;
// std::cout << "Lora Config: " << peft_model_config_str << std::endl;
assert(lora_config.trainable == bc->requestsInfo[i].peft_bwd &&
"Trainable flag mismatch");
int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch;
@@ -311,7 +306,7 @@ void peft_bwd_kernel(Context ctx,
Runtime *runtime,
LoraLinearMeta *m,
BatchConfig const *bc,
int shard_id,
int shard_id,
DT *input_grad_ptr,
DT const *output_grad_ptr,
int in_dim,
@@ -340,8 +335,8 @@ void peft_bwd_kernel(Context ctx,
if (!lora_applies_to_this_layer(m, lora_config)) {
continue;
}
std::cout << "Lora layer activated!" << std::endl;
std::cout << "Lora Config: " << peft_model_config_str << std::endl;
// std::cout << "Lora layer activated!" << std::endl;
// std::cout << "Lora Config: " << peft_model_config_str << std::endl;
assert(lora_config.trainable == bc->requestsInfo[i].peft_bwd &&
"Trainable flag mismatch");
m->peft_memory_manager->check_ft_model_id(
@@ -359,12 +354,17 @@
DT beta = (bc->requestsInfo[i].optimizer_tasks.reset_gradients_to_zero)
? 0.0f
: 1.0f;
std::cout << "Lora B gradient computation, beta = " << (float) beta << std::endl;
// std::cout << "Lora B gradient computation, beta = " << (float) beta <<
// std::endl;
if (m->inference_debugging) {
// save result to file for checking
std::string filename = get_peft_dbg_folder(m, shard_id, false) + ".low_rank_activation";
std::cout << "Save low_rank_activation (" << lora_config.rank << ", " << num_peft_tokens << ") to " << filename << std::endl;
save_tensor(static_cast<const DT*>(weight.low_rank_activation), lora_config.rank*num_peft_tokens, filename.c_str());
std::string filename =
get_peft_dbg_folder(m, shard_id, false) + ".low_rank_activation";
std::cout << "Save low_rank_activation (" << lora_config.rank << ", "
<< num_peft_tokens << ") to " << filename << std::endl;
save_tensor(static_cast<const DT *>(weight.low_rank_activation),
lora_config.rank * num_peft_tokens,
filename.c_str());
}
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_N,
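As in the linear kernel, the LoRA B gradient GEMM above picks its beta from reset_gradients_to_zero, relying on the standard GEMM contract C = alpha*op(A)*op(B) + beta*C. A small CPU reference (plain C++, hypothetical 2x2 data) showing what the two beta choices do to a gradient buffer:

```cpp
#include <iostream>
#include <vector>

// Reference GEMM update: C = alpha * A(m x k) * B(k x n) + beta * C(m x n),
// the same contract cublasGemmEx follows for the gradient buffers above.
void gemm(int m, int n, int k, float alpha, std::vector<float> const &A,
          std::vector<float> const &B, float beta, std::vector<float> &C) {
  for (int i = 0; i < m; i++) {
    for (int j = 0; j < n; j++) {
      float acc = 0.0f;
      for (int p = 0; p < k; p++) {
        acc += A[i * k + p] * B[p * n + j];
      }
      C[i * n + j] = alpha * acc + beta * C[i * n + j];
    }
  }
}

int main() {
  // Hypothetical 2x2 example: the gradient buffer already holds 10.0f.
  std::vector<float> A = {1, 0, 0, 1}, B = {2, 3, 4, 5};
  std::vector<float> grad_reset(4, 10.0f), grad_accum(4, 10.0f);
  gemm(2, 2, 2, 1.0f, A, B, /*beta=*/0.0f, grad_reset);  // overwrite
  gemm(2, 2, 2, 1.0f, A, B, /*beta=*/1.0f, grad_accum);  // accumulate
  std::cout << "beta=0: " << grad_reset[0]                // prints 2
            << "  beta=1: " << grad_accum[0] << "\n";     // prints 12
  return 0;
}
```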
1 change: 1 addition & 0 deletions src/ops/linear.cc
@@ -769,6 +769,7 @@ void Linear::peft_bwd_task(Task const *task,
num_peft_tokens);
}
peft_bwd_kernel_wrapper(m,
bc,
input_grad.ptr,
output_grad.ptr,
weight.ptr,
3 changes: 2 additions & 1 deletion src/ops/lora_linear.cc
@@ -846,7 +846,8 @@ void LoraLinear::peft_bwd_task(Task const *task,
int out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1;
// int num_infr_tokens = bc->num_active_infr_tokens();
// int num_peft_tokens = bc->num_active_peft_tokens();
peft_bwd_kernel_wrapper(ctx, runtime, m, bc, shard_id, input_grad, output_grad);
peft_bwd_kernel_wrapper(
ctx, runtime, m, bc, shard_id, input_grad, output_grad);

save_peft_weights_if_needed(m, bc, in_dim, out_dim, shard_id);

3 changes: 2 additions & 1 deletion src/ops/lora_linear_params.cc
@@ -291,7 +291,8 @@ LoraLinearConfig LoraLinearConfig::deserialize_from_json_string(
config.target_modules = j["target_modules"].get<std::vector<std::string>>();
config.trainable = j["trainable"].get<bool>();
config.init_lora_weights = j["init_lora_weights"].get<bool>();
config.base_model_name_or_path = j["base_model_name_or_path"].get<std::string>();
config.base_model_name_or_path =
j["base_model_name_or_path"].get<std::string>();
config.precision = j["precision"].get<std::string>();
config.optimizer_config = optimizer_config_;
return config;
4 changes: 3 additions & 1 deletion src/runtime/inference_manager.cc
@@ -273,7 +273,9 @@ void InferenceManager::compile_model_and_allocate_buffer(FFModel *model) {
}
reset_inputs.insert(op->inputs[i]->region);
} else {
reset_inputs.insert(op->inputs[i]->region);
if (op->op_type != OP_LORA) {
reset_inputs.insert(op->inputs[i]->region);
}
}
}
}
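This is the change the commit title points at: inputs feeding an OP_LORA operator are no longer added to reset_inputs, presumably because the adapter shares its input (and input gradient) region with the base layer, which now handles accumulation through the beta logic above. A minimal sketch of the filtering pattern with a hypothetical op/region model (the real code walks FFModel's operators and Legion regions):

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

enum OpType { OP_LINEAR, OP_LORA, OP_RELU };

struct Op {
  OpType op_type;
  std::vector<int> input_regions;  // stand-in for Legion region handles
};

int main() {
  // Hypothetical graph: a linear layer and a LoRA adapter sharing region 7.
  std::vector<Op> operators = {
      {OP_LINEAR, {7}}, {OP_LORA, {7}}, {OP_RELU, {8}}};
  std::set<int> reset_inputs;
  for (Op const &op : operators) {
    for (int region : op.input_regions) {
      // LoRA operators no longer mark their input region for reset
      // (the change in this commit).
      if (op.op_type != OP_LORA) {
        reset_inputs.insert(region);
      }
    }
  }
  for (int r : reset_inputs) {
    std::cout << "reset region " << r << "\n";  // prints 7 and 8
  }
  return 0;
}
```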
9 changes: 5 additions & 4 deletions src/runtime/peft_weight_allocator.cc
@@ -22,8 +22,9 @@ using Legion::TaskLauncher;

void PEFTMemoryManager::allocate_inference_memory() {
// allocate chunk of memory for all the PEFT adapters
Realm::Rect<1, coord_t> bounds(Realm::Point<1, coord_t>(0),
Realm::Point<1, coord_t>(max_lora_size*max_concurrent_adapters - 1));
Realm::Rect<1, coord_t> bounds(
Realm::Point<1, coord_t>(0),
Realm::Point<1, coord_t>(max_lora_size * max_concurrent_adapters - 1));
std::vector<size_t> field_sizes;
field_sizes.push_back(sizeof(char));
Realm::RegionInstance::create_instance(peftLegionInst,
@@ -38,8 +39,8 @@ void PEFTMemoryManager::allocate_inference_memory() {

void PEFTMemoryManager::allocate_finetuning_memory() {
size_t ft_size = max_lora_size * 3; // weights, gradients, momentum values
ft_size +=
max_peft_tokens * (in_dim + max_rank) * data_type_size(dt); // input, low-rank activations
ft_size += max_peft_tokens * (in_dim + max_rank) *
data_type_size(dt); // input, low-rank activations
// allocate chunk of memory for PEFT adapter
Realm::Rect<1, coord_t> bounds(Realm::Point<1, coord_t>(0),
Realm::Point<1, coord_t>(ft_size - 1));
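The finetuning buffer combines three copies of the adapter (weights, gradients, optimizer momentum) with per-token storage for the layer input and the low-rank activation. A worked sizing example with hypothetical dimensions; note that how max_lora_size itself is computed is not shown in this diff, so the A/B-matrix estimate below is an assumption:

```cpp
#include <cstddef>
#include <iostream>

int main() {
  // Hypothetical values; FlexFlow takes these from the model and LoRA config.
  size_t in_dim = 4096, out_dim = 4096, max_rank = 16;
  size_t max_peft_tokens = 1024;
  size_t dt_size = 2;  // e.g. half precision

  // Assumed adapter footprint: LoRA A (in_dim x rank) + LoRA B (rank x out_dim).
  size_t max_lora_size = dt_size * (in_dim * max_rank + max_rank * out_dim);

  // weights + gradients + optimizer momentum, then per-token input and
  // low-rank activations -- mirroring allocate_finetuning_memory above.
  size_t ft_size = max_lora_size * 3;
  ft_size += max_peft_tokens * (in_dim + max_rank) * dt_size;

  std::cout << "finetuning buffer: " << ft_size << " bytes ("
            << ft_size / (1024.0 * 1024.0) << " MiB)\n";
  return 0;
}
```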
6 changes: 4 additions & 2 deletions src/runtime/request_manager.cc
@@ -268,7 +268,8 @@ void RequestManager::set_peft_config(PEFTModelID const &peft_model_id,
// check that peft_model_id is not already in use
assert(peft_configs.find(peft_model_id) == peft_configs.end() &&
"PEFT model ID already in use");
// LoraLinearConfig new_config = LoraLinearConfig::deserialize_from_json_string(
// LoraLinearConfig new_config =
// LoraLinearConfig::deserialize_from_json_string(
// peft_config.serialize_to_json_string());
peft_configs[peft_model_id] = peft_config;
}
@@ -305,7 +306,8 @@ PEFTModelID *
std::cout << peft_config << std::endl;
assert(false);
}
std::cout << "Registering PEFT adapter" << peft_config.serialize_to_json_string() << std::endl;
std::cout << "Registering PEFT adapter"
<< peft_config.serialize_to_json_string() << std::endl;
// go over base_layer_to_peft_layer and check that you can find at least one
// match
for (int i = 0; i < peft_config.target_modules.size(); i++) {