Skip to content

Commit

Permalink
better package sharing, add configurability to workflows configs
Browse files Browse the repository at this point in the history
  • Loading branch information
erhant committed Oct 9, 2024
1 parent 7da3a09 commit 438ad78
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 186 deletions.
233 changes: 67 additions & 166 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ default-members = ["compute"]

[workspace.package]
edition = "2021"
version = "0.2.12"
version = "0.2.13"
license = "Apache-2.0"
readme = "README.md"

Expand All @@ -21,7 +21,9 @@ debug = true

[workspace.dependencies]
# async stuff
tokio-util = { version = "0.7.10", features = ["rt"] }
tokio-util = { version = "0.7.10", features = [
"rt",
] } # tokio-util provides CancellationToken
tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal"] }
async-trait = "0.1.81"

Expand All @@ -32,7 +34,7 @@ serde_json = "1.0"
# http client
reqwest = "0.12.5"

# env reading
# utilities
dotenvy = "0.15.7"

# randomization
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ trace:

.PHONY: build # | Build
build:
cargo build
cargo build --workspace

.PHONY: profile-cpu # | Profile CPU usage with flamegraph
profile-cpu:
Expand Down
28 changes: 16 additions & 12 deletions compute/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,35 @@ readme = "README.md"
authors = ["Erhan Tezcan <[email protected]>"]

[dependencies]
tokio-util = { version = "0.7.10", features = ["rt"] }
tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
async-trait = "0.1.81"
reqwest = "0.12.5"
# async stuff
tokio-util.workspace = true
tokio.workspace = true
async-trait.workspace = true

# serialize & deserialize
serde.workspace = true
serde_json.workspace = true

# http & networking
reqwest.workspace = true
port_check = "0.2.1"
url = "2.5.0"
urlencoding = "2.1.3"

# utilities
dotenvy.workspace = true
base64 = "0.22.0"
hex = "0.4.3"
hex-literal = "0.4.1"
url = "2.5.0"
urlencoding = "2.1.3"
uuid = { version = "1.8.0", features = ["v4"] }

port_check = "0.2.1"

# logging & errors
rand.workspace = true
env_logger.workspace = true
log.workspace = true
eyre.workspace = true
tracing = { version = "0.1.40" }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
# tracing = { version = "0.1.40" }
# tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }

# encryption (ecies) & signatures (ecdsa) & hashing & bloom-filters
ecies = { version = "0.2", default-features = false, features = ["pure"] }
Expand Down
2 changes: 1 addition & 1 deletion compute/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl DriaComputeNodeConfig {
log::error!("No models were provided, make sure to restart with at least one model provided within DKN_MODELS.");
panic!("No models provided.");
}
log::info!("Models: {:?}", workflows.models);
log::info!("Configured models: {:?}", workflows.models);

let p2p_listen_addr_str = env::var("DKN_P2P_LISTEN_ADDR")
.map(|addr| addr.trim_matches('"').to_string())
Expand Down
2 changes: 1 addition & 1 deletion compute/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ async fn main() -> Result<()> {
let cancellation_token = token.clone();
tokio::spawn(async move {
if let Some(timeout_str) = env::var("DKN_EXIT_TIMEOUT").ok() {
// add cancellation check
let duration_secs = timeout_str.parse().unwrap_or(120);
log::warn!("Waiting for {} seconds before exiting.", duration_secs);
tokio::time::sleep(tokio::time::Duration::from_secs(duration_secs)).await;
cancellation_token.cancel();
} else {
Expand Down
12 changes: 11 additions & 1 deletion workflows/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,26 @@ authors = ["Erhan Tezcan <[email protected]>"]
# ollama-rs is re-exported from ollama-workflows as well
ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows" }

# async stuff
tokio-util.workspace = true
tokio.workspace = true
async-trait.workspace = true

# serialize & deserialize
serde.workspace = true
serde_json.workspace = true
async-trait.workspace = true

# http & networking
reqwest.workspace = true

# utilities
rand.workspace = true

# logging & errors
log.workspace = true
eyre.workspace = true

[dev-dependencies]
# only used for tests
env_logger.workspace = true
dotenvy.workspace = true
16 changes: 15 additions & 1 deletion workflows/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub struct DriaWorkflowsConfig {
}

impl DriaWorkflowsConfig {
/// Creates a new config with the given models.
pub fn new(models: Vec<Model>) -> Self {
let models_and_providers = models
.into_iter()
Expand All @@ -28,6 +29,19 @@ impl DriaWorkflowsConfig {
ollama: OllamaConfig::new(),
}
}

/// Sets the Ollama configuration for the Workflows config.
pub fn with_ollama_config(mut self, ollama: OllamaConfig) -> Self {
self.ollama = ollama;
self
}

/// Sets the OpenAI configuration for the Workflows config.
pub fn with_openai_config(mut self, openai: OpenAIConfig) -> Self {
self.openai = openai;
self
}

/// Parses Ollama-Workflows compatible models from a comma-separated values string.
pub fn new_from_csv(input: &str) -> Self {
let models_str = split_csv_line(input);
Expand All @@ -40,7 +54,7 @@ impl DriaWorkflowsConfig {
Self::new(models)
}

/// Returns the models that belong to a given providers from the config.
/// Returns the models from the config that belongs to a given provider.
pub fn get_models_for_provider(&self, provider: ModelProvider) -> Vec<Model> {
self.models
.iter()
Expand Down
18 changes: 18 additions & 0 deletions workflows/src/providers/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,24 @@ impl OllamaConfig {
}
}

/// Sets the timeout duration for checking model performance during a generation.
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}

/// Sets the minimum tokens per second (TPS) for checking model performance during a generation.
pub fn with_min_tps(mut self, min_tps: f64) -> Self {
self.min_tps = min_tps;
self
}

/// Sets the auto-pull flag for Ollama models.
pub fn with_auto_pull(mut self, auto_pull: bool) -> Self {
self.auto_pull = auto_pull;
self
}

/// Check if requested models exist in Ollama, and then tests them using a workflow.
pub async fn check(&self, external_models: Vec<Model>) -> Result<Vec<Model>> {
log::info!(
Expand Down
6 changes: 6 additions & 0 deletions workflows/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ impl OpenAIConfig {
}
}

/// Sets the API key for OpenAI.
pub fn with_api_key(mut self, api_key: String) -> Self {
self.api_key = Some(api_key);
self
}

/// Check if requested models exist & are available in the OpenAI account.
pub async fn check(&self, models: Vec<Model>) -> Result<Vec<Model>> {
log::info!("Checking OpenAI requirements");
Expand Down

0 comments on commit 438ad78

Please sign in to comment.