Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fuse optimization #1248

Open
wants to merge 19 commits into
base: BertMLM_fixes
Choose a base branch
from
2 changes: 2 additions & 0 deletions include/flexflow/flexflow_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ void flexflow_model_compute_metrics(flexflow_model_t handle);

void flexflow_model_update(flexflow_model_t handle);

void flexflow_model_unified_update(flexflow_model_t handle);

void flexflow_model_compile(flexflow_model_t handle,
enum LossType loss_type,
int *metrics,
Expand Down
2 changes: 2 additions & 0 deletions include/flexflow/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ enum TaskIDs {
// Optimizer with NCCL
SGD_UPD_NCCL_TASK_ID,
ADAM_UPD_NCCL_TASK_ID,
ADAM_UNIFY_UPD_NCCL_TASK_ID,
// Initializer
GLOROT_INIT_TASK_ID,
ZERO_INIT_TASK_ID,
Expand Down Expand Up @@ -777,6 +778,7 @@ class FFModel {
void get_metrics();
void backward(int seq_length = -1);
void update();
void unified_update();
bool apply_fusion(std::vector<Op *> const &operators,
std::vector<Op *> &new_operators);
Op *get_final_operator() const;
Expand Down
7 changes: 7 additions & 0 deletions include/flexflow/ops/dropout.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
#include "flexflow/node.h"
#include "flexflow/operator.h"
#include "flexflow/ops/dropout_params.h"
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
#include <curand.h>
#include <curand_kernel.h>
#elif defined(FF_USE_HIP_ROCM)
#include <hiprand/hiprand.h>
#include <hiprand/hiprand_kernel.h>
#endif

namespace FlexFlow {

Expand Down
16 changes: 12 additions & 4 deletions include/flexflow/ops/kernels/dropout_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "flexflow/fftype.h"
#include "flexflow/op_meta.h"
#include "flexflow/ops/dropout.h"
#include "flexflow/accessor.h"

namespace FlexFlow {

Expand All @@ -17,33 +18,40 @@ class DropoutMeta : public OpMeta {
~DropoutMeta(void);
Realm::RegionInstance reserveInst;
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
curandState *state;
cudnnTensorDescriptor_t inputTensor, outputTensor;
cudnnDropoutDescriptor_t dropoutDesc;
#else
miopenTensorDescriptor_t inputTensor, outputTensor;
miopenDropoutDescriptor_t dropoutDesc;
hiprandState *state;
#endif
void *reserveSpace, *dropoutStates;
size_t reserveSpaceSize, dropoutStateSize;
size_t num_elements;
long long seed;
float rate;
};

namespace Kernels {
namespace Dropout {
void forward_kernel_wrapper(DropoutMeta *m,
float const *input_ptr,
float *output_ptr);
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);
void backward_kernel_wrapper(DropoutMeta *m,
float const *output_grad_ptr,
float *input_grad_ptr);
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW const &input_grad);

namespace Internal {
void forward_kernel(DropoutMeta *m,
float const *input_ptr,
float *output_ptr,
size_t num_elements,
ffStream_t stream);
void backward_kernel(DropoutMeta *m,
float const *output_grad_ptr,
float *input_grad_ptr,
size_t num_elements,
ffStream_t stream);
} // namespace Internal
} // namespace Dropout
Expand Down
23 changes: 23 additions & 0 deletions include/flexflow/optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Optimizer {
virtual void init(void) = 0;
virtual void next(void) = 0;
virtual void update(const ParallelTensor p) = 0;
virtual void unified_update(std::vector<ParallelTensor> const parameters) = 0;
FFModel const *model;
};

Expand All @@ -43,6 +44,7 @@ class SGDOptimizer : public Optimizer {
void init(void);
void next(void);
void update(const ParallelTensor p);
void unified_update(std::vector<ParallelTensor> const parameters);
void set_weight_decay(double _weight_decay);
static void ps_update_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Expand All @@ -60,6 +62,11 @@ class SGDOptimizer : public Optimizer {
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void
nccl_unified_update_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void nccl_update_task_gpu(SGDOptimizer const *op,
OpMeta const *meta,
float const *w_grad_ptr,
Expand All @@ -85,6 +92,7 @@ class AdamOptimizer : public Optimizer {
void init(void);
void next(void);
void update(const ParallelTensor p);
void unified_update(std::vector<ParallelTensor> const parameters);
void set_weight_decay(double _weight_decay);
static void ps_update_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Expand All @@ -103,17 +111,32 @@ class AdamOptimizer : public Optimizer {
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void
nccl_unified_update_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void nccl_update_task_gpu(AdamOptimizer const *op,
OpMeta const *meta,
float const *w_grad_ptr,
size_t size,
float *w_ptr,
float *v_ptr,
float *m_ptr);
static void nccl_unified_update_task_gpu(AdamOptimizer const *op,
OpMeta const *meta,
float const *w_grad_ptr[],
size_t *size,
float *w_ptr[],
float *v_ptr[],
float *m_ptr[]);
#endif
double alpha, beta1, beta2, weight_decay, epsilon;
double alpha_t, beta1_t, beta2_t;
std::map<Legion::LogicalRegion, ParallelTensor> v_values, m_values;
size_t reservedWorkSpaceSize = 0;
int parameters_num = 0;
int processed_parameters_num = 0;
};

}; // namespace FlexFlow
Expand Down
7 changes: 7 additions & 0 deletions python/flexflow/core/flexflow_cffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2018,6 +2018,13 @@ def update(self):
:returns: None -- no returns.
"""
ffc.flexflow_model_update(self.handle)

def unified_update(self):
"""Update weights and biases of all layers.

:returns: None -- no returns.
"""
ffc.flexflow_model_unified_update(self.handle)

def compile(self, optimizer=None, loss_type=None, metrics=None, comp_mode=None):
"""Configure the model for trainting. FlexFlow uses lazy initialization,
Expand Down
5 changes: 5 additions & 0 deletions src/c/flexflow_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ void flexflow_model_update(flexflow_model_t handle_) {
handle->update();
}

void flexflow_model_unified_update(flexflow_model_t handle_) {
FFModel *handle = FFCObjectWrapper::unwrap(handle_);
handle->unified_update();
}

void flexflow_model_compile(flexflow_model_t handle_,
enum LossType loss_type,
int *metrics,
Expand Down
38 changes: 25 additions & 13 deletions src/ops/dropout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ using PCG::Node;

using namespace FlexFlow::Kernels::Dropout;

Tensor FFModel::dropout(const Tensor input,
Tensor FFModel::dropout(Tensor const input,
float rate,
unsigned long long seed,
char const *name) {
Expand Down Expand Up @@ -86,7 +86,7 @@ bool operator==(DropoutParams const &lhs, DropoutParams const &rhs) {
}

Dropout::Dropout(FFModel &model,
const ParallelTensor _input,
ParallelTensor const _input,
float _rate,
unsigned long long _seed,
char const *name)
Expand All @@ -111,12 +111,12 @@ Dropout::Dropout(FFModel &model,

Dropout::Dropout(FFModel &model,
Dropout const &other,
const ParallelTensor input)
ParallelTensor const input)
: Dropout(model, input, other.rate, other.seed, other.name) {}

Dropout::Dropout(FFModel &model,
DropoutParams const &params,
const ParallelTensor input,
ParallelTensor const input,
char const *name)
: Dropout(model, input, params.rate, params.seed, name) {}

Expand Down Expand Up @@ -210,12 +210,12 @@ void Dropout::forward_task(Task const *task,
assert(task->regions.size() == 2);
// const Dropout* dropout = (const Dropout*) task->args;
DropoutMeta *m = *((DropoutMeta **)task->local_args);
float const *input_ptr = helperGetTensorPointerRO<float>(
regions[0], task->regions[0], FID_DATA, ctx, runtime);
float *output_ptr = helperGetTensorPointerWO<float>(
regions[1], task->regions[1], FID_DATA, ctx, runtime);

forward_kernel_wrapper(m, input_ptr, output_ptr);

GenericTensorAccessorR input = helperGetGenericTensorAccessorRO(
m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
GenericTensorAccessorW output = helperGetGenericTensorAccessorWO(
m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
forward_kernel_wrapper(m, input, output);
}

void Dropout::backward(FFModel const &ff) {
Expand Down Expand Up @@ -264,7 +264,13 @@ void Dropout::backward_task(Task const *task,
float const *output_grad_ptr = helperGetTensorPointerRO<float>(
regions[1], task->regions[1], FID_DATA, ctx, runtime);

backward_kernel_wrapper(m, output_grad_ptr, input_grad_ptr);

GenericTensorAccessorW input_grad = helperGetGenericTensorAccessorRW(
m->output_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO(
m->input_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);

backward_kernel_wrapper(m, output_grad, input_grad);
}

void Dropout::serialize(Legion::Serializer &sez) const {
Expand Down Expand Up @@ -304,30 +310,36 @@ bool Dropout::measure_operator_cost(Simulator *sim,
sim->free_all();
float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT);
assert(input_ptr != NULL);

GenericTensorAccessorR input_acc(m->input_type[0], sub_input.get_domain(), input_ptr);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
assert(output_ptr != NULL);

GenericTensorAccessorW output_acc(m->output_type[0], sub_input.get_domain(), output_ptr);
cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

assert(m->profiling == false);

std::function<void()> forward, backward;
forward = [&] { forward_kernel_wrapper(m, input_ptr, output_ptr); };
forward = [&] { forward_kernel_wrapper(m, input_acc, output_acc); };
if (sim->computationMode == COMP_MODE_TRAINING) {
float *input_grad_ptr =
(float *)sim->allocate(sub_input.get_volume(), DT_FLOAT);
assert(input_grad_ptr != NULL);
GenericTensorAccessorW input_grad_acc(m->output_type[0], sub_input.get_domain(), input_grad_ptr);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_grad_ptr =
(float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
assert(output_grad_ptr != NULL);
GenericTensorAccessorR output_grad_acc(m->output_type[0], sub_input.get_domain(), output_grad_ptr);
cost_metrics.outputs_memory +=
cost_metrics.total_mem_diff_from(sim->offset);

backward = [&] {
backward_kernel_wrapper(m, output_grad_ptr, input_grad_ptr);
backward_kernel_wrapper(m, output_grad_acc, input_grad_acc);
};
}

Expand Down
Loading
Loading