Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 584714705
  • Loading branch information
The swirl_dynamics Authors committed Nov 22, 2023
1 parent cb629d0 commit 18bc479
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion swirl_dynamics/lib/diffusion/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class OdeSampler:
denoise_fn: The denoising function; required to work on batched states and
noise levels.
guidance_transforms: An optional sequence of guidance transforms that
modifies the denoising funciton in a post-process fashion.
modifies the denoising function in a post-process fashion.
apply_denoise_at_end: Whether to apply the denoise function for another time
to the terminal state.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@
"from orbax import checkpoint\n",
"\n",
"from swirl_dynamics.lib.diffusion import diffusion\n",
"from swirl_dynamics.projects.probabilistic_diffusion import unconditional\n",
"from swirl_dynamics.projects.probabilistic_diffusion import trainers\n",
"from swirl_dynamics.projects.probabilistic_diffusion import models\n",
"from swirl_dynamics.templates import callbacks\n",
"from swirl_dynamics.templates import train"
]
Expand Down Expand Up @@ -177,7 +178,7 @@
" sigma=diffusion.tangent_noise_schedule(),\n",
" data_std=1.0,\n",
")\n",
"model = unconditional.DenoisingModel(\n",
"model = models.DenoisingModel(\n",
" input_shape=(64, 64, 1), # this must agree with the expected sample shape (without the batch dimension)\n",
" denoiser=denoiser_model,\n",
" noise_sampling=diffusion.log_uniform_sampling(\n",
Expand Down Expand Up @@ -243,9 +244,9 @@
},
"outputs": [],
"source": [
"# NOTE: use `unconditional.DistributedDenoisingTrainer` for multi-device\n",
"# NOTE: use `trainers.DistributedDenoisingTrainer` for multi-device\n",
"# training with data parallelism\n",
"trainer = unconditional.DenoisingTrainer(\n",
"trainer = trainers.DenoisingTrainer(\n",
" model=model,\n",
" rng=jax.random.PRNGKey(888),\n",
" optimizer=optax.adam(\n",
Expand Down Expand Up @@ -324,7 +325,7 @@
"id": "JrbCgv_P6D0n"
},
"source": [
"#### Unconditional generation"
"#### Unconditional generation"
]
},
{
Expand All @@ -333,7 +334,7 @@
"id": "6AURUIg5RT4m"
},
"source": [
"The trained denoiser may be used to generate unconditional samples.\n",
"The trained denoiser may be used to generate unconditional samples.\n",
"\n",
"First, let's try to restore the model from checkpoint."
]
Expand All @@ -349,11 +350,11 @@
"# Restore train state from checkpoint. By default, the most recently saved\n",
"# checkpoint is restored. Alternatively, one can directly use\n",
"# `trainer.train_state` if continuing from the training section above.\n",
"trained_state = unconditional.TrainState.restore_from_orbax_ckpt(\n",
"trained_state = trainers.TrainState.restore_from_orbax_ckpt(\n",
" f\"{workdir}/checkpoints\", step=None\n",
")\n",
"# Construct the inference function\n",
"denoise_fn = unconditional.DenoisingTrainer.inference_fn_from_state_dict(\n",
"denoise_fn = trainers.DenoisingTrainer.inference_fn_from_state_dict(\n",
" trained_state, use_ema=True, denoiser=denoiser_model\n",
")"
]
Expand Down Expand Up @@ -405,7 +406,7 @@
"source": [
"# Optional: JIT compile the generate function so that it runs faster if\n",
"# repeatedly called.\n",
"generate = jax.jit(sampler.generate, static_argnums=(2,))"
"generate = jax.jit(sampler.generate, static_argnames=('num_samples',))"
]
},
{
Expand Down Expand Up @@ -576,10 +577,10 @@
" integrator=sde.EulerMaruyama(),\n",
" scheme=diffusion_scheme,\n",
" denoise_fn=denoise_fn,\n",
" guidance_fn=guidance_fn,\n",
" guidance_transforms=(guidance_fn,),\n",
")\n",
"\n",
"guided_generate = jax.jit(guided_sampler.generate, static_argnums=(2,))"
"guided_generate = jax.jit(guided_sampler.generate, static_argnames=('num_samples',))"
]
},
{
Expand All @@ -598,7 +599,7 @@
" num_samples=4,\n",
" # The shape of the guidance input must be compatible with\n",
" # `sample[guidance_fn.slices]`\n",
" guidance_input=jnp.ones((1, 8, 8, 1)),\n",
" guidance_inputs={'observed_slices': jnp.ones((1, 8, 8, 1))},\n",
")"
]
},
Expand Down

0 comments on commit 18bc479

Please sign in to comment.