Skip to content

Commit

Permalink
ADD: add SiLU function
Browse files Browse the repository at this point in the history
  • Loading branch information
T-K-233 committed Jul 15, 2024
1 parent ea718b8 commit 525738d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
29 changes: 29 additions & 0 deletions nn/functional/nn_silu.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@

#include "nn_silu.h"


/**
 * Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
 *
 * y[i] = x[i] * sigmoid(x[i]) = x[i] / (1 + exp(-x[i]))
 *
 * Only DTYPE_F32 tensors are handled; any other dtype falls through to
 * an error message on stdout.
 *
 * @param y: the output tensor (must match x in ndim, dtype, and size)
 * @param x: the input tensor
 */
void NN_silu(Tensor *y, const Tensor *x) {
  assert(y->ndim == x->ndim);
  assert(y->dtype == x->dtype);
  assert(y->size == x->size);

  if (y->dtype == DTYPE_F32) {
    // Hoist the casted base pointers out of the loop for readability.
    const float *src = (const float *)x->data;
    float *dst = (float *)y->data;
    size_t n = y->size;
    for (size_t i = 0; i < n; i += 1) {
      float v = src[i];
      dst[i] = v / (1.0f + expf(-v));
    }
    return;
  }

  printf("[ERROR] Unsupported operation between tensor with dtype %s = SiLU(%s)\n",
    NN_get_datatype_name(y->dtype), NN_get_datatype_name(x->dtype)
  );
}

/**
 * Applies SiLU in place, overwriting the input tensor.
 *
 * Thin wrapper that passes x as both the output and input of NN_silu.
 *
 * @param x: the tensor to transform (serves as both input and output)
 */
void NN_silu_inplace(Tensor *x) {
  NN_silu(x, x);
}
24 changes: 24 additions & 0 deletions nn/functional/nn_silu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/* Guard renamed from __NN_SILU_H: identifiers beginning with a double
 * underscore are reserved for the implementation (C11 §7.1.3). */
#ifndef NN_SILU_H
#define NN_SILU_H

#include <assert.h>
#include <math.h>   // expf, used by the F32 path in nn_silu.c
#include <stdio.h>  // printf, used by the unsupported-dtype error path

#include "nn_tensor.h"
#include "maximum1.h"  // NOTE(review): appears unused by SiLU — confirm before removing


/**
 * Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
 *
 * The SiLU function is also known as the swish function.
 *
 * y = silu(x) = x * theta(x), where theta(x) is the logistic sigmoid.
 *
 * @param y: the output tensor (must match x in ndim, dtype, and size)
 * @param x: the input tensor
 */
void NN_silu(Tensor *y, const Tensor *x);

/**
 * Applies SiLU in place, overwriting the input tensor.
 *
 * @param x: the tensor to transform (serves as both input and output)
 */
void NN_silu_inplace(Tensor *x);

#endif // NN_SILU_H

0 comments on commit 525738d

Please sign in to comment.