From ee765d17e6270d58daf432597053489438929f7b Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 30 Jan 2024 16:18:48 +0100 Subject: [PATCH] lint --- .../orttraining/test/python/orttraining_test_dort.py | 8 ++++++++ .../orttraining/test/python/orttraining_test_ortvalue.py | 3 ++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py index a5e9dbc329d88..32cb3fdbf8d8c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort.py @@ -519,6 +519,14 @@ def test_expand(self): with self.subTest(requires_grad=req, test_backend_backward=back, use_aot_autograd=grad): self.common_test_expand(req, back, grad) + def test_slice(self): + x = torch.arange(20, requires_grad=True, dtype=torch.float32).reshape((-1, 4)) + self.assertONNX(lambda x: x[:, 1:2], x) + + def test_slice_dynamic(self): + x = torch.rand(3, 4, requires_grad=True) + self.assertONNX(lambda x: x[x.size(0) :, x.size(1) - 3], x, opset_version=10) + if __name__ == "__main__": unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_test_ortvalue.py b/orttraining/orttraining/test/python/orttraining_test_ortvalue.py index 9dc5cfb118d01..2a98eafde1c71 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortvalue.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortvalue.py @@ -17,7 +17,8 @@ import onnxruntime as onnxrt from onnxruntime.capi import _pybind_state as C -from onnxruntime.capi.onnxruntime_pybind11_state import OrtValue as C_OrtValue, OrtDevice +from onnxruntime.capi.onnxruntime_pybind11_state import OrtDevice +from onnxruntime.capi.onnxruntime_pybind11_state import OrtValue as C_OrtValue from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector from onnxruntime.training.ortmodule import ORTModule, _utils