Skip to content

Commit

Permalink
Merge branch 'repo-refactor' into repo-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Bob-Chen222 authored Mar 14, 2024
2 parents ffa7f79 + 3237169 commit 502b41f
Show file tree
Hide file tree
Showing 17 changed files with 455 additions and 864 deletions.
4 changes: 2 additions & 2 deletions lib/compiler/test/test_dp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ TEST_CASE("optimal_cost") {

Node n0 = g.add_node(InputAttrs());
Node n1 = g.add_node(RepartitionAttrs(ff_dim_t(0), 2));
Node n2 = g.add_node(ElementScalarUnaryAttrs(OP_SCALAR_ADD, 0));
Node n3 = g.add_node(ElementScalarUnaryAttrs(OP_SCALAR_ADD, 1));
Node n2 = g.add_node(ElementUnaryAttrs(OP_SCALAR_ADD, 0));
Node n3 = g.add_node(ElementUnaryAttrs(OP_SCALAR_ADD, 1));
Node n4 = g.add_node(ConcatAttrs(ff_dim_t(1)));
Node n5 = g.add_node(CombineAttrs(ff_dim_t(0), 2));

Expand Down
41 changes: 23 additions & 18 deletions lib/kernels/include/kernels/element_unary_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,47 @@

#include "kernels/accessor.h"
#include "kernels/device.h"
#include "legion.h"
#include "kernels/ff_handle.h"
#include "op-attrs/ops/element_unary.h"
#include <cstddef>

namespace FlexFlow {

class ElementUnaryPerDeviceState : public PerDeviceOpState {
public:
ElementUnaryPerDeviceState(FFHandler handle);
using ElementUnaryUnifiedAttrs =
variant<ElementUnaryAttrs, ElementScalarUnaryAttrs>;

struct ElementUnaryPerDeviceState {
ffTensorDescriptor_t inputTensor, outputTensor;
ffActivationDescriptor_t actiDesc;

OperatorType op_type;
DataType data_type;
bool inplace;
float scalar;
char op_name[MAX_OPNAME];
};

FF_VISITABLE_STRUCT_NO_EQ(ElementUnaryPerDeviceState,
inputTensor,
outputTensor,
actiDesc);

namespace Kernels {
namespace ElementUnary {

void init_kernel(ElementUnaryPerDeviceState *m,
Legion::Domain const &input_domain,
Legion::Domain const &output_domain);
ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape,
ArrayShape const &output_shape,
ElementUnaryUnifiedAttrs const &attrs);

void forward_kernel(ffStream_t stream,
ElementUnaryPerDeviceState const *m,
ElementUnaryPerDeviceState const &device_state,
ElementUnaryUnifiedAttrs const &attrs,
PerDeviceFFHandle &handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);

void backward_kernel(ffStream_t stream,
ElementUnaryPerDeviceState const *m,
ElementUnaryPerDeviceState const &device_state,
ElementUnaryUnifiedAttrs const &attrs,
PerDeviceFFHandle &handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &input_grad,
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &output_grad);
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &output,
GenericTensorAccessorR const &output_grad);

} // namespace ElementUnary
} // namespace Kernels
Expand Down
175 changes: 103 additions & 72 deletions lib/kernels/src/cuda/element_unary_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,6 @@
#include "kernels/element_unary_kernels.h"

namespace FlexFlow {

// declare Legion names
using Legion::coord_t;
using Legion::Domain;

ElementUnaryPerDeviceState::ElementUnaryPerDeviceState(FFHandler handler)
: PerDeviceOpState(handler) {
checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor));
checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor));
checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc));
}

namespace Kernels {
namespace ElementUnary {

Expand All @@ -45,13 +33,31 @@ static bool use_cudnn(OperatorType op_type) {
}
}

void init_kernel(ElementUnaryPerDeviceState *m,
Domain const &input_domain,
Domain const &output_domain) {
template <T>
optional<T> get_scalar(ElementUnaryAttrs const &attrs) {}

template <T>
optional<T> get_scalar(ElementScalarUnaryAttrs const &attrs) {
return (T)attrs.scalar;
}

ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape,
ArrayShape const &output_shape,
ElementUnaryUnifiedAttrs const &attrs) {

ffTensorDescriptor_t inputTensor;
ffTensorDescriptor_t outputTensor;
ffActivationDescriptor_t actiDesc;

if (use_cudnn(m->op_type)) {
checkCUDNN(cudnnCreateTensorDescriptor(&inputTensor));
checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor));
checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc));

