Skip to content

Commit

Permalink
fix image encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
SamitHuang committed Jun 28, 2024
1 parent d94ec5e commit 0603a16
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from mindone.models.modules.pos_embed import get_2d_sincos_pos_embed

from ..models.vae.vae import VideoAutoencoderKL, VideoAutoencoderPipeline
from ..models.layers.rotary_embedding import precompute_freqs_cis
from ..schedulers.iddpm import create_diffusion

Expand Down Expand Up @@ -61,7 +62,12 @@ def __init__(
# @ms.jit
def vae_encode(self, x: Tensor) -> Tensor:
    """Encode ``x`` into scaled latents using the spatial VAE.

    ``self.vae`` may be either a ``VideoAutoencoderKL`` (used directly) or a
    ``VideoAutoencoderPipeline`` (its ``spatial_vae`` attribute is used).

    Args:
        x: Input tensor handed to the spatial VAE's ``module.encode``.

    Returns:
        ``module.encode(x) * scale_factor`` of the selected spatial VAE,
        wrapped in ``ops.stop_gradient``.

    Raises:
        TypeError: If ``self.vae`` is neither of the supported VAE types.
    """
    if isinstance(self.vae, VideoAutoencoderKL):
        spatial_vae = self.vae
    elif isinstance(self.vae, VideoAutoencoderPipeline):
        spatial_vae = self.vae.spatial_vae
    else:
        # Fail loudly instead of hitting an UnboundLocalError further down.
        raise TypeError(
            "vae_encode expects self.vae to be a VideoAutoencoderKL or "
            f"VideoAutoencoderPipeline, got {type(self.vae).__name__}"
        )

    # TODO: unify scale inside vae class
    image_latents = ops.stop_gradient(spatial_vae.module.encode(x) * spatial_vae.scale_factor)
    return image_latents

def vae_decode(self, x: Tensor) -> Tensor:
Expand Down

0 comments on commit 0603a16

Please sign in to comment.