Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ONNX Support for Decision Transformer Model #2038

Open
wants to merge 18 commits into
base: main
Choose a base branch
from

Conversation

ra9hur
Copy link

@ra9hur ra9hur commented Oct 1, 2024

Fixes #2032

@IlyasMoutawwakil ,

As per your suggestions, I have made necessary changes to support decision transformer model to ONNX. I ran the below command:

optimum-cli export onnx -m 'edbeeching/decision-transformer-gym-hopper-medium' --task 'reinforcement-learning' --framework 'pt' onnx/

Attaching the terminal output:
output_log.txt

The model was exported successfully with a warning:

The ONNX export succeeded with the warning: The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance 1e-05:
- state_preds: max diff = nan
- action_preds: max diff = nan
- return_preds: max diff = nan
- last_hidden_state: max diff = nan.

since max diff is nan, printed the results of the reference and ONNX models. Have attached the results in the below text file.
Note: Have extracted results from output_log.txt file to a separate text file for clarity.
output.txt

Excluding nan, if we closely observe, the difference seems to be well within set tolerance limits.

Please review and let me know your suggestions.

@ra9hur
Copy link
Author

ra9hur commented Oct 3, 2024

Thanks @IlyasMoutawwakil for the detail review comments, will check and re-submit again by next week.

@ra9hur
Copy link
Author

ra9hur commented Oct 11, 2024

@IlyasMoutawwakil ,
I have made necessary changes as per your suggestions.

optimum-cli command runs successfully.

optimum-cli export onnx -m 'edbeeching/decision-transformer-gym-hopper-medium' --task 'reinforcement-learning' --framework 'pt' onnx/

Here is the terminal output for your reference.
success_output_log.txt

Please let me know, if it is good for approval.

@IlyasMoutawwakil
Copy link
Member

cool addition ! I committed a couple of suggestions, mainly about keeping the naming of dimensions closer to the config attributes.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ra9hur
Copy link
Author

ra9hur commented Oct 11, 2024

Thanks @IlyasMoutawwakil , your inputs were very valuable, it was a great learning experience while working on this PR !!

@ra9hur
Copy link
Author

ra9hur commented Oct 25, 2024

@IlyasMoutawwakil ,

Got to test this again and there were a few issues as in the attached log.

OrderedDict_Issue.txt

Further analyzing, changes had to be done to base.py. Output was sensitive to the order of dict elements.
Please review the changes.

@ra9hur
Copy link
Author

ra9hur commented Nov 6, 2024

@IlyasMoutawwakil ,

Can you please review the changes.

@IlyasMoutawwakil
Copy link
Member

triggering tests and checking today !

@ra9hur
Copy link
Author

ra9hur commented Nov 20, 2024

Sure, thanks @IlyasMoutawwakil

The branch might be out of sync with the main repo. Please let me know if I should re-sync, copy the PR changes and then commit again.

@IlyasMoutawwakil
Copy link
Member

@ra9hur yes please, a rebase or merge would be great ! also a make style for code quality.

@ra9hur
Copy link
Author

ra9hur commented Nov 20, 2024

@IlyasMoutawwakil , I have merged changes in main to this feature branch.

Pardon my ignorance, couldn't quite get what you meant by make style. In one of previous commits, you had made changes for variable names consistency. Is that what you meant ?

Last change that I made was only ordering of outputs in OrderDict in base.py, so nothing major.

Any references so that I can follow ?

@IlyasMoutawwakil
Copy link
Member

IlyasMoutawwakil commented Nov 20, 2024

The export seems to fail due to some outputs mismatch https://github.com/huggingface/optimum/actions/runs/11931260970/job/33256099991?pr=2038#step:5:7057
does it work for you locally ?

Pardon my ignorance, couldn't quite get what you meant by make style

I mean doing pip install -e .[quality] and then make style in terminal, so that ruff and black apply the code quality changes necessary.

@ra9hur
Copy link
Author

ra9hur commented Nov 20, 2024

Thanks for your guidance !!
make style made changes to 3 files, have pushed those changes. Hope, its all good now.

Regarding output mismatch, it is working fine locally.

Validating ONNX model /home/raghu/DL/topics/RL/Godot/onnx/model.onnx...
	-[✓] ONNX model output names match reference model (state_preds, last_hidden_state, return_preds, action_preds)
	-[✓] ONNX model output names match reference model (state_preds, last_hidden_state, return_preds, action_preds)
	- Validating ONNX Model output "state_preds":
	- Validating ONNX Model output "state_preds":
		-[✓] (2, 16, 11) matches (2, 16, 11)
		-[✓] (2, 16, 11) matches (2, 16, 11)
		-[✓] all values close (atol: 1e-05)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "action_preds":
	- Validating ONNX Model output "action_preds":
		-[✓] (2, 16, 3) matches (2, 16, 3)
		-[✓] (2, 16, 3) matches (2, 16, 3)
		-[✓] all values close (atol: 1e-05)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "return_preds":
	- Validating ONNX Model output "return_preds":
		-[✓] (2, 16, 1) matches (2, 16, 1)
		-[✓] (2, 16, 1) matches (2, 16, 1)
		-[✓] all values close (atol: 1e-05)
		-[✓] all values close (atol: 1e-05)
	- Validating ONNX Model output "last_hidden_state":
	- Validating ONNX Model output "last_hidden_state":
		-[✓] (2, 48, 128) matches (2, 48, 128)
		-[✓] (2, 48, 128) matches (2, 48, 128)
		-[✓] all values close (atol: 1e-05)
		-[✓] all values close (atol: 1e-05)
The ONNX export succeeded and the exported model was saved at: /home/raghu/DL/topics/RL/Godot/onnx

@IlyasMoutawwakil
Copy link
Member

@ra9hur the export does work for reinforcement-learning task but not for feature-extraction, because the outputs defined for reinforcement learning are specific to this architecture, and it still expects the common single feature extraction output "feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}),.
I tried changing things directly on your fork, moved the outputs to the model config directly and cleaned up some minor things.
I also ran tests locally:

tests/exporters/onnx/test_exporters_onnx_cli.py::OnnxCLIExportTestCase::test_exporters_cli_pytorch_cpu_136_decision_transformer_feature_extraction_default_edbeeching_decision_transformer_gym_hopper_medium PASSED
tests/exporters/onnx/test_exporters_onnx_cli.py::OnnxCLIExportTestCase::test_exporters_cli_pytorch_cpu_137_decision_transformer_no_task_edbeeching_decision_transformer_gym_hopper_medium PASSED
tests/exporters/onnx/test_exporters_onnx_cli.py::OnnxCLIExportTestCase::test_exporters_cli_pytorch_cpu_138_decision_transformer_reinforcement_learning_default_edbeeching_decision_transformer_gym_hopper_medium PASSED

we get 3/3 tests that are passing, for both supported tasks and no_task which infers it automatically.
tell me if this works for you.

@ra9hur
Copy link
Author

ra9hur commented Nov 20, 2024

@IlyasMoutawwakil , I tried your changes locally and executes without errors. Good to go from my side.
Thanks for helping me to troubleshoot and fix those issues !!

Checking new build errors, none of those seem to be specific to this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ONNX support for decision transformers
3 participants