This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Dynamic shape #5625

Open · wants to merge 6 commits into master
84 changes: 84 additions & 0 deletions docs/source/compression/dynamic_shape.rst
@@ -0,0 +1,84 @@
Compression for Models with Dynamic-shape Input
================================================

Compression for models with dynamic-shape input is an experimental feature introduced in NNI 3.0.
It makes deployment more convenient: when feeding images of different heights and widths to a network, we no longer need to build a separate model for each resolution, and when feeding text, the sequence length is no longer restricted to a fixed size.
The feature works in two steps: first export a dynamic-shape ONNX model, then build a dynamic-shape TensorRT engine from that ONNX model.

.. Note::

    NNI strives to ensure maximum compatibility among different compressors in dynamic-shape compression.
    Nevertheless, in some scenarios mutual interference between the model modifications made by different compression algorithms cannot be avoided.
    We encourage users to combine algorithms only after acquiring a solid understanding of the underlying principles of the compression methods.
    If you run into problems that cannot be resolved while using dynamic-shape compression, you are welcome to open an issue for discussion.

Main API
--------

To understand how dynamic-shape compression works, note that each module in the model has a corresponding wrapper in the compressor,
and the wrapper stores the data needed for compression.
``ModelSpeedupTensorRT`` now takes a ``dummy_input`` parameter in place of ``input_shape``.
``dummy_input`` is an example input accepted by the torch model you want to deploy; it is used to export the ONNX model.
In addition, you should provide two parameters, ``dynamic_axes`` and ``dynamic_shape_setting``.
``dynamic_axes`` determines which dimensions of the model's inputs (or outputs) are dynamic.
``dynamic_shape_setting`` specifies the concrete range of the dynamic dimensions.
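
As an illustration of the first step, the sketch below shows roughly what exporting a dynamic-shape ONNX model looks like with plain ``torch.onnx.export``. The toy model, file name, and axis choice are assumptions for illustration only; ``ModelSpeedupTensorRT`` performs this export for you.

.. code-block:: python

    import torch

    # A toy model and example input; only the export call matters here.
    model = torch.nn.Linear(16, 4).eval()
    dummy_input = torch.randn(1, 16)

    # Mark dimension 0 of the input and output as dynamic ("batch"),
    # mirroring what `dynamic_axes` expresses in the NNI API.
    torch.onnx.export(
        model,
        dummy_input,
        'model_dynamic.onnx',
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
        opset_version=11,
    )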

Example
-------

Quantize BERT and deploy the model to ONNX & TensorRT with dynamic-shape input.

The full example can be found `here <https://github.com/microsoft/nni/examples/tutorials/quantization_bert_glue.py>`__.

The following code shows a typical pipeline: quantization first, then deployment.

.. code-block:: python

    ...
    task_name = 'rte'
    finetune_lr = 4e-5
    quant_lr = 1e-5
    quant_method = 'ptq'
    ...
    config_list = [{
        'op_types': ['Linear'],
        'op_names_re': ['bert.encoder.layer.{}'.format(i) for i in range(12)],
        'target_names': ['weight', '_input_', '_output_'],
        'quant_dtype': 'int8',
        'quant_scheme': 'symmetric',  # 'affine' or 'symmetric'
        'granularity': 'default',
    }]

The steps are the same as for normal quantization with NNI: first set the hyperparameters of the quantizer configuration.
When fake quantization has finished, save the parameters of the quantization nodes as ``calibration_config``,
and then remove the quantization nodes from the model with ``quantizer.unwrap_model()``.
Next, prepare the ``dummy_input`` required by the model.
To match the model's real input format as closely as possible, it is recommended to take ``dummy_input`` directly from the training or validation dataset of the task.
Finally, convert ``dummy_input`` to a ``dict`` with the helper function ``transfer_dummy_input``.

.. code-block:: python

    ...
    input_names = ['input_ids', 'token_type_ids', 'attention_mask']
    dummy_input = transfer_dummy_input(dummy_input, input_names)
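
In practice, ``dummy_input`` can simply be one encoded sample taken from the dataset. The sketch below assumes a ``merged_validation_dataset`` whose items are dicts of token-id lists with a ``labels`` field, as in the full example; the names are illustrative.

.. code-block:: python

    # Take one encoded sample, drop the label, and wrap each field as a
    # batch-of-one tensor; the result has the same dict-of-tensors layout
    # that transfer_dummy_input produces.
    sample = merged_validation_dataset[0]
    dummy_input = {name: torch.tensor([value])
                   for name, value in sample.items() if name != 'labels'}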


``dynamic_axes`` is a dict whose keys are the names of the inputs and outputs with dynamic shapes,
and whose values are dicts specifying which dimensions are dynamic.
``dynamic_shape_setting`` requires three entries: the smallest shape your input can take, the most commonly used (optimal) shape, and the largest shape.
These three shapes help TensorRT allocate memory for the model.

.. code-block:: python

    ...
    dynamic_axes = {'input_ids': {1: 'seq_len'},
                    'token_type_ids': {1: 'seq_len'},
                    'attention_mask': {1: 'seq_len'}}
    dynamic_shape_setting = {'min_shape': (1, 18),
                             'opt_shape': (1, 72),
                             'max_shape': (1, 360)}
    ...
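
For readers familiar with TensorRT, ``dynamic_shape_setting`` corresponds roughly to a TensorRT optimization profile. A hedged sketch of the equivalent raw calls follows, assuming the ``tensorrt`` Python package is installed; ``ModelSpeedupTensorRT`` handles this step for you when building the engine.

.. code-block:: python

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    profile = builder.create_optimization_profile()
    # One (min, opt, max) shape range per dynamic input,
    # mirroring min_shape / opt_shape / max_shape above.
    for name in ('input_ids', 'token_type_ids', 'attention_mask'):
        profile.set_shape(name, (1, 18), (1, 72), (1, 360))
    config.add_optimization_profile(profile)
    # `config` is then used together with the parsed ONNX network to build the engine.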
.. code-block:: python

    ...
    engine = ModelSpeedupTensorRT(model, dummy_input=dummy_input, config=calibration_config,
                                  onnx_path='bert_rte.onnx',
                                  input_names=['input_ids', 'token_type_ids', 'attention_mask'],
                                  output_names=['output'],
                                  dynamic_axes=dynamic_axes,
                                  dynamic_shape_setting=dynamic_shape_setting)
    engine.compress()

After ``engine.compress()``, you get a TensorRT engine built from the original model.
You can check the model's output and inference time with ``output, time_span = engine.inference(dummy_input)``,
and evaluate its accuracy with ``test_Accuracy(engine)``.
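
Because the engine was built with a dynamic sequence dimension, the same engine accepts any sequence length inside the ``min_shape``/``max_shape`` range without being rebuilt. A small sketch (the length 128 and the all-zero token ids are arbitrary placeholders):

.. code-block:: python

    # Re-run inference with a different sequence length; no new engine is
    # needed as long as 18 <= seq_len <= 360.
    longer_input = {name: torch.zeros(1, 128, dtype=torch.long)
                    for name in ('input_ids', 'token_type_ids', 'attention_mask')}
    output, time_span = engine.inference(longer_input, reset_context=True)
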
111 changes: 105 additions & 6 deletions examples/tutorials/quantization_bert_glue.py
@@ -54,18 +54,18 @@
from transformers.training_args import TrainingArguments


task_name = 'qnli'
task_name = 'rte'
finetune_lr = 4e-5
quant_lr = 1e-5
quant_method = 'lsq'
dev_mode = True
quant_method = 'ptq'
dev_mode = False

if dev_mode:
    quant_max_epochs = 1
    finetune_max_epochs = 1
else:
    quant_max_epochs = 10
    finetune_max_epochs = 10
    finetune_max_epochs = 10


# %%
@@ -212,13 +212,42 @@ def build_finetuning_model(state_dict_path: str, is_quant=False):
from nni.contrib.compression.quantization import QATQuantizer, LsqQuantizer, PtqQuantizer
from nni.contrib.compression.utils import TransformersEvaluator

# dummy_input is used for torch2onnx and onnx2trt

# transfer dummy_input type into dict
def transfer_dummy_input(dummy_input, input_names):
    dict_dummy_input = {}
    if isinstance(dummy_input, dict):
        for input_name, input_tensor in dummy_input.items():
            if torch.is_tensor(input_tensor):
                continue
            else:
                dummy_input[input_name] = torch.tensor(input_tensor)
        dict_dummy_input = dummy_input
    elif isinstance(dummy_input, tuple):
        for i in range(len(dummy_input)):
            if torch.is_tensor(dummy_input[i]):
                # keep tensors instead of dropping them
                dict_dummy_input[input_names[i]] = dummy_input[i]
            else:
                dict_dummy_input[input_names[i]] = torch.tensor(dummy_input[i])
    elif torch.is_tensor(dummy_input):
        dict_dummy_input[input_names[0]] = dummy_input
    else:
        print('the dummy_input type is not allowed !')
    return dict_dummy_input

dummy_input = ([[101, 11271, 20726, 1010, 1996, 7794, 1997, 1996, 3364, 5696, 20726, 1010, 2038, 2351, 1997, 11192, 4456, 2012, 2287, 4008, 1010, 2429, 2000, 1996, 5696, 20726, 3192, 1012, 102, 5696, 20726, 2018, 2019, 4926, 1012, 102]],[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]],[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
input_names=['input_ids','token_type_ids','attention_mask']
dummy_input = transfer_dummy_input(dummy_input,input_names)

def fake_quantize():
    config_list = [{
        'op_types': ['Linear'],
        'op_names_re': ['bert.encoder.layer.{}'.format(i) for i in range(12)],
        'target_names': ['weight', '_output_'],
        'target_names': ['weight', '_input_', '_output_'],
        'quant_dtype': 'int8',
        'quant_scheme': 'affine',
        'quant_scheme': 'symmetric',  # 'affine' or 'symmetric'
        'granularity': 'default',
    }]

@@ -243,6 +272,75 @@ def fake_quantize():
    quantizer.evaluator.bind_model(model, quantizer._get_param_names_map())
    print(quantizer.evaluator.evaluate())

    model.eval()
    model.to('cpu')
    print('quantized torch-model output: ', model(**dummy_input))
    model.to('cuda')
    quantizer.unwrap_model()
    evaluate()

    # Speed up the model with TensorRT
    from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT
    engine = ModelSpeedupTensorRT(model, dummy_input=dummy_input, config=calibration_config,
                                  onnx_path='bert_rte.onnx',
                                  input_names=['input_ids', 'token_type_ids', 'attention_mask'],
                                  output_names=['output'],
                                  dynamic_axes={'input_ids': {1: 'seq_len'},
                                                'token_type_ids': {1: 'seq_len'},
                                                'attention_mask': {1: 'seq_len'}},
                                  dynamic_shape_setting={'min_shape': (1, 18),
                                                         'opt_shape': (1, 72),
                                                         'max_shape': (1, 360)})
    engine.compress()
    import time
    start_time = time.time()
    output, time_span = engine.inference(dummy_input)
    infer_time = time.time() - start_time
    print('test dummy_input inference output: ', output)
    print('test dummy_input inference time: ', time_span, infer_time)
    test_Accuracy(engine)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

def test_Accuracy(engine):
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    _, validation_datasets = prepare_datasets(task_name, tokenizer, '')
    merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()])  # type: ignore
    true_cnt = 0
    total_time = 0
    for input_data in merged_validation_dataset:
        for input_name, input_tensor in input_data.items():
            if 'labels' != input_name:
                input_data[input_name] = torch.tensor([input_tensor])
        test_data = {key: input_data[key] for key in list(input_data.keys())[:-1]}
        output, time_span = engine.inference(test_data, reset_context=True)
        total_time += time_span
        prediction = torch.argmax(output, -1)
        if input_data['labels'] == prediction:
            true_cnt += 1
    Accuracy = true_cnt / len(merged_validation_dataset)
    print('inference time: ', total_time / len(merged_validation_dataset))
    print('Accuracy of model #1: ', Accuracy)

def test_onnx_Accuracy(onnx_model):
    import onnxruntime
    ort_session = onnxruntime.InferenceSession(onnx_model)
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    _, validation_datasets = prepare_datasets(task_name, tokenizer, '')
    merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()])  # type: ignore
    true_cnt = 0
    for input_data in merged_validation_dataset:
        for input_name, input_tensor in input_data.items():
            if 'labels' != input_name:
                input_data[input_name] = to_numpy(torch.tensor([input_tensor]))
        test_data = {key: input_data[key] for key in list(input_data.keys())[:-1]}
        output = ort_session.run(None, test_data)
        prediction = np.argmax(output, -1)
        if input_data['labels'] == prediction:
            true_cnt += 1
    Accuracy = true_cnt / len(merged_validation_dataset)
    print('Accuracy of model #1: ', Accuracy)



def evaluate():
    model = build_finetuning_model(f'./output/bert_finetuned/{task_name}.bin', is_quant=False)
    trainer = prepare_traced_trainer(model, is_quant=False)
@@ -251,6 +349,7 @@ def evaluate():


fake_quantize()
test_onnx_Accuracy('bert_rte.onnx')
evaluate()


77 changes: 51 additions & 26 deletions nni/compression/pytorch/quantization_speedup/frontend_to_onnx.py
@@ -54,30 +54,52 @@ def unwrapper(model_onnx, index2name, config):
    dict
        The configuration of onnx model layers and calibration parameters
    """
    # Support Gemm, Conv, Relu, Clip(Relu6) and Maxpool
    support_op = ['Gemm', 'Conv', 'Relu', 'Clip', 'MaxP']
    # Support Gemm, Conv, Relu, Clip(Relu6) and Maxpool + MatMul
    support_op = ['Gemm', 'Conv', 'Relu', 'Clip', 'MaxP', 'MatMul']
    idx = 0
    onnx_config = {}
    while idx < len(model_onnx.graph.node):
        nd = model_onnx.graph.node[idx]
        if nd.name[0:4] in support_op and idx > 1:
            # Grad constant node and multiply node
            const_nd = model_onnx.graph.node[idx-2]
            mul_nd = model_onnx.graph.node[idx-1]
            # Get index number which is transferred by constant node
            index = int(onnx.numpy_helper.to_array(const_nd.attribute[0].t))
            if index != -1:
                name = index2name[index]
                onnx_config[nd.name] = config[name]
            nd.input[0] = mul_nd.input[0]
            # Remove constant node and multiply node
            model_onnx.graph.node.remove(const_nd)
            model_onnx.graph.node.remove(mul_nd)
            idx = idx - 2
        idx = idx + 1
    mul_name_list = []
    const_name_list = []
    const_list = []
    mul_list = []
    # find mul node output name
    for node in model_onnx.graph.node:
        for op in support_op:
            if op in node.name:
                for node_input_name in node.input:
                    if 'Mul_output' in node_input_name:
                        mul_name_list.append(node_input_name)
    # find const node output name by mul node output name
    for node in model_onnx.graph.node:
        if node.output[0] in mul_name_list:
            for node_input_name in node.input:
                if 'Constant_output' in node_input_name:
                    const_name_list.append(node_input_name)
    # find mul node and const node
    for node in model_onnx.graph.node:
        for nd_name in mul_name_list:
            if node.output[0] == nd_name:
                mul_list.append(node)
        for nd_name in const_name_list:
            if node.output[0] == nd_name:
                const_list.append(node)
    for node in model_onnx.graph.node:
        for mul_node in mul_list:
            if mul_node.output[0] in node.input:
                for const_node in const_list:
                    if const_node.output[0] in mul_node.input:
                        index = int(onnx.numpy_helper.to_array(const_node.attribute[0].t))
                        if index != -1:
                            name = index2name[index]
                            onnx_config[node.name] = config[name]
                        node.input[0] = mul_node.input[0]
                        model_onnx.graph.node.remove(const_node)
                        model_onnx.graph.node.remove(mul_node)
    return model_onnx, onnx_config

def torch_to_onnx(model, config, input_shape, model_path, input_names, output_names):
def torch_to_onnx(model, config, dummy_input, model_path, input_names, output_names,dynamic_axes=None):
"""
Convert torch model to onnx model and get layer bits config of onnx model.

Expand All @@ -103,6 +125,8 @@ def torch_to_onnx(model, config, input_shape, model_path, input_names, output_na
dict
The configuration of onnx model layers and calibration parameters
"""
    device = torch.device('cpu')
    model.to(device)
    # Support Gemm, Conv, Relu, Clip(Relu6) and MaxPool
    support_op = [torch.nn.Conv2d, torch.nn.Linear, torch.nn.ReLU, torch.nn.ReLU6, torch.nn.MaxPool2d]
    # Transfer bits number to onnx layer by using wrapper
@@ -124,14 +148,15 @@ def torch_to_onnx(model, config, input_shape, model_path, input_names, output_na
            set_nested_attr(model, name, wrapper_module)
    # Convert torch model to onnx model and save it in model_path
    device = torch.device('cpu')
    dummy_input = torch.randn(input_shape)
    dummy_input = dummy_input.to(device)
    model.to(device)
    torch.onnx.export(model, dummy_input, model_path, verbose=False, input_names=input_names, output_names=output_names, export_params=True)
    if dynamic_axes is None:
        dynamic_axes = {'input': {2: 'image_height', 3: 'image_width'},  # for image input
                        'output': {2: 'image_height', 3: 'image_width'}}
    # dummy_input = dummy_input.to(device)
    # model.to(device)
    torch.onnx.export(model, dummy_input, model_path, verbose=False, input_names=input_names, output_names=output_names, export_params=True, opset_version=11, dynamic_axes=dynamic_axes)
    # Load onnx model
    model_onnx = onnx.load(model_path)
    model_onnx, onnx_config = unwrapper(model_onnx, index2name, config)
    onnx.save(model_onnx, model_path)

    onnx.checker.check_model(model_onnx)
    return model_onnx, onnx_config
    return model_onnx, onnx_config