Paper: https://arxiv.org/abs/2104.06820
Repository: https://github.com/utkarshojha/few-shot-gan-adaptation
```python
import jax
import numpy as np
import dill as pickle
from PIL import Image

import flaxmodels as fm

# Load a pretrained checkpoint and take the EMA generator parameters
with open('sketches.pickle', 'rb') as f:
    ckpt = pickle.load(f)
params = ckpt['params_ema_G']
generator = fm.few_shot_gan_adaption.Generator()

# Seed
key = jax.random.PRNGKey(0)

# Input noise: a batch of 4 latent vectors of dimension 512
z = jax.random.normal(key, shape=(4, 512))

# Generate images
images, _ = generator.apply(params, z, truncation_psi=0.5, train=False, noise_mode='const')

# Normalize images to be in range [0, 1]
images = (images - np.min(images)) / (np.max(images) - np.min(images))

# Save images
for i in range(images.shape[0]):
    Image.fromarray(np.uint8(images[i] * 255)).save(f'image_{i}.jpg')
```
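The `truncation_psi=0.5` argument above trades sample diversity for fidelity. Assuming it follows the usual StyleGAN2 definition (the adapted generators are StyleGAN2-based), each intermediate latent `w` is pulled toward an average latent `w_avg` as `w_avg + psi * (w - w_avg)`. The NumPy sketch below illustrates only that formula; `truncate` is a hypothetical helper, and the random `w` and zero `w_avg` are stand-ins for the mapping-network outputs and the running average a real generator would track.

```python
import numpy as np

def truncate(w, w_avg, psi):
    """Hypothetical helper: pull latent codes w toward the average latent w_avg.

    psi = 1.0 leaves w unchanged; psi = 0.0 collapses every sample to w_avg.
    """
    return w_avg + psi * (w - w_avg)

# Stand-ins (assumptions): a real generator would produce w with its mapping
# network and keep w_avg as a running average tracked during training.
rng = np.random.default_rng(0)
w = rng.normal(size=(4, 512)).astype(np.float32)
w_avg = np.zeros(512, dtype=np.float32)

w_trunc = truncate(w, w_avg, psi=0.5)

# With psi = 0.5, every code's distance from w_avg is halved.
dist_before = np.linalg.norm(w - w_avg, axis=1)
dist_after = np.linalg.norm(w_trunc - w_avg, axis=1)
print(np.allclose(dist_after, 0.5 * dist_before))  # True
```

Lower `psi` values give more typical but less varied images; `psi = 1.0` disables truncation.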
Pretrained checkpoints:
- Sketches (357.2 MB)
- Amedeo Modigliani (357.2 MB)
- Babies (357.2 MB)
- Otto Dix (357.2 MB)
- Rafael (357.2 MB)
The documentation can be found here.
To train this model in JAX/Flax, see here.