diff --git a/examples/cifar10/train_cifar10.py b/examples/cifar10/train_cifar10.py index de45419..921aac8 100644 --- a/examples/cifar10/train_cifar10.py +++ b/examples/cifar10/train_cifar10.py @@ -159,7 +159,7 @@ def train(argv): "optim": optim.state_dict(), "step": step, }, - savedir + f"para_cifar10_weights_step_{step}.pt", + savedir + f"cifar10_weights_step_{step}.pt", )