Skip to content

Commit

Permalink
Merge branch 'main' of github.com:atong01/conditional-flow-matching
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Mar 2, 2024
2 parents 48852ae + bcc5082 commit 2b7f946
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions examples/images/cifar10/compute_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
flags.DEFINE_integer("step", 400000, help="training steps")
flags.DEFINE_integer("num_gen", 50000, help="number of samples to generate")
flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance (absolute and relative)")
flags.DEFINE_float("batch_size_fid", 1024, help="Batch size to compute FID")

FLAGS(sys.argv)


Expand Down Expand Up @@ -70,7 +72,7 @@

def gen_1_img(unused_latent):
with torch.no_grad():
x = torch.randn(500, 3, 32, 32, device=device)
x = torch.randn(FLAGS.batch_size_fid, 3, 32, 32, device=device)
if FLAGS.integration_method == "euler":
print("Use method: ", FLAGS.integration_method)
t_span = torch.linspace(0, 1, FLAGS.integration_steps + 1, device=device)
Expand All @@ -90,7 +92,7 @@ def gen_1_img(unused_latent):
score = fid.compute_fid(
gen=gen_1_img,
dataset_name="cifar10",
batch_size=500,
batch_size=FLAGS.batch_size_fid,
dataset_res=32,
num_gen=FLAGS.num_gen,
dataset_split="train",
Expand Down

0 comments on commit 2b7f946

Please sign in to comment.