PEFT bug fixes and alignment #1269

Merged: 12 commits, Jan 10, 2024
2 changes: 0 additions & 2 deletions include/flexflow/ops/add_bias_residual_layer_norm.h
@@ -124,15 +124,13 @@ class AddBiasResidualLayerNorm : public Op {
T const *output_grad_ptr,
T *input_grad_ptr,
T *residual_grad_ptr,
T *attn_bias_grad_ptr,
T const *gamma_ptr,
ffStream_t stream);
static void
peft_bwd_kernel_wrapper(AddBiasResidualLayerNormMeta const *m,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW &input_grad,
GenericTensorAccessorW const &residual_grad,
GenericTensorAccessorW const &attn_bias_grad,
GenericTensorAccessorR const &gamma);

public:
3 changes: 2 additions & 1 deletion include/flexflow/ops/kernels/softmax_kernels.h
@@ -39,7 +39,8 @@ void backward_kernel_wrapper(SoftmaxMeta const *m,
void inference_kernel_wrapper(SoftmaxMeta const *m,
BatchConfig const *bc,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &output_grad);

void peft_bwd_kernel_wrapper(SoftmaxMeta const *m,
BatchConfig const *bc,
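Note on the softmax change: inference_kernel_wrapper now also receives the output-gradient accessor. The diff does not say why; presumably the forward (inference) pass caches the softmax result in the gradient buffer so the PEFT backward pass can start from it, but that is an assumption. A minimal call-site sketch under that assumption (region indices and accessor helpers are illustrative, not taken from this PR):

    GenericTensorAccessorR input = helperGetGenericTensorAccessorRO(
        m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
    GenericTensorAccessorW output = helperGetGenericTensorAccessorWO(
        m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
    // New third accessor: the output-gradient region is now threaded through
    // the forward path as well.
    GenericTensorAccessorW output_grad = helperGetGenericTensorAccessorWO(
        m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime);
    inference_kernel_wrapper(m, bc, input, output, output_grad);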
1 change: 1 addition & 0 deletions include/flexflow/ops/residual_layer_norm.h
@@ -28,6 +28,7 @@ class ResidualLayerNorm : public Op {
float _eps,
bool allocate_weights,
char const *name);
void map_output_tensors(FFModel &ff) override;
void init(FFModel const &) override;
void init_inference(FFModel const &,
std::vector<ParallelTensor> const &,
44 changes: 16 additions & 28 deletions inference/incr_decoding/incr_decoding.cc
@@ -138,9 +138,9 @@ void FlexFlow::top_level_task(Task const *task,
bool do_sample = false;
float temperature = 0.0f;
float topp = 0.0f;
int max_requests_per_batch = 2;
int max_tokens_per_batch = 300;
int max_sequence_length = 300;
int max_requests_per_batch = 8;
int max_tokens_per_batch = 128;
int max_sequence_length = 256;

InputArgs const &command_args = HighLevelRuntime::get_input_args();
char **argv = command_args.argv;
@@ -272,7 +272,6 @@ void FlexFlow::top_level_task(Task const *task,

int total_num_requests = 0;
{
#ifdef DEADCODE
using json = nlohmann::json;
std::ifstream file_handle(file_paths.prompt_file_path);
assert(file_handle.good() && "Prompt file does not exist.");
@@ -286,32 +285,21 @@
std::string text = prompt.get<std::string>();
printf("Prompt[%d]: %s\n", total_num_requests, text.c_str());
// Add inference request
Request inference_req;
inference_req.prompt = text;
inference_req.max_sequence_length = 128;
inference_req.peft_model_id = peft_model_id;
requests.push_back(inference_req);
// Request inference_req;
// inference_req.prompt = text;
// inference_req.max_sequence_length = 128;
// inference_req.peft_model_id = peft_model_id;
// requests.push_back(inference_req);
// total_num_requests++;
// Add fine-tuning request
Request fine_tuning_req;
fine_tuning_req.req_type = Request::RequestType::REQ_FINETUNING;
fine_tuning_req.max_sequence_length = 128;
fine_tuning_req.peft_model_id = peft_model_id;
fine_tuning_req.dataset_text.push_back(std::make_pair(text, ""));
requests.push_back(fine_tuning_req);
total_num_requests++;
}
#endif
std::vector<Request> requests;
for (int i = 0; i < (max_requests_per_batch - 1) * 4; i++) {
Request inference_req;
inference_req.prompt = "b";
inference_req.max_sequence_length = 40;
requests.push_back(inference_req);
total_num_requests++;
}
// Add a fine-tuning request
Request fine_tuning_req;
fine_tuning_req.req_type = Request::RequestType::REQ_FINETUNING;
fine_tuning_req.max_sequence_length = 256;
fine_tuning_req.max_training_steps = 256;
fine_tuning_req.peft_model_id = peft_model_id;
fine_tuning_req.dataset_text.push_back(std::make_pair("b", ""));
requests.push_back(fine_tuning_req);
total_num_requests++;

GenerationResult result = model.generate(requests);
}

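The request-setup block above now builds a synthetic workload: several short inference requests plus a single LoRA fine-tuning request in the same batch, while the JSON-driven path survives only as commented-out code. A sketch of the prompt-file-driven variant, assembled from those commented-out lines (field values such as max_sequence_length are illustrative, not mandated by the PR):

    using json = nlohmann::json;
    std::ifstream file_handle(file_paths.prompt_file_path);
    assert(file_handle.good() && "Prompt file does not exist.");
    json prompt_json = json::parse(file_handle);
    std::vector<Request> requests;
    for (auto &prompt : prompt_json) {
      std::string text = prompt.get<std::string>();
      // One ordinary inference request per prompt.
      Request inference_req;
      inference_req.prompt = text;
      inference_req.max_sequence_length = 128;
      inference_req.peft_model_id = peft_model_id;
      requests.push_back(inference_req);
      // One LoRA fine-tuning request using the same text as training data.
      Request fine_tuning_req;
      fine_tuning_req.req_type = Request::RequestType::REQ_FINETUNING;
      fine_tuning_req.max_sequence_length = 128;
      fine_tuning_req.peft_model_id = peft_model_id;
      fine_tuning_req.dataset_text.push_back(std::make_pair(text, ""));
      requests.push_back(fine_tuning_req);
    }
    GenerationResult result = model.generate(requests);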
10 changes: 5 additions & 5 deletions inference/models/opt.cc
@@ -193,7 +193,7 @@ void OPT::create_opt_model(FFModel &ff,
Tensor fc1 =
ff.dense(final_norm,
opt_config.ffn_dim,
AC_MODE_NONE,
AC_MODE_RELU,
true,
DT_NONE,
nullptr,
@@ -202,8 +202,7 @@
REG_MODE_NONE,
0.0f,
std::string("layers_" + std::to_string(i) + "_fc1").c_str());
Tensor activation = ff.relu(fc1, false);
fc2 = ff.dense(activation,
fc2 = ff.dense(fc1,
opt_config.hidden_size,
AC_MODE_NONE,
true,
@@ -216,7 +215,7 @@
std::string("layers_" + std::to_string(i) + "_fc2").c_str());
// Low-Rank Adapter (LoRA) for the second linear layer
ff.lora_linear(
activation,
fc1,
fc2,
OP_LORA_MLP_SECOND,
std::string("layers_" + std::to_string(i) + "_fc2_lora").c_str());
@@ -255,7 +254,8 @@
output = ff.argmax(softmax, /*beam_Search*/ true);
} else {
// output = ff.arg_top_k(lm_head, /*k=*/1, false);
output = ff.argmax(lm_head, /*beam_Search*/ false);
Tensor softmax = ff.softmax(lm_head, -1);
output = ff.argmax(softmax, /*beam_Search*/ false);
}

//------------------- compile the model --------------------------------
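In the OPT model above, the ReLU is now fused into the first MLP projection (AC_MODE_RELU) instead of being a standalone ff.relu op, the LoRA adapter is attached across the fc1 -> fc2 pair directly, and greedy decoding inserts an explicit softmax before argmax. A condensed sketch of the resulting MLP block, assuming the trailing dense() arguments keep their defaults (the layer name and anything not visible in the diff are placeholders):

    // fc1 applies ReLU itself, so no standalone ff.relu op is created.
    Tensor fc1 = ff.dense(final_norm, opt_config.ffn_dim, AC_MODE_RELU, /*use_bias*/ true);
    Tensor fc2 = ff.dense(fc1, opt_config.hidden_size, AC_MODE_NONE, /*use_bias*/ true);
    // LoRA adapter for the second linear layer now spans fc1 -> fc2.
    ff.lora_linear(fc1, fc2, OP_LORA_MLP_SECOND, "layers_i_fc2_lora");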
27 changes: 5 additions & 22 deletions src/ops/add_bias_residual_layer_norm.cc
@@ -931,33 +931,25 @@ Legion::FutureMap AddBiasResidualLayerNorm::peft_bwd(
launcher.add_region_requirement(
RegionRequirement(batch_inputs[0]->part_grad,
0 /*projection id*/,
READ_WRITE,
reset_input_grads[0] ? WRITE_ONLY : READ_WRITE,
EXCLUSIVE,
batch_inputs[0]->region_grad));
launcher.add_field(field_id++, FID_DATA);
// residual grad
launcher.add_region_requirement(
RegionRequirement(batch_inputs[1]->part_grad,
0 /*projection id*/,
READ_WRITE,
reset_input_grads[1] ? WRITE_ONLY : READ_WRITE,
EXCLUSIVE,
batch_inputs[1]->region_grad));
launcher.add_field(field_id++, FID_DATA);
// attn bias grad
launcher.add_region_requirement(
RegionRequirement(batch_inputs[2]->part_grad,
0 /*projection id*/,
READ_WRITE,
EXCLUSIVE,
batch_inputs[2]->region_grad));
launcher.add_field(field_id++, FID_DATA);
if (elementwise_affine) {
// gamma
launcher.add_region_requirement(RegionRequirement(weights[0]->part,
launcher.add_region_requirement(RegionRequirement(weights[1]->part,
0 /*projection id*/,
READ_ONLY,
EXCLUSIVE,
weights[0]->region));
weights[1]->region));
launcher.add_field(field_id++, FID_DATA);
}
return runtime->execute_index_space(ctx, launcher);
@@ -1001,14 +993,6 @@ void AddBiasResidualLayerNorm::peft_bwd_task(
ctx,
runtime);

GenericTensorAccessorW attn_bias_grad =
helperGetGenericTensorAccessorRW(m->weight_type[0],
regions[region_idx++],
task->regions[task_region_idx++],
FID_DATA,
ctx,
runtime);

GenericTensorAccessorR gamma;
if (m->elementwise_affine) {
assert(m->use_bias == (regions.size() == 6));
@@ -1020,13 +1004,12 @@
runtime);
}
AddBiasResidualLayerNorm::peft_bwd_kernel_wrapper(
m, output_grad, input_grad, residual_grad, attn_bias_grad, gamma);
m, output_grad, input_grad, residual_grad, gamma);

if (m->inference_debugging) {
assert(task->index_point.get_dim() == 1);
int shard_id = task->index_point.point_data[0];
std::vector<GenericTensorAccessorR> weights_accessors;
weights_accessors.push_back(attn_bias_grad);
if (m->elementwise_affine) {
weights_accessors.push_back(gamma);
}
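The peft_bwd changes above do three things: each input-gradient region's privilege now depends on reset_input_grads (WRITE_ONLY to overwrite on first use, READ_WRITE to accumulate otherwise), the attention-bias gradient region is no longer mapped at all, and gamma is read from weights[1] instead of weights[0]. A condensed sketch of the privilege-selection pattern (the loop form is mine; the names come from the diff):

    // One region requirement per input gradient; overwrite vs. accumulate
    // semantics are chosen by the per-input reset flag.
    for (int i = 0; i < 2; i++) {
      launcher.add_region_requirement(
          RegionRequirement(batch_inputs[i]->part_grad,
                            0 /*projection id*/,
                            reset_input_grads[i] ? WRITE_ONLY : READ_WRITE,
                            EXCLUSIVE,
                            batch_inputs[i]->region_grad));
      launcher.add_field(field_id++, FID_DATA);
    }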
72 changes: 34 additions & 38 deletions src/ops/add_bias_residual_layer_norm.cu
@@ -101,9 +101,9 @@ __inline__ __device__ T BlockReduceSum(T val, T *shared, int max_num_threads) {
shared[wid] = val;
}
__syncthreads();
val = (threadIdx.x < min(blockDim.x, max_num_threads) / C10_WARP_SIZE)
val = (threadIdx.x < (min(blockDim.x, max_num_threads) / C10_WARP_SIZE))
? shared[lid]
: 0;
: T(0);
if (wid == 0) {
val = WarpReduceSum(val);
}
@@ -536,8 +536,9 @@ __device__ __inline__ void compute_gI(T const *__restrict__ dY,
T const *__restrict__ rstd,
T const *__restrict__ gamma,
T *dX,
T *dX_residual1,
T *dX_residual2,
T *dX_residual,
bool reset_input_grad,
bool reset_residual_grad,
int const N,
T *buf) {
auto const i1 = blockIdx.x;
Expand All @@ -549,9 +550,7 @@ __device__ __inline__ void compute_gI(T const *__restrict__ dY,
T const *X_i = X + i1 * N;
T const *dY_i = dY + i1 * N;
T *dX_i = dX + i1 * N;
T *dX_residual1_i = dX_residual1 + i1 * N;
T *dX_residual2_i =
(dX_residual2 != nullptr) ? dX_residual2 + i1 * N : nullptr;
T *dX_residual_i = dX_residual + i1 * N;
// vectorized reads don't improve perf, so use regular unrolling

for (; l + unroll - 1 < N; l += blockDim.x * unroll) {
@@ -592,10 +591,15 @@ __device__ __inline__ void compute_gI(T const *__restrict__ dY,
f_grad_input -= (x - mean_val) * rstd_val * stats_x2;
f_grad_input -= stats_x1;
f_grad_input *= term1;
dX_i[l] += f_grad_input;
dX_residual1_i[l] += f_grad_input;
if (dX_residual2 != nullptr) {
dX_residual2_i[l] += f_grad_input;
if (reset_input_grad) {
dX_i[l] = f_grad_input;
} else {
dX_i[l] += f_grad_input;
}
if (reset_residual_grad) {
dX_residual_i[l] = f_grad_input;
} else {
dX_residual_i[l] += f_grad_input;
}
}
}
@@ -607,13 +611,24 @@ __global__ void layer_norm_grad_input_kernel(T const *__restrict__ dY,
T const *__restrict__ rstd,
T const *__restrict__ gamma,
T *dX,
T *dX_residual1,
T *dX_residual2,
T *dX_residual,
bool reset_input_grad,
bool reset_residual_grad,
int const N) {
alignas(sizeof(double)) extern __shared__ char s_data1[];
T *buf = reinterpret_cast<T *>(&s_data1);

compute_gI(dY, X, mean, rstd, gamma, dX, dX_residual1, dX_residual2, N, buf);
compute_gI(dY,
X,
mean,
rstd,
gamma,
dX,
dX_residual,
reset_input_grad,
reset_residual_grad,
N,
buf);
}

/*static*/
@@ -661,7 +676,8 @@ void AddBiasResidualLayerNorm::backward_kernel(
gamma_ptr,
input_grad_ptr,
residual_grad_ptr,
attn_bias_grad_ptr,
m->reset_input_grads[0],
m->reset_input_grads[1],
N);

if (gamma_grad_ptr != NULL || beta_grad_ptr != NULL) {
@@ -764,29 +780,11 @@ void AddBiasResidualLayerNorm::peft_bwd_kernel(
T const *output_grad_ptr,
T *input_grad_ptr,
T *residual_grad_ptr,
T *attn_bias_grad_ptr,
T const *gamma_ptr,
cudaStream_t stream) {
const int64_t M = m->effective_batch_size;
const int64_t N = m->effective_num_elements;
ComputeInternalGradientsCUDAKernel<T>
<<<M, kCUDABlockReduceNumThreads, 0, stream>>>(
N,
output_grad_ptr,
static_cast<T const *>(m->input_activation),
gamma_ptr,
static_cast<T *>(m->ds_ptr),
static_cast<T *>(m->db_ptr));
const int64_t B = (M + kCUDANumThreads - 1) / kCUDANumThreads;
ComputeGradientFusedParamsCUDAKernel<T>
<<<B, kCUDANumThreads, 0, stream>>>(M,
N,
static_cast<T *>(m->mean_ptr),
static_cast<T *>(m->rstd_ptr),
static_cast<T *>(m->ds_ptr),
static_cast<T *>(m->db_ptr),
static_cast<T *>(m->scale_ptr),
static_cast<T *>(m->bias_ptr));

int const warp_size = C10_WARP_SIZE;
int const num_threads = 128;
const dim3 blocks(M);
@@ -799,7 +797,8 @@ void AddBiasResidualLayerNorm::peft_bwd_kernel(
gamma_ptr,
input_grad_ptr,
residual_grad_ptr,
attn_bias_grad_ptr,
m->reset_input_grads[0],
m->reset_input_grads[1],
N);
}

@@ -809,7 +808,6 @@ void AddBiasResidualLayerNorm::peft_bwd_kernel_wrapper(
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW &input_grad,
GenericTensorAccessorW const &residual_grad,
GenericTensorAccessorW const &attn_bias_grad,
GenericTensorAccessorR const &gamma) {
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
@@ -825,15 +823,13 @@ void AddBiasResidualLayerNorm::peft_bwd_kernel_wrapper(
output_grad.get_float_ptr(),
input_grad.get_float_ptr(),
residual_grad.get_float_ptr(),
attn_bias_grad.get_float_ptr(),
m->elementwise_affine ? gamma.get_float_ptr() : nullptr,
stream);
} else if (m->output_type[0] == DT_HALF) {
peft_bwd_kernel(m,
output_grad.get_half_ptr(),
input_grad.get_half_ptr(),
residual_grad.get_half_ptr(),
attn_bias_grad.get_half_ptr(),
m->elementwise_affine ? gamma.get_half_ptr() : nullptr,
stream);
} else {
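At the kernel level, compute_gI and layer_norm_grad_input_kernel now take reset_input_grad / reset_residual_grad flags and either overwrite or accumulate into the input and residual gradient buffers, and the attention-bias gradient pointer is gone. The essence of the write-back gating, shown as a hypothetical standalone helper (not a function introduced by this PR):

    // Overwrite the destination when the buffer is being reset for this PEFT
    // backward pass, otherwise accumulate on top of the existing gradient.
    template <typename T>
    __device__ __forceinline__ void write_grad(T *dst, T val, bool reset) {
      if (reset) {
        *dst = val;
      } else {
        *dst += val;
      }
    }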
15 changes: 15 additions & 0 deletions src/ops/fused.cc
@@ -528,6 +528,21 @@ FutureMap FusedOp::inference(FFModel const &ff,
batch_outputs[i]->region));
launcher.add_field(offset + i, FID_DATA);
}
offset += numOutputs;
// add softmax output grad
if (operators[numOperators - 1]->op_type == OP_SOFTMAX) {
printf("operator %i is last SOFTMAX! adding output %i\n",
numOperators - 1,
numOutputs - 1);
assert(outputs[numOutputs - 1]->region != LogicalRegion::NO_REGION);
launcher.add_region_requirement(
RegionRequirement(batch_outputs[numOutputs - 1]->part_grad,
0 /*projection id*/,
WRITE_ONLY,
EXCLUSIVE,
batch_outputs[numOutputs - 1]->region_grad));
launcher.add_field(offset, FID_DATA);
}
return runtime->execute_index_space(ctx, launcher);
}

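Finally, FusedOp::inference now maps one extra region when the last operator in the fused group is a softmax: the gradient partition of the final output, with WRITE_ONLY privilege. The likely purpose is to give the fused softmax somewhere to stash data for the later PEFT backward step, but the diff does not state this explicitly. A hypothetical consumer-side sketch (member names and indices are illustrative, not from this PR):

    // If the launcher attached the extra softmax output-grad region, it is the
    // last physical region handed to the fused inference task.
    if (operators[numOperators - 1]->op_type == OP_SOFTMAX) {
      int rid = (int)regions.size() - 1;
      GenericTensorAccessorW softmax_output_grad =
          helperGetGenericTensorAccessorWO(DT_FLOAT, // placeholder data type
                                           regions[rid],
                                           task->regions[rid],
                                           FID_DATA,
                                           ctx,
                                           runtime);
      // ...forward softmax_output_grad to the softmax inference kernel...
    }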