Op op_type = std::visit([](auto &&arg) { get_op_type(arg); }, attrs);

if (use_cudnn(op_type)) {
cudnnActivationMode_t mode;
switch (m->op_type) {
switch (op_type) {
case OP_SIGMOID:
mode = CUDNN_ACTIVATION_SIGMOID;
break;
Expand All @@ -67,78 +73,89 @@ void init_kernel(ElementUnaryPerDeviceState *m,
default:
assert(false);
}
checkCUDNN(cudnnSetActivationDescriptor(
m->actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0));
checkCUDNN(
cudnnSetTensorDescriptorFromDomain(m->inputTensor, input_domain));
// input_domain == output_domain
cudnnSetActivationDescriptor(actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0));
checkCUDNN(
cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape));
checkCUDNN(
cudnnSetTensorDescriptorFromDomain(m->outputTensor, output_domain));
cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape));
}

ElementUnaryPerDeviceState per_device_state = {
inputTensor, outputTensor, actiDesc};

return per_device_state;
}

template <DataType T>
struct ForwardKernel {
void operator()(ffStream_t stream,
ElementUnaryPerDeviceState const *m,
ElementUnaryPerDeviceState const &m,
ElementUnaryUnifiedAttrs const &attrs,
PerDeviceFFHandle const &handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) const {
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
if (use_cudnn(m->op_type)) {
checkCUDNN(cudnnSetStream(handle.dnn, stream));
Op op_type = std::visit([](auto &&arg) { get_op_type(arg); }, attrs);
if (use_cudnn(op_type)) {
float alpha = 1.0f, beta = 0.0f;
checkCUDNN(cudnnActivationForward(m->handle.dnn,
m->actiDesc,
checkCUDNN(cudnnActivationForward(handle.dnn,
m.actiDesc,
&alpha,
m->inputTensor,
m.inputTensor,
input.get<T>(),
&beta,
m->outputTensor,
m.outputTensor,
output.get<T>()));
} else {
optional<T> scalar =
std::visit([](auto &&arg) { get_scalar<T>(arg); }, attrs);
size_t num_elements = input.shape.num_elements();
elewise_unary_forward_kernel<<<GET_BLOCKS(num_elements),
CUDA_NUM_THREADS,
0,
stream>>>(num_elements,
(T)m->scalar,
m->op_type,
input.get<T>(),
output.get<T>());
stream>>>(
num_elements, scalar, op_type, input.get<T>(), output.get<T>());
}
}
}

template <DataType T>
struct BackwardKernel {
void operator()(ffStream_t stream,
ElementUnaryPerDeviceState const *m,
ElementUnaryPerDeviceState const &m,
ElementUnaryUnifiedAttrs const &attrs,
PerDeviceFFHandle const &handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &input_grad,
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &output_grad) {
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &output,
GenericTensorAccessorR const &output_grad) {
checkCUDNN(cudnnSetStream(handle.dnn, stream));

if (use_cudnn(m->op_type)) {
Op op_type = std::visit([](auto &&arg) { get_op_type(arg); }, attrs);
if (use_cudnn(op_type)) {
float alpha = 1.0f;
checkCUDNN(cudnnActivationBackward(m->handle.dnn,
m->actiDesc,
checkCUDNN(cudnnActivationBackward(handle.dnn,
m.actiDesc,
&alpha,
m->outputTensor,
m.outputTensor,
output.get<T>(),
m->outputTensor,
m.outputTensor,
output_grad.get<T>()),
m->inputTensor,
m.inputTensor,
input.get<T>(),
&alpha,
m->inputTensor,
m.inputTensor,
input_grad.get<T>()));
} else {
optional<T> scalar =
std::visit([](auto &&arg) { get_scalar<T>(arg); }, attrs);
size_t num_elements = input.shape.num_elements();
elewise_unary_backward_kernel<T>
<<<GET_BLOCKS(num_elements), CUDA_NUM_THREADS, 0, stream>>>(
num_elements,
m->scalar,
m->op_type,
scalar,
op_type,
output.get<T>(),
output_grad.get<T>(),
input.get<T>(),
Expand All @@ -148,26 +165,40 @@ struct BackwardKernel {
}

void forward_kernel(ffStream_t stream,
ElementUnaryPerDeviceState const *m,
ElementUnaryPerDeviceState const &device_state,
ElementUnaryUnifiedAttrs const &attrs,
PerDeviceFFHandle const &handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
{
DataTypeDispatch1<ForwardKernel>{}(m->data_type, stream, m, input, output);
}
DataTypeDispatch1<ForwardKernel>{}(
input.data_type, stream, m, attrs, handle, input, output);
}

void backward_kernel(ffStream_t stream,
ElementUnaryPerDeviceState const *m,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &input_grad,
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &output_grad)
DataTypeDispatch1<BackwardKernel>{}(
m->data_type, stream, m, input, input_grad, output, output_grad);
void backward_kernel(ffStream_t stream,
ElementUnaryPerDeviceState const &device_state,
ElementUnaryUnifiedAttrs const &attrs,
PerDeviceFFHandle const &handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &input_grad,
GenericTensorAccessorW const &output,
GenericTensorAccessorW const &output_grad) {
DataTypeDispatch1<BackwardKernel>{}(input.data_type,
stream,
m,
attrs,
handle,
input,
input_grad,
output,
output_grad);
}

template <typename T>
__global__ void elewise_unary_forward_kernel(
coord_t volume, const T scalar, OperatorType type, T const *in, T *out) {
__global__ void elewise_unary_forward_kernel(coord_t volume,
optional<T> const scalar,
OperatorType type,
T const *in,
T *out) {
CUDA_KERNEL_LOOP(i, volume) {
switch (type) {
case OP_EXP: {
Expand All @@ -179,19 +210,19 @@ __global__ void elewise_unary_forward_kernel(
break;
}
case OP_SCALAR_MULTIPLY: {
out[i] = in[i] * scalar;
out[i] = in[i] * scalar.value();
break;
}
case OP_SCALAR_ADD: {
out[i] = in[i] + scalar;
out[i] = in[i] + scalar.value();
break;
}
case OP_SCALAR_SUB: {
out[i] = in[i] - scalar;
out[i] = in[i] - scalar.value();
break;
}
case OP_SCALAR_TRUE_DIV: {
out[i] = in[i] / scalar;
out[i] = in[i] / scalar.value();
break;
}
case OP_GELU: {
Expand All @@ -203,7 +234,7 @@ __global__ void elewise_unary_forward_kernel(
break;
}
case OP_POW: {
out[i] = (T)(powf(in[i], scalar));
out[i] = (T)(powf(in[i], scalar.value()));
break;
}
case OP_SIN: {
Expand All @@ -222,7 +253,7 @@ __global__ void elewise_unary_forward_kernel(

template <typename T>
__global__ void elewise_unary_backward_kernel(coord_t volume,
const T scalar,
optional<T> const scalar,
OperatorType type,
T const *output,
T const *output_grad,
Expand All @@ -240,7 +271,7 @@ __global__ void elewise_unary_backward_kernel(coord_t volume,
break;
}
case OP_SCALAR_MULTIPLY: {
input_grad[i] += output_grad[i] * scalar;
input_grad[i] += output_grad[i] * scalar.value();
break;
}
case OP_SCALAR_ADD: {
Expand All @@ -252,7 +283,7 @@ __global__ void elewise_unary_backward_kernel(coord_t volume,
break;
}
case OP_SCALAR_TRUE_DIV: {
input_grad[i] += output_grad[i] / scalar;
input_grad[i] += output_grad[i] / scalar.value();
break;
}
case OP_GELU: {
Expand All @@ -268,8 +299,8 @@ __global__ void elewise_unary_backward_kernel(coord_t volume,
break;
}
case OP_POW: {
input_grad[i] =
(T)(output_grad[i] * scalar * powf(input[i], scalar - 1));
input_grad[i] = (T)(output_grad[i] * scalar.value() *
powf(input[i], scalar.value() - 1));
break;
}
case OP_SIN: {
Expand Down
Loading

0 comments on commit 502b41f

Please sign in to comment.