Skip to content

Commit

Permalink
ADD: move the remaining functions to new format
Browse files Browse the repository at this point in the history
  • Loading branch information
T-K-233 committed Jun 18, 2024
1 parent 12030f1 commit 6451de2
Show file tree
Hide file tree
Showing 14 changed files with 108 additions and 235 deletions.
11 changes: 6 additions & 5 deletions nn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@ add_library(nn
src/nn_print.c
src/nn_abs.c
src/nn_add.c
src/batchnorm2d/nn_batchnorm2d.c
src/conv2d/nn_conv2d.c
src/nn_batchnorm2d.c
src/nn_conv2d.c
src/nn_clip.c
src/nn_copy.c
src/nn_div.c
src/nn_elu.c
src/nn_fill.c
src/interpolate/nn_interpolate.c
src/nn_interpolate.c
src/nn_linear.c
src/nn_matmul.c
src/matrixnorm/nn_matrixnorm.c
src/nn_matrixnorm.c
src/nn_max.c
src/nn_maximum.c
src/maxpool2d/nn_maxpool2d.c
src/nn_maxpool2d.c
src/nn_min.c
src/nn_minimum.c
src/nn_mul.c
Expand Down
15 changes: 10 additions & 5 deletions nn/inc/nn_elu.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,24 @@
#include <assert.h>
#include <math.h>

#include "nn_types.h"
#include "nn_add.h"
#include "nn_matmul.h"
#include "nn_tensor.h"


/**
* Applies the rectified linear unit function element-wise: y = max(0, x)
* Applies the Exponential Linear Unit (ELU) function, element-wise.
*
* The ELU function is defined as:
*
* ELU(x) = x, if x > 0
* alpha * (exp(x) - 1), if x <= 0
*
* @param y: output tensor
* @param x: input tensor
* @param alpha: the alpha value for the ELU formulation
*/
void NN_elu_F32(Tensor *y, Tensor *x, float alpha);
void NN_ELU(Tensor *y, Tensor *x, float alpha);

void NN_ELUInplace(Tensor *x, float alpha);


#endif // __NN_RELU_H
5 changes: 2 additions & 3 deletions nn/inc/nn_matrixnorm.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
*
* @param tensor: the input tensor of shape (m, n)
*/
float NN_matrixNorm_F32(Tensor *tensor);
void NN_matrixNorm(Tensor *scalar, Tensor *x);


float NN_matrixNorm_F32_RVV(Tensor *tensor);
void NN_matrixNorm_F32(Tensor *scalar, Tensor *x);


#endif // __NN_MATRIXNORM_H
60 changes: 0 additions & 60 deletions nn/src/conv2d/nn_conv2d_gemmini.c

This file was deleted.

101 changes: 0 additions & 101 deletions nn/src/conv2d/nn_conv2d_rvv.c

This file was deleted.

17 changes: 0 additions & 17 deletions nn/src/elu/nn_elu.c

This file was deleted.

16 changes: 0 additions & 16 deletions nn/src/matrixnorm/nn_matrixnorm.c

This file was deleted.

28 changes: 0 additions & 28 deletions nn/src/matrixnorm/nn_matrixnorm_rvv.c

This file was deleted.

File renamed without changes.
File renamed without changes.
34 changes: 34 additions & 0 deletions nn/src/nn_elu.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

#include "nn_elu.h"


/**
 * Applies the Exponential Linear Unit (ELU) function element-wise:
 *
 *   ELU(v) = v                        if v > 0
 *            alpha * (exp(v) - 1)     if v <= 0
 *
 * @param y: output tensor, same dtype and element count as x
 * @param x: input tensor
 * @param alpha: the alpha value for the ELU formulation
 *
 * Only DTYPE_F32 is supported; other dtypes print an error.
 */
void NN_ELU(Tensor *y, Tensor *x, float alpha) {
  assert(y->ndim == x->ndim);
  assert(y->dtype == x->dtype);
  assert(y->size == x->size);

  switch (y->dtype) {
    case DTYPE_F32:
      // Iterate over the flat element count so tensors of any rank work.
      // The previous shape[0] * shape[1] bound silently assumed a 2-D tensor,
      // contradicting the y->size == x->size assertion above.
      for (size_t i = 0; i < y->size; i += 1) {
        float v = ((float *)x->data)[i];
        ((float *)y->data)[i] = v > 0 ? v : alpha * (expf(v) - 1.f);
      }
      return;

    default:
      break;
  }

  printf("[ERROR] Unsupported operation between tensor with dtype %s = ELU(%s)\n",
    NN_getDataTypeName(y->dtype), NN_getDataTypeName(x->dtype)
  );
}

/**
 * Applies the ELU function element-wise, writing the result back into x.
 *
 * @param x: tensor transformed in place
 * @param alpha: the alpha value for the ELU formulation
 */
void NN_ELUInplace(Tensor *x, float alpha) {
  // Delegate to NN_ELU with x as both source and destination; the
  // element-wise update never reads an element after writing it.
  NN_ELU(x, x, alpha);
}
File renamed without changes.
56 changes: 56 additions & 0 deletions nn/src/nn_matrixnorm.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@

#include "nn_matrixnorm.h"

#ifdef RVV
#include <riscv_vector.h>
#endif

/**
 * Computes the Frobenius norm of a 2-D tensor, dispatching by dtype.
 *
 * @param scalar: scalar output tensor receiving the norm
 * @param x: input tensor of shape (m, n)
 *
 * Only DTYPE_F32 is supported; other dtypes print an error.
 */
void NN_matrixNorm(Tensor *scalar, Tensor *x) {
  assert(x->ndim == 2);
  assert(NN_isScalar(scalar));
  assert(scalar->dtype == x->dtype);

  // Guard-style dispatch: handle the supported dtype and bail out early.
  if (x->dtype == DTYPE_F32) {
    NN_matrixNorm_F32(scalar, x);
    return;
  }

  printf("[ERROR] Unsupported operation between tensor with dtype %s = ||%s||\n",
    NN_getDataTypeName(scalar->dtype), NN_getDataTypeName(x->dtype)
  );
}

/**
 * Computes the Frobenius norm of a 2-D F32 tensor:
 * sqrt of the sum of squares of all elements.
 *
 * @param scalar: scalar F32 output tensor receiving the norm
 * @param x: input F32 tensor of shape (m, n); data assumed contiguous
 *           (the RVV path already relies on this — TODO confirm for strided tensors)
 */
void NN_matrixNorm_F32(Tensor *scalar, Tensor *x) {
  float sum = 0;
  #ifdef RVV
  float *ptr = x->data;
  
  size_t vlmax = __riscv_vsetvlmax_e32m1();
  vfloat32m1_t vec_zero = __riscv_vfmv_v_f_f32m1(0, vlmax);
  vfloat32m1_t vec_accumulate = __riscv_vfmv_v_f_f32m1(0, vlmax);
  
  size_t n = x->shape[0] * x->shape[1];
  while (n > 0) {
    size_t vl = __riscv_vsetvl_e32m1(n);
    vfloat32m1_t vec_a = __riscv_vle32_v_f32m1(ptr, vl);
    vec_accumulate = __riscv_vfmacc_vv_f32m1(vec_accumulate, vec_a, vec_a, vl);
    ptr += vl;
    n -= vl;
  }
  vfloat32m1_t vec_sum = __riscv_vfredusum_vs_f32m1_f32m1(vec_accumulate, vec_zero, vlmax);
  sum = __riscv_vfmv_f_s_f32m1_f32(vec_sum);
  #else
  // Accumulate the sum of squares over the flat element range.
  // v * v replaces pow(v, 2): no libm call, no double round-trip.
  const float *data = (const float *)x->data;
  size_t count = x->shape[0] * x->shape[1];
  for (size_t i = 0; i < count; i += 1) {
    sum += data[i] * data[i];
  }
  #endif
  
  // sqrtf keeps the result in single precision, matching the F32 tensor.
  ((float *)scalar->data)[0] = sqrtf(sum);
}
File renamed without changes.

0 comments on commit 6451de2

Please sign in to comment.