
Commit 7ff96d7: fixes
goliaro committed Oct 6, 2024
1 parent 1691100
Showing 21 changed files with 673 additions and 521 deletions.
8 changes: 5 additions & 3 deletions include/flexflow/flexflow_c.h
@@ -600,10 +600,12 @@ flexflow_tensor_t flexflow_model_add_argmax(flexflow_model_t handle_,
                                             bool beam_search,
                                             char const *name);
 
-void flexflow_model_add_lora_layers(flexflow_model_t handle_, int num_target_modules, char const **target_modules_);
+void flexflow_model_add_lora_layers(flexflow_model_t handle_,
+                                    int num_target_modules,
+                                    char const **target_modules_);
 
-flexflow_peft_model_id_t flexflow_model_register_peft_adapter(flexflow_model_t handle_, const flexflow_lora_linear_config_t peft_config_);
+flexflow_peft_model_id_t flexflow_model_register_peft_adapter(
+    flexflow_model_t handle_, const flexflow_lora_linear_config_t peft_config_);
 
 void flexflow_model_set_sgd_optimizer(flexflow_model_t handle,
                                       flexflow_sgd_optimizer_t optimizer);
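As an aside, a minimal sketch of how these two bindings might be called together. The model handle and adapter config are assumed to come from the usual FlexFlow setup (creation elided), and "down_proj" is only an example target-module name:

#include "flexflow/flexflow_c.h"

// Hypothetical helper: wires a LoRA adapter into an already-built model.
void attach_lora(flexflow_model_t handle,
                 flexflow_lora_linear_config_t peft_config) {
  // Insert LoRA layers wherever a layer name matches a target module.
  char const *target_modules[] = {"down_proj"};
  flexflow_model_add_lora_layers(handle, /*num_target_modules=*/1,
                                 target_modules);
  // Register the adapter; the returned ID identifies it in later requests.
  flexflow_peft_model_id_t peft_model_id =
      flexflow_model_register_peft_adapter(handle, peft_config);
  (void)peft_model_id;
}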
12 changes: 6 additions & 6 deletions include/flexflow/model.h
@@ -845,9 +845,9 @@ class FFModel {
   // ========================================
   // PEFT Layers
   // ========================================
-  // PEFTModelID *add_lora_layer(LoraLinearConfig const peft_config);
-  void add_lora_layers(std::vector<std::string> target_modules);
-  PEFTModelID *register_peft_adapter(LoraLinearConfig const &peft_config);
+  // PEFTModelID *add_lora_layer(LoraLinearConfig const peft_config);
+  void add_lora_layers(std::vector<std::string> target_modules);
+  PEFTModelID *register_peft_adapter(LoraLinearConfig const &peft_config);
   // ========================================
   // Inference APIs
   // ========================================
@@ -1182,9 +1182,9 @@ class FFModel {
   std::vector<ParallelTensor> parameters;
   // PEFT related
   std::unordered_map<Layer *, Layer *> base_layer_to_peft_layer;
-  // std::unordered_map<Layer *, std::vector<PEFTModelID>> peft_layer_to_peft_id;
-  // std::unordered_map<PEFTModelID, LoraLinearConfig> peft_configs;
-  // std::vector<Op *> peft_operators;
+  // std::unordered_map<Layer *, std::vector<PEFTModelID>>
+  // peft_layer_to_peft_id; std::unordered_map<PEFTModelID, LoraLinearConfig>
+  // peft_configs; std::vector<Op *> peft_operators;
 
   FFHandler handlers[MAX_NUM_WORKERS];
   Legion::Future current_metrics;
3 changes: 2 additions & 1 deletion include/flexflow/ops/kernels/lora_linear_kernels.h
@@ -37,7 +37,8 @@ class LoraLinearMeta : public OpMeta {
 namespace Kernels {
 namespace LoraLinear {
 
-bool lora_applies_to_this_layer(LoraLinearMeta *m, LoraLinearConfig const &config);
+bool lora_applies_to_this_layer(LoraLinearMeta *m,
+                                LoraLinearConfig const &config);
 
 void init_kernel_wrapper(LoraLinearMeta *m, int seed);
 void inference_kernel_wrapper(LoraLinearMeta *m,
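The header only declares lora_applies_to_this_layer; a plausible reading, given the target_modules plumbing elsewhere in this commit, is a substring match of the layer's name against the adapter's target modules. A self-contained sketch with hypothetical stand-in types (the real signature takes a LoraLinearMeta *):

#include <iostream>
#include <string>
#include <vector>

struct MetaSketch { // stand-in for LoraLinearMeta
  std::string layer_name;
};
struct ConfigSketch { // stand-in for LoraLinearConfig
  std::vector<std::string> target_modules;
};

// True when the layer's name contains one of the adapter's target modules.
bool lora_applies_to_this_layer(MetaSketch const &m,
                                ConfigSketch const &config) {
  for (std::string const &target : config.target_modules) {
    if (m.layer_name.find(target) != std::string::npos) {
      return true;
    }
  }
  return false;
}

int main() {
  MetaSketch m{"layers.11.mlp.down_proj"};
  ConfigSketch cfg{{"down_proj"}};
  std::cout << std::boolalpha << lora_applies_to_this_layer(m, cfg)
            << std::endl; // true
}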
15 changes: 7 additions & 8 deletions include/flexflow/ops/lora_linear.h
@@ -17,14 +17,13 @@ class LoraLinear : public Op {
   using Params = LoraLinearParams;
   using Input = std::pair<ParallelTensor, ParallelTensor>;
 
-  LoraLinear(
-      FFModel &model,
-      LayerID const &layer_guid,
-      ParallelTensor const input,
-      ParallelTensor const output,
-      int max_rank,
-      int max_concurrent_adapters,
-      char const *name = nullptr);
+  LoraLinear(FFModel &model,
+             LayerID const &layer_guid,
+             ParallelTensor const input,
+             ParallelTensor const output,
+             int max_rank,
+             int max_concurrent_adapters,
+             char const *name = nullptr);
   LoraLinear(FFModel &model,
              LoraLinear const &other,
              ParallelTensor const input,
130 changes: 30 additions & 100 deletions include/flexflow/ops/lora_linear_params.h
@@ -19,7 +19,7 @@ class LoraOptimizerConfig {
   LoraOptimizerConfig();
   virtual std::string getType() const = 0;
   virtual nlohmann::json toJson() const = 0;
-  static std::unique_ptr<LoraOptimizerConfig> fromJson(const nlohmann::json& j);
+  static std::unique_ptr<LoraOptimizerConfig> fromJson(nlohmann::json const &j);
   virtual ~LoraOptimizerConfig() = default;
 };
 
@@ -32,26 +32,16 @@ class LoraSGDOptimizerConfig : public LoraOptimizerConfig {
                          bool weight_decay_ = 0.0f);
   friend std::ostream &operator<<(std::ostream &os,
                                   LoraSGDOptimizerConfig const &llc);
 
-  std::string getType() const override { return "SGD"; }
-
-  nlohmann::json toJson() const override {
-    return {{"type", "SGD"},
-            {"lr", lr},
-            {"momentum", momentum},
-            {"nesterov", nesterov},
-            {"weight_decay", weight_decay}};
-  }
-
-  static std::unique_ptr<LoraSGDOptimizerConfig> fromJson(const nlohmann::json& j) {
-    auto sgd = std::make_unique<LoraSGDOptimizerConfig>();
-    sgd->lr = j["lr"];
-    sgd->momentum = j["momentum"];
-    sgd->nesterov = j["nesterov"];
-    sgd->weight_decay = j["weight_decay"];
-    return sgd;
-  }
+  std::string getType() const override {
+    return "SGD";
+  }
+
+  nlohmann::json toJson() const override;
+
+  static std::unique_ptr<LoraSGDOptimizerConfig>
+      fromJson(nlohmann::json const &j);
 
 public:
   double lr = 0.001f;
   double momentum = 0.0f;
@@ -69,28 +59,16 @@ class LoraAdamOptimizerConfig : public LoraOptimizerConfig {
                           double epsilon_ = 1e-8);
   friend std::ostream &operator<<(std::ostream &os,
                                   LoraAdamOptimizerConfig const &llc);
 
-  std::string getType() const override { return "Adam"; }
-
-  nlohmann::json toJson() const override {
-    return {{"type", "Adam"},
-            {"alpha", alpha},
-            {"beta1", beta1},
-            {"beta2", beta2},
-            {"weight_decay", weight_decay},
-            {"epsilon", epsilon}};
-  }
-
-  static std::unique_ptr<LoraAdamOptimizerConfig> fromJson(const nlohmann::json& j) {
-    auto adam = std::make_unique<LoraAdamOptimizerConfig>();
-    adam->alpha = j["alpha"];
-    adam->beta1 = j["beta1"];
-    adam->beta2 = j["beta2"];
-    adam->weight_decay = j["weight_decay"];
-    adam->epsilon = j["epsilon"];
-    return adam;
-  }
+  std::string getType() const override {
+    return "Adam";
+  }
+
+  nlohmann::json toJson() const override;
+
+  static std::unique_ptr<LoraAdamOptimizerConfig>
+      fromJson(nlohmann::json const &j);
 
 public:
   // Adam
   double alpha = 0.001f;
@@ -100,14 +78,6 @@ class LoraAdamOptimizerConfig : public LoraOptimizerConfig {
   double epsilon = 1e-8;
 };
 
-std::unique_ptr<LoraOptimizerConfig> LoraOptimizerConfig::fromJson(const nlohmann::json& j) {
-  std::string type = j["type"];
-  if (type == "SGD") return LoraSGDOptimizerConfig::fromJson(j);
-  if (type == "Adam") return LoraAdamOptimizerConfig::fromJson(j);
-  throw std::runtime_error("Unknown optimizer type");
-}
-
-
 class LoraLinearConfig {
 public:
   static const LoraLinearConfig EmptyConfig;
@@ -126,11 +96,14 @@ class LoraLinearConfig {
   LoraLinearConfig();
 
   // Method to set optimizer
-  template<typename T>
-  void setOptimizer(T&& opt) {
-    if constexpr (std::is_base_of_v<LoraOptimizerConfig, std::remove_reference_t<T>>) {
-      optimizer_config = std::make_unique<std::remove_reference_t<T>>(std::forward<T>(opt));
-    } else if constexpr (std::is_same_v<std::unique_ptr<LoraOptimizerConfig>, std::remove_reference_t<T>>) {
+  template <typename T>
+  void setOptimizer(T &&opt) {
+    if constexpr (std::is_base_of_v<LoraOptimizerConfig,
+                                    std::remove_reference_t<T>>) {
+      optimizer_config =
+          std::make_unique<std::remove_reference_t<T>>(std::forward<T>(opt));
+    } else if constexpr (std::is_same_v<std::unique_ptr<LoraOptimizerConfig>,
+                                        std::remove_reference_t<T>>) {
       optimizer_config = std::move(opt);
     } else {
       static_assert(always_false<T>, "Unsupported optimizer type");
@@ -139,62 +112,19 @@ class LoraLinearConfig {
   // Helper template for static_assert
   template <typename>
   static inline constexpr bool always_false = false;
 
   friend bool operator==(LoraLinearConfig const &lhs,
                          LoraLinearConfig const &rhs);
   friend std::ostream &operator<<(std::ostream &os,
                                   LoraLinearConfig const &llc);
-  std::string serialize_to_json_string(int indent=-1) const {
-    nlohmann::json j = {
-        {"cache_folder", cache_folder},
-        {"peft_model_id", peft_model_id},
-        {"rank", rank},
-        {"lora_alpha", lora_alpha},
-        {"lora_dropout", lora_dropout},
-        {"target_modules", target_modules},
-        {"trainable", trainable},
-        {"init_lora_weights", init_lora_weights},
-        {"base_model_name_or_path", base_model_name_or_path},
-        {"precision", precision},
-        // {"optimizer_config", optimizer_config ? optimizer_config->toJson() : nullptr}
-        {"optimizer_config", optimizer_config ? nlohmann::json(optimizer_config->toJson()) : nlohmann::json()}
-    };
-
-    return j.dump(indent); // No indentation
-  }
-  void serialize_to_json_file(const std::string& filename) const {
-    std::string j = serialize_to_json_string(4);
-    std::ofstream file(filename);
-    file << j;
-  }
+  std::string serialize_to_json_string(int indent = -1) const;
+  void serialize_to_json_file(std::string const &filename) const;
   // Deserialization method
-  static LoraLinearConfig deserialize_from_json_string(const std::string& json_string) {
-    nlohmann::json j = nlohmann::json::parse(json_string);
-    LoraLinearConfig config(
-        j["cache_folder"].get<std::string>(),
-        j["peft_model_id"].get<std::string>(),
-        j["trainable"].get<bool>(),
-        nullptr, // optimizer_config will be set later if present
-        j["init_lora_weights"].get<bool>(),
-        j["base_model_name_or_path"].get<std::string>(),
-        j["precision"].get<std::string>(),
-        j["rank"].get<int>(),
-        j["lora_alpha"].get<float>(),
-        j["lora_dropout"].get<float>(),
-        j["target_modules"].get<std::vector<std::string>>());
-    if (!j["optimizer_config"].is_null()) {
-      config.setOptimizer(LoraOptimizerConfig::fromJson(j["optimizer_config"]));
-    }
-    return config;
-  }
+  static LoraLinearConfig
+      deserialize_from_json_string(std::string const &json_string);
   // Deserialization method
-  static LoraLinearConfig deserialize_from_json_file(const std::string& filename) {
-    std::ifstream file(filename);
-    std::string j;
-    file >> j;
-    return deserialize_from_json_string(j);
-  }
+  static LoraLinearConfig
+      deserialize_from_json_file(std::string const &filename);
 
   std::string cache_folder;
   // Huggingface model ID (for download and/or upload)
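The optimizer-config bodies removed above presumably move out-of-line into a source file; the header keeps the pattern of a type tag written by getType()/toJson() and a static fromJson factory that dispatches on that tag. A minimal self-contained sketch of the round trip, using simplified stand-in classes (not FlexFlow's) and assuming nlohmann/json is installed:

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <nlohmann/json.hpp>

class OptimizerConfigSketch {
public:
  virtual std::string getType() const = 0;
  virtual nlohmann::json toJson() const = 0;
  static std::unique_ptr<OptimizerConfigSketch> fromJson(nlohmann::json const &j);
  virtual ~OptimizerConfigSketch() = default;
};

class SgdConfigSketch : public OptimizerConfigSketch {
public:
  std::string getType() const override { return "SGD"; }
  nlohmann::json toJson() const override {
    return {{"type", getType()}, {"lr", lr}}; // tag + payload
  }
  static std::unique_ptr<SgdConfigSketch> fromJson(nlohmann::json const &j) {
    auto sgd = std::make_unique<SgdConfigSketch>();
    sgd->lr = j["lr"];
    return sgd;
  }
  double lr = 0.001;
};

// Dispatch on the "type" tag, mirroring the removed LoraOptimizerConfig::fromJson.
std::unique_ptr<OptimizerConfigSketch>
OptimizerConfigSketch::fromJson(nlohmann::json const &j) {
  std::string type = j["type"];
  if (type == "SGD") {
    return SgdConfigSketch::fromJson(j);
  }
  throw std::runtime_error("Unknown optimizer type");
}

int main() {
  SgdConfigSketch sgd;
  sgd.lr = 0.01;
  std::string s = sgd.toJson().dump();
  auto restored = OptimizerConfigSketch::fromJson(nlohmann::json::parse(s));
  std::cout << restored->toJson().dump() << std::endl; // {"lr":0.01,"type":"SGD"}
}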
4 changes: 2 additions & 2 deletions include/flexflow/request_manager.h
@@ -151,7 +151,7 @@ class RequestManager {
   void register_output_filepath(std::string const &);
   void register_peft_config(PEFTModelID const &peft_model_id,
                             LoraLinearConfig const &peft_config);
-  LoraLinearConfig get_peft_config(PEFTModelID peft_model_id);
+  LoraLinearConfig const &get_peft_config(PEFTModelID const &peft_model_id);
   void set_max_lora_rank(int max_lora_rank);
   void set_max_concurrent_adapters(int max_concurrent_adapters);
   int get_max_lora_rank();
@@ -295,7 +295,7 @@ class RequestManager {
   int max_spec_tree_token_num;
   int max_sequence_length;
   Status request_manager_status;
-
+  // peft
  std::unordered_map<PEFTModelID, LoraLinearConfig> peft_configs;
  int max_lora_rank;
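The new get_peft_config signature returns a const reference into the peft_configs map rather than a copy. A self-contained sketch of why that shape works, with simplified stand-in types (not FlexFlow's actual RequestManager):

#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

using PEFTModelID = int; // stand-in for FlexFlow's PEFTModelID
struct LoraLinearConfigSketch {
  std::string peft_model_id; // stand-in with a single field
};

class RequestManagerSketch {
public:
  void register_peft_config(PEFTModelID const &id,
                            LoraLinearConfigSketch const &cfg) {
    peft_configs[id] = cfg;
  }
  // Const-reference return: the (potentially large) config is not copied,
  // and callers cannot mutate the registered entry.
  LoraLinearConfigSketch const &get_peft_config(PEFTModelID const &id) const {
    auto it = peft_configs.find(id);
    if (it == peft_configs.end()) {
      throw std::runtime_error("unknown PEFT model id");
    }
    return it->second;
  }

private:
  std::unordered_map<PEFTModelID, LoraLinearConfigSketch> peft_configs;
};

int main() {
  RequestManagerSketch rm;
  rm.register_peft_config(0, {"my-org/my-lora-adapter"}); // placeholder name
  std::cout << rm.get_peft_config(0).peft_model_id << std::endl;
}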
94 changes: 63 additions & 31 deletions include/flexflow/utils/peft_weight_allocator.h
@@ -101,52 +101,83 @@ struct LoraLinearWeight {
   void *low_rank_activation;
   // v values for SGD optimizer (when using momentum)
   void *w0_v_values_ptr, *w1_v_values_ptr;
-  LoraLinearWeight(void *w0=nullptr, void *w1=nullptr, void *w0_grad=nullptr, void *w1_grad=nullptr,
-                   void *w0_v_values=nullptr, void *w1_v_values=nullptr, void *low_rank_activation_=nullptr, void *input_activation_=nullptr)
-      : w0_ptr(w0), w1_ptr(w1),
-        w0_grad_ptr(w0_grad), w1_grad_ptr(w1_grad),
-        w0_v_values_ptr(w0_v_values), w1_v_values_ptr(w1_v_values),
-        low_rank_activation(low_rank_activation_), input_activation(input_activation_) {}
+  LoraLinearWeight(void *w0 = nullptr,
+                   void *w1 = nullptr,
+                   void *w0_grad = nullptr,
+                   void *w1_grad = nullptr,
+                   void *w0_v_values = nullptr,
+                   void *w1_v_values = nullptr,
+                   void *low_rank_activation_ = nullptr,
+                   void *input_activation_ = nullptr)
+      : w0_ptr(w0), w1_ptr(w1), w0_grad_ptr(w0_grad), w1_grad_ptr(w1_grad),
+        w0_v_values_ptr(w0_v_values), w1_v_values_ptr(w1_v_values),
+        low_rank_activation(low_rank_activation_),
+        input_activation(input_activation_) {}
 };
 
-void init_peft_weight_wrapper(LoraLinearWeight const &weight, int in_dim, int out_dim, int rank, DataType dt, int seed);
+void init_peft_weight_wrapper(LoraLinearWeight const &weight,
+                              int in_dim,
+                              int out_dim,
+                              int rank,
+                              DataType dt,
+                              int seed);
 
 class PEFTMemoryManager {
 public:
-  PEFTMemoryManager(Legion::Memory gpu_mem_, int max_rank_, int max_concurrent_adapters_, int max_peft_tokens_, int in_dim_, int out_dim_, int num_shards_, int shard_id_, std::string const &lora_layername_substr_, DataType dt_)
-      : gpu_mem(gpu_mem_),
-        max_concurrent_adapters(max_concurrent_adapters_),
-        max_rank(max_rank_),
-        in_dim(in_dim_), out_dim(out_dim_), num_shards(num_shards_), shard_id(shard_id_),
-        max_peft_tokens(max_peft_tokens_),
-        lora_layername_substr(lora_layername_substr_), dt(dt_),
-        base_ptr(nullptr),
-        finetuning_ptr(nullptr),
-        finetuning_model_id(PEFTModelID::NO_ID) {
-    max_lora_size = data_type_size(dt) * (max_rank * in_dim + max_rank * out_dim);
-    assert(max_concurrent_adapters > 0 && "PEFT Memory Manager max_concurrent_adapters must be > 0");
-    assert(max_lora_size > 0 && "PEFT Memory Manager max_lora_size must be > 0");
+  PEFTMemoryManager(Legion::Memory gpu_mem_,
+                    int max_rank_,
+                    int max_concurrent_adapters_,
+                    int max_peft_tokens_,
+                    int in_dim_,
+                    int out_dim_,
+                    int num_shards_,
+                    int shard_id_,
+                    std::string const &lora_layername_substr_,
+                    DataType dt_)
+      : gpu_mem(gpu_mem_), max_concurrent_adapters(max_concurrent_adapters_),
+        max_rank(max_rank_), in_dim(in_dim_), out_dim(out_dim_),
+        num_shards(num_shards_), shard_id(shard_id_),
+        max_peft_tokens(max_peft_tokens_),
+        lora_layername_substr(lora_layername_substr_), dt(dt_),
+        base_ptr(nullptr), finetuning_ptr(nullptr),
+        finetuning_model_id(PEFTModelID::NO_ID) {
+    max_lora_size =
+        data_type_size(dt) * (max_rank * in_dim + max_rank * out_dim);
+    assert(max_concurrent_adapters > 0 &&
+           "PEFT Memory Manager max_concurrent_adapters must be > 0");
+    assert(max_lora_size > 0 &&
+           "PEFT Memory Manager max_lora_size must be > 0");
     allocate_inference_memory();
-    // finetuning memory is allocated upon the first finetuning request, so we can skip for inference-only workloads
+    // finetuning memory is allocated upon the first finetuning request, so we
+    // can skip for inference-only workloads
   }
 
-  // allocate memory for all the PEFT adapters for a given layer on a given shard
+  // allocate memory for all the PEFT adapters for a given layer on a given
+  // shard
   void allocate_inference_memory();
-  // allocate memory for the PEFT adapter for a finetuning request for a given layer and shard
+  // allocate memory for the PEFT adapter for a finetuning request for a given
+  // layer and shard
   void allocate_finetuning_memory();
 
-  LoraLinearWeight get_peft(PEFTModelID const &model_id, LoraLinearConfig const &lora_config);
+  LoraLinearWeight get_peft(PEFTModelID const &model_id,
+                            LoraLinearConfig const &lora_config);
   void check_ft_model_id(PEFTModelID const &model_id);
 
private:
-  // Check if the PEFT adapter for the given model is in memory. If not, sets the cache_miss flag to true. If this is the first finetuning request, allocate memory for the finetuning adapter.
+  // Check if the PEFT adapter for the given model is in memory. If not, sets
+  // the cache_miss flag to true. If this is the first finetuning request,
+  // allocate memory for the finetuning adapter.
   void get_finetuning_slot(PEFTModelID const &model_id, bool *cache_miss);
-  // Returns the slot in memory where the peft model weights are/will be stored.
-  // If the model is not in memory (cache miss), set the cache_miss flag to true.
+  // Returns the slot in memory where the peft model weights are/will be stored.
+  // If the model is not in memory (cache miss), set the cache_miss flag to
+  // true.
   int get_inference_peft_slot(PEFTModelID const &model_id, bool *cache_miss);
-  void load_peft_model(LoraLinearWeight &weight, LoraLinearConfig const &lora_config);
-  LoraLinearWeight get_inference_peft(PEFTModelID const &model_id, LoraLinearConfig const &lora_config);
-  LoraLinearWeight get_finetuning_peft(PEFTModelID const &model_id, LoraLinearConfig const &lora_config);
+  void load_peft_model(LoraLinearWeight &weight,
+                       LoraLinearConfig const &lora_config);
+  LoraLinearWeight get_inference_peft(PEFTModelID const &model_id,
+                                      LoraLinearConfig const &lora_config);
+  LoraLinearWeight get_finetuning_peft(PEFTModelID const &model_id,
+                                       LoraLinearConfig const &lora_config);
 
   // Legion memory management apparatus
   Legion::Memory gpu_mem;
@@ -160,7 +191,8 @@ class PEFTMemoryManager {
   int max_peft_tokens;
   // LRU cache apparatus
   std::unordered_map<PEFTModelID, int> lru_hashtable;
-  std::vector<PEFTModelID> lru_list; // head = least recently used, tail=most recently used
+  std::vector<PEFTModelID>
+      lru_list; // head = least recently used, tail=most recently used
   std::unordered_map<PEFTModelID, int> peft2mem_slot;
   // Miscellanea
   std::string lora_layername_substr;
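The members above (lru_list with head = least recently used and tail = most recently used, plus peft2mem_slot mapping adapters to fixed-capacity memory slots) suggest an eviction scheme along the following lines. This is a simplified, self-contained sketch under those assumptions, not the actual PEFTMemoryManager code; it omits the GPU buffers and the separate lru_hashtable index:

#include <iostream>
#include <unordered_map>
#include <vector>

using PEFTModelID = int; // stand-in for FlexFlow's PEFTModelID

class LruSlotsSketch {
public:
  explicit LruSlotsSketch(int max_slots) : max_slots(max_slots) {}

  // Returns the memory slot for model_id; *cache_miss tells the caller
  // whether the adapter weights must be (re)loaded into that slot.
  int get_slot(PEFTModelID const &model_id, bool *cache_miss) {
    auto it = peft2mem_slot.find(model_id);
    if (it != peft2mem_slot.end()) {
      touch(model_id); // hit: move to most-recently-used position
      *cache_miss = false;
      return it->second;
    }
    int slot;
    if ((int)lru_list.size() < max_slots) {
      slot = (int)lru_list.size(); // fill empty slots first
    } else {
      PEFTModelID victim = lru_list.front(); // head = least recently used
      lru_list.erase(lru_list.begin());
      slot = peft2mem_slot[victim]; // reuse the evicted adapter's slot
      peft2mem_slot.erase(victim);
    }
    lru_list.push_back(model_id); // tail = most recently used
    peft2mem_slot[model_id] = slot;
    *cache_miss = true;
    return slot;
  }

private:
  void touch(PEFTModelID const &model_id) {
    for (size_t i = 0; i < lru_list.size(); i++) {
      if (lru_list[i] == model_id) {
        lru_list.erase(lru_list.begin() + i);
        break;
      }
    }
    lru_list.push_back(model_id);
  }

  int max_slots;
  std::vector<PEFTModelID> lru_list;
  std::unordered_map<PEFTModelID, int> peft2mem_slot;
};

int main() {
  LruSlotsSketch slots(2);
  bool miss;
  std::cout << slots.get_slot(7, &miss) << " miss=" << miss << std::endl; // 0 miss=1
  std::cout << slots.get_slot(8, &miss) << " miss=" << miss << std::endl; // 1 miss=1
  std::cout << slots.get_slot(7, &miss) << " miss=" << miss << std::endl; // 0 miss=0
  std::cout << slots.get_slot(9, &miss) << " miss=" << miss << std::endl; // 1 miss=1 (evicts 8)
}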