Skip to content

Commit

Permalink
Implemented and tested the backward pass for C++ interface, debugging python interface
Browse files Browse the repository at this point in the history
  • Loading branch information
yingchen21 committed Aug 4, 2024
1 parent c9d0fb1 commit 6c4349d
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 67 deletions.
6 changes: 3 additions & 3 deletions include/flexflow/ops/inc_multihead_self_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ class IncMultiHeadSelfAttention : public Op {
BatchConfig const *bc,
int shard_id,
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &weight,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorR const &bias);
// GenericTensorAccessorR const &weight,
GenericTensorAccessorR const &output_grad);
// GenericTensorAccessorR const &bias);
Params get_params() const;

public:
Expand Down
6 changes: 3 additions & 3 deletions src/ops/fused.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1046,9 +1046,9 @@ __host__ void FusedOp::peft_bwd_task(Task const *task,
bc,
task->index_point.point_data[0],
my_input_grad_accessor[0],
my_weight_accessor[0],
my_output_grad_accessor[0],
biases);
// my_weight_accessor[0],
my_output_grad_accessor[0]);
// biases);
break;
}
case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION:
Expand Down
89 changes: 47 additions & 42 deletions src/ops/inc_multihead_self_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -767,31 +767,31 @@ FutureMap IncMultiHeadSelfAttention::peft_bwd(
EXCLUSIVE,
batch_inputs[0]->region_grad));
launcher.add_field(idx++, FID_DATA);
launcher.add_region_requirement(
RegionRequirement(weights[0]->part,
0 /*projection id*/,
READ_ONLY,
EXCLUSIVE,
weights[0]->region,
ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0));
launcher.add_field(idx++, FID_DATA);
// launcher.add_region_requirement(
// RegionRequirement(weights[0]->part,
// 0 /*projection id*/,
// READ_ONLY,
// EXCLUSIVE,
// weights[0]->region,
// ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0));
// launcher.add_field(idx++, FID_DATA);
launcher.add_region_requirement(
RegionRequirement(batch_outputs[0]->part_grad,
0 /*projection id*/,
READ_WRITE,
EXCLUSIVE,
batch_outputs[0]->region_grad));
launcher.add_field(idx++, FID_DATA);
if (qkv_bias || final_bias) {
launcher.add_region_requirement(
RegionRequirement(weights[1]->part,
0 /*projection id*/,
READ_ONLY,
EXCLUSIVE,
weights[1]->region,
ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0));
launcher.add_field(idx++, FID_DATA);
}
// if (qkv_bias || final_bias) {
// launcher.add_region_requirement(
// RegionRequirement(weights[1]->part,
// 0 /*projection id*/,
// READ_ONLY,
// EXCLUSIVE,
// weights[1]->region,
// ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0));
// launcher.add_field(idx++, FID_DATA);
// }
return runtime->execute_index_space(ctx, launcher);
}

Expand All @@ -818,37 +818,42 @@ void IncMultiHeadSelfAttention::peft_bwd_task(
IncMultiHeadSelfAttentionMeta *m =
*((IncMultiHeadSelfAttentionMeta **)task->local_args);

assert(((*m->qkv_bias || *m->final_bias) ? regions.size() == 4
: regions.size() == 3));
// assert(((*m->qkv_bias || *m->final_bias) ? regions.size() == 4
// : regions.size() == 3));
assert(regions.size() == 2); // input grad, output grad

GenericTensorAccessorW input_grad = helperGetGenericTensorAccessorRW(
m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO(
m->weight_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
// GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO(
// m->weight_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
// GenericTensorAccessorW output_grad = helperGetGenericTensorAccessorRW(
// m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime);
GenericTensorAccessorW output_grad = helperGetGenericTensorAccessorRW(
m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime);
m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
GenericTensorAccessorR biases;
if (*m->qkv_bias || *m->final_bias) {
biases = helperGetGenericTensorAccessorRO(m->weight_type[1],
regions[3],
task->regions[3],
FID_DATA,
ctx,
runtime);
Domain bias_domain = runtime->get_index_space_domain(
ctx, task->regions[3].region.get_index_space());
assert(bias_domain.get_dim() == 4);
}
// if (*m->qkv_bias || *m->final_bias) {
// biases = helperGetGenericTensorAccessorRO(m->weight_type[1],
// regions[3],
// task->regions[3],
// FID_DATA,
// ctx,
// runtime);
// Domain bias_domain = runtime->get_index_space_domain(
// ctx, task->regions[3].region.get_index_space());
// assert(bias_domain.get_dim() == 4);
// }

Domain input_grad_domain = runtime->get_index_space_domain(
ctx, task->regions[0].region.get_index_space());
Domain weight_domain = runtime->get_index_space_domain(
ctx, task->regions[1].region.get_index_space());
// Domain weight_domain = runtime->get_index_space_domain(
// ctx, task->regions[1].region.get_index_space());
// Domain output_grad_domain = runtime->get_index_space_domain(
// ctx, task->regions[2].region.get_index_space());
Domain output_grad_domain = runtime->get_index_space_domain(
ctx, task->regions[2].region.get_index_space());
ctx, task->regions[1].region.get_index_space());

assert(input_grad_domain.get_dim() == 4);
assert(weight_domain.get_dim() == 2);
// assert(weight_domain.get_dim() == 2);
assert(output_grad_domain.get_dim() == 4);

assert(task->index_point.get_dim() == 1);
Expand All @@ -858,15 +863,15 @@ void IncMultiHeadSelfAttention::peft_bwd_task(
bc,
task->index_point.point_data[0],
input_grad,
weight,
output_grad,
biases);
// weight,
output_grad);
// biases);

if (m->inference_debugging) {
assert(task->index_point.get_dim() == 1);
int shard_id = task->index_point.point_data[0];
IncMultiHeadSelfAttention::save_inference_tensors_to_file(
m, shard_id, bc, {input_grad}, {weight}, {output_grad}, false);
m, shard_id, bc, {input_grad}, {}, {output_grad}, false);
}
}

