
Performance Issue: ONNX Runtime training loss decreases very little each epoch and saturates after 10 epochs #19185

Open
Leaner23 opened this issue Jan 17, 2024 · 5 comments
Labels
stale issues that have not been addressed in a while; categorized by a bot

Comments

@Leaner23

Leaner23 commented Jan 17, 2024

Describe the issue

I am trying to train an ONNX model on device. The loss decreases only very slightly in each epoch; I tried different batch sizes, but the problem remains the same.
When I trained the same model in Keras, the training loss decreased in every epoch.
Example of training loss with ONNX Runtime:
Epoch 1 Loss [[26.37851]]
Epoch 2 Loss [[24.919254]]
Epoch 3 Loss [[24.84161]]
Epoch 4 Loss [[24.851688]]
Epoch 5 Loss [[24.845762]]
Epoch 6 Loss [[24.842438]]
Epoch 7 Loss [[24.838167]]
Epoch 8 Loss [[24.836271]]
Epoch 9 Loss [[24.83929]]
Epoch 10 Loss [[24.839489]]
Epoch 11 Loss [[24.850527]]
Epoch 12 Loss [[24.865587]]
Epoch 13 Loss [[24.867554]]
Epoch 14 Loss [[24.873014]]
Epoch 15 Loss [[24.880104]]
Epoch 16 Loss [[24.879396]]
Epoch 17 Loss [[24.882072]]
Epoch 18 Loss [[24.835163]]
Epoch 19 Loss [[24.87151]]
Epoch 20 Loss [[24.835596]]
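For comparison, the Keras run of the same model reduced the loss every epoch. A minimal sketch of that baseline (an assumption: a plain fit() call using the compile settings shown in the model-generation code later in this thread):

# Sketch of the assumed Keras baseline; `transformer`, `data`, and `label` are the model
# and data defined elsewhere in this issue.
history = transformer.fit(data, label, batch_size=128, epochs=20)
print(history.history['loss'])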

To reproduce

Code for generating the training artifacts:

import onnx
from onnxruntime.training import artifacts

model_name = r'transformer_Jan_16_3.onnx'
# Load the onnx model.
onnx_model = onnx.load(f"{model_name}")

requires_grad = '''model_7/transformer_encoder_47/layer_normalization_1/batchnorm/mul/ReadVariableOp:0
model_7/transformer_encoder_47/layer_normalization_1/batchnorm/ReadVariableOp:0
model_7/output/MatMul/ReadVariableOp:0
model_7/output/BiasAdd/ReadVariableOp:0
model_7/dense_9/MatMul/ReadVariableOp:0
model_7/dense_9/BiasAdd/ReadVariableOp:0
model_7/dense_8/MatMul/ReadVariableOp:0
model_7/dense_8/BiasAdd/ReadVariableOp:0
model_7/dense_11/MatMul/ReadVariableOp:0
model_7/dense_11/BiasAdd/ReadVariableOp:0
model_7/dense_10/MatMul/ReadVariableOp:0
model_7/dense_10/BiasAdd/ReadVariableOp:0'''.split("\n")

frozen_params = [
    param.name
    for param in onnx_model.graph.initializer
    if param.name not in requires_grad
]

# Generate the training artifacts.
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.BCEWithLogitsLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts_16jan"
)
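The requires_grad names above come from the converted graph. A small sketch (not part of the original report) for discovering such names by listing the model's initializers:

# Sketch: print all initializer (parameter) names in the converted ONNX graph,
# which is one way to pick the parameters that should receive gradients.
import onnx

m = onnx.load('transformer_Jan_16_3.onnx')
for init in m.graph.initializer:
    print(init.name)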

Code for training the model and running inference:

import pandas as pd
import numpy as np
from onnxruntime import InferenceSession
from onnxruntime.capi import _pybind_state as C
import onnxruntime.training.api as orttraining

df = pd.read_csv('imdb_128.csv', dtype=np.float32)
label = df['label'].values
df.drop(columns=['label'], inplace=True)
data = df.values
print(f'Data Shape: {data.shape}, Label Shape: {label.shape}, Data Type: {data.dtype}, Label Type: {label.dtype}')

# Instantiate the training session by defining the checkpoint state, module, and optimizer.
# The checkpoint state contains the state of the model parameters at any given time.
checkpoint_state = orttraining.CheckpointState.load_checkpoint(
    'training_artifacts_17jan/checkpoint')
model = orttraining.Module(
    r"training_artifacts_17jan/training_model.onnx",
    checkpoint_state,
    r"training_artifacts_17jan/eval_model.onnx",
)
optimizer = orttraining.Optimizer(
    "training_artifacts_17jan/optimizer_model.onnx", model
)

label = label.reshape((len(label), 1))
num_epoch = 20
batchsize = 128
for epoch in range(num_epoch):
    model.train()
    loss = 0
    for i in range(0, data.shape[0], batchsize):
        # ort training api - training model execution outputs the training loss and the parameter gradients
        loss += model(data[i:i+batchsize].astype(np.float32), label[i:i+batchsize].astype(np.float32))
        # ort training api - update the model parameters by taking a step in the direction of the gradients
        optimizer.step()
        # ort training api - reset the gradients to zero so that new gradients can be computed in the next run
        model.lazy_reset_grad()

    print(f"Epoch {epoch+1} Loss {loss}")

# model.export_model_for_inferencing("inference_artifacts/inference.onnx", ["output"])
session = InferenceSession("inference_artifacts/inference.onnx", providers=C.get_available_providers())
for i in range(5, 10):
    result = session.run(["output"], {'encoder_input': data[i].reshape(1, 128)})
    print("result is", 1 if result[0][0] > 0.5 else 0, 'label is', label[i])

Urgency

It is very urgent.

Platform

Linux

OS Version

ubuntu 20.2

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.16

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

Model File

No response

Is this a quantized model?

No

@xadupre
Member

xadupre commented Jan 17, 2024

I assume you trained the same list of weights on the same data with both onnxruntime and keras. + @baijumeswani

@baijumeswani
Contributor

@Leaner23 Could you please share your model and data to reproduce the behavior you're seeing?

@Leaner23
Author

Leaner23 commented Jan 18, 2024

Here is the code I used to generate the model:

import tf2onnx
import onnx
import tensorflow as tf
from tensorflow import keras
import keras_nlp

NUM_LAYERS = 2
EMBD_DIM = 128
FF_DIM = 128
NUM_HEADS = 8
DROPOUT = 0.1
NORM_EPSILON = 1e-9
MAX_SEQ_LENGTH = 128   # not in the original snippet; 128 matches the padded sequence length of the training data
VOCAB_SIZE = 10000     # not in the original snippet; an assumed vocabulary size covering the token ids used

encoder_input = keras.Input(shape=(MAX_SEQ_LENGTH,), dtype=tf.float32, name='encoder_input')
encoder_embedding_layer = keras_nlp.layers.TokenAndPositionEmbedding(vocabulary_size=VOCAB_SIZE, sequence_length=MAX_SEQ_LENGTH, embedding_dim=EMBD_DIM, mask_zero=True)
encoder_output = encoder_embedding_layer(encoder_input)
encoder_output = keras.layers.LayerNormalization(epsilon=NORM_EPSILON)(encoder_output)
encoder_output = keras.layers.Dropout(rate=DROPOUT)(encoder_output)
for i in range(NUM_LAYERS):
    encoder_output = keras_nlp.layers.TransformerEncoder(
        intermediate_dim=FF_DIM,
        num_heads=NUM_HEADS,
        activation=keras.activations.gelu
    )(encoder_output)
outputs = keras.layers.GlobalAveragePooling1D()(encoder_output)
outputs = keras.layers.Dense(128, activation="relu")(outputs)
outputs = keras.layers.Dense(1, activation='sigmoid', name='output')(outputs)

transformer = keras.Model(inputs=encoder_input, outputs=outputs)

learning_rate = 3e-5
optimizer = tf.keras.optimizers.experimental.AdamW(learning_rate=learning_rate)
loss = tf.keras.losses.BinaryCrossentropy()
metrics = tf.keras.metrics.BinaryAccuracy()

transformer.compile(loss=loss, metrics=metrics, optimizer=optimizer)

onnx_model, _ = tf2onnx.convert.from_keras(transformer)
onnx.save(onnx_model, 'transformer_Jan_16_3.onnx')
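The imdb_128.csv file is not attached. A minimal sketch of how an equivalent file could be produced (a hypothetical preprocessing, not the original script) from the Keras IMDB dataset, padded/truncated to 128 tokens:

# Hypothetical preprocessing sketch - not the reporter's original script.
import numpy as np
import pandas as pd
from tensorflow import keras

(x_train, y_train), _ = keras.datasets.imdb.load_data(num_words=10000)
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=128, padding='post', truncating='post')
df = pd.DataFrame(x_train.astype(np.float32))
df['label'] = y_train.astype(np.float32)
df.to_csv('imdb_128.csv', index=False)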

You can use these data points:

data = [[1.000e+00 7.780e+02 1.280e+02 7.400e+01 1.200e+01 6.300e+02 1.630e+02
  1.500e+01 4.000e+00 1.766e+03 7.982e+03 1.051e+03 2.000e+00 3.200e+01
  8.500e+01 1.560e+02 4.500e+01 4.000e+01 1.480e+02 1.390e+02 1.210e+02
  6.640e+02 6.650e+02 1.000e+01 1.000e+01 1.361e+03 1.730e+02 4.000e+00
  7.490e+02 2.000e+00 1.600e+01 3.804e+03 8.000e+00 4.000e+00 2.260e+02
  6.500e+01 1.200e+01 4.300e+01 1.270e+02 2.400e+01 2.000e+00 1.000e+01
  1.000e+01 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00]
 [1.000e+00 6.740e+03 3.650e+02 1.234e+03 5.000e+00 1.156e+03 3.540e+02
  1.100e+01 1.400e+01 5.327e+03 6.638e+03 7.000e+00 1.016e+03 2.000e+00
  5.940e+03 3.560e+02 4.400e+01 4.000e+00 1.349e+03 5.000e+02 7.460e+02
  5.000e+00 2.000e+02 4.000e+00 4.132e+03 1.100e+01 2.000e+00 9.363e+03
  1.117e+03 1.831e+03 7.485e+03 5.000e+00 4.831e+03 2.600e+01 6.000e+00
  2.000e+00 4.183e+03 1.700e+01 3.690e+02 3.700e+01 2.150e+02 1.345e+03
  1.430e+02 2.000e+00 5.000e+00 1.838e+03 8.000e+00 1.974e+03 1.500e+01
  3.600e+01 1.190e+02 2.570e+02 8.500e+01 5.200e+01 4.860e+02 9.000e+00
  6.000e+00 2.000e+00 8.564e+03 6.300e+01 2.710e+02 6.000e+00 1.960e+02
  9.600e+01 9.490e+02 4.121e+03 4.000e+00 2.000e+00 7.000e+00 4.000e+00
  2.212e+03 2.436e+03 8.190e+02 6.300e+01 4.700e+01 7.700e+01 7.175e+03
  1.800e+02 6.000e+00 2.270e+02 1.100e+01 9.400e+01 2.494e+03 2.000e+00
  1.300e+01 4.230e+02 4.000e+00 1.680e+02 7.000e+00 4.000e+00 2.200e+01
  5.000e+00 8.900e+01 6.650e+02 7.100e+01 2.700e+02 5.600e+01 5.000e+00
  1.300e+01 1.970e+02 1.200e+01 1.610e+02 5.390e+03 9.900e+01 7.600e+01
  2.300e+01 2.000e+00 7.000e+00 4.190e+02 6.650e+02 4.000e+01 9.100e+01
  8.500e+01 1.080e+02 7.000e+00 4.000e+00 2.084e+03 5.000e+00 4.773e+03
  8.100e+01 5.500e+01 5.200e+01 1.901e+03 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00]
 [1.000e+00 5.400e+01 1.300e+01 1.610e+03 1.400e+01 2.000e+01 1.300e+01
  6.900e+01 5.500e+01 3.640e+02 1.398e+03 2.100e+01 5.400e+01 1.300e+01
  2.190e+02 1.200e+01 1.300e+01 1.706e+03 1.500e+01 4.000e+00 2.000e+01
  1.600e+01 3.290e+02 6.000e+00 1.760e+02 3.290e+02 7.400e+01 5.100e+01
  1.300e+01 8.730e+02 4.000e+00 1.560e+02 7.100e+01 7.800e+01 4.000e+00
  7.412e+03 3.220e+02 1.600e+01 3.100e+01 7.000e+00 4.000e+00 2.490e+02
  4.000e+00 6.500e+01 1.600e+01 3.800e+01 3.790e+02 1.200e+01 1.000e+02
  1.570e+02 1.800e+01 6.000e+00 9.100e+02 2.000e+01 5.490e+02 1.800e+01
  4.000e+00 1.496e+03 2.100e+01 1.400e+01 3.100e+01 9.000e+00 2.400e+01
  6.000e+00 2.120e+02 1.200e+01 9.000e+00 6.000e+00 1.322e+03 9.910e+02
  7.000e+00 3.002e+03 4.000e+00 4.250e+02 9.000e+00 7.300e+01 2.218e+03
  5.490e+02 1.800e+01 3.100e+01 1.550e+02 3.600e+01 1.000e+02 7.630e+02
  3.790e+02 2.000e+01 1.030e+02 3.510e+02 5.308e+03 1.300e+01 2.020e+02
  1.200e+01 2.241e+03 5.000e+00 6.000e+00 3.200e+02 4.600e+01 7.000e+00
  4.570e+02 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00]
 [1.000e+00 1.300e+01 1.190e+02 9.540e+02 1.890e+02 1.554e+03 1.300e+01
  9.200e+01 4.590e+02 4.800e+01 4.000e+00 1.160e+02 9.000e+00 1.492e+03
  2.291e+03 4.200e+01 7.260e+02 4.000e+00 1.939e+03 1.680e+02 2.031e+03
  1.300e+01 4.230e+02 1.400e+01 2.000e+01 5.490e+02 1.800e+01 4.000e+00
  2.000e+00 5.470e+02 3.200e+01 4.000e+00 9.600e+01 3.900e+01 4.000e+00
  4.540e+02 7.000e+00 4.000e+00 2.200e+01 8.000e+00 4.000e+00 5.500e+01
  1.300e+02 1.680e+02 1.300e+01 9.200e+01 3.590e+02 6.000e+00 1.580e+02
  1.511e+03 2.000e+00 4.200e+01 6.000e+00 1.913e+03 1.900e+01 1.940e+02
  4.455e+03 4.121e+03 6.000e+00 1.140e+02 8.000e+00 7.200e+01 2.100e+01
  4.650e+02 9.667e+03 3.040e+02 4.000e+00 5.100e+01 9.000e+00 1.400e+01
  2.000e+01 4.400e+01 1.550e+02 8.000e+00 6.000e+00 2.260e+02 1.620e+02
  6.160e+02 6.510e+02 5.100e+01 9.000e+00 1.400e+01 2.000e+01 4.400e+01
  1.000e+01 1.000e+01 1.400e+01 2.180e+02 4.843e+03 6.290e+02 4.200e+01
  3.017e+03 2.100e+01 4.800e+01 2.500e+01 2.800e+01 3.500e+01 5.340e+02
  5.000e+00 6.000e+00 3.200e+02 8.000e+00 5.160e+02 5.000e+00 4.200e+01
  2.500e+01 1.810e+02 8.000e+00 1.300e+02 5.600e+01 5.470e+02 3.571e+03
  5.000e+00 1.471e+03 8.510e+02 1.400e+01 2.286e+03 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00]
 [1.000e+00 5.030e+02 2.000e+01 3.300e+01 1.180e+02 4.810e+02 3.020e+02
  2.600e+01 1.840e+02 5.200e+01 8.350e+02 1.120e+03 5.420e+02 2.603e+03
  1.300e+01 1.408e+03 4.500e+01 6.000e+00 2.364e+03 1.000e+01 1.000e+01
  2.500e+01 2.760e+02 4.900e+01 2.000e+00 3.239e+03 1.100e+01 1.290e+02
  1.642e+03 8.000e+00 6.070e+02 2.500e+01 3.900e+01 8.520e+02 5.226e+03
  2.000e+00 2.500e+01 6.050e+02 8.520e+02 3.925e+03 5.000e+00 2.777e+03
  4.600e+01 8.520e+02 2.000e+00 2.500e+01 2.146e+03 3.000e+01 6.080e+02
  4.044e+03 1.000e+01 1.000e+01 2.500e+01 7.890e+02 3.400e+01 4.000e+00
  2.000e+00 5.400e+01 1.544e+03 2.173e+03 2.018e+03 2.500e+01 7.900e+01
  7.200e+01 2.020e+02 7.200e+01 6.000e+00 9.680e+02 2.000e+00 1.000e+01
  1.000e+01 2.872e+03 7.500e+01 3.590e+02 2.872e+03 6.214e+03 4.000e+00
  2.000e+00 3.200e+01 7.500e+01 2.800e+01 9.000e+00 1.400e+01 2.000e+00
  1.000e+01 1.000e+01 8.840e+02 1.866e+03 9.000e+00 4.000e+00 4.017e+03
  2.809e+03 1.000e+01 1.000e+01 7.190e+02 2.000e+00 7.000e+01 2.885e+03
  4.000e+00 2.552e+03 2.000e+00 4.430e+03 1.750e+02 6.640e+03 1.100e+01
  4.000e+00 2.000e+00 5.430e+02 1.609e+03 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00]]

label : [0., 1., 0., 0., 0.]
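To push just these five rows through the training snippet above, a short sketch (the name `rows` is a hypothetical placeholder for the five arrays pasted above, re-entered as nested Python lists):

# Sketch: run one training step on the five pasted rows.
# `rows` is assumed to hold the five arrays above as nested Python lists.
sample = np.array(rows, dtype=np.float32)                                     # shape (5, 128)
sample_labels = np.array([0., 1., 0., 0., 0.], dtype=np.float32).reshape(-1, 1)
model.train()
step_loss = model(sample, sample_labels)
optimizer.step()
model.lazy_reset_grad()
print("loss on the sample batch:", step_loss)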

@github-actions
Contributor

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@github-actions github-actions bot added the stale label Feb 17, 2024
@Whadup

Whadup commented Nov 29, 2024

I have noticed similar behavior. However, when I also return the predictions and compute the loss in numpy, the numpy loss does decrease. Very odd...
