[Training] On Device Training is not working #18168

Open
IzanCatalan opened this issue Oct 30, 2023 · 5 comments
Labels: ep:CUDA (issues related to the CUDA execution provider), stale (issues that have not been addressed in a while; categorized by a bot), training (issues related to ONNX Runtime training; typically submitted using template)

Comments

IzanCatalan commented Oct 30, 2023

Describe the issue

I am re-training some ONNX models from the ONNX Model Zoo repo, in particular ResNet50. Beforehand, I generated the training artifacts following the onnx-runtime-training-examples repo, using the following code:

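import sys

import onnx
from onnxruntime.training import artifacts

# Load the pre-trained model (assumed here to be passed as the first CLI argument).
onnx_model = onnx.load(sys.argv[1])

# BatchNorm running statistics are not updated by gradient descent, so freeze them;
# every other initializer is marked as trainable.
frozen_params = []
requires_grad = []
for init in onnx_model.graph.initializer:
    if init.name.endswith("running_mean") or init.name.endswith("running_var"):
        frozen_params.append(init.name)
    else:
        requires_grad.append(init.name)

print(len(requires_grad), len(frozen_params))
print(frozen_params)
# Generate the training artifacts in the directory given as the second CLI argument.
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory=sys.argv[2],
)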
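The selection rule in the snippet above can be checked without loading a model. The sketch below pulls it into a hypothetical pure-Python helper, `split_params`, that takes a list of initializer names and returns the trainable and frozen lists. The example names are illustrative only, modelled on the `resnetv17_` prefix that appears in the export call later in this thread.

```python
def split_params(initializer_names):
    """Partition initializer names like the artifact-generation snippet does:
    BatchNorm running statistics are frozen, everything else requires grad."""
    frozen_params = []
    requires_grad = []
    for name in initializer_names:
        if name.endswith("running_mean") or name.endswith("running_var"):
            frozen_params.append(name)
        else:
            requires_grad.append(name)
    return requires_grad, frozen_params


# Illustrative (hypothetical) names; a real ResNet50 has many more initializers.
names = [
    "resnetv17_conv0_weight",
    "resnetv17_batchnorm0_gamma",
    "resnetv17_batchnorm0_beta",
    "resnetv17_batchnorm0_running_mean",
    "resnetv17_batchnorm0_running_var",
]
trainable, frozen = split_params(names)
print(trainable)
print(frozen)
```

Note that the BatchNorm scale and shift (`gamma`/`beta`) still end up in `requires_grad`; only the running statistics are excluded.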

I follow the same train/test structure defined in the onnx-runtime-training-examples repo code. In my case, I load the ImageNet train and validation datasets:

imagenet_data_train = datasets.ImageNet('/home/Imagenet', split="train", transform=transform_test)
imagenet_data_val = datasets.ImageNet('/home/Imagenet', split="val", transform=transform_test)
train_loader = torch.utils.data.DataLoader(imagenet_data_train,
                                          batch_size=8,
                                          shuffle=True,
                                          num_workers=4)

val_loader = torch.utils.data.DataLoader(imagenet_data_val,
                                          batch_size=8,
                                          shuffle=True,
                                          num_workers=4)

However, every time I train, the accuracy I get is much lower than the accuracy reported in the ONNX Model Zoo repo. For example, after 20 epochs the accuracy reaches 0.53, compared to the original 0.75. I do not change anything or modify the weights in the process, so it surprises me that re-training a model which previously reached 0.75 accuracy barely gets to 0.53. Moreover, the first epoch only gives 0.31 accuracy.
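The training script posted later in this thread measures accuracy as top-1: it takes `np.argmax` over each row of logits and compares the result with the label via `evaluate.load('accuracy')`. A stdlib-only sketch of the same computation, where the helper name `top1_accuracy` and the toy batches are hypothetical:

```python
def top1_accuracy(batches):
    """Fraction of samples whose highest-scoring logit matches the label.

    `batches` is an iterable of (logits, targets) pairs, where logits is a list of
    per-sample score lists and targets is a list of integer class labels.
    """
    correct = 0
    total = 0
    for logits, targets in batches:
        for scores, label in zip(logits, targets):
            # argmax over the class scores (first maximum wins, like np.argmax)
            pred = max(range(len(scores)), key=scores.__getitem__)
            correct += int(pred == label)
            total += 1
    return correct / total if total else 0.0


batches = [
    ([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]], [1, 2]),  # first sample correct, second wrong
    ([[0.0, 0.1, 3.0]], [2]),                      # correct
]
print(top1_accuracy(batches))
```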

I would like to know whether this is normal behaviour and, if so, how I can re-train pre-trained and pre-tested ONNX models without losing a large amount of accuracy.

To reproduce

I am running onnxruntime built from source for CUDA 11.2, with GCC 9.5, CMake 3.27 and Python 3.8 on Ubuntu 20.04.

Urgency

No response

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

onnxruntime-training 1.17.0+cu112

PyTorch Version

None

Execution Provider

CUDA

Execution Provider Library Version

CUDA 11.2

@IzanCatalan IzanCatalan added the training issues related to ONNX Runtime training; typically submitted using template label Oct 30, 2023
@github-actions github-actions bot added the ep:CUDA issues related to the CUDA execution provider label Oct 30, 2023
askhade (Contributor) commented Oct 30, 2023

@IzanCatalan: Can you also paste your training recipe? What is your timeline, i.e. how urgent is this?

IzanCatalan (Author)

@askhade Are you asking for my training code? Here it is. I would like to find a solution as soon as possible, ideally within this week:

# Imports assumed by the script (not shown in the original post).
import numpy as np
import torch
import evaluate
from torchvision import datasets, transforms
import onnxruntime.training.api as orttraining

# -----------------------------------------------------------------------------------------------------------------------------------
# ----------------------------------------Training methods---------------------------------------------------------------------------
# -----------------------------------------------------------------------------------------------------------------------------------

# Util function to convert logits to predictions (top-1 class index per sample).
def get_pred(logits):
    return np.argmax(logits, axis=1)

# Training Loop :
def train(epoch):
    print("inside train")
    model.train()
    print("setup train")
    losses = []
    p = 0
    for i, (data, target) in enumerate(train_loader):
        forward_inputs = [data.reshape(len(data),3,224,224).numpy().astype(np.float32),target.numpy().astype(np.int64)]
        train_loss = model(*forward_inputs)  # forward + backward pass; returns the loss
        optimizer.step()                     # apply the AdamW parameter update
        model.lazy_reset_grad()              # reset gradients before the next batch
        losses.append(train_loss)
        p = i

    print("image", p)
    print(f'Epoch: {epoch+1}, Train Loss: {sum(losses)/len(losses):.4f}')


# Test Loop :
def test(epoch):
    print("inside test")
    model.eval()
    print("setup eval")
    losses = []
    metric = evaluate.load('accuracy')

    for _, (data, target) in enumerate(val_loader):
        forward_inputs = [data.reshape(len(data),3,224,224).numpy().astype(np.float32),target.numpy().astype(np.int64)]
        # The eval model used here returns both the loss and the logits.
        test_loss, logits = model(*forward_inputs)
        metric.add_batch(references=target, predictions=get_pred(logits))
        losses.append(test_loss)

    metrics = metric.compute()
    print(f'Epoch: {epoch+1}, Test Loss: {sum(losses)/len(losses):.4f}, Accuracy : {metrics["accuracy"]:.2f}')
    return metrics["accuracy"]


# --------------------------------------MAIN------------------------------------------------------
# pdb.set_trace()
print("STARTING")

# Define image transforms (ImageNet mean/std normalization).
# Note: this deterministic evaluation transform is used for both the train and val splits below.
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])


print("RUNNING")
imagenet_data_train = datasets.ImageNet('/mnt/beegfs/gap/izcagal/docker/Imagenet', split="train", transform=transform_test)
imagenet_data_val = datasets.ImageNet('/mnt/beegfs/gap/izcagal/docker/Imagenet', split="val", transform=transform_test)
train_loader = torch.utils.data.DataLoader(imagenet_data_train,
                                           batch_size=8,
                                           shuffle=True,
                                           num_workers=4)

val_loader = torch.utils.data.DataLoader(imagenet_data_val,
                                         batch_size=8,
                                         shuffle=True,
                                         num_workers=4)


print("DATASET LOADED")
# Instantiate the training session by defining the checkpoint state, module, and optimizer
# The checkpoint state contains the state of the model parameters at any given time.
checkpoint_state = orttraining.CheckpointState.load_checkpoint(
    "docker/training_artifacts/checkpoint")

model = orttraining.Module(
    "docker/training_artifacts/training_model.onnx",
    checkpoint_state,
    "docker/training_artifacts/eval_model2.onnx",
    "cuda"
)

optimizer = orttraining.Optimizer(
    "docker/training_artifacts/optimizer_model.onnx", model
)
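# --------------Training Phase----------------------
# Both counters must be initialized before the loop, otherwise `res` raises a NameError.
epoch = 0
res = 0.0
# Alternate one training epoch and one evaluation pass until validation accuracy reaches 0.69.
while res < 0.69:
    train(epoch)
    res = test(epoch)
    print(res)
    epoch += 1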

model.export_model_for_inferencing("inference.onnx", ["resnetv17_dense0_fwd"])
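The training loop above stops only when `res` reaches 0.69, so if validation accuracy never gets there it keeps running indefinitely. A sketch of a bounded variant, where `run_until`, `max_epochs` and the stub functions are hypothetical stand-ins for the script's `train`/`test`:

```python
def run_until(train_fn, eval_fn, target, max_epochs):
    """Alternate training and evaluation until `target` accuracy or `max_epochs`.

    train_fn(epoch) runs one epoch; eval_fn(epoch) returns that epoch's validation
    accuracy. Returns (best_accuracy, epochs_run).
    """
    best = 0.0
    epoch = 0
    while epoch < max_epochs:
        train_fn(epoch)
        res = eval_fn(epoch)
        best = max(best, res)
        epoch += 1
        if res >= target:
            break
    return best, epoch


# Stubs standing in for train()/test(): accuracy stalls below the 0.69 target.
history = [0.31, 0.45, 0.50, 0.53, 0.53, 0.53]
best, epochs = run_until(lambda e: None,
                         lambda e: history[min(e, len(history) - 1)],
                         0.69, 10)
print(best, epochs)
```

With the cap, a stalled run ends after `max_epochs` and still reports the best accuracy seen, instead of looping forever.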

IzanCatalan (Author)

@askhade, I have an update on this issue. I tried training with a batch size of 256 instead of 8, and the first epoch suddenly gave an accuracy of 0.54 instead of 0.31 (with batch size 8). So, am I right in assuming that the larger the batch size, the higher the accuracy per epoch?

However, I still don't understand how re-training a pre-trained model with an accuracy above 0.75 (a ResNet model downloaded from the official ONNX Model Zoo repo) can degrade that accuracy to 0.54 after the first epoch. So my question remains: is this behaviour normal for On-Device Training with ONNX Runtime Training? And if so, how can I change it to get a more effective training run with higher accuracy per epoch?
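The two runs differ in how many optimizer steps make up one epoch. A sketch of that arithmetic, assuming the standard ILSVRC-2012 (ImageNet-1k) training split of 1,281,167 images and the `DataLoader` default of keeping the last partial batch; `steps_per_epoch` is a hypothetical helper name:

```python
IMAGENET_TRAIN_IMAGES = 1_281_167  # standard ILSVRC-2012 training split size


def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches (and hence optimizer steps) in one epoch."""
    if drop_last:
        return num_samples // batch_size
    # ceiling division: the last, smaller batch still counts as a step
    return (num_samples + batch_size - 1) // batch_size


for bs in (8, 256):
    print(bs, steps_per_epoch(IMAGENET_TRAIN_IMAGES, bs))
```

A batch size of 256 therefore performs 32 times fewer optimizer steps per epoch than a batch size of 8, with each step computed from 32 times more images.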

Contributor

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@github-actions github-actions bot added the stale issues that have not been addressed in a while; categorized by a bot label Dec 23, 2023
carzh (Contributor) commented Jan 19, 2024

@IzanCatalan hello, do you have an updated link to the ONNX file you're generating artifacts from, or the name of the ResNet50 model you're using?


3 participants