<OnnxValue>.getValue() returns non-parseable java object #19440

Open
jazzblue opened this issue Feb 6, 2024 · 3 comments
Labels
stale issues that have not been addressed in a while; categorized by a bot

@jazzblue

jazzblue commented Feb 6, 2024

Describe the issue

I found that it is probably the same issue as #16781.

I am using ONNX Runtime to serve a scikit-learn-trained model from Java code. The output is returned as an OnnxValue, and I call getValue() to retrieve the output value. According to the API documentation, this returns the value as a Java object, so I expected to be able to extract a primitive value, such as a float, or an array. For OnnxTensor specifically, the API doc says: "Either returns a boxed primitive if the Tensor is a scalar, or a multidimensional array of primitives if it has multiple dimensions." Logging the output shows the expected type: OnnxTensor(info=TensorInfo(javaType=INT64,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,shape=[1])). However, casting the value to long or int throws an exception, and I cannot find any other method to get the primitive value or array. How do I extract the value from the returned Java object?
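For an INT64 tensor, the documented contract means getValue() yields a boxed Long for a scalar (shape []) and a long[], long[][], and so on for higher ranks. A minimal, stdlib-only sketch of a helper that flattens either form into a long[] follows; the class name Int64Values and the stand-in values are hypothetical (they are not produced by a real session), and only ranks 0 to 2 are handled.

```java
import java.util.Arrays;

/**
 * Sketch: flatten the Object returned by getValue() for an INT64 tensor into
 * a long[]. Handles ranks 0-2 only; the values in main() are hypothetical
 * stand-ins, not output from ONNX Runtime.
 */
public class Int64Values {

    public static long[] toLongArray(Object value) {
        if (value instanceof Long) {            // shape []: boxed scalar
            return new long[] {(Long) value};
        } else if (value instanceof long[]) {   // shape [n]: 1-D array
            return (long[]) value;
        } else if (value instanceof long[][]) { // shape [n, m]: 2-D array
            return Arrays.stream((long[][]) value)
                    .flatMapToLong(Arrays::stream)
                    .toArray();
        }
        throw new IllegalArgumentException(
                "Unexpected value type: " + value.getClass().getName());
    }

    public static void main(String[] args) {
        Object scalar = 7L;                              // like a shape [] tensor
        Object vector = new long[] {2L};                 // like a shape [1] tensor
        Object matrix = new long[][] {{1L, 2L}, {3L, 4L}}; // like a shape [2, 2] tensor
        System.out.println(Arrays.toString(toLongArray(scalar))); // [7]
        System.out.println(Arrays.toString(toLongArray(vector))); // [2]
        System.out.println(Arrays.toString(toLongArray(matrix))); // [1, 2, 3, 4]
    }
}
```

Dispatching on instanceof keeps the caller independent of the tensor's rank, at the cost of listing each supported array type explicitly.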

To reproduce

  1. Java version:
openjdk version "11.0.21" 2023-10-17
OpenJDK Runtime Environment (build 11.0.21+9-post-Ubuntu-0ubuntu122.04)
OpenJDK 64-Bit Server VM (build 11.0.21+9-post-Ubuntu-0ubuntu122.04, mixed mode)
  2. Directory tree:
|-- pom.xml
|-- src
|   `-- main
|       |-- java
|       |   `-- onnx
|       |       `-- example
|       |           `-- OnnxRf.java
|       `-- resources
|           `-- rf_iris.onnx
  3. Training script (in Python):
    pip install skl2onnx
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
X, y = iris.data, iris.target
X = X.astype(np.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = RandomForestClassifier()
clr.fit(X_train, y_train)

# Convert into ONNX format.
from skl2onnx import to_onnx

onx = to_onnx(clr, X)
with open("rf_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())

Save the exported model file, rf_iris.onnx, under src/main/resources.
  4. File pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>onnx.example</groupId>
    <artifactId>onnx-example</artifactId>
    <version>1.0</version>

    <properties>
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <main.class>onnx.example.OnnxRf</main.class>
        <onnxruntime.version>1.16.3</onnxruntime.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>com.microsoft.onnxruntime</groupId>
            <artifactId>onnxruntime</artifactId>
            <version>${onnxruntime.version}</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>3.1.1</version>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <artifactSet>
                                <excludes>
                                    <exclude>com.google.code.findbugs:jsr305</exclude>
                                </excludes>
                            </artifactSet>
                            <filters>
                                <filter>
                                    <!-- Do not copy the signatures in the META-INF folder.
                                    Otherwise, this might cause SecurityExceptions when using the JAR. -->
                                    <artifact>*:*</artifact>
                                    <excludes>
                                        <exclude>META-INF/*.SF</exclude>
                                        <exclude>META-INF/*.DSA</exclude>
                                        <exclude>META-INF/*.RSA</exclude>
                                    </excludes>
                                </filter>
                            </filters>
                            <transformers>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                                    <!-- Replace this with the main class of your job -->
                                    <mainClass>onnx.example.OnnxRf</mainClass>
                                </transformer>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
                            </transformers>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
  5. File OnnxRf.java
package onnx.example;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.*;

import ai.onnxruntime.OnnxValue;

public class OnnxRf {

    public static void main(String[] args) throws Exception {

        OrtSession session = null;
        OrtEnvironment env;
        Set<String> outputNames;

        Map<String, OnnxTensor> inputs = new HashMap<>();
        Iterator<String> namesIterator;
        String inFieldName;
        float[][] inputShaped;
        inputShaped = new float[1][1];

        env = OrtEnvironment.getEnvironment();
        session = env.createSession("src/main/resources/rf_iris.onnx", new OrtSession.SessionOptions());
        outputNames = session.getOutputNames();

        float[] inVal = new float[] {1.1f, 2.3f, 3.4f, 5.6f};
        inputShaped[0] = inVal;

        inputs.put("X", OnnxTensor.createTensor(env, inputShaped));
        System.out.println ("outputNames: " + outputNames);

        try (var results = session.run(inputs, outputNames)) {
            System.out.println ("----- output types -----");
            for (String fieldName : outputNames) {
                System.out.println(fieldName + ": " + results.get(fieldName).get() + " : " + results.get(fieldName).get().getType());
            }
            System.out.println ("----------");
            System.out.println("output_label, class: " + results.get("output_label").get().getValue().getClass());
            System.out.println("output_label, str: " + results.get("output_label").get().getValue().toString());
            System.out.println("output_label, getType: " + results.get("output_label").get().getType());
            System.out.println("output_label, getInfo: " + results.get("output_label").get().getInfo());

            // Trying to cast to long here since output_label output type is shown as INT64, but int did not work either
            System.out.println("output_label, long: " + (long) results.get("output_label").get().getValue());

        } catch (OrtException e) {
            // e.printStackTrace();
            System.out.println (">>>>>" + e.getCode() + ": " + e.getMessage());
        }
    }
}
  6. Build the Java package
mvn package
  7. Run the Java inference code using the ONNX model
java -jar target/onnx-example-1.0.jar
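The failing line in OnnxRf.java is the (long) cast. Because getValue() is declared to return Object, javac accepts (long) value as a cast to Long followed by unboxing, so the problem only surfaces at runtime, as a ClassCastException, when the object is actually a long[]. The sketch below reproduces this with a hypothetical stand-in value instead of a real session, so it runs without ONNX Runtime on the classpath; the class and method names are illustrative only.

```java
/**
 * Sketch: why "(long) value" fails when value is a long[]. The Object used in
 * main() is a hypothetical stand-in for
 * results.get("output_label").get().getValue() on a shape [1] INT64 tensor.
 */
public class CastDemo {

    /** Returns true if casting the value to a primitive long throws. */
    public static boolean primitiveCastFails(Object value) {
        try {
            long ignored = (long) value; // compiles as Object -> Long -> long
            return false;
        } catch (ClassCastException e) {
            return true;                 // value was not a java.lang.Long
        }
    }

    /** Reads the first label, assuming the value is a long[] of shape [batch_size]. */
    public static long firstLabel(Object value) {
        return ((long[]) value)[0];
    }

    public static void main(String[] args) {
        Object value = new long[] {2L};
        System.out.println("(long) cast fails: " + primitiveCastFails(value)); // true
        System.out.println("first label: " + firstLabel(value));             // 2
    }
}
```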

Urgency

No response

Platform

Windows

OS Version

Ubuntu 22.04.3 LTS

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.16.3

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

@Craigacp
Contributor

Craigacp commented Feb 6, 2024

Cast it to long[], not long. You can always reflectively inspect the type of the object returned by getValue, e.g. results.get("output_label").get().getValue().getClass() will return long[].class. Scalars have shape [], whereas this model produces an output of shape [batch_size]; so although there is only a single element in this case, it's a 1-D vector, which we return as a 1-D array.
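Following the reflective approach suggested above, the shape of whatever getValue() returns can be recovered with java.lang.reflect.Array without knowing the rank in advance. This is a sketch that assumes rectangular arrays (it only inspects the first element at each level); the class name ValueShape and the stand-in values are hypothetical.

```java
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Sketch: recover the shape of a getValue() result, where a boxed primitive
 * means shape [] and nested primitive arrays mean higher ranks. Assumes
 * rectangular arrays: only the first element of each level is inspected.
 */
public class ValueShape {

    public static long[] shapeOf(Object value) {
        List<Long> dims = new ArrayList<>();
        Object current = value;
        while (current != null && current.getClass().isArray()) {
            int length = Array.getLength(current);
            dims.add((long) length);
            if (length == 0) {
                break; // cannot descend into an empty array
            }
            current = Array.get(current, 0); // boxed element or next-level array
        }
        return dims.stream().mapToLong(Long::longValue).toArray();
    }

    public static void main(String[] args) {
        Object label = new long[] {2L};         // like output_label for a 1-row batch
        Object batch = new long[] {2L, 0L, 1L}; // like output_label for a 3-row batch
        Object scalar = 2L;                     // a true scalar (shape [])
        System.out.println(label.getClass() == long[].class);  // true
        System.out.println(Arrays.toString(shapeOf(label)));   // [1]
        System.out.println(Arrays.toString(shapeOf(batch)));   // [3]
        System.out.println(Arrays.toString(shapeOf(scalar)));  // []
    }
}
```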

@jazzblue
Author

jazzblue commented Feb 7, 2024

@Craigacp casting to long[] worked, thanks!

Contributor

github-actions bot commented Mar 8, 2024

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@github-actions github-actions bot added the stale issues that have not been addressed in a while; categorized by a bot label Mar 8, 2024