Commit

fix

goliaro committed Nov 5, 2024
1 parent 7eb953a commit 92c2c37
Showing 10 changed files with 52 additions and 76 deletions.
9 changes: 6 additions & 3 deletions include/flexflow/ops/kernels/lora_linear_kernels.h
@@ -10,6 +10,9 @@

namespace FlexFlow {

using Legion::Context;
using Legion::Runtime;

#ifdef DEADCODE
struct LoraLinearModelState {
LoraLinearWeight weights;
@@ -40,7 +43,7 @@ namespace LoraLinear {
bool lora_applies_to_this_layer(LoraLinearMeta *m,
LoraLinearConfig const &config);

void init_kernel_wrapper(LoraLinearMeta *m, int seed);
// void init_kernel_wrapper(LoraLinearMeta *m, int seed);
void inference_kernel_wrapper(LoraLinearMeta *m,
BatchConfig const *bc,
GenericTensorAccessorR const &input,
@@ -53,8 +56,8 @@ void peft_bwd_kernel_wrapper(Context ctx,
GenericTensorAccessorR const &output_grad);

namespace Internal {
template <typename DT>
void init_kernel(LoraLinearMeta *m, int seed, ffStream_t stream);
// template <typename DT>
// void init_kernel(LoraLinearMeta *m, int seed, ffStream_t stream);
template <typename DT>
void inference_kernel(LoraLinearMeta *m,
BatchConfig const *bc,
40 changes: 7 additions & 33 deletions include/flexflow/ops/lora_linear_params.h
@@ -19,8 +19,8 @@ class LoraOptimizerConfig {
LoraOptimizerConfig();
virtual std::string getType() const = 0;
virtual nlohmann::json toJson() const = 0;
static std::unique_ptr<LoraOptimizerConfig> fromJson(nlohmann::json const &j);
virtual ~LoraOptimizerConfig() = default;
static LoraOptimizerConfig *fromJson(nlohmann::json const &j);
virtual ~LoraOptimizerConfig() {}
};

class LoraSGDOptimizerConfig : public LoraOptimizerConfig {
@@ -32,15 +32,11 @@ class LoraSGDOptimizerConfig : public LoraOptimizerConfig {
bool weight_decay_ = 0.0f);
friend std::ostream &operator<<(std::ostream &os,
LoraSGDOptimizerConfig const &llc);

std::string getType() const override {
return "SGD";
}

nlohmann::json toJson() const override;

static std::unique_ptr<LoraSGDOptimizerConfig>
fromJson(nlohmann::json const &j);
static LoraSGDOptimizerConfig *fromJson(nlohmann::json const &j);

public:
double lr = 0.001f;
@@ -63,11 +59,8 @@ class LoraAdamOptimizerConfig : public LoraOptimizerConfig {
std::string getType() const override {
return "Adam";
}

nlohmann::json toJson() const override;

static std::unique_ptr<LoraAdamOptimizerConfig>
fromJson(nlohmann::json const &j);
static LoraAdamOptimizerConfig *fromJson(nlohmann::json const &j);

public:
// Adam
@@ -94,29 +87,11 @@ class LoraLinearConfig {
std::vector<std::string> const &target_modules_ = {});
// constructor used to support std::unordered_map
LoraLinearConfig();

// Method to set optimizer
template <typename T>
void setOptimizer(T &&opt) {
if constexpr (std::is_base_of_v<LoraOptimizerConfig,
std::remove_reference_t<T>>) {
optimizer_config =
std::make_unique<std::remove_reference_t<T>>(std::forward<T>(opt));
} else if constexpr (std::is_same_v<std::unique_ptr<LoraOptimizerConfig>,
std::remove_reference_t<T>>) {
optimizer_config = std::move(opt);
} else {
static_assert(always_false<T>, "Unsupported optimizer type");
}
}
// Helper template for static_assert
template <typename>
static inline constexpr bool always_false = false;

friend bool operator==(LoraLinearConfig const &lhs,
LoraLinearConfig const &rhs);
friend std::ostream &operator<<(std::ostream &os,
LoraLinearConfig const &llc);

std::string serialize_to_json_string(int indent = -1) const;
void serialize_to_json_file(std::string const &filename) const;
// Deserialization method
@@ -138,8 +113,7 @@ class LoraLinearConfig {
// whether the weights are trainable (fine-tuning scenario) or not
// (inference-only). If set to true, allocate space for the gradients
bool trainable = false;
// LoraOptimizerConfig *optimizer_config;
std::unique_ptr<LoraOptimizerConfig> optimizer_config;
LoraOptimizerConfig *optimizer_config;
// whether to initialize weights randomly (instead of attempting to load them
// from file)
bool init_lora_weights;
Expand Down Expand Up @@ -170,4 +144,4 @@ struct hash<FlexFlow::LoraLinearParams> {
};
} // namespace std

#endif // _FLEXFLOW_LORA_LINEAR_PARAMS_H
#endif // _FLEXFLOW_LORA_LINEAR_PARAMS_H
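
With the setOptimizer() template and the std::unique_ptr member gone, callers attach an optimizer by assigning a raw LoraOptimizerConfig * to the public optimizer_config field. A minimal sketch of that pattern, using only members visible in this header; the default-constructed config and the hard-coded SGD value are illustrative only, and the diff does not say who is responsible for deleting the raw pointer:

#include "flexflow/ops/lora_linear_params.h"

using FlexFlow::LoraLinearConfig;
using FlexFlow::LoraSGDOptimizerConfig;

void attach_sgd_optimizer_sketch() {
  LoraLinearConfig config;   // empty config (constructor kept for std::unordered_map support)
  config.trainable = true;   // fine-tuning scenario: gradients will be allocated

  // A raw pointer replaces the former std::unique_ptr member.
  LoraSGDOptimizerConfig *sgd = new LoraSGDOptimizerConfig();
  sgd->lr = 0.001;
  config.optimizer_config = sgd;
}
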
4 changes: 2 additions & 2 deletions include/flexflow/request_manager.h
@@ -150,8 +150,8 @@ class RequestManager {
std::vector<int> eos_token_ids,
std::string const &path);
void register_output_filepath(std::string const &);
void register_peft_config(PEFTModelID const &peft_model_id,
LoraLinearConfig const &peft_config);
void set_peft_config(PEFTModelID const &peft_model_id,
LoraLinearConfig const &peft_config);
LoraLinearConfig const &get_peft_config(PEFTModelID const &peft_model_id);
void set_max_lora_rank(int max_lora_rank);
void set_max_concurrent_adapters(int max_concurrent_adapters);
13 changes: 7 additions & 6 deletions inference/peft/peft.cc
@@ -320,18 +320,19 @@ void FlexFlow::top_level_task(Task const *task,
assert(false && "unknow model type");
}

// Add PEFT layer
// Start background server
rm->start_background_server(&model);

// Add PEFT adapter(s)
PEFTModelID *peft_model_id = nullptr, *peft_model_id_finetuning = nullptr;
if (!peft_model_name.empty()) {
peft_model_id = model.add_lora_layer(peft_config);
peft_model_id = model.register_peft_adapter(peft_config);
if (enable_peft_finetuning) {
peft_model_id_finetuning = model.add_lora_layer(peft_config_finetuning);
peft_model_id_finetuning =
model.register_peft_adapter(peft_config_finetuning);
}
}

// Start background server
rm->start_background_server(&model);

// Run workload
{
std::vector<Request> requests;
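
The same reordering — start the background server, then register the adapters — is repeated in the three benchmark drivers below. A condensed sketch of the resulting sequence in this driver (model construction and config parsing elided; all names appear in the diff above):

rm->start_background_server(&model);   // server comes up first now

PEFTModelID *peft_model_id = nullptr, *peft_model_id_finetuning = nullptr;
if (!peft_model_name.empty()) {
  peft_model_id = model.register_peft_adapter(peft_config);
  if (enable_peft_finetuning) {
    peft_model_id_finetuning =
        model.register_peft_adapter(peft_config_finetuning);
  }
}
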
8 changes: 4 additions & 4 deletions inference/peft/peft_bwd_benchmark.cc
@@ -304,15 +304,15 @@ void FlexFlow::top_level_task(Task const *task,
assert(false && "unknow model type");
}

// Start background server
rm->start_background_server(&model);

// Add PEFT layer
PEFTModelID *peft_model_id = nullptr;
if (!peft_model_name.empty()) {
peft_model_id = model.add_lora_layer(peft_config);
peft_model_id = model.register_peft_adapter(peft_config);
}

// Start background server
rm->start_background_server(&model);

// Warmup stage
{
std::vector<Request> requests;
8 changes: 4 additions & 4 deletions inference/peft/peft_fwd_benchmark.cc
@@ -304,15 +304,15 @@ void FlexFlow::top_level_task(Task const *task,
assert(false && "unknow model type");
}

// Start background server
rm->start_background_server(&model);

// Add PEFT layer
PEFTModelID *peft_model_id = nullptr;
if (!peft_model_name.empty()) {
peft_model_id = model.add_lora_layer(peft_config);
peft_model_id = model.register_peft_adapter(peft_config);
}

// Start background server
rm->start_background_server(&model);

// Run workload
{
std::vector<Request> requests;
6 changes: 3 additions & 3 deletions inference/peft/req_rate_benchmark.cc
@@ -366,14 +366,14 @@ void FlexFlow::top_level_task(Task const *task,
assert(false && "unknow model type");
}

rm->start_background_server(&model);

// Add PEFT layer
PEFTModelID *peft_model_id = nullptr;
if (!peft_model_name.empty()) {
peft_model_id = model.add_lora_layer(peft_config);
peft_model_id = model.register_peft_adapter(peft_config);
}

rm->start_background_server(&model);

// Warmup stage
{
std::vector<Request> requests;
5 changes: 3 additions & 2 deletions src/ops/kernels/lora_linear_kernels.cu
@@ -35,6 +35,7 @@ LoraLinearMeta::~LoraLinearMeta(void) {}
namespace Kernels {
namespace LoraLinear {

#ifdef DEADCODE
void init_kernel_wrapper(LoraLinearMeta *m, int seed) {
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
@@ -47,6 +48,7 @@ void init_kernel_wrapper(LoraLinearMeta *m, int seed) {
assert(false && "Unsupported data type");
}
}
#endif

void inference_kernel_wrapper(LoraLinearMeta *m,
BatchConfig const *bc,
Expand Down Expand Up @@ -314,7 +316,6 @@ void inference_kernel(LoraLinearMeta *m,
DT *output_ptr,
int in_dim,
int out_dim,
int num_shards,
ffStream_t stream) {
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
Expand Down Expand Up @@ -593,7 +594,7 @@ void peft_bwd_kernel(Context ctx,
if (lora_config.optimizer_config->getType() == "SGD") {
LoraSGDOptimizerConfig const *sgd_config =
static_cast<LoraSGDOptimizerConfig const *>(
lora_config.optimizer_config.get());
lora_config.optimizer_config);
// LoRA_A weight is split in tensor parallelism, so no need to apply
// all-reduce
sgd_update<<<GET_BLOCKS(w0_num_elements),
29 changes: 13 additions & 16 deletions src/ops/lora_linear_params.cc
@@ -12,8 +12,7 @@ namespace FlexFlow {
// empty optimizer
LoraOptimizerConfig::LoraOptimizerConfig() {}

std::unique_ptr<LoraOptimizerConfig>
LoraOptimizerConfig::fromJson(nlohmann::json const &j) {
LoraOptimizerConfig *LoraOptimizerConfig::fromJson(nlohmann::json const &j) {
std::string type = j["type"];
if (type == "SGD") {
return LoraSGDOptimizerConfig::fromJson(j);
Expand Down Expand Up @@ -50,9 +49,9 @@ nlohmann::json LoraSGDOptimizerConfig::toJson() const {
{"weight_decay", weight_decay}};
}

std::unique_ptr<LoraSGDOptimizerConfig>
LoraSGDOptimizerConfig *
LoraSGDOptimizerConfig::fromJson(nlohmann::json const &j) {
auto sgd = std::make_unique<LoraSGDOptimizerConfig>();
LoraSGDOptimizerConfig *sgd = new LoraSGDOptimizerConfig();
sgd->lr = j["lr"];
sgd->momentum = j["momentum"];
sgd->nesterov = j["nesterov"];
Expand Down Expand Up @@ -89,9 +88,9 @@ nlohmann::json LoraAdamOptimizerConfig::toJson() const {
{"epsilon", epsilon}};
}

std::unique_ptr<LoraAdamOptimizerConfig>
LoraAdamOptimizerConfig *
LoraAdamOptimizerConfig::fromJson(nlohmann::json const &j) {
auto adam = std::make_unique<LoraAdamOptimizerConfig>();
LoraAdamOptimizerConfig *adam = new LoraAdamOptimizerConfig();
adam->alpha = j["alpha"];
adam->beta1 = j["beta1"];
adam->beta2 = j["beta2"];
Expand Down Expand Up @@ -220,12 +219,11 @@ std::ostream &operator<<(std::ostream &os, LoraLinearConfig const &llc) {
os << "trainable: " << llc.trainable << ", ";
if (llc.optimizer_config != nullptr) {
os << "optimizer_config: ";
if (llc.optimizer_config.get()->getType() == "SGD") {
os << *static_cast<LoraSGDOptimizerConfig const *>(
llc.optimizer_config.get());
} else if (llc.optimizer_config.get()->getType() == "Adam") {
os << *static_cast<LoraAdamOptimizerConfig const *>(
llc.optimizer_config.get());
if (typeid(*llc.optimizer_config) == typeid(LoraSGDOptimizerConfig)) {
os << *static_cast<LoraSGDOptimizerConfig *>(llc.optimizer_config);
} else if (typeid(*llc.optimizer_config) ==
typeid(LoraAdamOptimizerConfig)) {
os << *static_cast<LoraAdamOptimizerConfig *>(llc.optimizer_config);
} else {
os << "Unknown optimizer config type";
}
@@ -248,8 +246,6 @@ std::string LoraLinearConfig::serialize_to_json_string(int indent) const {
{"init_lora_weights", init_lora_weights},
{"base_model_name_or_path", base_model_name_or_path},
{"precision", precision},
// {"optimizer_config", optimizer_config ?
// optimizer_config->toJson() : nullptr}
{"optimizer_config",
optimizer_config
? nlohmann::json(optimizer_config->toJson())
Expand Down Expand Up @@ -282,7 +278,8 @@ LoraLinearConfig LoraLinearConfig::deserialize_from_json_string(
j["lora_dropout"].get<float>(),
j["target_modules"].get<std::vector<std::string>>());
if (!j["optimizer_config"].is_null()) {
config.setOptimizer(LoraOptimizerConfig::fromJson(j["optimizer_config"]));
config.optimizer_config =
LoraOptimizerConfig::fromJson(j["optimizer_config"]);
}
return config;
}
@@ -296,4 +293,4 @@ LoraLinearConfig
return deserialize_from_json_string(j);
}

}; // namespace FlexFlow
}; // namespace FlexFlow
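
Given the raw-pointer fromJson() factories above, a serialized config can be round-tripped and its optimizer inspected through getType(). A sketch under the assumption that deserialize_from_json_string is callable as a static member, as its definition here suggests; peft_config stands in for any existing LoraLinearConfig:

#include <iostream>
#include <string>
#include "flexflow/ops/lora_linear_params.h"

void json_round_trip_sketch(FlexFlow::LoraLinearConfig const &peft_config) {
  std::string json_str = peft_config.serialize_to_json_string(/*indent=*/2);
  FlexFlow::LoraLinearConfig restored =
      FlexFlow::LoraLinearConfig::deserialize_from_json_string(json_str);

  if (restored.optimizer_config != nullptr &&
      restored.optimizer_config->getType() == "SGD") {
    // Same downcast pattern as operator<< above; ownership of the raw pointer
    // allocated by fromJson() is not addressed by this commit.
    auto const *sgd = static_cast<FlexFlow::LoraSGDOptimizerConfig const *>(
        restored.optimizer_config);
    std::cout << *sgd << std::endl;  // friend operator<< declared in the header
  }
}
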
6 changes: 3 additions & 3 deletions src/runtime/request_manager.cc
@@ -263,8 +263,8 @@ size_t RequestManager::get_num_ssms() {
return ssm_models.size();
}

void RequestManager::register_peft_config(PEFTModelID const &peft_model_id,
LoraLinearConfig const &peft_config) {
void RequestManager::set_peft_config(PEFTModelID const &peft_model_id,
LoraLinearConfig const &peft_config) {
// check that peft_model_id is not already in use
assert(peft_configs.find(peft_model_id) == peft_configs.end() &&
"PEFT model ID already in use");
@@ -322,7 +322,7 @@ PEFTModelID *
}
PEFTModelID *peft_model_id = new PEFTModelID(peft_model_global_guid++);
RequestManager *rm = RequestManager::get_request_manager();
rm->register_peft_config(*peft_model_id, peft_config);
rm->set_peft_config(*peft_model_id, peft_config);
return peft_model_id;
}

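
For reference, the renamed pair keeps its previous behavior: register_peft_adapter() (shown above) allocates a fresh PEFTModelID, stores the config via set_peft_config(), and get_peft_config() reads it back. A short usage sketch; model and peft_config are assumed to exist as in the driver programs earlier in this commit:

RequestManager *rm = RequestManager::get_request_manager();

PEFTModelID *peft_model_id = model.register_peft_adapter(peft_config);
LoraLinearConfig const &stored = rm->get_peft_config(*peft_model_id);
// `stored` refers to the config held by the RequestManager; registering the
// same PEFTModelID twice would trip the assert in set_peft_config() above.
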
