Fix convergence for dolly+stage3 training (microsoft#17685)
### Fix convergence for dolly+stage3 training

In [ZeROOffloadSubscriber](https://github.com/microsoft/onnxruntime/blob/216214b7d302cb504d1e5a647f65b6fe49c22dbb/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py#L359C7-L359C28), we define several PythonOps that take an input tensor and return it in place, for example: https://github.com/microsoft/onnxruntime/blob/216214b7d302cb504d1e5a647f65b6fe49c22dbb/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py#L223C20-L223C20. This is allowed in PyTorch, but when ORT runs such a PythonOp, it releases the input OrtValue once the op completes, so the underlying data can be erased or overwritten. The PythonOp's returned OrtValue, however, still points to that same address; reading from or writing to it can produce wrong results or even undefined behavior (a toy illustration of this aliasing pattern is sketched at the end of this description). The resulting run shows the loss collapsing to 0 and the loss scale bottoming out:

```
/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_custom_autograd_function_runner.py:28: UserWarning: .rank-0: onnxruntime.training.utils.hooks._zero_offload_subscriber.ORTZeROOffloadPreForwardFunction->Backward: ONNX Op attribute 'tensor_reuse_map' doesn't indicate 8-th output is reusing any input, but detected inplace_map indicates it is reusing some input index. A clone will be done before returning to ORT, to align with ORT's NO Buffer reuse plan. Please update inplace_map explicitly to avoid such a copy.
  warnings.warn(f".rank-{get_rank()}: {message}")
  0%|▏ | 1/1000 [00:04<1:15:08, 4.51s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,023 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 14.1406, 'learning_rate': 0, 'epoch': 0.0}
  0%|▏ | 1/1000 [00:04<1:15:08, 4.51s/it]Invalidate trace cache @ step 5: expected module 6, but got module 7
  0%|▍ | 2/1000 [00:04<31:53, 1.92s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,124 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  0%|▋ | 3/1000 [00:04<18:05, 1.09s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,227 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  0%|▋ | 3/1000 [00:04<18:05, 1.09s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,326 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  0%|█▏ | 5/1000 [00:04<08:44, 1.90it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,419 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  0%|█▏ | 5/1000 [00:04<08:44, 1.90it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,505 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  1%|█▋ | 7/1000 [00:05<05:28, 3.02it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,597 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  1%|█▋ | 7/1000 [00:05<05:28, 3.02it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,690 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  1%|██▏ | 9/1000 [00:05<03:57, 4.17it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,791 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  1%|██▏ | 9/1000 [00:05<03:57, 4.17it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,889 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  1%|██▋ | 11/1000 [00:05<03:06, 5.32it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,981 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  1%|██▋ | 11/1000 [00:05<03:06, 5.32it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,073 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01}
  1%|███▏ | 13/1000 [00:05<02:33, 6.42it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,166 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01}
  1%|███▏ | 13/1000 [00:05<02:33, 6.42it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,256 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01}
  2%|███▌ | 15/1000 [00:05<02:12, 7.43it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,348 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01}
  2%|███▌ | 15/1000 [00:05<02:12, 7.43it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,439 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01}
  2%|████ | 17/1000 [00:06<01:59, 8.22it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,535 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01}
  2%|████ | 17/1000 [00:06<01:59, 8.22it/s]Traceback (most recent call last):
  File "examples/onnxruntime/training/language-modeling/run_clm.py", line 600, in <module>
    main()
  File "examples/onnxruntime/training/language-modeling/run_clm.py", line 548, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 457, in train
    return inner_training_loop(
  File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 781, in _inner_training_loop
    self.deepspeed.step()
  File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/engine.py", line 2084, in step
    self._take_model_step(lr_kwargs)
  File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/engine.py", line 1990, in _take_model_step
    self.optimizer.step()
  File "/bert_ort/pengwa/deepspeed/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/stage3.py", line 1854, in step
    if self._overflow_check_and_loss_scale_update():
  File "/bert_ort/pengwa/deepspeed/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/stage3.py", line 1788, in _overflow_check_and_loss_scale_update
    self._update_scale(self.overflow)
  File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/stage3.py", line 2132, in _update_scale
    self.loss_scaler.update_scale(has_overflow)
  File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/fp16/loss_scaler.py", line 175, in update_scale
    raise Exception(
Exception: Current loss scale already at minimum - cannot decrease scale anymore.
Exiting run.
  2%|████ | 17/1000 [00:06<06:07, 2.67it/s]
[2023-09-25 08:30:51,075] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 1065120) of binary: /bert_ort/pengwa/py38/bin/python
Traceback (most recent call last):
  File "/bert_ort/pengwa/py38/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
examples/onnxruntime/training/language-modeling/run_clm.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time       : 2023-09-25_08:30:51
  host       : orttrainingdev10.internal.cloudapp.net
  rank       : 0 (local_rank: 0)
  exitcode   : 1 (pid: 1065120)
  error_file : <N/A>
  traceback  : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
(/bert_ort/pengwa/py38) [email protected]@orttrainingdev10:/bert_ort/pengwa/optim
```

## The Fix

For outputs that reuse an input buffer without ORT being aware of it, we now detect the reuse on the fly (during the first iteration, by comparing the output tensor addresses against the input tensor addresses) and make an implicit copy before setting the tensor as the PythonOp's output. A sketch of this check is given at the end of this description.

With this fix (left: PyTorch, right: ORT):

![image](https://github.com/microsoft/onnxruntime/assets/10530022/0d72f431-2abd-4e52-af99-19974b85edde)
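To make the aliasing pattern above concrete, here is a toy, self-contained stand-in for a PythonOp whose forward takes a tensor and returns it in place. This is a sketch only; `PassThroughFunction` is an illustrative name and is not the real `ORTZeROOffloadPreForwardFunction`:

```python
import torch


class PassThroughFunction(torch.autograd.Function):
    """Toy stand-in for an in-place-returning PythonOp: forward() returns its
    input unchanged, so the output tensor shares the input's storage."""

    @staticmethod
    def forward(ctx, x):
        # The real subscriber runs ZeRO offload hooks here; this sketch only
        # keeps the "take an input and return it in place" shape of the op.
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Identity forward, so the gradient passes straight through.
        return grad_output


x = torch.randn(4)
y = PassThroughFunction.apply(x)
print(y.data_ptr() == x.data_ptr())  # True: no copy is made, output aliases input
```

Because no copy is made, anything that later frees or rewrites the input's buffer also invalidates the output, which is exactly what happens when ORT releases the input OrtValue after the PythonOp completes.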
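And a minimal sketch of the first-iteration check described in the fix: any output tensor whose data pointer matches an input tensor's data pointer, but whose reuse is not declared to ORT, gets cloned before being returned. The helper name `_clone_undeclared_reused_outputs` and the "`-1` means no declared reuse" convention for `tensor_reuse_map` are assumptions for illustration, not ORT's actual API:

```python
import torch


def _clone_undeclared_reused_outputs(inputs, outputs, tensor_reuse_map=None):
    """Hypothetical sketch: clone every output that silently aliases an input.

    `tensor_reuse_map[i] == -1` is assumed to mean "output i declares no input
    reuse"; declared reuse is left untouched because ORT already plans for it.
    """
    input_ptrs = {t.data_ptr() for t in inputs if isinstance(t, torch.Tensor)}
    tensor_reuse_map = tensor_reuse_map or [-1] * len(outputs)

    safe_outputs = []
    for idx, out in enumerate(outputs):
        aliases_input = isinstance(out, torch.Tensor) and out.data_ptr() in input_ptrs
        declared = tensor_reuse_map[idx] != -1
        if aliases_input and not declared:
            # ORT releases the input OrtValue once the PythonOp finishes, so an
            # undeclared alias must own its own buffer to stay valid.
            safe_outputs.append(out.detach().clone())
        else:
            safe_outputs.append(out)
    return safe_outputs


# Example: the first output aliases the input without declaring it, so it is cloned.
a = torch.randn(3)
outs = _clone_undeclared_reused_outputs([a], [a, a * 2])
print(outs[0].data_ptr() != a.data_ptr())  # True: the alias was copied
```

Since the PR does this detection in the first iteration, a natural extension (not shown here) is to cache the detected reuse pattern and reuse it for subsequent steps.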