From 0875bd0a1beafa79b9a406c565d7aee63ac4afc7 Mon Sep 17 00:00:00 2001 From: Luca Colagrande Date: Sat, 28 Oct 2023 15:37:53 +0200 Subject: [PATCH] Extend layout utils to accept HW config as input --- sw/dnn/layernorm/layout.csv | 6 +++--- sw/dnn/layernorm/src/layernorm.h | 9 +++++++-- target/common/common.mk | 3 ++- util/trace/layout_events.py | 17 ++++++++++++----- 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/sw/dnn/layernorm/layout.csv b/sw/dnn/layernorm/layout.csv index 9fd0970bec..fed43ee3cb 100644 --- a/sw/dnn/layernorm/layout.csv +++ b/sw/dnn/layernorm/layout.csv @@ -1,3 +1,3 @@ - , setup, dma in, compute tile, dma out, dma in, compute tile, dma out -"range(0,8)", 1, , 3, , , 5, -8 , 1, 2, , 4, 5, , 7 + , setup, dma in, compute tile, dma out, dma in, compute tile, dma out +"[i*9+j+cfg['cluster']['cluster_base_hartid'] for i in range(cfg['s1_quadrant']['nr_clusters']) for j in range(8)]", 1, , 3, , , 5, +"[i*9+8+cfg['cluster']['cluster_base_hartid'] for i in range(cfg['s1_quadrant']['nr_clusters'])]" , 1, 2, , 4, 5, , 7 diff --git a/sw/dnn/layernorm/src/layernorm.h b/sw/dnn/layernorm/src/layernorm.h index f5af22f457..6294b99bb2 100644 --- a/sw/dnn/layernorm/src/layernorm.h +++ b/sw/dnn/layernorm/src/layernorm.h @@ -110,12 +110,15 @@ static inline void layernorm_fp32(float *input, float *output, // layernorm_fp32(input, input, ldI, 0, 1, seq_len, embeddings, eps); // } -// Tiles the seq_len axis +// Tiles the seq_len axis (assumes seq_len is an integer multiple of n_tiles) +// Distributes tiles to clusters (assumes n_tiles is an integer multiple of +// the number of clusters) static inline void layernorm_layer(layernorm_layer_t l) { snrt_mcycle(); // Compute the tiling parameters uint32_t n_tiles = l.n_tiles; + uint32_t n_tiles_per_cluster = l.n_tiles / snrt_cluster_num(); uint32_t tile_seq_len = l.seq_len / n_tiles; uint32_t tile_size = l.batch_size * tile_seq_len * l.embeddings; uint32_t tile_offset = tile_seq_len * l.embeddings; @@ -130,7 +133,9 @@ static inline void layernorm_layer(layernorm_layer_t l) { // Iterate tiles snrt_mcycle(); - for (int tile_idx = 0; tile_idx < n_tiles; tile_idx++) { + for (uint32_t cluster_tile_idx = 0; cluster_tile_idx < n_tiles_per_cluster; cluster_tile_idx++) { + // Calculate absolute tile index + uint32_t tile_idx = snrt_cluster_idx() * n_tiles_per_cluster + cluster_tile_idx; // Copy input tile if (snrt_is_dm_core()) { float *remote_itile = remote_ifmap + tile_idx * tile_offset; diff --git a/target/common/common.mk b/target/common/common.mk index 48b875f760..73abbdf08f 100644 --- a/target/common/common.mk +++ b/target/common/common.mk @@ -59,7 +59,8 @@ VLT_FLAGS += --unroll-count 1024 VLT_CFLAGS += -std=c++14 -pthread VLT_CFLAGS +=-I ${VLT_BUILDDIR} -I $(VLT_ROOT)/include -I $(VLT_ROOT)/include/vltstd -I $(VLT_FESVR)/include -I $(TB_DIR) -I ${MKFILE_DIR}/test -ANNOTATE_FLAGS ?= -q --keep-time +ANNOTATE_FLAGS ?= -q --keep-time +LAYOUT_EVENTS_FLAGS ?= --cfg=$(CFG) # We need a recent LLVM installation (>11) to compile Verilator. # We also need to link the binaries with LLVM's libc++. diff --git a/util/trace/layout_events.py b/util/trace/layout_events.py index ea877c53cb..0d0e914358 100755 --- a/util/trace/layout_events.py +++ b/util/trace/layout_events.py @@ -41,6 +41,7 @@ import csv import pandas as pd from math import isnan +import hjson def main(): @@ -55,10 +56,9 @@ def main(): metavar='', help='Layout CSV file') parser.add_argument( - '--num-clusters', - type=int, - default=1, - help='Number of clusters') + '--cfg', + type=str, + help='System configuration .hjson file') parser.add_argument( '-o', '--output', @@ -71,6 +71,11 @@ def main(): # Read input CSV df = pd.read_csv(args.csv) + # Read system configuration .hjson file + cfg = None + with open(args.cfg) as cfg_file: + cfg = hjson.load(cfg_file) + # Output CSV data data = [] columns = [] @@ -92,7 +97,9 @@ def main(): # which generates a list of hart IDs expr = row[0] code = compile(expr, "", "eval") - tids = eval(code, {}, {'num_clusters': args.num_clusters}) + # Symbols must be added to globals to be used in list comprehensions + # see https://bugs.python.org/issue36300 + tids = eval(code, {'cfg': cfg}, {'cfg': cfg}) if type(tids) == int: tids = [tids]