Drastically Different Result Across Multiple Languages Except Python #17829

Closed
MikuAuahDark opened this issue Oct 7, 2023 · 2 comments
Labels: api:Java, api:Javascript, platform:windows

MikuAuahDark commented Oct 7, 2023

Describe the issue

I made a model that tries to classify a person's gender based on their name alone. For reference, here's the model in PyTorch:

import torch


class GolModel(torch.nn.Module):
    def __init__(self, nlen: int, nbits: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # Hidden size is twice the maximum name length
        self.upper_layer = nlen * 2
        self.lstm = torch.nn.LSTM(
            input_size=nbits,
            hidden_size=self.upper_layer,
            num_layers=1,
            bidirectional=False,
            batch_first=True,
        )
        # The LSTM outputs for all nlen positions are flattened into one vector
        self.fc1 = torch.nn.Linear(self.upper_layer * nlen, self.upper_layer)
        self.result = torch.nn.Linear(self.upper_layer, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, nlen, nbits)
        lstm, _ = self.lstm(x)
        lstm_r = lstm.reshape(lstm.size(0), -1)
        fc1 = self.fc1(lstm_r)
        # Softmax over the two classes: index 0 = male, index 1 = female
        result = torch.nn.functional.softmax(self.result(fc1), 1)
        return result

model = GolModel(64, 21)

(That aside, I just threw random operators at it, so please don't comment on my choice of network architecture.)
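A quick way to see why every input must be exactly 64 x 21 (with an explicit batch dimension) is to trace the tensor shapes through GolModel(64, 21). The sketch below does this with plain arithmetic rather than PyTorch; trace_shapes is a hypothetical helper, not part of the reproduction code.

```python
def trace_shapes(batch: int, nlen: int = 64, nbits: int = 21) -> list[tuple[str, tuple[int, ...]]]:
    """Hypothetical helper: shapes produced by GolModel(nlen, nbits).forward()."""
    hidden = nlen * 2  # self.upper_layer
    return [
        ("input x", (batch, nlen, nbits)),       # batch_first=True
        ("lstm output", (batch, nlen, hidden)),  # one hidden state per character
        ("reshape", (batch, nlen * hidden)),     # flattened; must equal fc1's input size
        ("fc1", (batch, hidden)),
        ("softmax(result)", (batch, 2)),         # two class probabilities per name
    ]


for stage, shape in trace_shapes(batch=1):
    print(f"{stage:>16}: {shape}")
```

Because fc1 is built for exactly nlen * hidden = 64 * 128 = 8192 input features, shorter names have to be zero-padded to all 64 positions instead of being fed at their natural length.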

The model expects n x 64 x 21 tensors, where n is the batch size (dimension 1; it must always be present, so use 1 for a single name). Each name holds at most 64 characters (dimension 2), and each character's code point is converted to its binary representation in reversed (least-significant-bit-first) order, from bit 0 to bit 20 (dimension 3).

For example, the text "a" has code point 97 (binary 1100001), so it is converted to [1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]. The preprocessing code in Python is roughly as follows:

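import numpy

MAX_CODEPOINT_LEN = 64  # maximum name length (dimension 2)
NUMBER_OF_BITS = 21     # bits per code point (dimension 3)

def convert_name_to_numpy(names: list[str]) -> numpy.ndarray:
    # One row per name: shape (len(names), MAX_CODEPOINT_LEN, NUMBER_OF_BITS)
    tensor = numpy.zeros((len(names), MAX_CODEPOINT_LEN, NUMBER_OF_BITS), numpy.float32)

    for n, name in enumerate(names):
        # Longer names are truncated; shorter ones stay zero-padded
        for i in range(min(len(name), MAX_CODEPOINT_LEN)):
            intval = ord(name[i])
            for j in range(NUMBER_OF_BITS):
                # bool() maps the masked value (0 or 2**j) to 0.0 / 1.0
                tensor[n, i, j] = bool(intval & (1 << j))

    return tensor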

(To elaborate, this same code was then rewritten 1:1 in JavaScript and Java.)
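When porting the preprocessing, it can help to check a single character in isolation against the "a" example above. The sketch assumes the same least-significant-bit-first order; encode_codepoint and decode_bits are hypothetical helper names, not part of the reproduction code.

```python
NUMBER_OF_BITS = 21


def encode_codepoint(ch: str, nbits: int = NUMBER_OF_BITS) -> list[int]:
    """Hypothetical helper: bit j of ord(ch), least significant bit first."""
    intval = ord(ch)
    return [(intval >> j) & 1 for j in range(nbits)]


def decode_bits(bits: list[int]) -> str:
    """Inverse of encode_codepoint: rebuild the character from its bit row."""
    return chr(sum(bit << j for j, bit in enumerate(bits)))


row = encode_codepoint("a")
print(row[:8])           # [1, 0, 0, 0, 0, 1, 1, 0]
print(decode_bits(row))  # a
```

Dumping the same per-character rows from the Java and JavaScript ports and diffing them against this output is a cheap way to confirm whether a port really is 1:1.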

The tensor is then passed to the model for inference. The issue is that while the ONNX model gives almost the same results as PyTorch when run in Python, it returns completely different results in Java and JavaScript (see below).

Note: I'm using ONNX Runtime 1.16.0 in Python, Java, and JavaScript. The language selector doesn't allow me to pick multiple languages.

To reproduce

Model and code can be downloaded here: complete_reproduction.zip

In there, you can find these files:

  • App.java - Java version of the inference.
  • hitit_v8.py - Complete Python definition of the model, how it's trained in PyTorch and exported to ONNX.
  • hitit_v8_fastinfer.py - ONNX inference using Python.
  • namegender.html - HTML + JavaScript realtime inference.
  • modelgol2.onnx - The model.

Here's the ONNX Python inference output:

D:\omitted>python hitit_v8_fastinfer.py N Na Nav Navi Navia
(debug tensor output omitted)
N
Male: 67.00568795204163 %
Female: 32.99431502819061 %
Guessed: Unisex

Na
Male: 28.068840503692627 %
Female: 71.93116545677185 %
Guessed: Unisex

Nav
Male: 82.9777717590332 %
Female: 17.022231221199036 %
Guessed: Male

Navi
Male: 41.178399324417114 %
Female: 58.821600675582886 %
Guessed: Unisex

Navia
Male: 4.214639961719513 %
Female: 95.7853615283966 %
Guessed: Female

(Python 3.11.4)

Here's the JavaScript inference output while typing "Navia" one letter at a time, logged with console.log() (gender hints added for clarity):

output Float32Array [ 0.6700569987297058, 0.3299430012702942 ] namegender.html:79:13 (UNISEX)
output Float32Array [ 0.6402597427368164, 0.359740287065506 ] namegender.html:79:13 (UNISEX)
output Float32Array [ 0.8092355132102966, 0.19076451659202576 ] namegender.html:79:13 (MALE)
output Float32Array [ 0.7302872538566589, 0.26971277594566345 ] namegender.html:79:13 (MALE)
output Float32Array [ 0.8071557283401489, 0.19284430146217346 ] namegender.html:79:13 (MALE)

(Mozilla Firefox 118.0.1)

And here's the output in Java:

D:\omitted>gradlew run --args "N Na Nav Navi Navia"
To honour the JVM settings for this build a single-use Daemon process will be forked. See https://docs.gradle.org/7.3/userguide/gradle_daemon.html#sec:disabling_the_daemon.
Daemon will be stopped at the end of the build

> Task :app:run
N
Male: 50.220966339111 %
Female: 49.779033660889 %
Guessed: Unisex

Na
Male: 15.308277308941 %
Female: 84.691721200943 %
Guessed: Female

Nav
Male: 1.0724958963692 %
Female: 98.927503824234 %
Guessed: Female

Navi
Male: 95.465511083603 %
Female: 4.5344825834036 %
Guessed: Male

Navia
Male: 98.757290840149 %
Female: 1.2427097186446 %
Guessed: Male


BUILD SUCCESSFUL in 6s
3 actionable tasks: 1 executed, 2 up-to-date

(JVM 17.0.6 Microsoft OpenJDK)

And finally, here's the output of running inference directly in PyTorch, which should be used as the baseline reference:

D:\omitted>python hitit_v8.py infer N Na Nav Navi Navia
N
Male: 67.00567603111267 %
Female: 32.99432694911957 %
Guessed: Unisex

Na
Male: 28.06881070137024 %
Female: 71.93118929862976 %
Guessed: Unisex

Nav
Male: 82.9777479171753 %
Female: 17.022255063056946 %
Guessed: Male

Navi
Male: 41.17844998836517 %
Female: 58.82155895233154 %
Guessed: Unisex

Navia
Male: 4.214644432067871 %
Female: 95.78534960746765 %
Guessed: Female

(PyTorch 2.1.0 + Python 3.11.4)

Output TL;DR: In Python, the ONNX Runtime output matches PyTorch. In the other languages, however, the results are vastly different, and they don't even agree with each other.
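To put "same" and "different" on a firmer footing, the Male probabilities reported above can be compared against the PyTorch baseline with an absolute tolerance. The values below are copied from the outputs in this issue (percentages divided by 100); the 1e-5 tolerance and the matches_baseline helper are assumptions chosen for illustration.

```python
import math

NAMES = ["N", "Na", "Nav", "Navi", "Navia"]

# Male probabilities copied from the outputs above.
PYTORCH = [0.6700567603111267, 0.2806881070137024, 0.829777479171753,
           0.4117844998836517, 0.04214644432067871]
ONNX_PYTHON = [0.6700568795204163, 0.28068840503692627, 0.829777717590332,
               0.41178399324417114, 0.04214639961719513]
ONNX_JS = [0.6700569987297058, 0.6402597427368164, 0.8092355132102966,
           0.7302872538566589, 0.8071557283401489]
ONNX_JAVA = [0.50220966339111, 0.15308277308941, 0.010724958963692,
             0.95465511083603, 0.98757290840149]


def matches_baseline(values: list[float], atol: float = 1e-5) -> list[bool]:
    """Per name: does the Male probability agree with PyTorch within atol?"""
    return [math.isclose(v, ref, rel_tol=0.0, abs_tol=atol)
            for v, ref in zip(values, PYTORCH)]


runs = [("ONNX (Python)", ONNX_PYTHON),
        ("ONNX (JavaScript)", ONNX_JS),
        ("ONNX (Java)", ONNX_JAVA)]
for label, values in runs:
    print(label, dict(zip(NAMES, matches_baseline(values))))
```

With these numbers, ONNX Runtime in Python agrees with PyTorch for all five names and Java agrees for none, while JavaScript agrees only for the single-character input "N". That pattern may hint that the JavaScript problem only appears from the second character onward, though it is not proof on its own.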

Urgency

It's a hobby project, so it's not urgent. However, I'd love to have it resolved as soon as possible.

Platform

Windows

OS Version

10.0.22621.2361

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.16.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

-

github-actions bot added the api:Java, api:Javascript, and platform:windows labels Oct 7, 2023
Craigacp (Contributor) commented Oct 7, 2023

The Java code doesn't produce a binary array: it's writing the position integer into the tensor rather than a single 0/1 bit. You want dest[i][j] = ((intval & (1 << j)) != 0) ? 1 : 0; which is the equivalent of passing the value through Python's bool function, as hitit_v8_fastinfer.py does.
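The distinction matters because intval & (1 << j) on its own evaluates to 0 or 2^j, not 0 or 1, so a port that stores the masked value would feed the model inputs as large as 2^20 where it was trained on 1. The Python sketch below compares the masked value with two 0/1 forms, using the letter "a" from the issue; bit_values is a hypothetical helper for illustration only.

```python
def bit_values(intval: int, nbits: int) -> dict[str, list[int]]:
    """Compare three ways of extracting bit j from intval (illustrative only)."""
    masked = [intval & (1 << j) for j in range(nbits)]               # 0 or 2**j, NOT 0/1
    via_bool = [int(bool(intval & (1 << j))) for j in range(nbits)]  # what the Python code stores
    shifted = [(intval >> j) & 1 for j in range(nbits)]              # same 0/1 result, no bool needed
    return {"masked": masked, "via_bool": via_bool, "shifted": shifted}


vals = bit_values(ord("a"), 8)
print(vals["masked"])    # [1, 0, 0, 0, 0, 32, 64, 0]
print(vals["via_bool"])  # [1, 0, 0, 0, 0, 1, 1, 0]
print(vals["shifted"])   # [1, 0, 0, 0, 0, 1, 1, 0]
```

In Java, the shifted form (intval >> j) & 1 is already an int equal to 0 or 1, so it can be assigned to an int or float array element without a conditional.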

You might also want to check the character encoding: I think in Python you're processing it in UTF-8, but Java Strings are UTF-16.
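The difference only shows up for characters above U+FFFF. Python's ord() works on whole code points, while a UTF-16 string (as in Java or JavaScript) stores such a character as a surrogate pair of two 16-bit code units, so a loop that indexes the string position by position sees two entries. The sketch uses an emoji as an assumed example input; utf16_code_units is a hypothetical helper.

```python
def utf16_code_units(s: str) -> list[int]:
    """Return the UTF-16 code units of s, as Java's charAt() or JS's charCodeAt() would see them."""
    data = s.encode("utf-16-le")
    return [int.from_bytes(data[k:k + 2], "little") for k in range(0, len(data), 2)]


name = "A\U0001F600"  # 'A' followed by U+1F600, a character outside the BMP
print([hex(ord(c)) for c in name])               # ['0x41', '0x1f600']: 2 code points
print([hex(u) for u in utf16_code_units(name)])  # ['0x41', '0xd83d', '0xde00']: 3 code units
```

For plain ASCII names such as "Navia", both views produce the same sequence of numbers.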

In JavaScript, I think your indexing into the 1D array is broken; it should probably be floatArray[i * NUMBER_OF_BITS + j]. I suspect there are other problems in that code too.
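A JavaScript tensor is backed by a flat typed array in row-major (C) order, so element [n][i][j] of an n x 64 x 21 tensor sits at n * 64 * 21 + i * 21 + j; for a batch of one, this reduces to the i * NUMBER_OF_BITS + j suggested above. The NumPy sketch checks that formula against numpy.ravel_multi_index; flat_index is a hypothetical helper.

```python
import numpy as np

MAX_CODEPOINT_LEN = 64
NUMBER_OF_BITS = 21


def flat_index(n: int, i: int, j: int) -> int:
    """Row-major offset of element [n, i, j] in an (N, 64, 21) tensor."""
    return (n * MAX_CODEPOINT_LEN + i) * NUMBER_OF_BITS + j


shape = (2, MAX_CODEPOINT_LEN, NUMBER_OF_BITS)
tensor = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
flat = tensor.ravel()  # C-order flattening, the layout a Float32Array would hold

for n, i, j in [(0, 0, 5), (0, 3, 7), (1, 63, 20)]:
    assert flat[flat_index(n, i, j)] == tensor[n, i, j]
    assert flat_index(n, i, j) == np.ravel_multi_index((n, i, j), shape)
print("flat_index agrees with NumPy's row-major layout")
```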

MikuAuahDark (Author) commented Oct 8, 2023

Hello, thank you. I double-checked and applied the changes, and the results are now indeed consistent.

As for the character encoding, it's not an issue. I use ord() in Python and String.codePointAt() in JavaScript and Java; the latter may not handle surrogate pairs correctly. However, to keep the model simple, I trained it only on the ASCII character range (which means NUMBER_OF_BITS could be reduced to 7).
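The bit counts follow from the largest code point that has to be representable: ASCII ends at U+007F (7 bits), while Unicode ends at U+10FFFF (21 bits), matching the 21 bits used by GolModel(64, 21). The sketch below computes both and adds a hypothetical fits_in_bits pre-check for names outside the trained range.

```python
def bits_needed(max_codepoint: int) -> int:
    """Smallest NUMBER_OF_BITS that can represent every code point up to max_codepoint."""
    return max_codepoint.bit_length()


def fits_in_bits(name: str, nbits: int) -> bool:
    """Hypothetical pre-check: would every character survive an nbits encoding?"""
    return all(ord(c) < (1 << nbits) for c in name)


print(bits_needed(0x7F))            # 7: ASCII (U+0000..U+007F)
print(bits_needed(0x10FFFF))        # 21: the largest Unicode code point
print(fits_in_bits("Navia", 7))     # True
print(fits_in_bits("Zo\u00eb", 7))  # False: 'ë' is U+00EB
```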

Sorry for the inconvenience.
