From 6930884984ec7d06c9af50363eecc9ce32a8807a Mon Sep 17 00:00:00 2001 From: Wok Date: Sat, 19 Sep 2020 21:33:10 +0200 Subject: [PATCH 1/3] Add CLI argument to sample more scales for moving in a direction --- apply_factor.py | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/apply_factor.py b/apply_factor.py index 1a66365f..47e1c14f 100755 --- a/apply_factor.py +++ b/apply_factor.py @@ -27,6 +27,13 @@ default=2, help='channel multiplier factor. config-f = 2, else = 1', ) + parser.add_argument( + "-d_num", + "--degree_num", + type=int, + default=3, + help="number of scalar factors for moving latent vectors along eigenvector", + ) parser.add_argument("--ckpt", type=str, required=True, help="stylegan2 checkpoints") parser.add_argument( "--size", type=int, default=256, help="output image size of the generator" @@ -64,29 +71,23 @@ latent = torch.randn(args.n_sample, 512, device=args.device) latent = g.get_latent(latent) - direction = args.degree * eigvec[:, args.index].unsqueeze(0) + direction = eigvec[:, args.index].unsqueeze(0) - img, _ = g( - [latent], - truncation=args.truncation, - truncation_latent=trunc, - input_is_latent=True, - ) - img1, _ = g( - [latent + direction], - truncation=args.truncation, - truncation_latent=trunc, - input_is_latent=True, - ) - img2, _ = g( - [latent - direction], - truncation=args.truncation, - truncation_latent=trunc, - input_is_latent=True, - ) + img_list = [] + + for u in torch.linspace(- args.degree, args.degree, args.d_num): + + img_batch, _ = g( + [latent + u * direction], + truncation=args.truncation, + truncation_latent=trunc, + input_is_latent=True, + ) + + img_list.append(img_batch) grid = utils.save_image( - torch.cat([img1, img, img2], 0), + torch.cat(img_list, 0), f"{args.out_prefix}_index-{args.index}_degree-{args.degree}.png", normalize=True, range=(-1, 1), From 25b87a0e1cab229b5ce29642b1ac27df8cf96547 Mon Sep 17 00:00:00 2001 From: Wok Date: Sat, 19 Sep 2020 21:34:58 +0200 Subject: [PATCH 2/3] Transpose the grid of results (one sample per row) --- apply_factor.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/apply_factor.py b/apply_factor.py index 47e1c14f..7599f520 100755 --- a/apply_factor.py +++ b/apply_factor.py @@ -73,7 +73,7 @@ direction = eigvec[:, args.index].unsqueeze(0) - img_list = [] + img_dict = dict() for u in torch.linspace(- args.degree, args.degree, args.d_num): @@ -84,12 +84,24 @@ input_is_latent=True, ) - img_list.append(img_batch) + for j in range(img_batch.shape[0]): + + img = img_batch[j].unsqueeze(0) + + try: + img_dict[j].append(img) + except KeyError: + img_dict[j] = [img] + + img_list = [ + torch.cat(img_dict[j], 0) + for j in range(args.n_sample) + ] grid = utils.save_image( torch.cat(img_list, 0), f"{args.out_prefix}_index-{args.index}_degree-{args.degree}.png", normalize=True, range=(-1, 1), - nrow=args.n_sample, + nrow=args.d_num, ) From 5fa9f302676bc4f9297686e0e0d4d72e2480e83b Mon Sep 17 00:00:00 2001 From: Wok Date: Sat, 19 Sep 2020 22:02:52 +0200 Subject: [PATCH 3/3] Fix typo --- apply_factor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apply_factor.py b/apply_factor.py index 7599f520..4a0874a0 100755 --- a/apply_factor.py +++ b/apply_factor.py @@ -75,7 +75,7 @@ img_dict = dict() - for u in torch.linspace(- args.degree, args.degree, args.d_num): + for u in torch.linspace(- args.degree, args.degree, args.degree_num): img_batch, _ = g( [latent + u * direction], @@ -103,5 +103,5 @@ f"{args.out_prefix}_index-{args.index}_degree-{args.degree}.png", normalize=True, range=(-1, 1), - nrow=args.d_num, + nrow=args.degree_num, )