diff --git a/onnxsim/onnx_simplifier.py b/onnxsim/onnx_simplifier.py
index 6ac8d1f..cadb9cd 100644
--- a/onnxsim/onnx_simplifier.py
+++ b/onnxsim/onnx_simplifier.py
@@ -106,7 +106,9 @@ def get_np_type_from_elem_type(elem_type: int):
 
 
 def get_inputs(model: onnx.ModelProto) -> List[onnx.ValueInfoProto]:
+    # weights are stored as graph initializers
    initializer_names = [x.name for x in model.graph.initializer]
+    # real inputs are the graph inputs that are not initializers
    return [ipt for ipt in model.graph.input if ipt.name not in initializer_names]
 
 
@@ -127,9 +129,11 @@ def remove_unused_output(model: onnx.ModelProto, unused_output: Sequence[str]) -
        if unused_output_name not in output_names:
            raise RuntimeError(
                f'The model doesn\'t have output named "{unused_output_name}"')
+
    for graph_output in copy.deepcopy(model.graph.output):
        if graph_output.name in unused_output_names:
            model.graph.output.remove(graph_output)
+
    model = onnxoptimizer.optimize(model, ['eliminate_deadend'],
                                   fixed_point=True)
    onnx.checker.check_model(model)
@@ -144,9 +148,20 @@ def generate_specific_rand_input(model, input_shapes: TensorShapes):
    for key, shape in input_shapes.items():
        shape_np = np.array(shape)
        if not np.all(shape_np > 0):
-            # treat batch size as 1 automatically if dynamic_input_shape is True
+            if config.dynamic_input_shape and len(shape_np) >= 3 and np.all(shape_np[1:] > 0):
+                # treat batch size as 1 automatically if dynamic_input_shape is True
                input_shapes[key] = [1] + shape[1:]
+                print(f"dynamic input shape: {key} {shape} -> {input_shapes[key]}")
+                continue
+
+            if config.dynamic_input_shape:
+                dyn_dim = np.where(shape_np == -1)[0]
+                shape_tmp = copy.deepcopy(shape)
+                for i, idx in enumerate(dyn_dim):
+                    shape_tmp[idx.item()] = 10 + i
+                input_shapes[key] = shape_tmp
+                print(f"dynamic input shape: {key} {shape} -> {shape_tmp}")
                continue
 
            raise RuntimeError(
@@ -362,6 +377,7 @@ def optimize(model: onnx.ModelProto, skip_fuse_bn: bool, skipped_optimizers: Opt
    onnx.checker.check_model(model)
    onnx.helper.strip_doc_string(model)
    optimizers_list = onnxoptimizer.get_fuse_and_elimination_passes()
+    print(f'opt passes: {optimizers_list}')
    if skip_fuse_bn:
        optimizers_list.remove('fuse_bn_into_conv')
    if skipped_optimizers is not None:
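---
Note on the new fallback in generate_specific_rand_input (a minimal standalone sketch, not part of the patch; the helper name concretize_dynamic_shape and the example shapes are illustrative assumptions): when config.dynamic_input_shape is set and a shape still contains dynamic (-1) dimensions that the batch-size-as-1 branch cannot handle, the patch assigns each -1 a distinct concrete size (10, 11, ...) so a random test input can still be generated.

    import copy
    import numpy as np

    def concretize_dynamic_shape(shape):
        # Mirror the fallback branch above: replace each -1 with a
        # distinct concrete size so a random tensor can be built.
        shape_np = np.array(shape)
        dyn_dim = np.where(shape_np == -1)[0]  # indices of dynamic dims
        shape_tmp = copy.deepcopy(shape)
        for i, idx in enumerate(dyn_dim):
            shape_tmp[idx.item()] = 10 + i
        return shape_tmp

    print(concretize_dynamic_shape([-1, 3, -1, -1]))  # [10, 3, 11, 12]

Using distinct sizes rather than one constant presumably helps surface bugs where two dynamic dimensions are wrongly assumed equal: inference with such an input fails loudly instead of silently passing.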