diff --git a/README.md b/README.md index 4df6c395..7b1f6a49 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # ADI MAX78000/MAX78002 Model Training and Synthesis -July 22, 2024 +August 27, 2024 **Note: This branch requires PyTorch 2. Please see the archive-1.8 branch for PyTorch 1.8 support. [KNOWN_ISSUES](KNOWN_ISSUES.txt) contains a list of known issues.** @@ -1620,13 +1620,15 @@ When using the `-8` command line switch, all module outputs are quantized to 8-b The last layer can optionally use 32-bit output for increased precision. This is simulated by adding the parameter `wide=True` to the module function call. -##### Weights: Quantization-Aware Training (QAT) +##### Weights and Activations: Quantization-Aware Training (QAT) Quantization-aware training (QAT) is enabled by default. QAT is controlled by a policy file, specified by `--qat-policy`. -* After `start_epoch` epochs, training will learn an additional parameter that corresponds to a shift of the final sum of products. +* After `start_epoch` epochs, an intermediate epoch with no backpropagation is run to collect activation statistics. Each layer's activation range is then determined based on the range/resolution trade-off observed in the collected activations. QAT then starts, and an additional parameter (`output_shift`) is learned that shifts activations to compensate for the scaling-down of weights and biases. * `weight_bits` describes the number of bits available for weights. * `overrides` allows specifying the `weight_bits` on a per-layer basis. +* `outlier_removal_z_score` defines the z-score threshold for outlier removal during activation range calculation. (default: 8.0) +* `shift_quantile` defines the quantile of the parameter distribution to be used for the `output_shift` parameter. (default: 1.0) By default, weights are quantized to 8-bits after 30 epochs as specified in `policies/qat_policy.yaml`. A more refined example that specifies weight sizes for individual layers can be seen in `policies/qat_policy_cifar100.yaml`.
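For illustration, a policy file combining the keys above might look as follows — a sketch with illustrative values; the layer name under `overrides` is hypothetical (see `policies/qat_policy_cifar100.yaml` for a complete, authoritative example):

```yaml
start_epoch: 30               # begin QAT after this many epochs
weight_bits: 8                # default number of bits for weights
outlier_removal_z_score: 8.0  # z-score threshold for activation outlier removal
shift_quantile: 0.985         # quantile of the parameter distribution for output_shift
overrides:
  fc:                         # hypothetical layer name
    weight_bits: 4            # this layer uses 4-bit weights
```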
@@ -1745,7 +1747,7 @@ For both approaches, the `quantize.py` software quantizes an existing PyTorch ch #### Quantization-Aware Training (QAT) -Quantization-aware training is the better performing approach. It is enabled by default. QAT learns additional parameters during training that help with quantization (see [Weights: Quantization-Aware Training (QAT)](#weights-quantization-aware-training-qat). No additional arguments (other than input, output, and device) are needed for `quantize.py`. +Quantization-aware training is the better-performing approach. It is enabled by default. QAT learns additional parameters during training that help with quantization (see [Weights and Activations: Quantization-Aware Training (QAT)](#weights-and-activations-quantization-aware-training-qat)). No additional arguments (other than input, output, and device) are needed for `quantize.py`. The input checkpoint to `quantize.py` is either `qat_best.pth.tar`, the best QAT epoch’s checkpoint, or `qat_checkpoint.pth.tar`, the final QAT epoch’s checkpoint. @@ -2004,7 +2006,7 @@ The behavior of a training session might change when Quantization Aware Training While there can be multiple reasons for this, check two important settings that can influence the training behavior: * The initial learning rate may be set too high. Reduce LR by a factor of 10 or 100 by specifying a smaller initial `--lr` on the command line, and possibly by reducing the epoch `milestones` for further reduction of the learning rate in the scheduler file specified by `--compress`. Note that the selected optimizer and the batch size both affect the learning rate. -* The epoch when QAT is engaged may be set too low. Increase `start_epoch` in the QAT scheduler file specified by `--qat-policy`, and increase the total number of training epochs by increasing the value specified by the `--epochs` command line argument and by editing the `ending_epoch` in the scheduler file specified by `--compress`. *See also the rule of thumb discussed in the section [Weights: Quantization-Aware Training (QAT)](#weights:-auantization-aware-training \(qat\)).* +* The epoch when QAT is engaged may be set too low. Increase `start_epoch` in the QAT scheduler file specified by `--qat-policy`, and increase the total number of training epochs by increasing the value specified by the `--epochs` command line argument and by editing the `ending_epoch` in the scheduler file specified by `--compress`. *See also the rule of thumb discussed in the section [Weights and Activations: Quantization-Aware Training (QAT)](#weights-and-activations-quantization-aware-training-qat).* @@ -2209,6 +2211,7 @@ The following table describes the most important command line arguments for `ai8 | `--no-unload` | Do not create the `cnn_unload()` function | | | `--no-kat` | Do not generate the `check_output()` function (disable known-answer test) | | | `--no-deduplicate-weights` | Do not deduplicate weights and bias values | | +| `--scale-output` | Use scales from the checkpoint to recover the output range when generating the `cnn_unload()` function | | ### YAML Network Description @@ -2330,6 +2333,12 @@ The following keywords are required for each `unload` list item: `width`: Data width (optional, defaults to 8) — either 8 or 32 `write_gap`: Gap between data words (optional, defaults to 0) +When `--scale-output` is specified, the scales from the checkpoint file are used to recover the output range. If an 8-bit output has a non-zero scale, the output is scaled and stored as a 16-bit value. If the scale is zero, the output remains 8-bit. A 32-bit output is always kept at 32 bits. + +Example: + +![Unload Array](docs/unload_example.png) + ##### `layers` (Mandatory) `layers` is a list that defines the per-layer description, as shown below: @@ -2654,7 +2663,7 @@ Example: By default, the final layer is used as the output layer. Output layers are checked using the known-answer test, and they are copied from hardware memory when `cnn_unload()` is called. The tool also checks that output layer data isn’t overwritten by any later layers. When specifying `output: true`, any layer (or a combination of layers) can be used as an output layer. -*Note:* When `unload:` is used, output layers are not used for generating `cnn_unload()`. +*Note:* When `--no-unload` is used, output layers are not used for generating `cnn_unload()`.
Example: `output: true` diff --git a/docs/unload_example.png b/docs/unload_example.png new file mode 100644 index 00000000..7ecacf2b Binary files /dev/null and b/docs/unload_example.png differ diff --git a/gen-demos-max78000.sh b/gen-demos-max78000.sh index a3f0ff9c..53f82f72 100755 --- a/gen-demos-max78000.sh +++ b/gen-demos-max78000.sh @@ -12,7 +12,8 @@ python ai8xize.py --test-dir $TARGET --prefix cifar-100-simplewide2x-mixed --che python ai8xize.py --test-dir $TARGET --prefix cifar-100-residual --checkpoint-file trained/ai85-cifar100-residual-qat8-q.pth.tar --config-file networks/cifar100-ressimplenet.yaml --softmax $COMMON_ARGS --boost 2.5 "$@" python ai8xize.py --test-dir $TARGET --prefix kws20_v3 --checkpoint-file trained/ai85-kws20_v3-qat8-q.pth.tar --config-file networks/kws20-v3-hwc.yaml --softmax $COMMON_ARGS "$@" python ai8xize.py --test-dir $TARGET --prefix kws20_nas --checkpoint-file trained/ai85-kws20_nas-qat8-q.pth.tar --config-file networks/kws20-nas-hwc.yaml --softmax $COMMON_ARGS "$@" -python ai8xize.py --test-dir $TARGET --prefix faceid --checkpoint-file trained/ai85-faceid-qat8-q.pth.tar --config-file networks/faceid.yaml --fifo $COMMON_ARGS "$@" +python izer/add_fake_passthrough.py --input-checkpoint-path trained/ai85-faceid_112-qat-q.pth.tar --output-checkpoint-path trained/ai85-fakepass-faceid_112-qat-q.pth.tar --layer-name fakepass --layer-depth 128 --layer-name-after-pt linear --low-memory-footprint "$@" +python ai8xize.py --test-dir $TARGET --prefix faceid_112 --checkpoint-file trained/ai85-fakepass-faceid_112-qat-q.pth.tar --config-file networks/ai85-faceid_112.yaml --fifo $COMMON_ARGS "$@" python ai8xize.py --test-dir $TARGET --prefix cats-dogs --checkpoint-file trained/ai85-catsdogs-qat8-q.pth.tar --config-file networks/cats-dogs-hwc.yaml --fifo --softmax $COMMON_ARGS "$@" python ai8xize.py --test-dir $TARGET --prefix camvid_unet --checkpoint-file trained/ai85-camvid-unet-large-fakept-q.pth.tar --config-file networks/camvid-unet-large-fakept.yaml $COMMON_ARGS --overlap-data --mlator --no-unload --max-checklines 8192 --new-kernel-loader "$@" python ai8xize.py --test-dir $TARGET --prefix aisegment_unet --checkpoint-file trained/ai85-aisegment-unet-large-fakept-q.pth.tar --config-file networks/aisegment-unet-large-fakept.yaml $COMMON_ARGS --overlap-data --mlator --no-unload --max-checklines 8192 --new-kernel-loader "$@" diff --git a/gen-demos-max78002.sh b/gen-demos-max78002.sh index 0af26bbb..a7e24576 100755 --- a/gen-demos-max78002.sh +++ b/gen-demos-max78002.sh @@ -12,7 +12,7 @@ python ai8xize.py --test-dir $TARGET --prefix cifar-100-simplewide2x-mixed --che python ai8xize.py --test-dir $TARGET --prefix cifar-100-residual --checkpoint-file trained/ai85-cifar100-residual-qat8-q.pth.tar --config-file networks/cifar100-ressimplenet.yaml --softmax $COMMON_ARGS "$@" python ai8xize.py --test-dir $TARGET --prefix kws20_v3_1 --checkpoint-file trained/ai87-kws20_v3-qat8-q.pth.tar --config-file networks/ai87-kws20-v3-hwc.yaml --softmax $COMMON_ARGS "$@" python ai8xize.py --test-dir $TARGET --prefix kws20_v2_1 --checkpoint-file trained/ai87-kws20_v2-qat8-q.pth.tar --config-file networks/ai87-kws20-v2-hwc.yaml --softmax $COMMON_ARGS "$@" -python ai8xize.py --test-dir $TARGET --prefix faceid --checkpoint-file trained/ai85-faceid-qat8-q.pth.tar --config-file networks/faceid.yaml --fifo $COMMON_ARGS "$@" +python ai8xize.py --test-dir $TARGET --prefix mobilefacenet-112 --checkpoint-file trained/ai87-mobilefacenet-112-qat-q.pth.tar --config-file 
networks/ai87-mobilefacenet-112.yaml --fifo $COMMON_ARGS "$@" python ai8xize.py --test-dir $TARGET --prefix cats-dogs --checkpoint-file trained/ai85-catsdogs-qat8-q.pth.tar --config-file networks/cats-dogs-hwc-no-fifo.yaml --softmax $COMMON_ARGS "$@" python ai8xize.py --test-dir $TARGET --prefix camvid_unet --checkpoint-file trained/ai85-camvid-unet-large-fakept-q.pth.tar --config-file networks/camvid-unet-large-fakept.yaml $COMMON_ARGS --overlap-data --mlator --no-unload --max-checklines 8192 "$@" python ai8xize.py --test-dir $TARGET --prefix aisegment_unet --checkpoint-file trained/ai85-aisegment-unet-large-fakept-q.pth.tar --config-file networks/aisegment-unet-large-fakept.yaml $COMMON_ARGS --overlap-data --mlator --no-unload --max-checklines 8192 "$@" @@ -21,5 +21,5 @@ python ai8xize.py --test-dir $TARGET --prefix cifar-100-effnet2 --checkpoint-fil python ai8xize.py --test-dir $TARGET --prefix cifar-100-mobilenet-v2-0.75 --checkpoint-file trained/ai87-cifar100-mobilenet-v2-0.75-qat8-q.pth.tar --config-file networks/ai87-cifar100-mobilenet-v2-0.75.yaml --softmax $COMMON_ARGS "$@" python ai8xize.py --test-dir $TARGET --prefix imagenet --checkpoint-file trained/ai87-imagenet-effnet2-q.pth.tar --config-file networks/ai87-imagenet-effnet2.yaml $COMMON_ARGS "$@" python ai8xize.py --test-dir $TARGET --prefix facedet_tinierssd --checkpoint-file trained/ai87-facedet-tinierssd-qat8-q.pth.tar --config-file networks/ai87-facedet-tinierssd.yaml --sample-input tests/sample_vggface2_facedetection.npy $COMMON_ARGS "$@" -python ai8xize.py --test-dir $TARGET --prefix pascalvoc_fpndetector --checkpoint-file trained/ai87-pascalvoc-fpndetector-qat8-q.pth.tar --config-file networks/ai87-pascalvoc-fpndetector.yaml --fifo --sample-input tests/sample_pascalvoc_256_320.npy --overwrite --no-unload $COMMON_ARGS "$@" +python ai8xize.py --test-dir $TARGET --prefix pascalvoc_fpndetector --checkpoint-file trained/ai87-pascalvoc-fpndetector-qat8-q.pth.tar --config-file networks/ai87-pascalvoc-fpndetector.yaml --fifo --sample-input tests/sample_pascalvoc_256_320.npy --no-unload $COMMON_ARGS "$@" python ai8xize.py --test-dir $TARGET --prefix kinetics --checkpoint-file trained/ai85-kinetics-qat8-q.pth.tar --config-file networks/ai85-kinetics-actiontcn.yaml --overlap-data --softmax --zero-sram $COMMON_ARGS "$@" diff --git a/izer/backend/max7800x.py b/izer/backend/max7800x.py index ceff6bee..fc410289 100644 --- a/izer/backend/max7800x.py +++ b/izer/backend/max7800x.py @@ -1,5 +1,5 @@ ################################################################################################### -# Copyright (C) 2019-2023 Maxim Integrated Products, Inc. All Rights Reserved. +# Copyright (C) 2019-2024 Maxim Integrated Products, Inc. All Rights Reserved. # # Maxim Integrated Products, Inc. 
Default Copyright Notice: # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html @@ -69,6 +69,7 @@ def create_net(self) -> str: # pylint: disable=too-many-locals,too-many-branche fast_fifo_quad = state.fast_fifo_quad fifo = state.fifo final_layer = state.final_layer + final_scale = state.final_scale first_layer_used = state.first_layer_used flatten = state.flatten forever = state.forever @@ -136,6 +137,7 @@ def create_net(self) -> str: # pylint: disable=too-many-locals,too-many-branche riscv = state.riscv riscv_cache = state.riscv_cache riscv_flash = state.riscv_flash + scale_output = state.scale_output simple1b = state.simple1b simulated_sequence = state.simulated_sequence snoop = state.snoop @@ -1152,7 +1154,8 @@ def create_net(self) -> str: # pylint: disable=too-many-locals,too-many-branche conv_str = ', no convolution, ' apb.output(conv_str + f'{output_chan[ll]}x{output_dim_str[ll]} output\n', embedded_code) - + apb.output('\n', embedded_code) + apb.output(f'// Final Scales: {final_scale}\n', embedded_code) apb.output('\n', embedded_code) apb.header() @@ -3553,8 +3556,20 @@ def run_eltwise( elif block_mode: assets.copy('assets', 'blocklevel-ai' + str(device), base_directory, test_name) elif embedded_code: - output_count = output_chan[terminating_layer] \ - * output_dim[terminating_layer][0] * output_dim[terminating_layer][1] + output_count = 0 + for i in range(terminating_layer + 1): + if output_layer[i]: + if output_width[i] != 32: + if scale_output: + output_count += (output_chan[i] * output_dim[i][0] * output_dim[i][1] + + (32 // (2 * output_width[i]) - 1)) \ + // (32 // (2 * output_width[i])) + else: + output_count += (output_chan[i] * output_dim[i][0] * output_dim[i][1] + + (32 // output_width[i] - 1)) \ + // (32 // output_width[i]) + else: + output_count += output_chan[i] * output_dim[i][0] * output_dim[i][1] insert = summary_stats + \ '\n/* Number of outputs for this network */\n' \ f'#define CNN_NUM_OUTPUTS {output_count}' diff --git a/izer/checkpoint.py b/izer/checkpoint.py index b88bd8c8..91b6587b 100644 --- a/izer/checkpoint.py +++ b/izer/checkpoint.py @@ -1,5 +1,5 @@ ################################################################################################### -# Copyright (C) 2019-2023 Maxim Integrated Products, Inc. All Rights Reserved. +# Copyright (C) 2019-2024 Maxim Integrated Products, Inc. All Rights Reserved. # # Maxim Integrated Products, Inc. 
Default Copyright Notice: # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html @@ -56,6 +56,7 @@ def load( bias_min = [] bias_max = [] bias_size = [] + final_scale = {} checkpoint = torch.load(checkpoint_file, map_location='cpu') print(f'Reading {checkpoint_file} to configure network weights...') @@ -251,6 +252,12 @@ def load( # Add implicit shift based on quantization output_shift[seq] += 8 - abs(quantization[seq]) + final_scale_name = '.'.join([layer, 'final_scale']) + if final_scale_name in checkpoint_state: + w = checkpoint_state[final_scale_name].numpy().astype(np.int64) + final_scale[seq] = w.item() + else: + final_scale[seq] = 0 layers += 1 seq += 1 @@ -286,4 +293,4 @@ def load( sys.exit(1) return layers, weights, bias, output_shift, \ - input_channels, output_channels + input_channels, output_channels, final_scale diff --git a/izer/commandline.py b/izer/commandline.py index fedb8d07..c1677fad 100644 --- a/izer/commandline.py +++ b/izer/commandline.py @@ -464,6 +464,8 @@ def get_parser() -> argparse.Namespace: help='GitHub repository name for update checking') group.add_argument('--yamllint', metavar='S', default='yamllint', help='name of linter for YAML files (default: yamllint)') + group.add_argument('--scale-output', action='store_true', default=False, + help="scale output with final layer scale factor (default: false)") args = parser.parse_args() @@ -691,6 +693,7 @@ def set_state(args: argparse.Namespace) -> None: state.rtl_preload_weights = args.rtl_preload_weights state.runtest_filename = args.runtest_filename state.sample_filename = args.sample_filename + state.scale_output = args.scale_output state.simple1b = args.simple1b state.sleep = args.deepsleep state.slow_load = args.slow_load diff --git a/izer/izer.py b/izer/izer.py index 483c6cb7..f9dd618c 100644 --- a/izer/izer.py +++ b/izer/izer.py @@ -1,5 +1,5 @@ ################################################################################################### -# Copyright (C) 2019-2023 Maxim Integrated Products, Inc. All Rights Reserved. +# Copyright (C) 2019-2024 Maxim Integrated Products, Inc. All Rights Reserved. # # Maxim Integrated Products, Inc. Default Copyright Notice: # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html @@ -74,6 +74,7 @@ def main(): # If not using test data, load weights and biases # This also configures the network's output channels + final_scale = None if cfg['arch'] != 'test': if not args.checkpoint_file: eprint('--checkpoint-file is a required argument.') @@ -96,7 +97,7 @@ def main(): else: # PyTorch checkpoint file selected layers, weights, bias, output_shift, \ - input_channels, output_channels = \ + input_channels, output_channels, final_scale = \ checkpoint.load( args.checkpoint_file, cfg['arch'], @@ -134,6 +135,8 @@ def main(): params['bypass'], filename=args.bias_input, ) + if final_scale is None: + final_scale = {ll: 0 for ll in range(cfg_layers)} if cfg_layers > layers: # Add empty weights/biases and channel counts for layers not in checkpoint file. # The checkpoint file does not contain weights for non-convolution operations. 
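As a quick sanity check on the new per-layer scales, the `final_scale` entries that `checkpoint.load()` reads can be listed directly from a quantized checkpoint — a minimal sketch, assuming the usual `state_dict` checkpoint layout; the file name is hypothetical:

```python
import torch

# Load a quantized QAT checkpoint on the CPU (file name is hypothetical)
checkpoint = torch.load('trained/ai85-example-qat8-q.pth.tar', map_location='cpu')
state = checkpoint['state_dict']

# List every per-layer final_scale entry recorded during quantization
for name, value in state.items():
    if name.endswith('.final_scale'):
        print(f'{name}: {int(value.item())}')
```

Layers without a `final_scale` entry default to a scale of 0, matching the fallback above.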
@@ -630,6 +633,7 @@ def main(): state.eltwise = eltwise state.final_layer = final_layer state.first_layer_used = min_layer + state.final_scale = final_scale state.flatten = flatten state.in_offset = input_offset state.in_sequences = in_sequences diff --git a/izer/quantize.py b/izer/quantize.py index a1916e51..f8b2b67f 100644 --- a/izer/quantize.py +++ b/izer/quantize.py @@ -1,5 +1,5 @@ ################################################################################################### -# Copyright (C) 2019-2023 Maxim Integrated Products, Inc. All Rights Reserved. +# Copyright (C) 2019-2024 Maxim Integrated Products, Inc. All Rights Reserved. # # Maxim Integrated Products, Inc. Default Copyright Notice: # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html @@ -241,6 +241,11 @@ def get_max_bit_shift(t, clamp_bits, shift_quantile, return_bit_shift=False): out_shift_name = '.'.join([layer, 'output_shift']) out_shift = torch.Tensor([-1 * get_max_bit_shift(params_r, clamp_bits, shift_quantile, True)]) + threshold_name = '.'.join([layer, 'threshold']) + if threshold_name in checkpoint_state: + threshold = checkpoint_state[threshold_name] + out_shift = (out_shift - threshold).clamp(min=-7.-clamp_bits, + max=23.-clamp_bits) new_checkpoint_state[out_shift_name] = out_shift if new_masks_dict is not None: new_masks_dict[out_shift_name] = out_shift diff --git a/izer/state.py b/izer/state.py index 443fe6bb..6ec0fbff 100644 --- a/izer/state.py +++ b/izer/state.py @@ -1,5 +1,5 @@ ################################################################################################### -# Copyright (C) 2021-2023 Maxim Integrated Products Inc. All Rights Reserved. +# Copyright (C) 2021-2024 Maxim Integrated Products Inc. All Rights Reserved. # # Maxim Integrated Products Inc. Default Copyright Notice: # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html @@ -70,6 +70,7 @@ final_layer: int = -1 first_layer_used: int = 0 fixed_input: bool = False +final_scale: List[int] = [] flatten: List[bool] = [] forever: bool = False generate_kat: bool = True @@ -174,6 +175,7 @@ rtl_preload_weights: bool = False rtl_preload: bool = False runtest_filename: str = '' +scale_output: bool = False sample_filename: str = '' simple1b: bool = False simulated_sequence: List[Any] = [] diff --git a/izer/toplevel.py b/izer/toplevel.py index d778b959..e75fa4ee 100644 --- a/izer/toplevel.py +++ b/izer/toplevel.py @@ -1,5 +1,5 @@ ################################################################################################### -# Copyright (C) 2019-2023 Maxim Integrated Products, Inc. All Rights Reserved. +# Copyright (C) 2019-2024 Maxim Integrated Products, Inc. All Rights Reserved. # # Maxim Integrated Products, Inc. Default Copyright Notice: # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html @@ -14,7 +14,7 @@ COPYRIGHT = \ '/*******************************************************************************\n' \ - '* Copyright (C) 2019-2023 Maxim Integrated Products, Inc., All rights Reserved.\n' \ + '* Copyright (C) 2019-2024 Maxim Integrated Products, Inc., All rights Reserved.\n' \ '*\n' \ '* This software is protected by copyright laws of the United States and\n' \ '* of foreign countries. This material may also be protected by patent laws\n' \ @@ -248,16 +248,11 @@ def function_footer( def write_ml_data( memfile: TextIO, - output_width: int, ) -> None: """ Write the ml_data variable with `output_width` to `memfile`. 
""" - if output_width != 32: - memfile.write('static int32_t ml_data32[(CNN_NUM_OUTPUTS + ' - f'{32 // output_width - 1}) / {32 // output_width}];\n') - else: - memfile.write('static int32_t ml_data[CNN_NUM_OUTPUTS];\n') + memfile.write('static int32_t ml_data[CNN_NUM_OUTPUTS];\n') def main( @@ -309,7 +304,7 @@ def main( unmask = ~mask & ((1 << tc.dev.P_NUMGROUPS_ALL) - 1) if unload and not softmax: - write_ml_data(memfile, output_width) + write_ml_data(memfile) memfile.write('\n') if arm_code_wrapper: @@ -347,8 +342,8 @@ def main( if embedded_code and not forever and softmax: memfile.write(' int digs, tens;\n') if output_width != 32: - memfile.write(f'int{output_width}_t *ml_data = ' - f'(int{output_width}_t *) ml_data32;\n') + memfile.write(f' int{output_width}_t *ml_data{output_width} = ' + f'(int{output_width}_t *) ml_data;\n') if embedded_code and softmax or oneshot > 0: memfile.write('\n') @@ -811,7 +806,7 @@ def main( memfile.write(' softmax_layer();\n') elif unload: memfile.write(' cnn_unload((uint32_t *) ' - f'ml_data{"32" if output_width != 32 else ""});\n') + 'ml_data);\n') if embedded_code: memfile.write('\n printf("\\n*** PASS ***\\n\\n");\n\n' @@ -911,13 +906,13 @@ def softmax_layer( Write the call to the softmax layer to `memfile`. """ memfile.write('// Classification layer:\n') - write_ml_data(memfile, output_width) + write_ml_data(memfile) memfile.write('static q15_t ml_softmax[CNN_NUM_OUTPUTS];\n\n') function_header(memfile, prefix='', function='softmax_layer', return_type='void') - memfile.write(f' cnn_unload((uint32_t *) ml_data{"32" if output_width != 32 else ""});\n') + memfile.write(' cnn_unload((uint32_t *) ml_data);\n') if output_width == 32: if shift == 0: @@ -927,7 +922,7 @@ def softmax_layer( memfile.write(' softmax_shift_q17p14_q15((q31_t *) ml_data, ' f'CNN_NUM_OUTPUTS, {shift}, ml_softmax);\n') else: - memfile.write(' arm_softmax_q7_q15((const q7_t *) ml_data32, ' + memfile.write(' arm_softmax_q7_q15((const q7_t *) ml_data, ' 'CNN_NUM_OUTPUTS, ml_softmax);\n') function_footer(memfile, return_value='void') diff --git a/izer/unload.py b/izer/unload.py index ce5d951c..9e0efd97 100644 --- a/izer/unload.py +++ b/izer/unload.py @@ -1,5 +1,5 @@ ################################################################################################### -# Copyright (C) 2019-2022 Maxim Integrated Products, Inc. All Rights Reserved. +# Copyright (C) 2019-2024 Maxim Integrated Products, Inc. All Rights Reserved. # # Maxim Integrated Products, Inc. 
Default Copyright Notice: # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html @@ -53,6 +53,36 @@ def mlator_write_one( return f'{prefix} out_buf{"32" if out_size != 32 else ""}' \ f'[offs++] = *mlat;{comment}\n' + def scaled_mlator_write_one( + prefix: str = '', + comment: str = '', + out_size: int = 8, + scale: int = 0, + ) -> str: + """ + Return a single mlator unload line with scaling + """ + + if scale > 0: + return f'{prefix} val = *mlat;{comment}\n' \ + f'{prefix} out_buf[offs++] = (val & 0xff) << {scale};\n' \ + f'{prefix} out_buf[offs++] = ((val >> 8) & 0xff) << {scale};\n' \ + f'{prefix} out_buf[offs++] = ((val >> 16) & 0xff) << {scale};\n' \ + f'{prefix} out_buf[offs++] = ((val >> 24) & 0xff) << {scale};\n' + + if scale < 0: + return f'{prefix} val = *mlat;{comment}\n' \ + f'{prefix} out_buf[offs++] = (int16_t)((val & 0xff) << 8) >> '\ + f'{abs(scale) + 8};\n' \ + f'{prefix} out_buf[offs++] = (int16_t)(((val >> 8) & 0xff) << 8) >> '\ + f'{abs(scale) + 8};\n' \ + f'{prefix} out_buf[offs++] = (int16_t)(((val >> 16) & 0xff) << 8) >> '\ + f'{abs(scale) + 8};\n' \ + f'{prefix} out_buf[offs++] = (int16_t)(((val >> 24) & 0xff) << 8) >> '\ + f'{abs(scale) + 8};\n' + + return mlator_write_one(prefix, comment, out_size) + # Cache for faster access apb_base = state.apb_base mlator = state.mlator @@ -61,6 +91,23 @@ def mlator_write_one( wide_chunk = state.wide_chunk if state.embedded_code else 0 unload_custom = state.unload_custom mlator_warning = state.mlator_warning + final_scale = state.final_scale + scale_output = state.scale_output + final_scale_detected = False + + for layer in final_scale: + if final_scale[layer] != 0: + final_scale_detected = True + break + + if not scale_output and final_scale_detected: + wprint('Non-zero output scale detected, but --scale-output not set. ' + 'Unload operation will be performed without scaling. ' + f'Final scales are {final_scale}.') + + if scale_output and not final_scale_detected: + nprint('--scale-output set, but all output scales are zero.
' 'Unload operation will be performed without scaling.') assert not state.block_mode or not mlator @@ -201,6 +248,7 @@ def mlator_write_one( else: # mlator def mlator_loop( num: int = 1, + ll: int = ll, ) -> None: """ Print multiple mlator unload lines using a partially unrolled loop """ @@ -214,13 +262,20 @@ def mlator_loop( if num >= 2 * mlator_chunk: result += f' for (i = 0; i < {num // mlator_chunk}; i++) {{\n' for _ in range(mlator_chunk): - result += mlator_write_one(' ', '', out_size) + if scale_output: + result += scaled_mlator_write_one(' ', '', out_size, + final_scale[ll]) + else: + result += mlator_write_one(' ', '', out_size) result += ' }\n' num = num % mlator_chunk # Emit single lines for all remaining statements while num > 0: - result += mlator_write_one('', '', out_size) + if scale_output: + result += scaled_mlator_write_one('', '', out_size, final_scale[ll]) + else: + result += mlator_write_one('', '', out_size) num -= 1 return result @@ -251,12 +306,20 @@ def mlator_loop( (proc // tc.dev.P_NUMPRO) * tc.dev.C_GROUP_OFFS // 4) + (doffs >> 2) * width + expand * out_size) \ * (write_gap[ll] + 1) * 4 - target = this_c * input_shape[ll][1] * input_shape[ll][2] \ - + row * input_shape[ll][1] + col + written + if scale_output: + target = this_c * input_shape[ll][1] * input_shape[ll][2] \ + + row * input_shape[ll][1] + col + written // 2 + else: + target = this_c * input_shape[ll][1] * input_shape[ll][2] \ + + row * input_shape[ll][1] + col + written + assert target & 3 == 0 if target != write_addr: + if scale_output: + out_text += f' offs = 0x{target:04x};\n' + else: + out_text += f' offs = 0x{target >> 2:04x};\n' if source != read_addr: if loop_count > 0: out_text += mlator_loop(loop_count) @@ -286,9 +349,13 @@ def mlator_loop( # FIXME: Do not write more than # `num_bytes = min(4, input_shape[2] - col)` if mlator_chunk == 1: - out_text += mlator_write_one('', - f' // {this_c},{row},{col}-{col+3}', - out_size) + if scale_output: + out_text += scaled_mlator_write_one('', f' // {this_c},{row},' + f'{col}-{col+3}', + out_size, final_scale[ll]) + else: + out_text += mlator_write_one('', f' // {this_c},{row},' + f'{col}-{col+3}', out_size) loop_count += 1 read_addr = source + 4 write_addr = target + 4 @@ -346,11 +413,16 @@ def mlator_loop( else: prefix = '' for _ in range(min(remaining, chunk)): - if delta_r == 4: - out_text += f'{prefix} *out_buf++ = *addr++;\n' + if scale_output and final_scale[ll] > 0: + out_text += f'{prefix} *out_buf++ = (*addr++) <<'\ + f' {final_scale[ll]};\n' + elif scale_output and final_scale[ll] < 0: + out_text += f'{prefix} *out_buf++ = (int{o_width}_t)(*addr++)'\ + f' >> {abs(final_scale[ll])};\n' else: - out_text += f'{prefix} *out_buf++ = *addr;\n' \ - f'{prefix} addr {"+" if delta_r >= 0 else "-"}= ' \ - f'0x{abs(delta_r) // 4:04x};\n' + out_text += f'{prefix} *out_buf++ = *addr++;\n' + if delta_r != 4: + out_text += f'{prefix} addr {"+" if delta_r >= 0 else "-"}= ' \ + f'0x{abs(delta_r) // 4:04x};\n' if loop_runs > 1: out_text += ' }\n' @@ -362,10 +434,21 @@ def mlator_loop( xy_dim = input_shape[ll][1] * input_shape[ll][2] short_write = xy_dim == 1 chunk = max(1, narrow_chunk) - if not short_write: + if not short_write and out_size == 1: out_text += ' offs = 0x0000;\n' if not first_output: - out_text += f' out_buf = ((uint8_t *) out_buf32) + 0x{written:04x};\n' + if scale_output and out_size == 1: + out_text += f' out_buf = ((uint{o_width*2}_t *) out_buf32)'\ + f' + 0x{(written // 2):04x};\n'
+ elif out_size == 4: + out_text += f' temp_out_buf = ((uint32_t *) out_buf32)'\ + f' + 0x{(written // 4):04x};\n' + else: + out_text += f' out_buf = ((uint{o_width}_t *) out_buf32)'\ + f' + 0x{written:04x};\n' while idx < len(emit_list): # Find how many have the same r/w addresses with different shift, # then how many the same deltas between rs and ws with the same set of shifts. @@ -407,47 +490,82 @@ def mlator_loop( else: prefix = '' for _ in range(min(remaining, chunk)): - if delta_r == 4: - out_text += f'{prefix} val = *addr++;\n' - else: - out_text += f'{prefix} val = *addr;\n' \ - f'{prefix} addr {"+" if delta_r >= 0 else "-"}= ' \ - f'0x{abs(delta_r) // 4:04x};\n' - for shift in shift_list: - if not short_write: - out_text += f'{prefix} out_buf[offs' - if shift > 0: - out_text += f'+0x{xy_dim * shift:02x}' - out_text += '] = ' + if out_size == 4: + if scale_output and final_scale[ll] > 0: + out_text += f'{prefix} *temp_out_buf++ = (*addr++)'\ + f' << {final_scale[ll]};\n' + elif scale_output and final_scale[ll] < 0: + out_text += f'{prefix} *temp_out_buf++ = (int32_t)(*addr++)'\ + f' >> {abs(final_scale[ll])};\n' else: - out_text += f'{prefix} *out_buf++ = ' - if shift == 0: - out_text += 'val' + out_text += f'{prefix} *temp_out_buf++ = *addr++;\n' + else: + if delta_r == 4: + out_text += f'{prefix} val = *addr++;\n' else: - out_text += f'(val >> {shift * 8})' - out_text += ' & 0xff;\n' + out_text += f'{prefix} val = *addr;\n' \ + f'{prefix} addr {"+" if delta_r >= 0 else "-"}= '\ + f'0x{abs(delta_r) // 4:04x};\n' + for shift in shift_list: + if not short_write: + out_text += f'{prefix} out_buf[offs' + if shift > 0: + out_text += f'+0x{xy_dim * shift:02x}' + out_text += '] = ' + else: + out_text += f'{prefix} *out_buf++ = ' + if scale_output: + if shift == 0: + out_text += f'(int{o_width*2}_t)((val' + else: + out_text += f'(int{o_width*2}_t)(((val >> {shift * 8})' + else: + if shift == 0: + out_text += '(val' + else: + out_text += f'((val >> {shift * 8})' + if not scale_output: + out_text += ' & 0xff);\n' + elif final_scale[ll] == 0: + out_text += ' & 0xff));\n' + elif final_scale[ll] > 0: + out_text += f' & 0xff)) << {final_scale[ll]};\n' + else: + out_text += ' & 0xff) << 8) >> '\ + f'{8 + abs(final_scale[ll])};\n' - if not short_write: - out_text += f'{prefix} offs++;\n' + if not short_write: + out_text += f'{prefix} offs++;\n' if loop_runs > 1: out_text += ' }\n' remaining -= loop_runs * chunk out_addr += 4 * loop_runs * chunk idx += (run + 1) * shift_count - if not short_write and idx < len(emit_list) and shift_count > 1: + if not short_write and idx < len(emit_list) and \ + shift_count > 1 and out_size == 1: out_text += f' offs += 0x{xy_dim * (shift_count - 1):04x};\n' - # Always a byte counter - written += input_shape[ll][0] * input_shape[ll][1] * input_shape[ll][2] \ - * output_width[ll] // 8 + if out_size == 4: + written += input_shape[ll][0] * input_shape[ll][1] * input_shape[ll][2] \ + * out_size + else: + if scale_output: + written += ((input_shape[ll][0] * input_shape[ll][1] * + input_shape[ll][2] + 1) // 2) * 4 + else: + written += ((input_shape[ll][0] * input_shape[ll][1] * + input_shape[ll][2] + 3) // 4) * 4 first_output = False prev_out_size = out_size - if o_width != 32 and have_non_mlator: + if o_width != 32 and have_non_mlator and not scale_output: memfile.write(f' uint{o_width}_t
*out_buf = (uint{o_width}_t *) out_buf32;\n') memfile.write(' uint32_t val;\n') + if o_width != 32 and scale_output: + memfile.write(f' uint{o_width*2}_t *out_buf = (uint{o_width*2}_t *) out_buf32;\n') + memfile.write(' uint32_t val;\n') + if 32 in o_widths and o_width != 32: + memfile.write(' uint32_t *temp_out_buf;\n') if o_width == 32 or have_non_mlator: memfile.write(' volatile uint32_t *addr;\n') if mlator_layers: diff --git a/trained/ai87-pascalvoc-fpndetector-qat8-q.pth.tar b/trained/ai87-pascalvoc-fpndetector-qat8-q.pth.tar index 98a9b2b9..ec4b4ca7 100644 Binary files a/trained/ai87-pascalvoc-fpndetector-qat8-q.pth.tar and b/trained/ai87-pascalvoc-fpndetector-qat8-q.pth.tar differ diff --git a/trained/ai87-pascalvoc-fpndetector-qat8.pth.tar b/trained/ai87-pascalvoc-fpndetector-qat8.pth.tar index f7c2771a..3b380f38 100644 Binary files a/trained/ai87-pascalvoc-fpndetector-qat8.pth.tar and b/trained/ai87-pascalvoc-fpndetector-qat8.pth.tar differ
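For reference, the C code that the new `scaled_mlator_write_one()` strings assemble reduces to the following per-word pattern — a stand-alone sketch (the function name and variables are illustrative, not part of the generated API):

```c
#include <stdint.h>

// Unpack one 32-bit word holding four 8-bit quantized outputs into 16-bit
// entries, recovering the range indicated by the layer's final_scale.
void unload_word_scaled(uint32_t val, int16_t *out_buf, int scale)
{
    for (int i = 0; i < 4; i++) {
        uint32_t b = (val >> (8 * i)) & 0xff;
        if (scale > 0) {
            out_buf[i] = (int16_t)(b << scale);     // scale up
        } else if (scale < 0) {
            // Sign-extend to 16 bits, then shift right arithmetically
            out_buf[i] = (int16_t)(b << 8) >> (8 - scale);
        } else {
            out_buf[i] = (int16_t)b;                // no scaling (plain copy)
        }
    }
}
```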