ADD: add fp16 matmul
T-K-233 committed Jul 9, 2024
1 parent 7565bdc commit e79801b
Showing 4 changed files with 135 additions and 115 deletions.
35 changes: 35 additions & 0 deletions nn/inc/ops/dot.h
@@ -7,6 +7,41 @@
#include <riscv_vector.h>
#endif

#include "nn_float16.h"


static inline void NN__dot_F16(size_t n, float16_t *s, float16_t *x, float16_t *y) {
  float16_t sum = 0.0;

  #ifdef RVV
    size_t vlmax = __riscv_vsetvlmax_e16m1();

    vfloat16m1_t vec_zero = __riscv_vfmv_v_f_f16m1(0, vlmax);
    vfloat16m1_t vec_s = __riscv_vfmv_v_f_f16m1(0, vlmax);

    while (n > 0) {
      size_t vl = __riscv_vsetvl_e16m1(n);
      vfloat16m1_t vec_x = __riscv_vle16_v_f16m1(x, vl);
      vfloat16m1_t vec_y = __riscv_vle16_v_f16m1(y, vl);
      vec_s = __riscv_vfmacc_vv_f16m1(vec_s, vec_x, vec_y, vl);

      x += vl;
      y += vl;
      n -= vl;
    }
    vec_s = __riscv_vfredusum_vs_f16m1_f16m1(vec_s, vec_zero, vlmax);
    sum = __riscv_vfmv_f_s_f16m1_f16(vec_s);
  #else
    float sum_f32 = 0;
    for (size_t i = 0; i < n; i += 1) {
      sum_f32 += NN_halfToFloat(x[i]) * NN_halfToFloat(y[i]);
    }
    sum = NN_floatToHalf(sum_f32);
  #endif

  *s = sum;
}

static inline void NN__dot_F32(size_t n, float *s, float *x, float *y) {
  float sum = 0.0;

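For reference, here is a minimal usage sketch of the new fp16 dot product (not part of the commit). It assumes the nn_float16.h conversion helpers shown above; the include path "ops/dot.h" is a guess based on the file location nn/inc/ops/dot.h. Since the scalar fallback accumulates in fp32 and converts to half only once at the end, the small case below is exact.

// Usage sketch (illustrative only, not from this commit).
#include "nn_float16.h"
#include "ops/dot.h"   // assumed include path for nn/inc/ops/dot.h

int main(void) {
  float16_t x[4], y[4], s;
  for (int i = 0; i < 4; i += 1) {
    x[i] = NN_floatToHalf((float)(i + 1));   // 1, 2, 3, 4
    y[i] = NN_floatToHalf(0.5f);
  }
  NN__dot_F16(4, &s, x, y);
  // (1 + 2 + 3 + 4) * 0.5 = 5.0, exactly representable in fp16
  return NN_halfToFloat(s) == 5.0f ? 0 : 1;
}
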
44 changes: 44 additions & 0 deletions nn/src/nn_matmul.c
@@ -24,12 +24,56 @@ void NN_matmul(Tensor *out, Tensor *a, Tensor *b) {
    }
    return;
  }
  if (a->dtype == DTYPE_F16 && b->dtype == DTYPE_F16 && out->dtype == DTYPE_F16) {
    // currently only support 2D matrix multiplication
    assert(a->ndim == 2);
    assert(b->ndim == 2);
    assert(a->dtype == DTYPE_F16);
    assert(b->dtype == DTYPE_F16);
    assert(out->dtype == DTYPE_F16);
    assert(a->shape[1] == b->shape[0]);
    assert(out->shape[0] == a->shape[0]);
    assert(out->shape[1] == b->shape[1]);

    for (size_t i = 0; i < out->shape[0]; i += 1) {
      for (size_t j = 0; j < out->shape[1]; j += 1) {
        float sum = 0;
        for (size_t k = 0; k < a->shape[1]; k += 1) {
          sum += NN_halfToFloat(((float16_t *)a->data)[i * a->shape[1] + k]) * NN_halfToFloat(((float16_t *)b->data)[k * b->shape[1] + j]);
        }
        ((float16_t *)out->data)[i * out->shape[1] + j] = NN_floatToHalf(sum);
      }
    }
    return;
  }
  printf("Unsupported operation: %s = %s @ %s\n",
    NN_getDataTypeName(out->dtype), NN_getDataTypeName(a->dtype), NN_getDataTypeName(b->dtype)
  );
}

void NN_matmulT(Tensor *out, Tensor *a, Tensor *b) {
  if (a->dtype == DTYPE_F16 && b->dtype == DTYPE_F16 && out->dtype == DTYPE_F16) {
    // currently only support 2D matrix multiplication
    assert(a->ndim == 2);
    assert(b->ndim == 2);
    assert(a->dtype == DTYPE_F16);
    assert(b->dtype == DTYPE_F16);
    assert(out->dtype == DTYPE_F16);
    assert(a->shape[1] == b->shape[1]);
    assert(out->shape[0] == a->shape[0]);
    assert(out->shape[1] == b->shape[0]);

    for (size_t i = 0; i < out->shape[0]; i += 1) {
      for (size_t j = 0; j < out->shape[1]; j += 1) {
        NN__dot_F16(a->shape[1],
            (float16_t *)out->data + i * out->shape[1] + j,
            (float16_t *)a->data + i * a->shape[1],
            (float16_t *)b->data + j * b->shape[1]
            );
      }
    }
    return;
  }
  if (a->dtype == DTYPE_F32 && b->dtype == DTYPE_F32 && out->dtype == DTYPE_F32) {
    // currently only support 2D matrix multiplication
    assert(a->ndim == 2);
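
A note on the two fp16 paths above, with a plain-C reference sketch (mine, not from the commit): NN_matmulT can hand each output element to NN__dot_F16 because both operands are read along contiguous rows (out[i][j] = sum_k a[i][k] * b[j][k]), while plain NN_matmul walks column j of b with stride b->shape[1], so it keeps a scalar loop that accumulates in fp32 and converts back to half once per output element. The function below spells out the transposed access pattern on raw fp32 buffers; the name and signature are illustrative.

// Reference sketch of the matmulT index pattern (fp32 for clarity).
// out is m x n, a is m x k, b is n x k (b is stored "already transposed").
static void matmulT_reference(size_t m, size_t n, size_t k,
                              float *out, const float *a, const float *b) {
  for (size_t i = 0; i < m; i += 1) {
    for (size_t j = 0; j < n; j += 1) {
      float sum = 0.0f;
      for (size_t p = 0; p < k; p += 1) {
        sum += a[i * k + p] * b[j * k + p];   // row i of a, row j of b: both contiguous
      }
      out[i * n + j] = sum;
    }
  }
}
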
34 changes: 18 additions & 16 deletions tests/src/generate_test.py
@@ -63,26 +63,28 @@ def rand16(shape):
# ", 1" ),
# ("ReLU6", lambda x: torch.nn.functional.relu6(x),
# [("x", rand((7, 7))) ]),
("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=0, dilation=1, groups=1).permute((0, 2, 3, 1)),
[("x", rand((1, 16, 16, 3))), ("w", rand((3, 3, 3, 6))), ("b", rand((6, )))],
", (size_t[]){1, 1}, (size_t[]){0, 0}, (size_t[]){1, 1}, 1" ),
("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=1, dilation=1, groups=1).permute((0, 2, 3, 1)),
[("x", rand((1, 16, 16, 3))), ("w", rand((3, 3, 3, 71))), ("b", rand((71, )))],
", (size_t[]){1, 1}, (size_t[]){1, 1}, (size_t[]){1, 1}, 1" ),
("NCHWToNHWC", lambda x: x.permute((0, 2, 3, 1)), [("x", rand((1, 2, 3, 3))) ]),
("NHWCToNCHW", lambda x: x.permute((0, 3, 1, 2)), [("x", rand((1, 3, 3, 2))) ]),
("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=1, dilation=1, groups=16).permute((0, 2, 3, 1)),
[("x", rand((1, 12, 12, 16))), ("w", rand((3, 3, 1, 16))), ("b", rand((16, )))],
", (size_t[]){1, 1}, (size_t[]){1, 1}, (size_t[]){1, 1}, 16" ),
("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=1, dilation=1, groups=1).permute((0, 2, 3, 1)),
[("x", rand((1, 12, 12, 16))), ("w", rand((3, 3, 16, 56))), ("b", rand((56, )))],
", (size_t[]){1, 1}, (size_t[]){1, 1}, (size_t[]){1, 1}, 1" ),
# ("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=0, dilation=1, groups=1).permute((0, 2, 3, 1)),
# [("x", rand((1, 16, 16, 3))), ("w", rand((3, 3, 3, 6))), ("b", rand((6, )))],
# ", (size_t[]){1, 1}, (size_t[]){0, 0}, (size_t[]){1, 1}, 1" ),
# ("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=1, dilation=1, groups=1).permute((0, 2, 3, 1)),
# [("x", rand((1, 16, 16, 3))), ("w", rand((3, 3, 3, 71))), ("b", rand((71, )))],
# ", (size_t[]){1, 1}, (size_t[]){1, 1}, (size_t[]){1, 1}, 1" ),
# ("NCHWToNHWC", lambda x: x.permute((0, 2, 3, 1)), [("x", rand((1, 2, 3, 3))) ]),
# ("NHWCToNCHW", lambda x: x.permute((0, 3, 1, 2)), [("x", rand((1, 3, 3, 2))) ]),
# ("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=1, dilation=1, groups=16).permute((0, 2, 3, 1)),
# [("x", rand((1, 12, 12, 16))), ("w", rand((3, 3, 1, 16))), ("b", rand((16, )))],
# ", (size_t[]){1, 1}, (size_t[]){1, 1}, (size_t[]){1, 1}, 16" ),
# ("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=1, dilation=1, groups=1).permute((0, 2, 3, 1)),
# [("x", rand((1, 12, 12, 16))), ("w", rand((3, 3, 16, 56))), ("b", rand((56, )))],
# ", (size_t[]){1, 1}, (size_t[]){1, 1}, (size_t[]){1, 1}, 1" ),
# ("LayerNorm", lambda x, w, b: torch.nn.functional.layer_norm(x, x.shape, w, b, eps=1e-05),
# [("x", rand((6, 5))), ("w", rand((6, 5))), ("b", rand((6, 5))) ],
# ", 1e-05" ),

# ("abs", lambda a: torch.abs(a), [("a", rand16((1, 4))), ]),
# ("add", lambda a, b: a + b, [("a", rand16((6, 7))), ("b", rand16((6, 7))) ]),
("abs", lambda a: torch.abs(a), [("a", rand16((1, 4))), ]),
("add", lambda a, b: a + b, [("a", rand16((6, 7))), ("b", rand16((6, 7))) ]),
("matmulT", lambda a, b: a @ b.T, [("a", rand16((6, 7))), ("b", rand16((5, 7))) ]),
("matmul", lambda a, b: a @ b, [("a", rand16((6, 7))), ("b", rand16((7, 5))) ]),
]
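
In the updated list, the Conv2d and layout-conversion cases are commented out and the half-precision cases (abs, add, matmulT, matmul) are enabled. rand16 is presumably the helper defined just above this hunk (see the @@ header) that returns random half-precision tensors, so the generated tests exercise the new fp16 element-wise and matmul paths against the PyTorch reference expressions.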

