diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py index 351f59092328..ad81db73aa53 100755 --- a/applications/ColossalChat/examples/training_scripts/train_dpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py @@ -44,7 +44,7 @@ def train(args): Default torch ddp plugin without any acceleration, for debugging purpose acceleration, for debugging purpose """ - plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False) + plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint) elif args.plugin == "gemini": plugin = GeminiPlugin( precision=args.mixed_precision, diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py index 62acad32f66a..e319340c3f60 100755 --- a/applications/ColossalChat/examples/training_scripts/train_sft.py +++ b/applications/ColossalChat/examples/training_scripts/train_sft.py @@ -68,7 +68,7 @@ def train(args): Default torch ddp plugin without any acceleration, for debugging purpose acceleration, for debugging purpose """ - plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False) + plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint) elif args.plugin == "gemini": plugin = GeminiPlugin( precision=args.mixed_precision,