diff --git a/train_dalle.py b/train_dalle.py index 1206474e..aaa8c64d 100644 --- a/train_dalle.py +++ b/train_dalle.py @@ -178,7 +178,7 @@ def cp_path_to_dir(cp_path, tag): assert dalle_path.exists(), 'DALL-E model file does not exist' loaded_obj = torch.load(str(dalle_path), map_location='cpu') - dalle_params, vae_params, weights= loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights'] + dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights'] opt_state = loaded_obj.get('opt_state') scheduler_state = loaded_obj.get('scheduler_state')