Skip to content

Commit

Permalink
dnn: Refactor and verify layernorm
Browse files Browse the repository at this point in the history
  • Loading branch information
colluca committed Oct 28, 2023
1 parent f869b2e commit 13acfd9
Show file tree
Hide file tree
Showing 56 changed files with 1,629 additions and 627 deletions.
3 changes: 3 additions & 0 deletions .clang-format-ignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@

# Ignore vendored third-party code
./sw/math/*
./target/snitch_cluster/sw/apps/transformer/src/transformer.c
./target/snitch_cluster/sw/apps/transformer/src/data.h
./sw/apps/transformer/src/transformer.h
1 change: 1 addition & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ jobs:
with:
flake8-version: "6.0.0"
max-line-length: "100"
exclude: "target/snitch_cluster/sw/apps/dnn/datagen.py"

######################
# Clang-Format Check #
Expand Down
14 changes: 7 additions & 7 deletions sw/blas/axpy/data/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import os

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
from data_utils import format_scalar_definition, format_vector_definition, \
format_vector_declaration, format_ifdef_wrapper # noqa: E402
from data_utils import format_scalar_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper # noqa: E402

MIN = -1000
MAX = +1000
Expand Down Expand Up @@ -47,16 +47,16 @@ def main():
a = np.random.uniform(MIN, MAX, 1)
x = np.random.uniform(MIN, MAX, length)
y = np.random.uniform(MIN, MAX, length)
z = np.zeros(length)
g = golden_model(a, x, y)

# Format header file
l_str = format_scalar_definition('const uint32_t', 'l', length)
a_str = format_scalar_definition('const double', 'a', a[0])
x_str = format_vector_definition('double', 'x', x, alignment=BURST_ALIGNMENT, section=section)
y_str = format_vector_definition('double', 'y', y, alignment=BURST_ALIGNMENT, section=section)
z_str = format_vector_declaration('double', 'z', z, alignment=BURST_ALIGNMENT, section=section)
g_str = format_vector_definition('double', 'g', g)
x_str = format_array_definition('double', 'x', x, alignment=BURST_ALIGNMENT, section=section)
y_str = format_array_definition('double', 'y', y, alignment=BURST_ALIGNMENT, section=section)
z_str = format_array_declaration('double', 'z', [length],
alignment=BURST_ALIGNMENT, section=section)
g_str = format_array_definition('double', 'g', g)
g_str = format_ifdef_wrapper('BIST', g_str)
f_str = '\n\n'.join([l_str, a_str, x_str, y_str, z_str, g_str])
f_str += '\n'
Expand Down
10 changes: 5 additions & 5 deletions sw/blas/axpy/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
sys.path.append(str(Path(__file__).parent / '../../../util/sim/'))
import verification # noqa: E402
from elf import Elf # noqa: E402
from data_utils import bytes_to_doubles # noqa: E402
from data_utils import bytes_to_float # noqa: E402


ERR_THRESHOLD = 1E-10
Expand All @@ -27,16 +27,16 @@ def main():
symbols_bin=args.symbols_bin,
log=args.log,
output_uids=['z'])
z_actual = np.array(bytes_to_doubles(raw_results['z']))
z_actual = np.array(bytes_to_float(raw_results['z'], prec='64'))

# Extract input operands from ELF file
if args.symbols_bin:
elf = Elf(args.symbols_bin)
else:
elf = Elf(args.snitch_bin)
a = np.array(bytes_to_doubles(elf.get_symbol_contents('a')))
x = np.array(bytes_to_doubles(elf.get_symbol_contents('x')))
y = np.array(bytes_to_doubles(elf.get_symbol_contents('y')))
a = np.array(bytes_to_float(elf.get_symbol_contents('a'), prec='64'))
x = np.array(bytes_to_float(elf.get_symbol_contents('x'), prec='64'))
y = np.array(bytes_to_float(elf.get_symbol_contents('y'), prec='64'))

# Verify results
z_golden = golden_model(a, x, y)
Expand Down
16 changes: 8 additions & 8 deletions sw/blas/gemm/data/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
from data_utils import emit_license, format_scalar_definition, \
format_vector_definition, format_ifdef_wrapper # noqa: E402
format_array_definition, format_ifdef_wrapper # noqa: E402


np.random.seed(42)
Expand Down Expand Up @@ -100,18 +100,18 @@ def emit_header(**kwargs):
data_str += [format_scalar_definition('uint32_t', 'BETA', kwargs['beta'])]
data_str += [format_scalar_definition('uint32_t', 'dtype_size', kwargs['prec']//8)]
data_str += [format_scalar_definition('uint32_t', 'expand', kwargs['expand'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten(),
data_str += [format_array_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'b', b.flatten(),
data_str += [format_array_definition(C_TYPES[str(kwargs['prec'])], 'b', b.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'c', c.flatten(),
data_str += [format_array_definition(C_TYPES[str(kwargs['prec'])], 'c', c.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
if kwargs['prec'] == 8:
result_def = format_vector_definition(C_TYPES['64'], 'result', result.flatten())
result_def = format_array_definition(C_TYPES['64'], 'result', result.flatten())
else:
result_def = format_vector_definition(C_TYPES[str(kwargs['prec'])],
'result',
result.flatten())
result_def = format_array_definition(C_TYPES[str(kwargs['prec'])],
'result',
result.flatten())
data_str += [format_ifdef_wrapper('BIST', result_def)]
data_str = '\n\n'.join(data_str)

Expand Down
36 changes: 17 additions & 19 deletions sw/blas/gemm/src/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ typedef char v8f8 __attribute__((vector_size(8)));
dump_float(gemm, 8);
dump_uint(index, 9);


void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
uint32_t ldA, uint32_t ta, double* B, uint32_t ldB,
uint32_t tb, double* C, uint32_t ldC, double BETA) {
Expand Down Expand Up @@ -74,24 +73,23 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
}

/* params:
* M: number of rows of A and C
* N: number of columns of B and C
* K: number of columns of A and rows of B
* A: pointer to matrix A
* ldA: row stride of A
* ta: transpose A
* B: pointer to matrix B
* ldB: row stride of B
* tb: transpose B
* C: pointer to matrix C
* ldC: row stride of C
* ALPHA: scalar alpha
* A is MxK, B is KxN, C is MxN
*/
* M: number of rows of A and C
* N: number of columns of B and C
* K: number of columns of A and rows of B
* A: pointer to matrix A
* ldA: row stride of A
* ta: transpose A
* B: pointer to matrix B
* ldB: row stride of B
* tb: transpose B
* C: pointer to matrix C
* ldC: row stride of C
* ALPHA: scalar alpha
* A is MxK, B is KxN, C is MxN
*/
void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
uint32_t ldA, uint32_t ta, float* B, uint32_t ldB,
uint32_t tb, float* C, uint32_t ldC, float ALPHA) {

// float c0, c1, c2, c3 = 0;
float c0 = 0.0f;
float c1 = 0.0f;
Expand All @@ -110,7 +108,7 @@ void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
c1 = 0.0f;
c2 = 0.0f;
c3 = 0.0f;
for (uint32_t k = 0; k < K; k+=4) {
for (uint32_t k = 0; k < K; k += 4) {
c0 += A[(k + 0) + m * ldA] * B[(k + 0) * ldB + n];
c1 += A[(k + 1) + m * ldA] * B[(k + 1) * ldB + n];
c2 += A[(k + 2) + m * ldA] * B[(k + 2) * ldB + n];
Expand All @@ -131,7 +129,7 @@ void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
c1 = 0.0f;
c2 = 0.0f;
c3 = 0.0f;
for (uint32_t k = 0; k < K; k+=4) {
for (uint32_t k = 0; k < K; k += 4) {
c0 += A[(k + 0) * M * ldA + m * ldA] * B[(k + 0) * ldB + n];
c1 += A[(k + 1) * M * ldA + m * ldA] * B[(k + 1) * ldB + n];
c2 += A[(k + 2) * M * ldA + m * ldA] * B[(k + 2) * ldB + n];
Expand All @@ -152,7 +150,7 @@ void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
c1 = 0.0f;
c2 = 0.0f;
c3 = 0.0f;
for (uint32_t k = 0; k < K; k+=4) {
for (uint32_t k = 0; k < K; k += 4) {
// c0 += A[k + m * ldA] * B[k + n * ldB];
c0 += A[(k + 0) + m * ldA] * B[(k + 0) + n * ldB];
c1 += A[(k + 1) + m * ldA] * B[(k + 1) + n * ldB];
Expand Down
20 changes: 10 additions & 10 deletions sw/blas/gemm/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
sys.path.append(str(Path(__file__).parent / '../../../util/sim/'))
import verification # noqa: E402
from elf import Elf # noqa: E402
from data_utils import bytes_to_doubles, bytes_to_uint32s # noqa: E402
from data_utils import bytes_to_float, bytes_to_int # noqa: E402


ERR_THRESHOLD = 0.001
Expand All @@ -27,21 +27,21 @@ def main():
symbols_bin=args.symbols_bin,
log=args.log,
output_uids=['c'])
c_actual = np.array(bytes_to_doubles(raw_results['c']))
c_actual = np.array(bytes_to_float(raw_results['c'], prec='64'))

# Extract input operands from ELF file
if args.symbols_bin:
elf = Elf(args.symbols_bin)
else:
elf = Elf(args.snitch_bin)
a = np.array(bytes_to_doubles(elf.get_symbol_contents('a')))
b = np.array(bytes_to_doubles(elf.get_symbol_contents('b')))
c = np.array(bytes_to_doubles(elf.get_symbol_contents('c')))
beta = bytes_to_uint32s(elf.get_symbol_contents('BETA'))[0]
m = bytes_to_uint32s(elf.get_symbol_contents('M'))[0]
n = bytes_to_uint32s(elf.get_symbol_contents('N'))[0]
k = bytes_to_uint32s(elf.get_symbol_contents('K'))[0]
tb = bytes_to_uint32s(elf.get_symbol_contents('TB'))[0]
a = np.array(bytes_to_float(elf.get_symbol_contents('a'), prec='64'))
b = np.array(bytes_to_float(elf.get_symbol_contents('b'), prec='64'))
c = np.array(bytes_to_float(elf.get_symbol_contents('c'), prec='64'))
beta = bytes_to_int(elf.get_symbol_contents('BETA'), prec='32', signedness='unsigned')[0]
m = bytes_to_int(elf.get_symbol_contents('M'), prec='32', signedness='unsigned')[0]
n = bytes_to_int(elf.get_symbol_contents('N'), prec='32', signedness='unsigned')[0]
k = bytes_to_int(elf.get_symbol_contents('K'), prec='32', signedness='unsigned')[0]
tb = bytes_to_int(elf.get_symbol_contents('TB'), prec='32', signedness='unsigned')[0]
a = np.reshape(a, (m, k))
if tb:
b = np.reshape(b, (n, k))
Expand Down
1 change: 1 addition & 0 deletions sw/dnn/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*/data/data.h
136 changes: 136 additions & 0 deletions sw/dnn/batchnorm/data/datagen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#!/usr/bin/env python3
# Copyright 2023 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Tim Fischer <[email protected]>
# Viviane Potocnik <[email protected]>
# Luca Colagrande <[email protected]>

import argparse
import pathlib
import hjson
import sys
import os
import torch

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
import data_utils # noqa: E402
from data_utils import emit_license, \
format_struct_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper # noqa: E402

# Seed the torch RNG with a fixed value so that every run of this script
# draws the same random tensors.
torch.manual_seed(42)

# AXI splits bursts crossing 4KB address boundaries. To minimize
# the occurrence of these splits, the data should be aligned to 4KB.
# NOTE(review): this constant is not referenced anywhere else in this script.
BURST_ALIGNMENT = 4096

# Maps a precision in bits (given as a string) to a precision label.
# NOTE(review): not referenced anywhere else in this script; presumably the
# labels mirror an enum in the C sources -- confirm before relying on it.
PRECISION_T = {
'64': 'FP64',
'32': 'FP32',
'16': 'FP16',
'8': 'FP8'
}


def golden_model(ifmap):
    """Compute the reference output of a folded batchnorm layer.

    Random running statistics are drawn and folded, together with the
    untouched (default-initialized) affine parameters of a
    ``torch.nn.BatchNorm2d``, into a per-channel scale and shift which
    are then applied to the input.

    Args:
        ifmap: Input feature map. Must be 4-dimensional, with the channel
            count on dimension 1.

    Returns:
        A tuple ``(ofmap, gamma, beta)``. ``gamma`` and ``beta`` are the
        1-D per-channel scale and shift; ``ofmap`` equals
        ``ifmap * gamma + beta`` with both broadcast over the two
        trailing dimensions of ``ifmap``.

    Raises:
        ValueError: If ``ifmap`` does not have exactly four dimensions.
    """
    # Only the channel count is needed. The 4-way unpack is kept on purpose
    # so that inputs of any other rank are still rejected.
    _, ci, _, _ = ifmap.shape

    bn = torch.nn.BatchNorm2d(ci)
    bn.weight.requires_grad = False
    bn.bias.requires_grad = False

    # The order of these two draws fixes the generated data for a given
    # seed, so the mean must be drawn before the variance.
    running_mean = torch.randn_like(bn.running_mean, requires_grad=False)
    running_var = torch.rand_like(bn.running_var, requires_grad=False)

    # Computed once and shared by the scale and the shift.
    std = torch.sqrt(running_var + bn.eps)
    gamma = bn.weight / std
    beta = bn.bias - running_mean * bn.weight / std

    # Append two singleton dims so the per-channel vectors broadcast
    # against dimension 1 of the 4-D input.
    ofmap = ifmap * gamma.unsqueeze(-1).unsqueeze(-1) + beta.unsqueeze(-1).unsqueeze(-1)
    return ofmap, gamma, beta


def emit_header(**kwargs):
    """Render the batchnorm data header and return it as a single string.

    Reads ``input_dim`` (``channels``, ``height``, ``width``), ``tile_ci``
    and ``prec`` from ``kwargs``, draws a random input feature map, runs
    it through ``golden_model`` and formats the layer configuration, the
    input arrays and the golden output as C source text.

    NOTE(review): a ``section`` entry in ``kwargs`` is accepted but never
    read here, and ``BURST_ALIGNMENT`` is not applied to any array --
    confirm whether that is intended.

    Returns:
        The header contents, with the individual pieces separated by
        blank lines.
    """
    input_dim = kwargs['input_dim']
    channels = input_dim['channels']
    height = kwargs['input_dim']['height']
    width = kwargs['input_dim']['width']
    tile_ci = kwargs['tile_ci']
    prec = str(kwargs['prec'])

    torch_type = data_utils.floating_point_torch_type(prec)
    ctype = data_utils.floating_point_ctype(prec)

    ifmap = torch.randn(1, channels, height, width,
                        requires_grad=False, dtype=torch_type)
    ofmap, gamma, beta = golden_model(ifmap)

    # Reorder both feature maps from CHW to HWC layout
    hwc_order = (0, 2, 3, 1)
    ifmap = ifmap.permute(*hwc_order)
    ofmap = ofmap.permute(*hwc_order)

    _, ih, iw, ci = ifmap.shape

    # Symbol names of the emitted arrays
    ifmap_uid, ofmap_uid = 'ifmap', 'ofmap'
    beta_uid, gamma_uid = 'beta', 'gamma'

    layer_cfg = {
        'CI': ci,
        'IH': ih,
        'IW': iw,
        'TILE_CI': tile_ci,
        'ifmap': ifmap_uid,
        'ofmap': ofmap_uid,
        'beta': beta_uid,
        'gamma': gamma_uid
    }

    sections = [
        emit_license(),
        # Forward declarations, so the layer struct can refer to the arrays
        format_array_declaration(ctype, ifmap_uid, ifmap.shape),
        format_array_declaration(ctype, ofmap_uid, ofmap.shape),
        format_array_declaration(ctype, beta_uid, beta.shape),
        format_array_declaration(ctype, gamma_uid, gamma.shape),
        # Layer configuration struct
        format_struct_definition('batchnorm_layer_t', 'layer', layer_cfg),
        # Definitions of the input arrays (ofmap is only declared)
        format_array_definition(ctype, ifmap_uid, ifmap),
        format_array_definition(ctype, beta_uid, beta),
        format_array_definition(ctype, gamma_uid, gamma),
        # Golden results, only compiled in when BIST is defined
        format_ifdef_wrapper('BIST', format_array_definition(ctype, 'golden', ofmap)),
    ]

    return '\n\n'.join(sections)


def main():
    """Generate the batchnorm data header from a parameter config file.

    Parses the command line, loads the hjson parameter file given with
    ``-c/--cfg``, adds the ``--section`` value to the parameters under the
    ``section`` key and writes the text returned by ``emit_header`` to the
    ``output`` path.
    """
    parser = argparse.ArgumentParser(description='Generate data for batchnorm kernel')
    parser.add_argument(
        "-c", "--cfg",
        type=pathlib.Path,
        required=True,
        help='Select param config file kernel'
    )
    parser.add_argument(
        '--section',
        type=str,
        help='Section to store matrices in')
    parser.add_argument(
        'output',
        type=pathlib.Path,
        help='Path of the output header file')
    args = parser.parse_args()

    # Load param config file
    with args.cfg.open() as f:
        param = hjson.loads(f.read())
    param['section'] = args.section

    # Render the header before opening the output file, so that a failure
    # during generation does not leave a truncated header behind.
    header = emit_header(**param)

    # Emit header file
    with open(args.output, 'w') as f:
        f.write(header)


if __name__ == '__main__':
main()
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,12 @@
// Solderpad Hardware License, Version 0.51, see LICENSE for details.
// SPDX-License-Identifier: SHL-0.51

// Parameters for a single BatchNorm layer

{
kernel: "BatchNorm"
channels: {
out: 32,
in: 32
}
input_dim: {
channels: 32
height: 8,
width: 8
}
tile_ci: 32
prec: 64
}
Loading

0 comments on commit 13acfd9

Please sign in to comment.