diff --git a/nn/CMakeLists.txt b/nn/CMakeLists.txt index 6dfa612..5a7bab7 100644 --- a/nn/CMakeLists.txt +++ b/nn/CMakeLists.txt @@ -1,17 +1,18 @@ set(INCLUDES - ./ - ./add/ - ./matmul/ + inc ) set(SOURCES - ./add/nn_add.c - ./matmul/nn_matmul.c - - ./add/nn_add_rvv.c - ./matmul/nn_matmul_rvv.c + src/nn_tensor.c + src/nn_print.c + src/add/nn_add.c + src/copy/nn_copy.c + src/matmul/nn_matmul.c + + src/add/nn_add_rvv.c + src/matmul/nn_matmul_rvv.c ) add_library(nn ${SOURCES}) diff --git a/nn/inc/nn.h b/nn/inc/nn.h new file mode 100644 index 0000000..1a8a0b3 --- /dev/null +++ b/nn/inc/nn.h @@ -0,0 +1,39 @@ +#ifndef __NN_H +#define __NN_H + +#include + +#include "nn_tensor.h" +#include "nn_print.h" +#include "nn_add.h" +#include "nn_copy.h" +#include "nn_matmul.h" + + +// http://elm-chan.org/junk/32bit/binclude.html +#define INCLUDE_FILE(section, filename, symbol) asm (\ + ".section "#section"\n" /* Change section */\ + ".balign 4\n" /* Word alignment */\ + ".global "#symbol"_start\n" /* Export the object start address */\ + ".global "#symbol"_data\n" /* Export the object address */\ + #symbol"_start:\n" /* Define the object start address label */\ + #symbol"_data:\n" /* Define the object label */\ + ".incbin \""filename"\"\n" /* Import the file */\ + ".global "#symbol"_end\n" /* Export the object end address */\ + #symbol"_end:\n" /* Define the object end address label */\ + ".balign 4\n" /* Word alignment */\ + ".section \".text\"\n") /* Restore section */ + + + +void NN_assert(int condition, char *message) { + if (!condition) { + printf("Assertion failed: "); + printf("%s\n", message); + exit(1); + } +} + + + +#endif // __NN_H \ No newline at end of file diff --git a/nn/add/nn_add.h b/nn/inc/nn_add.h similarity index 99% rename from nn/add/nn_add.h rename to nn/inc/nn_add.h index fda312c..74fe07a 100644 --- a/nn/add/nn_add.h +++ b/nn/inc/nn_add.h @@ -5,6 +5,7 @@ #include "nn_tensor.h" + /** * Element-wise addition * @@ -25,4 +26,5 @@ void NN_add_INT(Tensor *out, Tensor *a, Tensor *b); void NN_add_F32_RVV(Tensor *out, Tensor *a, Tensor *b); + #endif // __NN_ADD_H diff --git a/nn/clip/nn_clip.h b/nn/inc/nn_clip.h similarity index 100% rename from nn/clip/nn_clip.h rename to nn/inc/nn_clip.h diff --git a/nn/inc/nn_copy.h b/nn/inc/nn_copy.h new file mode 100644 index 0000000..f235533 --- /dev/null +++ b/nn/inc/nn_copy.h @@ -0,0 +1,16 @@ +#ifndef __NN_COPY_H +#define __NN_COPY_H + +#include "nn_tensor.h" + + +/** + * Copies values from one tensor to another + * + * @param dst: destination tensor + * @param src: source tensor + */ +void NN_copy(Tensor *dst, Tensor *src); + + +#endif // __NN_COPY_H diff --git a/nn/elu/nn_elu.h b/nn/inc/nn_elu.h similarity index 100% rename from nn/elu/nn_elu.h rename to nn/inc/nn_elu.h diff --git a/nn/linear/nn_linear.h b/nn/inc/nn_linear.h similarity index 100% rename from nn/linear/nn_linear.h rename to nn/inc/nn_linear.h diff --git a/nn/matmul/nn_matmul.h b/nn/inc/nn_matmul.h similarity index 100% rename from nn/matmul/nn_matmul.h rename to nn/inc/nn_matmul.h diff --git a/nn/inc/nn_print.h b/nn/inc/nn_print.h new file mode 100644 index 0000000..31c4cbc --- /dev/null +++ b/nn/inc/nn_print.h @@ -0,0 +1,14 @@ +#ifndef __NN_PRINT_H +#define __NN_PRINT_H + +#include "nn_tensor.h" + + +void NN_printFloat(float v, int16_t num_digits); + +void NN_printShape(Tensor *t); + +void NN_printf(Tensor *t); + + +#endif // __NN_PRINT_H diff --git a/nn/relu/nn_relu.h b/nn/inc/nn_relu.h similarity index 71% rename from nn/relu/nn_relu.h rename to nn/inc/nn_relu.h index 
d5cb3d4..e02a1a9 100644 --- a/nn/relu/nn_relu.h +++ b/nn/inc/nn_relu.h @@ -9,7 +9,9 @@ /** - * Applies the rectified linear unit function element-wise: y = max(0, x) + * Applies the rectified linear unit function element-wise + * + * y = max(0, x) * */ void NN_relu_F32(Tensor *y, Tensor *x); diff --git a/nn/nn_tensor.h b/nn/inc/nn_tensor.h similarity index 50% rename from nn/nn_tensor.h rename to nn/inc/nn_tensor.h index b1539c2..5ff324a 100644 --- a/nn/nn_tensor.h +++ b/nn/inc/nn_tensor.h @@ -1,8 +1,9 @@ -#ifndef __NN_TYPES -#define __NN_TYPES +#ifndef __NN_TENSOR +#define __NN_TENSOR #include #include +#include #include #define MAX_DIMS 4 @@ -20,7 +21,6 @@ typedef enum { DTYPE_F64, } DataType; - typedef struct { DataType dtype; /** datatype */ size_t ndim; /** number of dimensions */ @@ -42,12 +42,18 @@ static inline size_t NN_sizeof(DataType dtype) { switch (dtype) { case DTYPE_I8: return sizeof(int8_t); + case DTYPE_I16: + return sizeof(int16_t); case DTYPE_I32: return sizeof(int32_t); + case DTYPE_I64: + return sizeof(int64_t); case DTYPE_F32: return sizeof(float); + case DTYPE_F64: + return sizeof(double); default: - printf("Unsupported data type\n"); + printf("[WARNING] Unsupported data type: %d\n", dtype); return 0; } } @@ -56,13 +62,53 @@ static inline const char *NN_getDataTypeName(DataType dtype) { switch (dtype) { case DTYPE_I8: return "INT8"; + case DTYPE_I16: + return "INT16"; case DTYPE_I32: return "INT32"; + case DTYPE_I64: + return "INT64"; case DTYPE_F32: return "FLOAT32"; + case DTYPE_F64: + return "FLOAT64"; default: return "UNKNOWN"; } } -#endif // __NN_TYPES \ No newline at end of file +static inline void NN_freeTensorData(Tensor *t) { + free(t->data); +} + +static inline void NN_deleteTensor(Tensor *t) { + free(t); +} + + +/** + * Initialize a tensor + * + * @param ndim: number of dimensions + * @param shape: shape of tensor + * @param dtype: DataType + * @param data: pointer to data, if NULL, the data will be allocated + */ +void NN_initTensor(Tensor *t, size_t ndim, size_t *shape, DataType dtype, void *data); + +Tensor *NN_tensor(size_t ndim, size_t *shape, DataType dtype, void *data); + +Tensor *NN_zeros(size_t ndim, size_t *shape, DataType dtype); + +Tensor *NN_ones(size_t ndim, size_t *shape, DataType dtype); + +/** + * Convert tensor data type + * + * @param t: input tensor + * @param dtype: target data type + */ +void NN_asType(Tensor *t, DataType dtype); + + +#endif // __NN_TENSOR \ No newline at end of file diff --git a/nn/transpose/nn_transpose.h b/nn/inc/nn_transpose.h similarity index 100% rename from nn/transpose/nn_transpose.h rename to nn/inc/nn_transpose.h diff --git a/nn/riscv_vector.h b/nn/inc/riscv_vector.h similarity index 100% rename from nn/riscv_vector.h rename to nn/inc/riscv_vector.h diff --git a/nn/nn.h b/nn/nn.h deleted file mode 100644 index 21356f0..0000000 --- a/nn/nn.h +++ /dev/null @@ -1,259 +0,0 @@ -#ifndef __NN_H -#define __NN_H - -#include - -#include "nn_tensor.h" -#include "nn_add.h" -#include "nn_matmul.h" - - - - -// http://elm-chan.org/junk/32bit/binclude.html -#define INCLUDE_FILE(section, filename, symbol) asm (\ - ".section "#section"\n" /* Change section */\ - ".balign 4\n" /* Word alignment */\ - ".global "#symbol"_start\n" /* Export the object start address */\ - ".global "#symbol"_data\n" /* Export the object address */\ - #symbol"_start:\n" /* Define the object start address label */\ - #symbol"_data:\n" /* Define the object label */\ - ".incbin \""filename"\"\n" /* Import the file */\ - ".global "#symbol"_end\n" /* 
Export the object end address */\ - #symbol"_end:\n" /* Define the object end address label */\ - ".balign 4\n" /* Word alignment */\ - ".section \".text\"\n") /* Restore section */ - - - -void NN_assert(int condition, char *message) { - if (!condition) { - printf("Assertion failed: "); - printf("%s\n", message); - exit(1); - } -} - - -/** - * Create a tensor - * - * @param ndim: number of dimensions - * @param shape: shape of tensor - * @param dtype: DataType - */ -void NN_initTensor(Tensor *t, size_t ndim, size_t *shape, DataType dtype, void *data) { - t->ndim = ndim; - t->dtype = dtype; - t->data = data; - - // set shape - for (size_t i = 0; i < ndim; i += 1) { - t->shape[i] = shape[i]; - } - for (size_t i = ndim; i < MAX_DIMS; i += 1) { - t->shape[i] = 0; - } - - // set strides - t->strides[ndim-1] = NN_sizeof(dtype); - for (size_t i = 0; i < ndim-1; i += 1) { - t->strides[ndim-i-2] = t->strides[ndim-i-1] * t->shape[ndim-i-1]; - } - - // calculate size (number of elements) - t->size = 1; - for (size_t i = 0; i < ndim; i += 1) { - t->size *= t->shape[i]; - } -} - -void NN_freeTensor(Tensor *t) { - free(t->data); - free(t); -} - - -/* - * ====== Print Functions ====== - * - * These functions assumes that printf is available. - */ -void NN_printFloat(float v, int16_t num_digits) { - if (v < 0) { - printf("-"); // Print the minus sign for negative numbers - v = -v; // Make the number positive for processing - } - - // Calculate the integer part of the number - long int_part = (long)v; - float fractional_part = v - int_part; - - // Count the number of digits in the integer part - long temp = int_part; - int int_digits = (int_part == 0) ? 1 : 0; // Handle zero as a special case - while (temp > 0) { - int_digits++; - temp /= 10; - } - - // Print the integer part - printf("%ld", int_part); - - // Calculate the number of fractional digits we can print - int fractional_digits = num_digits - int_digits; - if (fractional_digits > 0) { - printf("."); // Print the decimal point - - // Handle the fractional part - while (fractional_digits-- > 0) { - fractional_part *= 10; - int digit = (int)(fractional_part); - printf("%d", digit); - fractional_part -= digit; - } - } -} - -void NN_printShape(Tensor *t) { - printf("("); - for (size_t i = 0; i < t->ndim; i += 1) { - printf("%d", (int)t->shape[i]); - if (i < t->ndim-1) { - printf(", "); - } - } - printf(")"); -} - -void NN_printf(Tensor *t) { - // print data with torch.Tensor style - if (t->ndim == 1) { - printf("["); - for (size_t i=0; i<t->shape[0]; i+=1) { - switch (t->dtype) { - case DTYPE_I8: - printf("%d", ((int8_t *)t->data)[i]); - break; - case DTYPE_I32: - printf("%ld", (size_t)((int32_t *)t->data)[i]); - break; - case DTYPE_F32: - NN_printFloat(((float *)t->data)[i], 4); - break; - } - if (i < t->shape[0]-1) { - printf(" "); - } - } - printf("]"); - printf("\n"); - return; - } - - printf("["); - for (size_t i=0; i<t->shape[0]; i+=1) { - if (i != 0) { - printf(" "); - } - printf("["); - for (size_t j=0; j<t->shape[1]; j+=1) { - switch (t->dtype) { - case DTYPE_I8: - printf("%d", ((int8_t *)t->data)[i*t->shape[1]+j]); - break; - case DTYPE_I32: - printf("%ld", (size_t)((int32_t *)t->data)[i*t->shape[1]+j]); - break; - case DTYPE_F32: - NN_printFloat(((float *)t->data)[i*t->shape[1]+j], 4); - break; - } - if (j < t->shape[1]-1) { - printf(" "); - } - } - printf("]"); - if (i < t->shape[0]-1) { - printf("\n"); - } - } - printf("]"); - printf("\n"); -} - -/** - * Convert tensor data type - * - * @param t: input tensor - * @param dtype: target data type - */ 
-void NN_asType(Tensor *t, DataType dtype) { - if (t->dtype == dtype) { - return; - } - if (t->dtype == DTYPE_I32 && dtype == DTYPE_F32) { - for (size_t i = 0; i<t->size; i+=1) { - ((float *)t->data)[i] = (float)((int32_t *)t->data)[i]; - } - t->dtype = DTYPE_F32; - return; - } - if (t->dtype == DTYPE_I32 && dtype == DTYPE_I8) { - for (size_t i = 0; i<t->size; i+=1) { - ((int8_t *)t->data)[i] = (int8_t)((int32_t *)t->data)[i]; - } - t->dtype = DTYPE_I8; - return; - } - - if (t->dtype == DTYPE_F32 && dtype == DTYPE_I32) { - for (size_t i = 0; i<t->size; i+=1) { - ((int32_t *)t->data)[i] = (int32_t)((float *)t->data)[i]; - } - t->dtype = DTYPE_I32; - return; - } - - printf("Cannot convert data type from %s to %s\n", NN_getDataTypeName(t->dtype), NN_getDataTypeName(dtype)); -} - - - -/** - * Copies values from one tensor to another - * - * @param dst: destination tensor - * @param src: source tensor - */ -void NN_copyTo(Tensor *dst, Tensor *src) { - assert(dst->shape[0] == src->shape[0]); - assert(dst->shape[1] == src->shape[1]); - assert(dst->dtype == src->dtype); - - switch (dst->dtype) { - case DTYPE_I8: - for (size_t i = 0; i<dst->size; i+=1) { - ((int8_t *)dst->data)[i] = ((int8_t *)src->data)[i]; - } - break; - case DTYPE_I32: - for (size_t i = 0; i<dst->size; i+=1) { - ((int32_t *)dst->data)[i] = ((int32_t *)src->data)[i]; - } - break; - case DTYPE_F32: - for (size_t i = 0; i<dst->size; i+=1) { - ((float *)dst->data)[i] = ((float *)src->data)[i]; - } - break; - default: - printf("Unsupported data type\n"); - } -} - - - - - -#endif // __NN_H \ No newline at end of file diff --git a/nn/relu/nn_relu.c b/nn/relu/nn_relu.c deleted file mode 100644 index efe3401..0000000 --- a/nn/relu/nn_relu.c +++ /dev/null @@ -1,13 +0,0 @@ - -#include "nn_relu.h" - -void NN_relu_F32(Tensor *y, Tensor *x) { - assert(y->shape[0] == x->shape[0]); - assert(y->shape[1] == x->shape[1]); - assert(y->dtype == DTYPE_F32); - assert(x->dtype == DTYPE_F32); - - for (size_t i = 0; i < y->shape[0] * y->shape[1]; i++) { - ((float *)y->data)[i] = ((float *)x->data)[i] > 0 ? 
((float *)x->data)[i] : 0; - } -} diff --git a/nn/add/nn_add.c b/nn/src/add/nn_add.c similarity index 100% rename from nn/add/nn_add.c rename to nn/src/add/nn_add.c diff --git a/nn/add/nn_add_rvv.c b/nn/src/add/nn_add_rvv.c similarity index 100% rename from nn/add/nn_add_rvv.c rename to nn/src/add/nn_add_rvv.c diff --git a/nn/clip/nn_clip.c b/nn/src/clip/nn_clip.c similarity index 100% rename from nn/clip/nn_clip.c rename to nn/src/clip/nn_clip.c diff --git a/nn/src/copy/nn_copy.c b/nn/src/copy/nn_copy.c new file mode 100644 index 0000000..1982eee --- /dev/null +++ b/nn/src/copy/nn_copy.c @@ -0,0 +1,28 @@ + +#include "nn_copy.h" + +void NN_copy(Tensor *dst, Tensor *src) { + dst->dtype = src->dtype; + dst->shape[0] = src->shape[0]; + dst->shape[1] = src->shape[1]; + + switch (dst->dtype) { + case DTYPE_I8: + for (size_t i = 0; i<dst->size; i+=1) { + ((int8_t *)dst->data)[i] = ((int8_t *)src->data)[i]; + } + break; + case DTYPE_I32: + for (size_t i = 0; i<dst->size; i+=1) { + ((int32_t *)dst->data)[i] = ((int32_t *)src->data)[i]; + } + break; + case DTYPE_F32: + for (size_t i = 0; i<dst->size; i+=1) { + ((float *)dst->data)[i] = ((float *)src->data)[i]; + } + break; + default: + printf("[ERROR] Unsupported data type: %d\n", dst->dtype); + } +} diff --git a/nn/elu/nn_elu.c b/nn/src/elu/nn_elu.c similarity index 100% rename from nn/elu/nn_elu.c rename to nn/src/elu/nn_elu.c diff --git a/nn/linear/nn_linear.c b/nn/src/linear/nn_linear.c similarity index 100% rename from nn/linear/nn_linear.c rename to nn/src/linear/nn_linear.c diff --git a/nn/matmul/nn_matmul.c b/nn/src/matmul/nn_matmul.c similarity index 100% rename from nn/matmul/nn_matmul.c rename to nn/src/matmul/nn_matmul.c diff --git a/nn/matmul/nn_matmul_rvv.c b/nn/src/matmul/nn_matmul_rvv.c similarity index 100% rename from nn/matmul/nn_matmul_rvv.c rename to nn/src/matmul/nn_matmul_rvv.c diff --git a/nn/src/nn_print.c b/nn/src/nn_print.c new file mode 100644 index 0000000..67c857c --- /dev/null +++ b/nn/src/nn_print.c @@ -0,0 +1,106 @@ + +#include "nn_print.h" + + +void NN_printFloat(float v, int16_t num_digits) { + if (v < 0) { + printf("-"); // Print the minus sign for negative numbers + v = -v; // Make the number positive for processing + } + + // Calculate the integer part of the number + long int_part = (long)v; + float fractional_part = v - int_part; + + // Count the number of digits in the integer part + long temp = int_part; + int int_digits = (int_part == 0) ? 
1 : 0; // Handle zero as a special case + while (temp > 0) { + int_digits++; + temp /= 10; + } + + // Print the integer part + printf("%ld", int_part); + + // Calculate the number of fractional digits we can print + int fractional_digits = num_digits - int_digits; + if (fractional_digits > 0) { + printf("."); // Print the decimal point + + // Handle the fractional part + while (fractional_digits-- > 0) { + fractional_part *= 10; + int digit = (int)(fractional_part); + printf("%d", digit); + fractional_part -= digit; + } + } +} + +void NN_printShape(Tensor *t) { + printf("("); + for (size_t i = 0; i < t->ndim; i += 1) { + printf("%d", (int)t->shape[i]); + if (i < t->ndim-1) { + printf(", "); + } + } + printf(")"); +} + +void NN_printf(Tensor *t) { + // print data with torch.Tensor style + if (t->ndim == 1) { + printf("["); + for (size_t i=0; i<t->shape[0]; i+=1) { + switch (t->dtype) { + case DTYPE_I8: + printf("%d", ((int8_t *)t->data)[i]); + break; + case DTYPE_I32: + printf("%ld", (size_t)((int32_t *)t->data)[i]); + break; + case DTYPE_F32: + NN_printFloat(((float *)t->data)[i], 4); + break; + } + if (i < t->shape[0]-1) { + printf(" "); + } + } + printf("]"); + printf("\n"); + return; + } + + printf("["); + for (size_t i=0; i<t->shape[0]; i+=1) { + if (i != 0) { + printf(" "); + } + printf("["); + for (size_t j=0; j<t->shape[1]; j+=1) { + switch (t->dtype) { + case DTYPE_I8: + printf("%d", ((int8_t *)t->data)[i*t->shape[1]+j]); + break; + case DTYPE_I32: + printf("%ld", (size_t)((int32_t *)t->data)[i*t->shape[1]+j]); + break; + case DTYPE_F32: + NN_printFloat(((float *)t->data)[i*t->shape[1]+j], 4); + break; + } + if (j < t->shape[1]-1) { + printf(" "); + } + } + printf("]"); + if (i < t->shape[0]-1) { + printf("\n"); + } + } + printf("]"); + printf("\n"); +} diff --git a/nn/src/nn_tensor.c b/nn/src/nn_tensor.c new file mode 100644 index 0000000..251cfea --- /dev/null +++ b/nn/src/nn_tensor.c @@ -0,0 +1,123 @@ + +#include "nn_tensor.h" + + +void NN_initTensor(Tensor *t, size_t ndim, size_t *shape, DataType dtype, void *data) { + t->ndim = ndim; + t->dtype = dtype; + + // set shape + for (size_t i = 0; i < ndim; i += 1) { + t->shape[i] = shape[i]; + } + for (size_t i = ndim; i < MAX_DIMS; i += 1) { + t->shape[i] = 0; + } + + // set strides + t->strides[ndim-1] = NN_sizeof(dtype); + for (size_t i = 0; i < ndim-1; i += 1) { + t->strides[ndim-i-2] = t->strides[ndim-i-1] * t->shape[ndim-i-1]; + } + + // calculate size (number of elements) + t->size = 1; + for (size_t i = 0; i < ndim; i += 1) { + t->size *= t->shape[i]; + } + + if (data == NULL) { + t->data = malloc(NN_sizeof(dtype) * t->size); + } else { + t->data = data; + } +} + + +Tensor *NN_tensor(size_t ndim, size_t *shape, DataType dtype, void *data) { + Tensor *t = (Tensor *)malloc(sizeof(Tensor)); + NN_initTensor(t, ndim, shape, dtype, data); + return t; +} + +Tensor *NN_zeros(size_t ndim, size_t *shape, DataType dtype) { + Tensor *t = NN_tensor(ndim, shape, dtype, NULL); + + switch (dtype) { + case DTYPE_I8: + for (size_t i = 0; i<t->size; i+=1) { + ((int8_t *)t->data)[i] = 0; + } + break; + case DTYPE_I32: + for (size_t i = 0; i<t->size; i+=1) { + ((int32_t *)t->data)[i] = 0; + } + break; + case DTYPE_F32: + for (size_t i = 0; i<t->size; i+=1) { + ((float *)t->data)[i] = 0; + } + break; + default: + printf("[WARNING] Unsupported data type: %d\n", dtype); + } + + return t; +} + +Tensor *NN_ones(size_t ndim, size_t *shape, DataType dtype) { + Tensor *t = NN_tensor(ndim, shape, dtype, NULL); + + switch (dtype) { + case DTYPE_I8: + for (size_t i = 0; i<t->size; 
i+=1) { + ((int8_t *)t->data)[i] = 1; + } + break; + case DTYPE_I32: + for (size_t i = 0; i<t->size; i+=1) { + ((int32_t *)t->data)[i] = 1; + } + break; + case DTYPE_F32: + for (size_t i = 0; i<t->size; i+=1) { + ((float *)t->data)[i] = 1; + } + break; + default: + printf("[WARNING] Unsupported data type: %d\n", dtype); + } + + return t; +} + +void NN_asType(Tensor *t, DataType dtype) { + if (t->dtype == dtype) { + return; + } + if (t->dtype == DTYPE_I32 && dtype == DTYPE_F32) { + for (size_t i = 0; i<t->size; i+=1) { + ((float *)t->data)[i] = (float)((int32_t *)t->data)[i]; + } + t->dtype = DTYPE_F32; + return; + } + if (t->dtype == DTYPE_I32 && dtype == DTYPE_I8) { + for (size_t i = 0; i<t->size; i+=1) { + ((int8_t *)t->data)[i] = (int8_t)((int32_t *)t->data)[i]; + } + t->dtype = DTYPE_I8; + return; + } + + if (t->dtype == DTYPE_F32 && dtype == DTYPE_I32) { + for (size_t i = 0; i<t->size; i+=1) { + ((int32_t *)t->data)[i] = (int32_t)((float *)t->data)[i]; + } + t->dtype = DTYPE_I32; + return; + } + + printf("[ERROR] Cannot convert data type from %s to %s\n", NN_getDataTypeName(t->dtype), NN_getDataTypeName(dtype)); +} diff --git a/nn/src/relu/nn_relu.c b/nn/src/relu/nn_relu.c new file mode 100644 index 0000000..723a91a --- /dev/null +++ b/nn/src/relu/nn_relu.c @@ -0,0 +1,17 @@ + +#include "nn_relu.h" + +void NN_relu_F32(Tensor *y, Tensor *x) { + assert(x->dtype == DTYPE_F32); + + y->dtype = DTYPE_F32; + y->shape[0] = x->shape[0]; + y->shape[1] = x->shape[1]; + + float *y_data = (float *)y->data; + float *x_data = (float *)x->data; + + for (size_t i = 0; i < y->shape[0] * y->shape[1]; i+=1) { + y_data[i] = x_data[i] > 0 ? x_data[i] : 0; + } +} diff --git a/nn/transpose/nn_transpose.c b/nn/src/transpose/nn_transpose.c similarity index 91% rename from nn/transpose/nn_transpose.c rename to nn/src/transpose/nn_transpose.c index f43831a..705a217 100644 --- a/nn/transpose/nn_transpose.c +++ b/nn/src/transpose/nn_transpose.c @@ -12,12 +12,6 @@ * @param a: input tensor of shape (m, n) */ void NN_transpose(Tensor *out, Tensor *a) { - assert(out->shape[0] == a->shape[1]); - assert(out->shape[1] == a->shape[0]); - assert(a->ndim == 2); - assert(out->dtype == a->dtype); - assert(out->data != a->data); - if (a->dtype == DTYPE_F32) { NN_transpose_F32(out, a); return;