diff --git a/examples/images/cifar10/compute_fid.py b/examples/images/cifar10/compute_fid.py index 52ffbb4..2f79554 100644 --- a/examples/images/cifar10/compute_fid.py +++ b/examples/images/cifar10/compute_fid.py @@ -27,6 +27,8 @@ flags.DEFINE_integer("step", 400000, help="training steps") flags.DEFINE_integer("num_gen", 50000, help="number of samples to generate") flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance (absolute and relative)") +flags.DEFINE_float("batch_size_fid", 1024, help="Batch size to compute FID") + FLAGS(sys.argv) @@ -70,7 +72,7 @@ def gen_1_img(unused_latent): with torch.no_grad(): - x = torch.randn(500, 3, 32, 32, device=device) + x = torch.randn(FLAGS.batch_size_fid, 3, 32, 32, device=device) if FLAGS.integration_method == "euler": print("Use method: ", FLAGS.integration_method) t_span = torch.linspace(0, 1, FLAGS.integration_steps + 1, device=device) @@ -90,7 +92,7 @@ def gen_1_img(unused_latent): score = fid.compute_fid( gen=gen_1_img, dataset_name="cifar10", - batch_size=500, + batch_size=FLAGS.batch_size_fid, dataset_res=32, num_gen=FLAGS.num_gen, dataset_split="train",