Merge pull request #2 from ucb-bar/gemmini-fp16
ADD: add Gemmini FP16 support
T-K-233 authored Jul 10, 2024
2 parents b28e112 + e79801b commit 5f29956
Showing 5 changed files with 163 additions and 115 deletions.
28 changes: 28 additions & 0 deletions docs/Tensor-Basics.md
@@ -0,0 +1,28 @@
# Tensor Basics

Tensor types are resolved dynamically, so the API stays generic, avoids multiple struct definitions, and allows multiple data types to coexist in a single program. That is, there is a single Tensor type. A tensor may hold doubles (DTYPE_F64), floats (DTYPE_F32), integers, and so on. This design makes it easy to write generic code.

The underlying fundamental operators are statically typed, so the tensor-level API determines at runtime which fundamental operator to dispatch to for the computation.
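To make the dispatch concrete, here is a minimal sketch of the pattern: the tensor-level entry point inspects the dtype and forwards to a statically typed kernel. The kernels `NN__dot_F32` and `NN__dot_F16` are the ones in this change; the wrapper name `nn_dot_example`, the umbrella header `"nn.h"`, and the fallback message are assumptions for illustration only, not library API.

```c
#include <stdio.h>
#include "nn.h"  // assumed umbrella header providing Tensor, DTYPE_*, float16_t, and the kernels

// Illustrative tensor-level wrapper: dispatch on dtype, then call a statically
// typed fundamental operator. Assumes 1-D inputs of equal length.
void nn_dot_example(Tensor *out, Tensor *a, Tensor *b) {
  size_t n = a->shape[0];
  if (a->dtype == DTYPE_F32 && b->dtype == DTYPE_F32 && out->dtype == DTYPE_F32) {
    NN__dot_F32(n, (float *)out->data, (float *)a->data, (float *)b->data);
    return;
  }
  if (a->dtype == DTYPE_F16 && b->dtype == DTYPE_F16 && out->dtype == DTYPE_F16) {
    NN__dot_F16(n, (float16_t *)out->data, (float16_t *)a->data, (float16_t *)b->data);
    return;
  }
  printf("Unsupported dtype combination\n");
}
```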


## Tensor Element in Memory

The data of a tensor must be contiguous in memory. This keeps the code framework simple. The trade-off is that operations such as transpose are expensive, so it is recommended to perform such transformations during the AOT compilation process.
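Because the data is contiguous, element addressing reduces to row-major index arithmetic; this is the same indexing the FP16 matmul kernel in this change uses. A minimal sketch, assuming the library's Tensor and float16_t types and a hypothetical helper name:

```c
// For a contiguous 2-D tensor of shape {R, C}, element (i, j) sits at a fixed
// row-major offset, with no per-dimension stride bookkeeping.
static inline float16_t tensor_get2d_f16(Tensor *t, size_t i, size_t j) {
  return ((float16_t *)t->data)[i * t->shape[1] + j];
}
```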


## Using Externally Created Data

If the data of the tensor is already allocated in memory, that memory can be viewed as a Tensor:

```c
float data[] = { 1, 2, 3,
                 4, 5, 6 };
Tensor *tensor = nn_tensor(2, (const size_t[]){2, 3}, DTYPE_F32, data);
```
## Zero-dimensional Tensors as Scalars

A scalar is represented by a zero-dimensional Tensor object. Such a Tensor holds a single value and can be a reference to a single element of a larger Tensor. It can be used anywhere a Tensor is expected.

When creating such a zero-dimensional tensor, the shape is a NULL pointer, but the size is set to 1 and a single element's worth of memory is allocated as the data buffer.
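A minimal sketch of creating such a scalar with `nn_tensor`, following the signature shown above; viewing an existing float as the backing element is my choice for illustration:

```c
float value = 3.14f;

// Zero-dimensional tensor: ndim is 0 and the shape pointer is NULL; the tensor
// refers to a single element, here the existing `value`.
Tensor *scalar = nn_tensor(0, NULL, DTYPE_F32, &value);
```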
35 changes: 35 additions & 0 deletions nn/inc/ops/dot.h
@@ -7,6 +7,41 @@
#include <riscv_vector.h>
#endif

#include "nn_float16.h"


static inline void NN__dot_F16(size_t n, float16_t *s, float16_t *x, float16_t *y) {
  float16_t sum = 0.0;

  #ifdef RVV
    // RVV path: strip-mine over n, accumulating x[i] * y[i] into vec_s with a
    // vector fused multiply-add, then reduce the lanes to a single FP16 sum.
    size_t vlmax = __riscv_vsetvlmax_e16m1();

    vfloat16m1_t vec_zero = __riscv_vfmv_v_f_f16m1(0, vlmax);
    vfloat16m1_t vec_s = __riscv_vfmv_v_f_f16m1(0, vlmax);

    while (n > 0) {
      size_t vl = __riscv_vsetvl_e16m1(n);
      vfloat16m1_t vec_x = __riscv_vle16_v_f16m1(x, vl);
      vfloat16m1_t vec_y = __riscv_vle16_v_f16m1(y, vl);
      vec_s = __riscv_vfmacc_vv_f16m1(vec_s, vec_x, vec_y, vl);

      x += vl;
      y += vl;
      n -= vl;
    }
    vec_s = __riscv_vfredusum_vs_f16m1_f16m1(vec_s, vec_zero, vlmax);
    sum = __riscv_vfmv_f_s_f16m1_f16(vec_s);
  #else
    // Scalar fallback: accumulate in FP32 to limit rounding error, then
    // convert the result back to FP16 once at the end.
    float sum_f32 = 0;
    for (size_t i = 0; i < n; i += 1) {
      sum_f32 += NN_halfToFloat(x[i]) * NN_halfToFloat(y[i]);
    }
    sum = NN_floatToHalf(sum_f32);
  #endif

  *s = sum;
}

static inline void NN__dot_F32(size_t n, float *s, float *x, float *y) {
  float sum = 0.0;

44 changes: 44 additions & 0 deletions nn/src/nn_matmul.c
@@ -24,12 +24,56 @@ void NN_matmul(Tensor *out, Tensor *a, Tensor *b) {
    }
    return;
  }
  if (a->dtype == DTYPE_F16 && b->dtype == DTYPE_F16 && out->dtype == DTYPE_F16) {
    // currently only support 2D matrix multiplication
    assert(a->ndim == 2);
    assert(b->ndim == 2);
    assert(a->dtype == DTYPE_F16);
    assert(b->dtype == DTYPE_F16);
    assert(out->dtype == DTYPE_F16);
    assert(a->shape[1] == b->shape[0]);
    assert(out->shape[0] == a->shape[0]);
    assert(out->shape[1] == b->shape[1]);

    // Naive FP16 GEMM: accumulate each output element in FP32, round back to FP16 on store.
    for (size_t i = 0; i < out->shape[0]; i += 1) {
      for (size_t j = 0; j < out->shape[1]; j += 1) {
        float sum = 0;
        for (size_t k = 0; k < a->shape[1]; k += 1) {
          sum += NN_halfToFloat(((float16_t *)a->data)[i * a->shape[1] + k]) * NN_halfToFloat(((float16_t *)b->data)[k * b->shape[1] + j]);
        }
        ((float16_t *)out->data)[i * out->shape[1] + j] = NN_floatToHalf(sum);
      }
    }
    return;
  }
  printf("Unsupported operation: %s = %s @ %s\n",
    NN_getDataTypeName(out->dtype), NN_getDataTypeName(a->dtype), NN_getDataTypeName(b->dtype)
  );
}

void NN_matmulT(Tensor *out, Tensor *a, Tensor *b) {
  if (a->dtype == DTYPE_F16 && b->dtype == DTYPE_F16 && out->dtype == DTYPE_F16) {
    // currently only support 2D matrix multiplication
    assert(a->ndim == 2);
    assert(b->ndim == 2);
    assert(a->dtype == DTYPE_F16);
    assert(b->dtype == DTYPE_F16);
    assert(out->dtype == DTYPE_F16);
    assert(a->shape[1] == b->shape[1]);
    assert(out->shape[0] == a->shape[0]);
    assert(out->shape[1] == b->shape[0]);

    // b is transposed, so each output element is a dot product of a row of a with a row of b.
    for (size_t i = 0; i < out->shape[0]; i += 1) {
      for (size_t j = 0; j < out->shape[1]; j += 1) {
        NN__dot_F16(a->shape[1],
          (float16_t *)out->data + i * out->shape[1] + j,
          (float16_t *)a->data + i * a->shape[1],
          (float16_t *)b->data + j * b->shape[1]
        );
      }
    }
    return;
  }
  if (a->dtype == DTYPE_F32 && b->dtype == DTYPE_F32 && out->dtype == DTYPE_F32) {
    // currently only support 2D matrix multiplication
    assert(a->ndim == 2);
34 changes: 18 additions & 16 deletions tests/src/generate_test.py
@@ -63,26 +63,28 @@ def rand16(shape):
    #     ", 1" ),
    # ("ReLU6", lambda x: torch.nn.functional.relu6(x),
    #     [("x", rand((7, 7))) ]),
    ("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=0, dilation=1, groups=1).permute((0, 2, 3, 1)),
        [("x", rand((1, 16, 16, 3))), ("w", rand((3, 3, 3, 6))), ("b", rand((6, )))],
        ", (size_t[]){1, 1}, (size_t[]){0, 0}, (size_t[]){1, 1}, 1" ),
    ("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=1, dilation=1, groups=1).permute((0, 2, 3, 1)),
        [("x", rand((1, 16, 16, 3))), ("w", rand((3, 3, 3, 71))), ("b", rand((71, )))],
        ", (size_t[]){1, 1}, (size_t[]){1, 1}, (size_t[]){1, 1}, 1" ),
    ("NCHWToNHWC", lambda x: x.permute((0, 2, 3, 1)), [("x", rand((1, 2, 3, 3))) ]),
    ("NHWCToNCHW", lambda x: x.permute((0, 3, 1, 2)), [("x", rand((1, 3, 3, 2))) ]),
    ("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=1, dilation=1, groups=16).permute((0, 2, 3, 1)),
        [("x", rand((1, 12, 12, 16))), ("w", rand((3, 3, 1, 16))), ("b", rand((16, )))],
        ", (size_t[]){1, 1}, (size_t[]){1, 1}, (size_t[]){1, 1}, 16" ),
    ("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=1, dilation=1, groups=1).permute((0, 2, 3, 1)),
        [("x", rand((1, 12, 12, 16))), ("w", rand((3, 3, 16, 56))), ("b", rand((56, )))],
        ", (size_t[]){1, 1}, (size_t[]){1, 1}, (size_t[]){1, 1}, 1" ),
    # ("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=0, dilation=1, groups=1).permute((0, 2, 3, 1)),
    #     [("x", rand((1, 16, 16, 3))), ("w", rand((3, 3, 3, 6))), ("b", rand((6, )))],
    #     ", (size_t[]){1, 1}, (size_t[]){0, 0}, (size_t[]){1, 1}, 1" ),
    # ("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=1, dilation=1, groups=1).permute((0, 2, 3, 1)),
    #     [("x", rand((1, 16, 16, 3))), ("w", rand((3, 3, 3, 71))), ("b", rand((71, )))],
    #     ", (size_t[]){1, 1}, (size_t[]){1, 1}, (size_t[]){1, 1}, 1" ),
    # ("NCHWToNHWC", lambda x: x.permute((0, 2, 3, 1)), [("x", rand((1, 2, 3, 3))) ]),
    # ("NHWCToNCHW", lambda x: x.permute((0, 3, 1, 2)), [("x", rand((1, 3, 3, 2))) ]),
    # ("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=1, dilation=1, groups=16).permute((0, 2, 3, 1)),
    #     [("x", rand((1, 12, 12, 16))), ("w", rand((3, 3, 1, 16))), ("b", rand((16, )))],
    #     ", (size_t[]){1, 1}, (size_t[]){1, 1}, (size_t[]){1, 1}, 16" ),
    # ("Conv2d", lambda x, w, b: torch.nn.functional.conv2d(x.permute((0, 3, 1, 2)), w.permute((3, 2, 0, 1)), b, stride=1, padding=1, dilation=1, groups=1).permute((0, 2, 3, 1)),
    #     [("x", rand((1, 12, 12, 16))), ("w", rand((3, 3, 16, 56))), ("b", rand((56, )))],
    #     ", (size_t[]){1, 1}, (size_t[]){1, 1}, (size_t[]){1, 1}, 1" ),
    # ("LayerNorm", lambda x, w, b: torch.nn.functional.layer_norm(x, x.shape, w, b, eps=1e-05),
    #     [("x", rand((6, 5))), ("w", rand((6, 5))), ("b", rand((6, 5))) ],
    #     ", 1e-05" ),

    # ("abs", lambda a: torch.abs(a), [("a", rand16((1, 4))), ]),
    # ("add", lambda a, b: a + b, [("a", rand16((6, 7))), ("b", rand16((6, 7))) ]),
    ("abs", lambda a: torch.abs(a), [("a", rand16((1, 4))), ]),
    ("add", lambda a, b: a + b, [("a", rand16((6, 7))), ("b", rand16((6, 7))) ]),
    ("matmulT", lambda a, b: a @ b.T, [("a", rand16((6, 7))), ("b", rand16((5, 7))) ]),
    ("matmul", lambda a, b: a @ b, [("a", rand16((6, 7))), ("b", rand16((7, 5))) ]),
]

