Move hidden states compute to a dedicated method (#1236)
To simplify `generate()` logic

---------

Co-authored-by: Anna Likholat <[email protected]>
ilya-lavrenov and likholat authored Nov 20, 2024
1 parent 1753672 commit 7c4b969
Showing 6 changed files with 271 additions and 241 deletions.
2 changes: 2 additions & 0 deletions src/cpp/src/image_generation/diffusion_pipeline.hpp
@@ -80,6 +80,8 @@ class DiffusionPipeline {

virtual ov::Tensor prepare_latents(ov::Tensor initial_image, const ImageGenerationConfig& generation_config) const = 0;

virtual void compute_hidden_states(const std::string& positive_prompt, const ImageGenerationConfig& generation_config) = 0;

virtual ov::Tensor generate(const std::string& positive_prompt, ov::Tensor initial_image, const ov::AnyMap& properties) = 0;

virtual ov::Tensor decode(const ov::Tensor latent) = 0;
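The new pure-virtual hook is the template-method idea: each concrete pipeline (Flux, Stable Diffusion, SDXL, ...) supplies its own prompt-encoding logic behind a common interface, so `generate()` shrinks to orchestration. Below is a minimal standalone sketch of that pattern, using simplified hypothetical names (`DiffusionPipelineSketch`, `FluxLikePipeline`) rather than the real classes:

```cpp
#include <iostream>
#include <string>

// Hypothetical stand-in for the real config type.
struct ImageGenerationConfig {
    int num_images_per_prompt = 1;
};

class DiffusionPipelineSketch {
public:
    virtual ~DiffusionPipelineSketch() = default;

    // Pipeline-specific prompt encoding, as introduced by this commit.
    virtual void compute_hidden_states(const std::string& positive_prompt,
                                       const ImageGenerationConfig& config) = 0;

    // generate() reduces to: encode the prompt, then run the denoising loop.
    void generate(const std::string& positive_prompt,
                  const ImageGenerationConfig& config) {
        compute_hidden_states(positive_prompt, config);
        denoise();
    }

private:
    void denoise() { std::cout << "running denoising loop\n"; }
};

class FluxLikePipeline : public DiffusionPipelineSketch {
public:
    void compute_hidden_states(const std::string& positive_prompt,
                               const ImageGenerationConfig&) override {
        std::cout << "encoding prompt: " << positive_prompt << "\n";
    }
};

int main() {
    FluxLikePipeline pipeline;
    pipeline.generate("a cat wearing a hat", ImageGenerationConfig{});
    return 0;
}
```

In the actual commit each pipeline still defines its own `generate()` and calls the hook itself; the sketch just isolates the encode-then-denoise split the refactor introduces.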
81 changes: 45 additions & 36 deletions src/cpp/src/image_generation/flux_pipeline.hpp
@@ -237,6 +237,43 @@ class FluxPipeline : public DiffusionPipeline {
m_vae->compile(device, properties);
m_transformer->compile(device, properties);
}

void compute_hidden_states(const std::string& positive_prompt, const ImageGenerationConfig& generation_config) override {
// encode_prompt
std::string prompt_2_str =
generation_config.prompt_2 != std::nullopt ? *generation_config.prompt_2 : positive_prompt;

m_clip_text_encoder->infer(positive_prompt, "", false);
ov::Tensor pooled_prompt_embeds_out = m_clip_text_encoder->get_output_tensor(1);

ov::Tensor prompt_embeds_out = m_t5_text_encoder->infer(prompt_2_str, generation_config.max_sequence_length);

ov::Tensor pooled_prompt_embeds, prompt_embeds;
if (generation_config.num_images_per_prompt == 1) {
pooled_prompt_embeds = pooled_prompt_embeds_out;
prompt_embeds = prompt_embeds_out;
} else {
pooled_prompt_embeds = numpy_utils::repeat(pooled_prompt_embeds_out, generation_config.num_images_per_prompt);
prompt_embeds = numpy_utils::repeat(prompt_embeds_out, generation_config.num_images_per_prompt);
}

// text_ids = torch.zeros(prompt_embeds.shape[1], 3)
ov::Shape text_ids_shape = {prompt_embeds.get_shape()[1], 3};
ov::Tensor text_ids(ov::element::f32, text_ids_shape);
std::fill_n(text_ids.data<float>(), text_ids_shape[0] * text_ids_shape[1], 0.0f);

const size_t num_channels_latents = m_transformer->get_config().in_channels / 4;
const size_t vae_scale_factor = m_vae->get_vae_scale_factor();
size_t height = generation_config.height / vae_scale_factor;
size_t width = generation_config.width / vae_scale_factor;

ov::Tensor latent_image_ids = prepare_latent_image_ids(generation_config.num_images_per_prompt, height / 2, width / 2);

m_transformer->set_hidden_states("pooled_projections", pooled_prompt_embeds);
m_transformer->set_hidden_states("encoder_hidden_states", prompt_embeds);
m_transformer->set_hidden_states("txt_ids", text_ids);
m_transformer->set_hidden_states("img_ids", latent_image_ids);
}

ov::Tensor prepare_latents(ov::Tensor initial_image,
const ImageGenerationConfig& generation_config) const override {
@@ -260,10 +297,14 @@
ov::Tensor generate(const std::string& positive_prompt,
ov::Tensor initial_image,
const ov::AnyMap& properties) override {
using namespace numpy_utils;
ImageGenerationConfig generation_config = m_generation_config;
generation_config.update_generation_config(properties);

if (!initial_image) {
// in case of typical text to image generation, we need to ignore 'strength'
generation_config.strength = 1.0f;
}

const size_t vae_scale_factor = m_vae->get_vae_scale_factor();
const auto& transformer_config = m_transformer->get_config();

@@ -274,50 +315,18 @@

check_inputs(generation_config, initial_image);

// encode_prompt
std::string prompt_2_str =
generation_config.prompt_2 != std::nullopt ? *generation_config.prompt_2 : positive_prompt;

m_clip_text_encoder->infer(positive_prompt, "", false);
ov::Tensor pooled_prompt_embeds_out = m_clip_text_encoder->get_output_tensor(1);

ov::Tensor prompt_embeds_out = m_t5_text_encoder->infer(positive_prompt, generation_config.max_sequence_length);

ov::Tensor pooled_prompt_embeds, prompt_embeds;
if (generation_config.num_images_per_prompt == 1) {
pooled_prompt_embeds = pooled_prompt_embeds_out;
prompt_embeds = prompt_embeds_out;
} else {
pooled_prompt_embeds = repeat(pooled_prompt_embeds_out, generation_config.num_images_per_prompt);
prompt_embeds = repeat(prompt_embeds_out, generation_config.num_images_per_prompt);
}

// text_ids = torch.zeros(prompt_embeds.shape[1], 3)
ov::Shape text_ids_shape = {prompt_embeds.get_shape()[1], 3};
ov::Tensor text_ids(ov::element::f32, text_ids_shape);
std::fill_n(text_ids.data<float>(), text_ids_shape[0] * text_ids_shape[1], 0.0f);

size_t num_channels_latents = m_transformer->get_config().in_channels / 4;
size_t height = generation_config.height / vae_scale_factor;
size_t width = generation_config.width / vae_scale_factor;
compute_hidden_states(positive_prompt, generation_config);

ov::Tensor latents = prepare_latents(initial_image, generation_config);
ov::Tensor latent_image_ids = prepare_latent_image_ids(generation_config.num_images_per_prompt, height / 2, width / 2);

m_transformer->set_hidden_states("pooled_projections", pooled_prompt_embeds);
m_transformer->set_hidden_states("encoder_hidden_states", prompt_embeds);
m_transformer->set_hidden_states("txt_ids", text_ids);
m_transformer->set_hidden_states("img_ids", latent_image_ids);

size_t image_seq_len = latents.get_shape()[1];
float mu = m_scheduler->calculate_shift(image_seq_len);

float linspace_end = 1.0f / generation_config.num_inference_steps;
std::vector<float> sigmas = linspace<float>(1.0f, linspace_end, generation_config.num_inference_steps, true);
std::vector<float> sigmas = numpy_utils::linspace<float>(1.0f, linspace_end, generation_config.num_inference_steps, true);

m_scheduler->set_timesteps_with_sigma(sigmas, mu);
std::vector<float> timesteps = m_scheduler->get_float_timesteps();
size_t num_inference_steps = timesteps.size();

// Use callback if defined
std::function<bool(size_t, ov::Tensor&)> callback;
@@ -331,7 +340,7 @@
ov::Tensor timestep(ov::element::f32, {1});
float* timestep_data = timestep.data<float>();

for (size_t inference_step = 0; inference_step < num_inference_steps; ++inference_step) {
for (size_t inference_step = 0; inference_step < timesteps.size(); ++inference_step) {
timestep_data[0] = timesteps[inference_step] / 1000;

ov::Tensor noise_pred_tensor = m_transformer->infer(latents, timestep);
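A detail worth noting in the moved code: when `num_images_per_prompt > 1`, the single-prompt embeddings are tiled along the batch axis with `numpy_utils::repeat`, and the copy is skipped entirely for the common single-image case. Here is a minimal sketch of what such a batch-repeat helper might do, assuming the first dimension is the batch axis and the tensor is dense host memory; the real `numpy_utils::repeat` in openvino.genai may differ:

```cpp
#include <cstdint>
#include <cstring>
#include <openvino/runtime/tensor.hpp>

// Hypothetical sketch: tile `input` n times along the first (batch) dimension,
// e.g. [1, seq_len, dim] -> [n, seq_len, dim].
ov::Tensor repeat_batch(const ov::Tensor& input, size_t n) {
    ov::Shape out_shape = input.get_shape();
    out_shape[0] *= n;

    ov::Tensor output(input.get_element_type(), out_shape);
    const size_t chunk_bytes = input.get_byte_size();

    auto* dst = static_cast<uint8_t*>(output.data());
    const auto* src = static_cast<const uint8_t*>(input.data());
    for (size_t i = 0; i < n; ++i) {
        std::memcpy(dst + i * chunk_bytes, src, chunk_bytes);
    }
    return output;
}
```

Tiling the outermost dimension of a row-major tensor is a contiguous copy, which is why one `memcpy` per repetition suffices here.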
3 changes: 1 addition & 2 deletions src/cpp/src/image_generation/generation_config.cpp
@@ -1,8 +1,7 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "image_generation/stable_diffusion_pipeline.hpp"
#include "image_generation/stable_diffusion_xl_pipeline.hpp"
#include "openvino/genai/image_generation/generation_config.hpp"

#include <ctime>
#include <cstdlib>
(diff for the remaining 3 changed files not shown)
