Describe the issue
I created a simple model with two hidden linear layers and ReLU activations, and converted it to ONNX format with all the essential export parameters. When I run inference with onnxruntime using the DnnlExecutionProvider on batches of 8 (shape 8x784, where 784 is the fixed embedding dimension) over 25 sentences in total, the full batches of 8 work fine, i.e. the first 24 sentences are processed. When it reaches the 25th sentence, which has shape (1, 784), it returns the error:
[E:onnxruntime:Default, dnnl_subgraph_primitive.cc:561 GetMemoryAndReshape] fc.0.weight, Dims From: 784 1024 , To: 784 784]
[E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running DNNL_15662691735485739911_0 node. Name:'DnnlExecutionProvider_DNNL_15662691735485739911_0_0' Status Message: not a valid reshape, inconsistent dim product]
This is the actual stack trace:
/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running DNNL_15662691735485739911_0 node. Name:'DnnlExecutionProvider_DNNL_15662691735485739911_0_0' Status Message: not a valid reshape, inconsistent dim product
I also ran the same inference with the CPUExecutionProvider and the CUDAExecutionProvider, using the same batch size and the same number of sentences, and both work absolutely fine.
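As a temporary workaround (a minimal sketch, assuming the failure only happens when a later batch is smaller than the first batch the session saw), padding the final partial batch up to the full batch size and discarding the padded rows avoids the error. The run_padded helper below is hypothetical and not part of the repro code:

import numpy as np

def run_padded(ort_sess, batch, full_batch_size):
    # Hypothetical workaround: keep the batch dimension constant by
    # zero-padding a partial batch, then dropping the outputs for the pad rows.
    n = batch.shape[0]
    if n < full_batch_size:
        pad = np.zeros((full_batch_size - n, batch.shape[1]), dtype=batch.dtype)
        batch = np.concatenate([batch, pad], axis=0)
    outputs = ort_sess.run(None, {"input": batch})
    return [out[:n] for out in outputs]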
Possible Approach
It seems like something in the memory allocation done by the GetMemoryAndReshape function is causing the issue. Judging from the log, the oneDNN subgraph appears to be prepared for the first input shape it sees and then tries to reshape the cached fc.0.weight memory (784x1024) to 784x784 when the partial batch arrives. That reshape is rejected because the element counts differ (784 * 1024 = 802816 vs 784 * 784 = 614656), hence "not a valid reshape, inconsistent dim product".
To reproduce
Here's the code for creating the model, converting it to ONNX, and running inference:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# Create your own custom model
class CustomModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(CustomModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, output_size)
        )

    def forward(self, x):
        return self.fc(x)

# Set random seed for reproducibility
torch.manual_seed(42)

# Define input and output dimensions
input_size = 28 * 28
output_size = 10

# Instantiate your custom model
model = CustomModel(input_size, output_size)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()  # Example loss function
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Example optimizer

# Load the dataset and create dataloaders
train_dataset = MNIST(root='data/', train=True, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Train the model
num_epochs = 10  # Example: number of training epochs
for epoch in range(num_epochs):
    # Loop over the dataset and perform forward pass, backward pass, and optimization
    for batch in train_loader:
        images, labels = batch
        images = images.view(images.size(0), -1)  # Flatten the input images
        optimizer.zero_grad()              # Clear gradients
        outputs = model(images)            # Forward pass
        loss = criterion(outputs, labels)  # Calculate loss
        loss.backward()                    # Backward pass
        optimizer.step()                   # Update weights

# Save the trained model weights to a file
torch.save(model.state_dict(), 'custom_model.pth')

# Load the weights back
model.load_state_dict(torch.load('custom_model.pth'))
model.eval()

input_names = ["input"]
output_names = ["output"]
MODEL_MAX_LENGTH = 28 * 28

# Export the loaded model to ONNX format with a dynamic batch axis
dummy_input = {key: torch.ones(1, MODEL_MAX_LENGTH, dtype=torch.float32) for key in input_names}
dynamic_axes = {'input': {0: 'batch_size'}}  # Specify dynamic axes for the input tensor
torch.onnx.export(model,
                  dummy_input["input"],
                  "custom_model.onnx",
                  verbose=False,
                  input_names=input_names,
                  output_names=output_names,
                  export_params=True,
                  dynamic_axes=dynamic_axes)

# Inference of the converted ONNX model
import os
import time
import numpy as np
import onnx
import onnxruntime as rt

def inference_onnx():
    CONVERTED_COMET_MODEL_PATH = os.path.join(os.getcwd(), 'custom_model.onnx')
    onnx.checker.check_model(CONVERTED_COMET_MODEL_PATH, full_check=True)
    # ort_sess = rt.InferenceSession(CONVERTED_COMET_MODEL_PATH, providers=['CUDAExecutionProvider'])
    ort_sess = rt.InferenceSession(CONVERTED_COMET_MODEL_PATH, providers=['DnnlExecutionProvider'])
    # ort_sess = rt.InferenceSession(CONVERTED_COMET_MODEL_PATH, providers=['CPUExecutionProvider'])
    batch_size = 4
    shape = (5, 28 * 28)  # 5 rows: one full batch of 4, then a partial batch of 1
    concatenated_data = {"input": torch.rand(*shape)}
    for idx in range(0, shape[0], batch_size):
        inp = {"input": np.array(concatenated_data["input"][idx: idx + batch_size])}
        start_time = time.time()
        print("inp = ", inp, inp["input"].shape)
        outputs = ort_sess.run(None, inp)  # fails on the (1, 784) partial batch with DnnlExecutionProvider
        print("output_sentence = ", outputs, outputs[0].shape)
        end_time = time.time()
        print("Time Taken : ", round(end_time - start_time, 2))
        # outputs = ort_sess.run(["score"], inp)

inference_onnx()
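If the first-seen-shape hypothesis above is right, the order of the calls should decide which shape fails. A quick check (hypothetical, reusing the model file exported above):

import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession("custom_model.onnx", providers=['DnnlExecutionProvider'])
small = {"input": np.random.rand(1, 28 * 28).astype(np.float32)}
large = {"input": np.random.rand(4, 28 * 28).astype(np.float32)}
print(sess.run(None, small)[0].shape)  # partial-batch shape runs first
print(sess.run(None, large)[0].shape)  # full-batch shape runs second; may now be the failing call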
Urgency
By the end of the month.
Platform
Linux
OS Version
Ubuntu 22.04.4 LTS
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.18.0
ONNX Runtime API
Python
Architecture
X64
Execution Provider
oneDNN
Execution Provider Library Version
No response