diff --git a/demo/quant/quant_post/README.md b/demo/quant/quant_post/README.md index 33cbe4a6aaad1..09679745046f4 100755 --- a/demo/quant/quant_post/README.md +++ b/demo/quant/quant_post/README.md @@ -43,7 +43,7 @@ python quant_post_static.py --model_path ./inference_model/MobileNet --save_path 运行以上命令后,可在``${save_path}``下看到量化后的模型文件和参数文件。 -> 使用的量化算法为``'KL'``, 使用训练集中的160张图片进行量化参数的校正。 +> 使用的量化算法为``'hist'``, 使用训练集中的32张图片进行量化参数的校正。 ### 测试精度 @@ -67,6 +67,6 @@ python eval.py --model_path ./quant_model_train/MobileNet --model_name __model__ 精度输出为 ``` -top1_acc/top5_acc= [0.70141864 0.89086477] +top1_acc/top5_acc= [0.70328485 0.89183184] ``` -从以上精度对比可以看出,对``mobilenet``在``imagenet``上的分类模型进行离线量化后 ``top1``精度损失为``0.77%``, ``top5``精度损失为``0.46%``. +从以上精度对比可以看出,对``mobilenet``在``imagenet``上的分类模型进行离线量化后 ``top1``精度损失为``0.59%``, ``top5``精度损失为``0.36%``. diff --git a/demo/quant/quant_post/quant_post.py b/demo/quant/quant_post/quant_post.py index 48a338fc28988..0d8935a5e5159 100755 --- a/demo/quant/quant_post/quant_post.py +++ b/demo/quant/quant_post/quant_post.py @@ -19,13 +19,15 @@ parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable -add_arg('batch_size', int, 16, "Minibatch size.") -add_arg('batch_num', int, 10, "Batch number") +add_arg('batch_size', int, 32, "Minibatch size.") +add_arg('batch_num', int, 1, "Batch number") add_arg('use_gpu', bool, True, "Whether to use GPU or not.") add_arg('model_path', str, "./inference_model/MobileNet/", "model dir") add_arg('save_path', str, "./quant_model/MobileNet/", "model dir to save quanted model") add_arg('model_filename', str, None, "model file name") add_arg('params_filename', str, None, "params file name") +add_arg('algo', str, 'hist', "calibration algorithm") +add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist") # yapf: enable @@ -46,7 +48,9 @@ def quantize(args): model_filename=args.model_filename, params_filename=args.params_filename, 
batch_size=args.batch_size, - batch_nums=args.batch_num) + batch_nums=args.batch_num, + algo=args.algo, + hist_percent=args.hist_percent) def main(): diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py index 9324276b8a894..2522fed7e1d73 100755 --- a/paddleslim/quant/quanter.py +++ b/paddleslim/quant/quanter.py @@ -313,7 +313,9 @@ def quant_post_static( batch_size=16, batch_nums=None, scope=None, - algo='KL', + algo='hist', + hist_percent=0.9999, + bias_correction=False, quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], is_full_quantize=False, weight_bits=8, @@ -358,9 +360,15 @@ def quant_post_static( generated by sample_generator as calibrate data. scope(paddle.static.Scope, optional): The scope to run program, use it to load and save variables. If scope is None, will use paddle.static.global_scope(). - algo(str, optional): If algo=KL, use KL-divergenc method to - get the more precise scale factor. If algo='direct', use - abs_max method to get the scale factor. Default: 'KL'. + algo(str, optional): If algo='KL', use KL-divergence method to + get the scale factor. If algo='hist', use the hist_percent percentile of the histogram + to get the scale factor. If algo='mse', search for the best scale factor which + makes the mse loss minimal. Using one batch of data for mse is enough. If + algo='avg', use the average of abs_max values to get the scale factor. If + algo='abs_max', use abs_max method to get the scale factor. Default: 'hist'. + hist_percent(float, optional): The percentile of histogram for algo 'hist'. Default: 0.9999. + bias_correction(bool, optional): Bias correction method of https://arxiv.org/abs/1810.05723. + Default: False. quantizable_op_type(list[str], optional): The list of op types that will be quantized. Default: ["conv2d", "depthwise_conv2d", "mul"]. 
@@ -397,6 +405,8 @@ def quant_post_static( batch_nums=batch_nums, scope=scope, algo=algo, + hist_percent=hist_percent, + bias_correction=bias_correction, quantizable_op_type=quantizable_op_type, is_full_quantize=is_full_quantize, weight_bits=weight_bits,