Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Jan 30, 2024
1 parent a5f61f3 commit ee765d1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
8 changes: 8 additions & 0 deletions orttraining/orttraining/test/python/orttraining_test_dort.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,14 @@ def test_expand(self):
with self.subTest(requires_grad=req, test_backend_backward=back, use_aot_autograd=grad):
self.common_test_expand(req, back, grad)

def test_slice(self):
x = torch.arange(20, requires_grad=True, dtype=torch.float32).reshape((-1, 4))
self.assertONNX(lambda x: x[:, 1:2], x)

def test_slice_dynamic(self):
x = torch.rand(3, 4, requires_grad=True)
self.assertONNX(lambda x: x[x.size(0) :, x.size(1) - 3], x, opset_version=10)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

import onnxruntime as onnxrt
from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_pybind11_state import OrtValue as C_OrtValue, OrtDevice
from onnxruntime.capi.onnxruntime_pybind11_state import OrtDevice
from onnxruntime.capi.onnxruntime_pybind11_state import OrtValue as C_OrtValue
from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
from onnxruntime.training.ortmodule import ORTModule, _utils

Expand Down

0 comments on commit ee765d1

Please sign in to comment.