Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoyu-work committed Mar 21, 2024
1 parent 1f8b9a8 commit 174f3c0
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/python/tools/transformers/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from onnx_model_vae import VaeOnnxModel

import onnxruntime

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note

Module 'onnxruntime' is imported with both 'import' and 'import from'.
from onnxruntime.python.tools.transformers.optimizer_utils import extract_external_data_from_model
from onnxruntime.transformers.optimizer_utils import extract_external_data_from_model

logger = logging.getLogger(__name__)

Expand Down
10 changes: 8 additions & 2 deletions onnxruntime/python/tools/transformers/optimizer_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from fusion_utils import NumpyHelper
from onnxruntime import OrtValue
from onnx.external_data_helper import set_external_data
from onnx import ModelProto
from onnx.external_data_helper import set_external_data

from onnxruntime import OrtValue


def extract_external_data_from_model(model: ModelProto):
"""
Expand Down
9 changes: 7 additions & 2 deletions onnxruntime/test/python/transformers/test_optimizer_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import unittest

import numpy
from onnx import ModelProto, TensorProto, helper
from onnxruntime.python.tools.transformers.optimizer_utils import extract_external_data_from_model

from onnxruntime.transformers.optimizer_utils import extract_external_data_from_model


class TestOptimizerUtils(unittest.TestCase):
Expand All @@ -12,7 +18,6 @@ def test_extract_external_data_from_model(self):
self.assertEqual(len(external_values), 1)
self.assertEqual(external_values[0].numpy(), [0.0])


def _get_model_proto_with_raw_data(self) -> ModelProto:
input = helper.make_tensor_value_info("inputs", TensorProto.FLOAT, [None])
output = helper.make_tensor_value_info("outputs", TensorProto.FLOAT, [None])
Expand Down

0 comments on commit 174f3c0

Please sign in to comment.