
[Web] Error: Tensor's size(512) does not match data length(1024) #21454

Closed
kabyanil opened this issue Jul 23, 2024 · 2 comments
Labels
platform:web issues related to ONNX Runtime web; typically submitted using template

Comments

@kabyanil

Describe the issue

I am trying to run web inference on a Transformer model trained in PyTorch and exported to ONNX. The encoder outputs a tensor of shape [batch_size, sequence_length, embed_dim], which in my case is [1, 32, 512] during inference. The decoder output starts at [1, 1, 512] and grows along the second dimension ([1, 2, 512], [1, 3, 512], ...) until EOS is hit. The first decoder iteration runs fine. However, in the second iteration, the output of the decoder's embedding layer has 1024 elements, with dimensions [1, 2, 512], as shown in the screenshot below:

[Screenshot 2024-07-23 at 10.00.02 AM]
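
The element count of a tensor is the product of its dimensions, so the decoder-side tensors grow by 512 elements per decoding step. A small sketch of that arithmetic (the `elementCount` helper is illustrative only, not part of ONNX Runtime Web):

```javascript
// Number of elements in a tensor = product of its dimensions.
function elementCount(dims) {
  return dims.reduce((acc, d) => acc * d, 1);
}

// Decoder output shape at decoding step t (1-based): [batch = 1, seq_len = t, embed_dim = 512]
const embedDim = 512;
for (let t = 1; t <= 3; t++) {
  const dims = [1, t, embedDim];
  console.log(`step ${t}: dims [${dims}] -> ${elementCount(dims)} elements`);
}
// step 1: dims [1,1,512] -> 512 elements
// step 2: dims [1,2,512] -> 1024 elements
// step 3: dims [1,3,512] -> 1536 elements
```

So any tensor declared with a fixed 512-element shape stops matching the decoder data from the second step on.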

When this tensor passes through the positional encoding layer, I get the following error:

Error: Tensor's size(512) does not match data length(1024).
    at new Je (tensor-impl.ts:229:13)
    at main (index.html:118:50)
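
This message comes from the `ort.Tensor` constructor (the `new Je` frame in `tensor-impl.ts`), which rejects data whose length differs from the product of the requested dims. Because the next frame is `main` itself, the failing call is most likely one of the `new ort.Tensor(...)` calls in the loop rather than a `session.run` call. The hypothetical `checkTensorSize` below is a minimal sketch of that kind of check (not ONNX Runtime's actual code), fed with the numbers from the second iteration:

```javascript
// Hypothetical stand-in for the size validation a tensor constructor performs:
// the product of dims must equal the number of supplied elements.
function checkTensorSize(dims, data) {
  const size = dims.reduce((acc, d) => acc * d, 1);
  if (size !== data.length) {
    throw new Error(`Tensor's size(${size}) does not match data length(${data.length}).`);
  }
  return size;
}

// Second decoding step: the decoder output holds [1, 2, 512] = 1024 floats,
// but the tensor is declared with dims [1, 512] (512 elements).
const decoderOutData = new Float32Array(1 * 2 * 512);
try {
  checkTensorSize([1, 512], decoderOutData);
} catch (e) {
  console.log(e.message); // Tensor's size(512) does not match data length(1024).
}
```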

I can run inference with the same ONNX modules in Python, using equivalent inference code.

Any help to resolve this issue would be appreciated.

To reproduce

            // decoder initial input
            let tgt_input = new ort.Tensor('int32', [eng_tokenizer.encode('<')[0]], [1, 1])

            let i = 0
            while (true) {
               console.log("inference step: ", i)
               if (tgt_input.dims[1] === max_len) break

               // causal mask
               let tgt_mask = causal_mask(tgt_input.dims[1])

               // decoder embedding
               let tgt_embed_out = await session.tgt_embed.run({ l_x_: tgt_input }).then((res) => res.mul)

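               // Reported as failing here in the second iteration; the stack trace (new Tensor called
               // from main) most likely points instead to the [1, 512] tensor built for the projection.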
               // decoder positional encoding
               let tgt_pos_out = await session.tgt_pos.run({ input_1: tgt_embed_out }).then((res) => res.input)

               // decoder
               let tgt_decoder_out = await session.tgt_decode.run({
                  'decoder_input': tgt_pos_out,
                  'encoder_output': src_encoder_out,
                  'src_mask': src_mask,
                  'tgt_mask': tgt_mask
               })
                  .then((res) => res.output)

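               // decoder projection
               // tgt_decoder_out has dims [1, t, 512] (t * 512 floats). Only the last position predicts
               // the next token, so slice out its 512 values before building a [1, 512] tensor.
               // Passing all t * 512 values with dims [1, 512] throws
               // "Tensor's size(512) does not match data length(1024)" once t = 2.
               const d_model = tgt_decoder_out.dims[2]
               const last_step = tgt_decoder_out.data.slice((tgt_decoder_out.dims[1] - 1) * d_model)
               let tgt_decoder_output_last_dim = new ort.Tensor(last_step, [1, d_model])
               const tgt_decoder_projection_out = await session.tgt_projection.run({ 'l_x_': tgt_decoder_output_last_dim }).then((res) => res.proj_1)
               let tgt_decoder_proj_out_last_dim = new ort.Tensor(tgt_decoder_projection_out.data, [75])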

               // get next word token
               const next_word = argMax([...Array.from(tgt_decoder_proj_out_last_dim.data)])

               // append next word token to decoder input
               tgt_input = new ort.Tensor('int32',
                  new Int32Array([...Array.from(tgt_input.data).concat(next_word)]),
                  [1, tgt_input.dims[1] + 1]
               )

               console.log("new tgt_input: ", tgt_input)

               i = i + 1
            }
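
The snippet calls `argMax` and `causal_mask` without showing them. Below are hypothetical sketches of both; `causalMaskData` is an illustrative name, and the mask's dtype, rank, and 1 = "may attend" convention are assumptions that must match whatever the decoder was exported to expect:

```javascript
// Hypothetical helper: index of the largest value in an array-like of numbers
// (e.g. a Float32Array of logits). Returns 0 for an empty input.
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Hypothetical helper: row-major n x n lower-triangular mask.
// 1 where query position i may attend to key position j (j <= i), 0 for future positions.
function causalMaskData(n) {
  const mask = new Int32Array(n * n);
  for (let i = 0; i < n; i++) {
    for (let j = 0; j <= i; j++) {
      mask[i * n + j] = 1;
    }
  }
  return mask;
}

console.log(argMax(new Float32Array([0.1, 2.5, -1.0]))); // 1
console.log(Array.from(causalMaskData(3)).join(",")); // 1,0,0,1,1,0,1,1,1
```

Inside the loop the mask data would then be wrapped as, for example, `new ort.Tensor('int32', causalMaskData(n), [1, n, n])`, again assuming an int32 mask of that shape. The loop as posted also stops only at `max_len`; stopping at EOS would need an extra comparison of `next_word` against the tokenizer's EOS id.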

Urgency

No response

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

ONNX Runtime Web v1.18.0

Execution Provider

'wasm'/'cpu' (WebAssembly CPU)

@kabyanil kabyanil added the platform:web issues related to ONNX Runtime web; typically submitted using template label Jul 23, 2024
@fs-eire
Contributor

fs-eire commented Jul 29, 2024

Usually this kind of error is model-specific. Could you share more info, including the model and the code that creates the session?

@kabyanil
Author

kabyanil commented Aug 2, 2024

Thanks for the heads-up. I have fixed the issue; it turned out to be a shape mismatch in the input tensor.
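
The comment does not say exactly which input was reshaped. One fix consistent with the error message is to feed the projection only the last decoder position instead of the whole [1, t, 512] output. A minimal sketch of that slicing step, using a hypothetical `lastPosition` helper and assuming batch size 1 and row-major data:

```javascript
// Hypothetical helper: return the hidden state of the last sequence position from a
// row-major [batch = 1, seqLen, embedDim] buffer, i.e. the embedDim values fed to the projection.
function lastPosition(data, dims) {
  const [batch, seqLen, embedDim] = dims;
  if (batch !== 1) throw new Error("sketch assumes batch size 1");
  if (data.length !== seqLen * embedDim) throw new Error("data length does not match dims");
  return data.slice((seqLen - 1) * embedDim, seqLen * embedDim);
}

// Second decoding step with 2 positions of width 4 (small embedDim for readability).
const out = Float32Array.from([0, 1, 2, 3, 10, 11, 12, 13]);
console.log(Array.from(lastPosition(out, [1, 2, 4])).join(",")); // 10,11,12,13
```

With ONNX Runtime Web this would be used as `new ort.Tensor(lastPosition(tgt_decoder_out.data, tgt_decoder_out.dims), [1, 512])`, so the declared dims always match the 512 values passed in.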

@kabyanil kabyanil closed this as completed Aug 2, 2024