Skip to content

Commit

Permalink
using SwapByteOrderCopy for SetIndices
Browse files Browse the repository at this point in the history
  • Loading branch information
ranjitshs committed Jun 19, 2024
1 parent e435327 commit 28ee5ca
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
12 changes: 8 additions & 4 deletions onnxruntime/core/framework/tensorprotoutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1648,7 +1648,14 @@ static void SetIndices(gsl::span<int64_t> gathered_indices,
} else {
auto* dst = ind_dest + dest_index;
T v = static_cast<T>(src_index);
memcpy(dst, &v, sizeof(T));
if constexpr (endian::native != endian::little) {
auto src = gsl::make_span<const unsigned char>(static_cast<const unsigned char*>(reinterpret_cast<const unsigned char*>(&v)), sizeof(T));
auto dest = gsl::make_span<unsigned char>(static_cast<unsigned char*>(reinterpret_cast<unsigned char*>(dst)) , sizeof(T));
onnxruntime::utils::SwapByteOrderCopy(sizeof(T),src ,dest);
}
else {
memcpy(dst, &v, sizeof(T));
}
}
++dest_index;
}
Expand Down Expand Up @@ -1698,9 +1705,6 @@ static void SparsifyGeneric(const void* dense_raw_data, size_t n_dense_elements,
} else {
SetIndices<int64_t>(gathered_span, raw_indices, indices);
}
if constexpr (endian::native != endian::little) {
utils::ConvertRawDataInTensorProto((ONNX_NAMESPACE::TensorProto*)&indices);
}
} else {
indices.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT8);
utils::SetRawDataInTensorProto(indices,std::string());
Expand Down
14 changes: 6 additions & 8 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1194,15 +1194,13 @@ Graph::Graph(const Model& owning_model,
const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
auto status = utils::ConstantNodeProtoToTensorProto(node, model_path, *tensor);
if constexpr (endian::native != endian::little) {
const AttributeProto& attrib = node.attribute(0);
if (attrib.type() == AttributeProto_AttributeType_SPARSE_TENSOR)
{
const TensorProto& sparse_values = node.attribute(0).sparse_tensor().values();
if ((!(sparse_values.has_raw_data())) && tensor->has_raw_data())
{
const AttributeProto& attrib = node.attribute(0);
if (attrib.type() == AttributeProto_AttributeType_SPARSE_TENSOR) {
const TensorProto& sparse_values = node.attribute(0).sparse_tensor().values();
if ((!(sparse_values.has_raw_data())) && tensor->has_raw_data()) {
onnxruntime::utils::ConvertRawDataInTensorProto(tensor);
}
}
}
}
}
ORT_ENFORCE(status.IsOK(), status.ToString());
// Ensure initializers are also graph inputs.
Expand Down

0 comments on commit 28ee5ca

Please sign in to comment.