diff --git a/code/train_mvtec.py b/code/train_mvtec.py index 3998b48..7ee61bd 100644 --- a/code/train_mvtec.py +++ b/code/train_mvtec.py @@ -100,9 +100,8 @@ def main(**args): del batch_sft current_step += 1 # torch.cuda.empty_cache() - current_step += 1 - if iter_every_epoch % 1000 == 0: - agent.save_model(args['save_path'], 0) + # if iter_every_epoch % 1000 == 0: + # agent.save_model(args['save_path'], 0) # save at the end of the training torch.distributed.barrier() agent.save_model(args['save_path'], 0)