Skip to content

Commit

Permalink
Add the flux model for image generation. (#2390)
Browse files Browse the repository at this point in the history
* Add the flux autoencoder.

* Add the encoder down-blocks.

* Upsampling in the decoder.

* Sketch the flow matching model.

* More flux model.

* Add some of the positional embeddings.

* Add the rope embeddings.

* Add the sampling functions.

* Add the flux example.

* Fix the T5 bits.

* Proper T5 tokenizer.

* Clip encoder path fix.

* Get the clip embeddings.

* No configurable weights in layer norm.

* More weights related fixes.

* Yet another shape fix.

* DType fix.

* Fix a couple more shape issues.

* DType fixes.

* Fix the latent dims.

* Fix more shape issues.

* Autoencoder fixes.

* Get some generations out.

* Bugfix.

* T5 padding.

* Clippy fix.

* Add the decode only mode.

* Fix.

* More fixes.

* Finally get some generations to work.

* Add readme.
  • Loading branch information
LaurentMazare authored Aug 4, 2024
1 parent 0fcb40b commit 19db6b9
Show file tree
Hide file tree
Showing 8 changed files with 1,346 additions and 0 deletions.
19 changes: 19 additions & 0 deletions candle-examples/examples/flux/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# candle-flux: image generation with latent rectified flow transformers

![rusty robot holding a candle](./assets/flux-robot.jpg)

Flux is a 12B rectified flow transformer capable of generating images from text
descriptions,
[huggingface](https://huggingface.co/black-forest-labs/FLUX.1-schnell),
[github](https://github.com/black-forest-labs/flux),
[blog post](https://blackforestlabs.ai/announcing-black-forest-labs/).


## Running the model

```bash
cargo run --features cuda --example flux -r -- \
--height 1024 --width 1024
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
```

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
182 changes: 182 additions & 0 deletions candle-examples/examples/flux/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use candle_transformers::models::{clip, flux, t5};

use anyhow::{Error as E, Result};
use candle::{IndexOp, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use tokenizers::Tokenizer;

#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The prompt to be used for image generation.
#[arg(long, default_value = "A rusty robot walking on a beach")]
prompt: String,

/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,

/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,

/// The height in pixels of the generated image.
#[arg(long)]
height: Option<usize>,

/// The width in pixels of the generated image.
#[arg(long)]
width: Option<usize>,

#[arg(long)]
decode_only: Option<String>,
}

fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

let Args {
prompt,
cpu,
height,
width,
tracing,
decode_only,
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);

let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};

let api = hf_hub::api::sync::Api::new()?;
let bf_repo = api.repo(hf_hub::Repo::model(
"black-forest-labs/FLUX.1-schnell".to_string(),
));
let device = candle_examples::device(cpu)?;
let dtype = device.bf16_default_to_f32();
let img = match decode_only {
None => {
let t5_emb = {
let repo = api.repo(hf_hub::Repo::with_revision(
"google/t5-v1_1-xxl".to_string(),
hf_hub::RepoType::Model,
"refs/pr/2".to_string(),
));
let model_file = repo.get("model.safetensors")?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: t5::Config = serde_json::from_str(&config)?;
let mut model = t5::T5EncoderModel::load(vb, &config)?;
let tokenizer_filename = api
.model("lmz/mt5-tokenizers".to_string())
.get("t5-v1_1-xxl.tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mut tokens = tokenizer
.encode(prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
tokens.resize(256, 0);
let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
println!("{input_token_ids}");
model.forward(&input_token_ids)?
};
println!("T5\n{t5_emb}");
let clip_emb = {
let repo = api.repo(hf_hub::Repo::model(
"openai/clip-vit-large-patch14".to_string(),
));
let model_file = repo.get("model.safetensors")?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
// https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
let config = clip::text_model::ClipTextConfig {
vocab_size: 49408,
projection_dim: 768,
activation: clip::text_model::Activation::QuickGelu,
intermediate_size: 3072,
embed_dim: 768,
max_position_embeddings: 77,
pad_with: None,
num_hidden_layers: 12,
num_attention_heads: 12,
};
let model =
clip::text_model::ClipTextTransformer::new(vb.pp("text_model"), &config)?;
let tokenizer_filename = repo.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
println!("{input_token_ids}");
model.forward(&input_token_ids)?
};
println!("CLIP\n{clip_emb}");
let img = {
let model_file = bf_repo.get("flux1-schnell.sft")?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = flux::model::Config::schnell();
let model = flux::model::Flux::new(&cfg, vb)?;

let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
println!("{state:?}");
let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell
println!("{timesteps:?}");
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
};
flux::sampling::unpack(&img, height, width)?
}
Some(file) => {
let mut st = candle::safetensors::load(file, &device)?;
st.remove("img").unwrap().to_dtype(dtype)?
}
};
println!("latent img\n{img}");

let img = {
let model_file = bf_repo.get("ae.sft")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = flux::autoencoder::Config::schnell();
let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?;
model.decode(&img)?
};
println!("img\n{img}");
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
Ok(())
}

fn main() -> Result<()> {
let args = Args::parse();
run(args)
}
Loading

0 comments on commit 19db6b9

Please sign in to comment.