Skip to content

Commit

Permalink
accept numpy as input dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
Prathik Rao committed Dec 19, 2023
1 parent 6d7519e commit a1c65ef
Showing 1 changed file with 8 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import List, Mapping, Optional, Sequence, Tuple, Union

import torch
import numpy

from onnxruntime.training.utils.torch_profile_utils import nvtx_function_decorator

Expand Down Expand Up @@ -43,6 +44,7 @@ def get_primitive_dtype(value):
int,
bool,
float,
numpy.ndarray,
torch.Tensor,
Sequence["ORTModelInputOutputType"],
Mapping[str, "ORTModelInputOutputType"],
Expand Down Expand Up @@ -111,6 +113,7 @@ def __eq__(self, other):
ORTModelInputOutputSchemaType = Union[
None,
str,
numpy.ndarray,
_TensorStub,
Sequence["ORTModelInputOutputSchemaType"],
Mapping[str, "ORTModelInputOutputSchemaType"],
Expand Down Expand Up @@ -196,6 +199,9 @@ def _flatten_from_data(data: ORTModelInputOutputType, prefix_name: str = ""):
tensor_idx[0], dtype=PrimitiveType.get_primitive_dtype(data), shape_dims=0, name=prefix_name
)
return data
elif isinstance(data, numpy.ndarray):
_warn_of_constant_inputs(data)
return data
# Depth first traversal to iterate over the data to replace every tensor with a stub
elif isinstance(data, torch.Tensor):
tensor_idx[0] += 1
Expand Down Expand Up @@ -290,6 +296,8 @@ def _replace_stub_with_tensor_value(data_schema: ORTModelInputOutputSchemaType,
return data_schema
elif PrimitiveType.is_primitive_type(data_schema):
return data_schema
elif isinstance(data_schema, numpy.ndarray):
return data_schema
elif isinstance(data_schema, _TensorStub):
assert isinstance(
data[data_schema.tensor_idx], torch.Tensor
Expand Down

0 comments on commit a1c65ef

Please sign in to comment.