From a1c65ef15d42675944096aada9414a8dbf205615 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 19 Dec 2023 08:14:58 +0000 Subject: [PATCH] accept numpy as input dtype --- .../orttraining/python/training/utils/torch_io_helper.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/orttraining/orttraining/python/training/utils/torch_io_helper.py b/orttraining/orttraining/python/training/utils/torch_io_helper.py index 34cc1ca942a8c..31b25cb68e8fd 100644 --- a/orttraining/orttraining/python/training/utils/torch_io_helper.py +++ b/orttraining/orttraining/python/training/utils/torch_io_helper.py @@ -9,6 +9,7 @@ from typing import List, Mapping, Optional, Sequence, Tuple, Union import torch +import numpy from onnxruntime.training.utils.torch_profile_utils import nvtx_function_decorator @@ -43,6 +44,7 @@ def get_primitive_dtype(value): int, bool, float, + numpy.ndarray, torch.Tensor, Sequence["ORTModelInputOutputType"], Mapping[str, "ORTModelInputOutputType"], @@ -111,6 +113,7 @@ def __eq__(self, other): ORTModelInputOutputSchemaType = Union[ None, str, + numpy.ndarray, _TensorStub, Sequence["ORTModelInputOutputSchemaType"], Mapping[str, "ORTModelInputOutputSchemaType"], @@ -196,6 +199,9 @@ def _flatten_from_data(data: ORTModelInputOutputType, prefix_name: str = ""): tensor_idx[0], dtype=PrimitiveType.get_primitive_dtype(data), shape_dims=0, name=prefix_name ) return data + elif isinstance(data, numpy.ndarray): + _warn_of_constant_inputs(data) + return data # Depth first traversal to iterate over the data to replace every tensor with a stub elif isinstance(data, torch.Tensor): tensor_idx[0] += 1 @@ -290,6 +296,8 @@ def _replace_stub_with_tensor_value(data_schema: ORTModelInputOutputSchemaType, return data_schema elif PrimitiveType.is_primitive_type(data_schema): return data_schema + elif isinstance(data_schema, numpy.ndarray): + return data_schema elif isinstance(data_schema, _TensorStub): assert isinstance( data[data_schema.tensor_idx], torch.Tensor