Skip to content

Commit

Permalink
added explicit batch size FID (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
QB3 authored Mar 2, 2024
1 parent fa13e7c commit bcc5082
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions examples/images/cifar10/compute_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
flags.DEFINE_integer("step", 400000, help="training steps")
flags.DEFINE_integer("num_gen", 50000, help="number of samples to generate")
flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance (absolute and relative)")
flags.DEFINE_float("batch_size_fid", 1024, help="Batch size to compute FID")

FLAGS(sys.argv)


Expand Down Expand Up @@ -70,7 +72,7 @@

def gen_1_img(unused_latent):
with torch.no_grad():
x = torch.randn(500, 3, 32, 32, device=device)
x = torch.randn(FLAGS.batch_size_fid, 3, 32, 32, device=device)
if FLAGS.integration_method == "euler":
print("Use method: ", FLAGS.integration_method)
t_span = torch.linspace(0, 1, FLAGS.integration_steps + 1, device=device)
Expand All @@ -90,7 +92,7 @@ def gen_1_img(unused_latent):
score = fid.compute_fid(
gen=gen_1_img,
dataset_name="cifar10",
batch_size=500,
batch_size=FLAGS.batch_size_fid,
dataset_res=32,
num_gen=FLAGS.num_gen,
dataset_split="train",
Expand Down

0 comments on commit bcc5082

Please sign in to comment.