Expand Down
34 changes: 19 additions & 15 deletions src/ops/inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1334,7 +1334,7 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
// TODO: check if this transposeAdd can correctly implement gradient accumulation
transposeAdd(C, B, n_, k_, alpha, beta, stream);
printf("backward of raw attn grad: %d, %d, with redudant dimension %d\n", k_, n_, m_);
// printf("backward of raw attn grad: %d, %d, with redudant dimension %d\n", k_, n_, m_);
if (m->inference_debugging) {
std::string filename =
get_peft_dbg_folder(m, shard_id) + ".self_attn.input_gradient_0";
Expand Down Expand Up @@ -1747,9 +1747,9 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper(
BatchConfig const *bc,
int shard_id,
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &weight,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorR const &bias) {
// GenericTensorAccessorR const &weight,
GenericTensorAccessorR const &output_grad) {
// GenericTensorAccessorR const &bias) {
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
bool use_bias = *m->qkv_bias || *m->final_bias;
Expand All @@ -1763,33 +1763,37 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper(
// assert(input.data_type == weight.data_type);
assert(input_grad.data_type == output_grad.data_type);
if (use_bias) {
assert(input_grad.data_type == bias.data_type);
}
// if (use_bias) {
// assert(input_grad.data_type == bias.data_type);
// }
if (input_grad.data_type == DT_HALF) {
assert(!m->offload);
half const *bias_ptr =
use_bias ? bias.get_half_ptr() : static_cast<half const *>(nullptr);
// half const *bias_ptr =
// use_bias ? bias.get_half_ptr() : static_cast<half const *>(nullptr);
Kernels::IncMultiHeadAttention::peft_bwd_kernel(m,
bc,
shard_id,
input_grad.get_half_ptr(),
weight.get_half_ptr(),
// weight.get_half_ptr(),
static_cast<half const *>(nullptr),
output_grad.get_half_ptr(),
bias_ptr,
// bias_ptr,
static_cast<half const *>(nullptr),
stream);
} else if (input_grad.data_type == DT_FLOAT) {
assert(!m->offload);
float const *bias_ptr =
use_bias ? bias.get_float_ptr() : static_cast<float const *>(nullptr);
// float const *bias_ptr =
// use_bias ? bias.get_float_ptr() : static_cast<float const *>(nullptr);
Kernels::IncMultiHeadAttention::peft_bwd_kernel(m,
bc,
shard_id,
input_grad.get_float_ptr(),
weight.get_float_ptr(),
// weight.get_float_ptr(),
static_cast<float const *>(nullptr),
output_grad.get_float_ptr(),
bias_ptr,
// bias_ptr,
static_cast<float const *>(nullptr),
stream);
} else {
assert(false && "Unspported data type");
Expand Down
2 changes: 1 addition & 1 deletion src/ops/kernels/linear_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ void peft_bwd_kernel(LinearMeta const *m,
in_dim,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
printf("%s: input_grad has shape %d, %d\n", m->op_name, in_dim, num_peft_tokens);
// printf("%s: input_grad has shape %d, %d\n", m->op_name, in_dim, num_peft_tokens);
}
}
Expand Down
1 change: 0 additions & 1 deletion src/runtime/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,6 @@ bool Op::check_output_input_weight_parallel_dims(bool allocate_weights) const {
bool Op::check_output_input_weight_same_parallel_is() const {
assert(numOutputs > 0);
IndexSpace parallel_is = outputs[0]->parallel_is;
printf("checking operator %s\n", name);
for (int i = 0; i < numOutputs; i++) {
if (outputs[i]->parallel_is != parallel_is) {
std::cout<<"outputs["<<i<<"] has different parallel_is "<<outputs[i]->parallel_is<<" than output[0] "<<parallel_is<<std::endl;
Expand Down
6 changes: 4 additions & 2 deletions src/runtime/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ fs::path get_dst_folder(std::string const &subdir,
char cwd[PATH_MAX];
getcwd(cwd, sizeof(cwd));

char const *ff_cache_path = std::string(std::getenv("FF_DEBUG_PATH")) == "." ?
cwd : std::getenv("FF_DEBUG_PATH");
// char const *ff_cache_path = std::string(std::getenv("FF_DEBUG_PATH")) == "." ?
// cwd : std::getenv("FF_DEBUG_PATH");

char const *ff_cache_path = std::getenv("FF_CACHE_PATH");

std::string debug_dir_ =
ff_cache_path ? std::string(ff_cache_path) + "/debug/flexflow"
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/request_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2756,6 +2756,7 @@ void RequestManager::start_background_server(FFModel *model) {
// Register callbacks for termination
{
std::set_terminate([]() {
// assert(false && "terminate");
RequestManager::terminate_background_server_at_exit();
std::abort();
});
Expand Down Expand Up @@ -3008,6 +3009,7 @@ void RequestManager::trigger_request_completion_future(
/*static*/
void RequestManager::terminate_background_server_at_exit() {
RequestManager *rm = RequestManager::get_request_manager();
// assert(false && "RM terminating bg server due to exit");
rm->terminate_background_server();
}

Expand Down

0 comments on commit 6c4349d

Please sign in to comment.