Skip to content

Commit

Permalink
[ElementUnary] add support for rsqrt
Browse files Browse the repository at this point in the history
  • Loading branch information
jiazhihao committed Sep 30, 2021
1 parent ed31405 commit ec7d783
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/ffconst.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ enum OperatorType {
OP_GELU,
OP_MULTIHEAD_ATTENTION,
OP_FUSED, // Fused operator type for internal fusion optimizations
OP_RSQRT, //https://pytorch.org/docs/stable/generated/torch.rsqrt.html
};

#endif // _FLEXFLOW_CONST_H_
4 changes: 4 additions & 0 deletions include/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ class FFModel {
const Tensor& y,
bool inplace_a = false,
char const *name = NULL);
// Add an rsqrt (element-wise reciprocal square root, 1/sqrt(x)) layer
Tensor rsqrt(const Tensor& x,
bool inplace = true,
char const *name = NULL);
// Add a scalar operation layer
Tensor scalar_multiply(const Tensor& x,
const float scalar,
Expand Down
18 changes: 17 additions & 1 deletion src/ops/element_unary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ Tensor FFModel::elu(const Tensor& x, bool inplace, const char *name)
return this->unary(OP_ELU, x, inplace, name);
}

// Builds an rsqrt layer on top of the generic unary-operator builder:
// forwards the input tensor, the in-place flag and the optional layer name
// to unary() tagged with OP_RSQRT, and returns whatever tensor it produces.
Tensor FFModel::rsqrt(const Tensor& x, bool inplace, const char *name)
{
  return unary(OP_RSQRT, x, inplace, name);
}

ElementUnary::ElementUnary(FFModel& model,
OperatorType _op_type,
const Tensor& x,
Expand Down Expand Up @@ -342,6 +347,11 @@ void elewise_unary_forward_kernel(coord_t volume,
out[i] = in[i] * 0.5 * erfc(-in[i]*M_SQRT1_2);
break;
}
case OP_RSQRT:
{
// Reciprocal square root of the i-th element: out = 1 / sqrt(in).
out[i] = 1.0f / sqrt(in[i]);
break;
}
default:
assert(false);
}
Expand Down Expand Up @@ -459,6 +469,7 @@ void elewise_unary_backward_kernel(coord_t volume,
const float beta,
const float scalar,
OperatorType type,
const float* output,
const float* output_grad,
const float* input,
float* input_grad)
Expand Down Expand Up @@ -502,6 +513,11 @@ void elewise_unary_backward_kernel(coord_t volume,
input_grad[i] = output_grad[i]*(0.5 * erfc(-input[i]*M_SQRT1_2)-0.5*M_SQRT1_2*input[i]*exp(-input[i]*input[i]*0.5));
break;
}
case OP_RSQRT:
{
  // Forward: y = x^(-1/2). Hence dy/dx = -1/2 * x^(-3/2) = -0.5 * y^3.
  // The gradient is expressed through the saved forward output y so that
  // the original input is not needed here.
  // The previous code used a factor of -1.0f, which doubled every gradient.
  const float y = output[i];
  input_grad[i] = -0.5f * output_grad[i] * y * y * y;
  break;
}
default:
assert(false);
}
Expand All @@ -526,7 +542,7 @@ void ElementUnary::backward_kernel(const ElementUnaryMeta* m,
m->inputTensor, input_ptr, &alpha, m->inputTensor, input_grad_ptr));
} else {
elewise_unary_backward_kernel<<<GET_BLOCKS(num_elements), CUDA_NUM_THREADS, 0, stream>>>(
num_elements, alpha, alpha, m->scalar, m->op_type, output_grad_ptr, input_ptr, input_grad_ptr);
num_elements, alpha, alpha, m->scalar, m->op_type, output_ptr, output_grad_ptr, input_ptr, input_grad_ptr);
}
}

Expand Down

0 comments on commit ec7d783

Please sign in to comment.