Skip to content

Commit

Permalink
Merge pull request #10 from atoma-network/inference-models
Browse files Browse the repository at this point in the history
feat: add stable diffusion
  • Loading branch information
jorgeantonio21 authored Apr 4, 2024
2 parents eff8f1a + b6d0312 commit 3370fce
Show file tree
Hide file tree
Showing 9 changed files with 685 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
Cargo.lock
target/
.vscode/
15 changes: 13 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
[workspace]
resolver = "2"
edition = "2021"

members = ["atoma-event-subscribe", "atoma-inference", "atoma-networking", "atoma-json-rpc", "atoma-storage"]
members = [
"atoma-event-subscribe",
"atoma-inference",
"atoma-networking",
"atoma-json-rpc",
"atoma-storage",
]

[workspace.package]
version = "0.1.0"

[workspace.dependencies]
reqwest = "0.12.1"
async-trait = "0.1.78"
candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" }
candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", version = "0.4.2" }
Expand All @@ -17,10 +25,13 @@ dotenv = "0.15.0"
ed25519-consensus = "2.1.0"
futures = "0.3.30"
hf-hub = "0.3.2"
image = { version = "0.25.0", default-features = false, features = [
"jpeg",
"png",
] }
serde = "1.0.197"
serde_json = "1.0.114"
rand = "0.8.5"
reqwest = "0.12.1"
thiserror = "1.0.58"
tokenizers = "0.15.2"
tokio = "1.36.0"
Expand Down
13 changes: 8 additions & 5 deletions atoma-inference/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
[package]
name = "inference"
version = "0.1.0"
version.workspace = true
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait.workspace = true
candle.workspace = true
Expand All @@ -19,8 +17,9 @@ hf-hub.workspace = true
reqwest = { workspace = true, features = ["json"] }
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
image = { workspace = true }
thiserror.workspace = true
tokenizers.workspace = true
tokenizers = { workspace = true, features = ["onig"] }
tokio = { workspace = true, features = ["full", "tracing"] }
tracing.workspace = true
tracing-subscriber.workspace = true
Expand All @@ -31,7 +30,11 @@ toml.workspace = true


[features]
accelerate = ["candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
accelerate = [
"candle/accelerate",
"candle-nn/accelerate",
"candle-transformers/accelerate",
]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
Expand Down
5 changes: 2 additions & 3 deletions atoma-inference/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
pub mod apis;
pub mod model_thread;
pub mod models;
pub mod service;
pub mod specs;

pub mod apis;
pub mod models;
7 changes: 3 additions & 4 deletions atoma-inference/src/models/candle/mamba.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ impl MambaModel {
}

impl ModelTrait for MambaModel {
type Fetch = ();
type Input = TextModelInput;
type Output = String;

Expand All @@ -64,8 +65,7 @@ impl ModelTrait for MambaModel {
let tokenizer_filename = filenames[1].clone();
let weights_filenames = filenames[2..].to_vec();

let tokenizer =
Tokenizer::from_file(tokenizer_filename).map_err(ModelError::TokenizerError)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename)?;

let config: Config =
serde_json::from_slice(&std::fs::read(config_filename).map_err(ModelError::IoError)?)
Expand Down Expand Up @@ -110,8 +110,7 @@ impl ModelTrait for MambaModel {
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(ModelError::TokenizerError)?
.encode(prompt, true)?
.get_ids()
.to_vec();
let mut logits_processor =
Expand Down
83 changes: 83 additions & 0 deletions atoma-inference/src/models/candle/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,84 @@
use std::{fs::File, io::Write, path::PathBuf};

use candle::{
utils::{cuda_is_available, metal_is_available},
DType, Device, Tensor,
};
use tracing::info;

use crate::bail;

use super::ModelError;

pub mod mamba;
pub mod stable_diffusion;

/// Picks the best available compute device, preferring accelerators:
/// CUDA first, then Metal, with the CPU as the final fallback.
pub fn device() -> Result<Device, candle::Error> {
    if cuda_is_available() {
        info!("Using CUDA");
        return Device::new_cuda(0);
    }
    if metal_is_available() {
        info!("Using Metal");
        return Device::new_metal(0);
    }
    info!("Using Cpu");
    Ok(Device::Cpu)
}

pub fn hub_load_safetensors(
repo: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>, ModelError> {
let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| repo.get(v).map_err(candle::Error::wrap))
.collect::<candle::Result<Vec<_>>>()?;
Ok(safetensors_files)
}

pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<(), ModelError> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => bail!("error saving image {p:?}"),
};
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}

pub fn save_tensor_to_file(tensor: &Tensor, filename: &str) -> Result<(), candle::Error> {
let json_output = serde_json::to_string(
&tensor
.to_device(&Device::Cpu)?
.flatten_all()?
.to_dtype(DType::F64)?
.to_vec1::<f64>()?,
)
.unwrap();
let mut file = File::create(PathBuf::from(filename))?;
file.write_all(json_output.as_bytes())?;
Ok(())
}
Loading

0 comments on commit 3370fce

Please sign in to comment.