Skip to content

Commit

Permalink
feat: migrate from sqlite to postgres (#245)
Browse files Browse the repository at this point in the history
* refactor sqlite to postgres

* add missing file
  • Loading branch information
jorgeantonio21 authored Nov 21, 2024
1 parent 594cdf1 commit 19fe7aa
Show file tree
Hide file tree
Showing 11 changed files with 372 additions and 388 deletions.
24 changes: 24 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,36 @@ HF_TOKEN= # required if you want to access a gated model
# ----------------------------------------------------------------------------------
# atoma node configuration

# Postgres Configuration
POSTGRES_DB=atoma
POSTGRES_USER=atoma
POSTGRES_PASSWORD=

# Sui Configuration
SUI_CONFIG_PATH=~/.sui/sui_config

# Atoma Node Service Configuration
ATOMA_SERVICE_PORT=3000

# Currently available docker compose profiles:
#
# All possible values are:
#
# 1. chat_completions_vllm
# 2. chat_completions_mistralrs_cpu
# 3. chat_completions_vllm_cpu, running this profile requires a CPU with AVX2 support
# 4. chat_completions_vllm_rocm, running this profile requires a GPU with AMD GPU drivers installed
# 7. chat_completions_mistralrs_rocm, running this profile requires a GPU with AMD GPU drivers installed
# 8. embeddings_tei - runs text embeddings inference server in docker compose
# 9. image_generations_mistralrs - runs image generations server in docker compose
#
# Setting the COMPOSE_PROFILES environment variable will start all services listed in the value, e.g.
# COMPOSE_PROFILES=chat_completions_vllm will start the chat completions server and the postgres database
#
# Please change it accordingly to which inference services you want to run, and which database you want to use (either PostgresSQL or SQLite)
COMPOSE_PROFILES=chat_completions_vllm

# Tracing level
TRACE_LEVEL=info

# ----------------------------------------------------------------------------------
Expand Down
15 changes: 15 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,21 @@ jobs:
clippy:
name: clippy
runs-on: [ubuntu-22.04]
services:
postgres:
image: postgres:13
env:
POSTGRES_DB: atoma
POSTGRES_USER: atoma
POSTGRES_PASSWORD: atoma
ports:
- 5432:5432
# health check to ensure database is ready
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- name: checkout
uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions atoma-service/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,6 @@ utoipa-swagger-ui = { workspace = true, features = ["axum"] }
rand = { workspace = true }
serial_test = { workspace = true }
tempfile = { workspace = true }

[dev_dependencies]
sqlx = { workspace = true, features = ["runtime-tokio", "postgres"] }
99 changes: 38 additions & 61 deletions atoma-service/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ mod middleware {
use flume::Sender;
use serde_json::json;
use serial_test::serial;
use std::{path::PathBuf, str::FromStr, sync::Arc};
use sqlx::PgPool;
use std::{str::FromStr, sync::Arc};
use sui_keys::keystore::{AccountKeystore, FileBasedKeystore};
use sui_sdk::types::{
base_types::SuiAddress,
Expand Down Expand Up @@ -83,29 +84,36 @@ mod middleware {
Tokenizer::from_str(&tokenizer_json).unwrap()
}

async fn truncate_tables() {
let db = PgPool::connect("postgres://atoma:atoma@localhost:5432/atoma")
.await
.expect("Failed to connect to database");
sqlx::query(
"TRUNCATE TABLE
tasks,
node_subscriptions,
stacks,
stack_settlement_tickets,
stack_attestation_disputes
CASCADE",
)
.execute(&db)
.await
.expect("Failed to truncate tables");
}

async fn setup_database(
public_key: PublicKey,
) -> (
PathBuf,
JoinHandle<()>,
Sender<AtomaAtomaStateManagerEvent>,
tokio::sync::watch::Sender<bool>,
Sender<AtomaEvent>,
) {
let db_path = std::path::Path::new("./db_path").to_path_buf();

std::fs::OpenOptions::new()
.create(true)
.truncate(true)
.write(true)
.read(true)
.open(&db_path)
.unwrap();

let (_event_subscriber_sender, event_subscriber_receiver) = flume::unbounded();
let (state_manager_sender, state_manager_receiver) = flume::unbounded();
let state_manager = AtomaStateManager::new_from_url(
"sqlite::memory:",
"postgres://atoma:atoma@localhost:5432/atoma",
event_subscriber_receiver,
state_manager_receiver,
)
Expand Down Expand Up @@ -151,7 +159,6 @@ mod middleware {
// but we need to return it so the tests can send events to the state manager
// otherwise the event subscriber will be dropped and the state manager shuts down
(
db_path,
state_manager_handle,
state_manager_sender,
shutdown_sender,
Expand All @@ -163,7 +170,6 @@ mod middleware {
AppState,
PublicKey,
Signature,
PathBuf,
tokio::sync::watch::Sender<bool>,
JoinHandle<()>,
Sender<AtomaEvent>,
Expand All @@ -182,13 +188,8 @@ mod middleware {
.sign_hashed(&keystore.addresses()[0], blake2b_hash.as_slice())
.expect("Failed to sign message");
let tokenizer = load_tokenizer().await;
let (
db_path,
state_manager_handle,
state_manager_sender,
shutdown_sender,
_event_subscriber_sender,
) = setup_database(public_key.clone()).await;
let (state_manager_handle, state_manager_sender, shutdown_sender, _event_subscriber_sender) =
setup_database(public_key.clone()).await;
(
AppState {
models: Arc::new(models.into_iter().map(|s| s.to_string()).collect()),
Expand All @@ -202,7 +203,6 @@ mod middleware {
},
public_key,
signature,
db_path,
shutdown_sender,
state_manager_handle,
_event_subscriber_sender,
Expand Down Expand Up @@ -238,7 +238,6 @@ mod middleware {
app_state,
_,
signature,
db_path,
shutdown_sender,
state_manager_handle,
_event_subscriber_sender,
Expand Down Expand Up @@ -279,21 +278,14 @@ mod middleware {

shutdown_sender.send(true).unwrap();
state_manager_handle.await.unwrap();
std::fs::remove_file(db_path).unwrap();
truncate_tables().await;
}

#[tokio::test]
#[serial]
async fn test_verify_stack_permissions_missing_public_key() {
let (
app_state,
_,
_,
db_path,
shutdown_sender,
state_manager_handle,
_event_subscriber_sender,
) = setup_app_state().await;
let (app_state, _, _, shutdown_sender, state_manager_handle, _event_subscriber_sender) =
setup_app_state().await;

let body = json!({
"model": "meta-llama/Llama-3.1-70B-Instruct",
Expand Down Expand Up @@ -321,7 +313,7 @@ mod middleware {
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
shutdown_sender.send(true).unwrap();
state_manager_handle.await.unwrap();
std::fs::remove_file(db_path).unwrap();
truncate_tables().await;
}

#[tokio::test]
Expand All @@ -331,7 +323,6 @@ mod middleware {
app_state,
_,
signature,
db_path,
shutdown_sender,
state_manager_handle,
_event_subscriber_sender,
Expand Down Expand Up @@ -363,7 +354,7 @@ mod middleware {
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
shutdown_sender.send(true).unwrap();
state_manager_handle.await.unwrap();
std::fs::remove_file(db_path).unwrap();
truncate_tables().await;
}

#[tokio::test]
Expand All @@ -373,7 +364,6 @@ mod middleware {
app_state,
_,
signature,
db_path,
shutdown_sender,
state_manager_handle,
_event_subscriber_sender,
Expand Down Expand Up @@ -405,7 +395,7 @@ mod middleware {
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
shutdown_sender.send(true).unwrap();
state_manager_handle.await.unwrap();
std::fs::remove_file(db_path).unwrap();
truncate_tables().await;
}

#[tokio::test]
Expand All @@ -415,7 +405,6 @@ mod middleware {
app_state,
_,
signature,
db_path,
shutdown_sender,
state_manager_handle,
_event_subscriber_sender,
Expand Down Expand Up @@ -450,21 +439,14 @@ mod middleware {
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
shutdown_sender.send(true).unwrap();
state_manager_handle.await.unwrap();
std::fs::remove_file(db_path).unwrap();
truncate_tables().await;
}

#[tokio::test]
#[serial]
async fn test_verify_stack_permissions_invalid_stack() {
let (
app_state,
_,
_,
db_path,
shutdown_sender,
state_manager_handle,
_event_subscriber_sender,
) = setup_app_state().await;
let (app_state, _, _, shutdown_sender, state_manager_handle, _event_subscriber_sender) =
setup_app_state().await;

let body = json!({
"model": "meta-llama/Llama-3.1-70B-Instruct",
Expand Down Expand Up @@ -492,7 +474,7 @@ mod middleware {
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
shutdown_sender.send(true).unwrap();
state_manager_handle.await.unwrap();
std::fs::remove_file(db_path).unwrap();
truncate_tables().await;
}

#[tokio::test]
Expand All @@ -502,7 +484,6 @@ mod middleware {
app_state,
_,
signature,
db_path,
shutdown_sender,
state_manager_handle,
_event_subscriber_sender,
Expand Down Expand Up @@ -551,7 +532,7 @@ mod middleware {
assert_eq!(response.status(), StatusCode::OK);
shutdown_sender.send(true).unwrap();
state_manager_handle.await.unwrap();
std::fs::remove_file(db_path).unwrap();
truncate_tables().await;
}

#[tokio::test]
Expand All @@ -561,7 +542,6 @@ mod middleware {
app_state,
_,
signature,
db_path,
shutdown_sender,
state_manager_handle,
_event_subscriber_sender,
Expand Down Expand Up @@ -621,7 +601,7 @@ mod middleware {
assert_eq!(response.status(), StatusCode::OK);
shutdown_sender.send(true).unwrap();
state_manager_handle.await.unwrap();
std::fs::remove_file(db_path).unwrap();
truncate_tables().await;
}

#[tokio::test]
Expand Down Expand Up @@ -909,7 +889,6 @@ mod middleware {
app_state,
_,
signature,
db_path,
shutdown_sender,
state_manager_handle,
_event_subscriber_sender,
Expand Down Expand Up @@ -975,7 +954,7 @@ mod middleware {

shutdown_sender.send(true).unwrap();
state_manager_handle.await.unwrap();
std::fs::remove_file(db_path).unwrap();
truncate_tables().await;
}

#[tokio::test]
Expand All @@ -985,7 +964,6 @@ mod middleware {
app_state,
_,
signature,
db_path,
shutdown_sender,
state_manager_handle,
_event_subscriber_sender,
Expand Down Expand Up @@ -1077,7 +1055,7 @@ mod middleware {

shutdown_sender.send(true).unwrap();
state_manager_handle.await.unwrap();
std::fs::remove_file(db_path).unwrap();
truncate_tables().await;
}

#[tokio::test]
Expand All @@ -1087,7 +1065,6 @@ mod middleware {
app_state,
_,
signature,
db_path,
shutdown_sender,
state_manager_handle,
_event_subscriber_sender,
Expand Down Expand Up @@ -1120,6 +1097,6 @@ mod middleware {

shutdown_sender.send(true).unwrap();
state_manager_handle.await.unwrap();
std::fs::remove_file(db_path).unwrap();
truncate_tables().await;
}
}
5 changes: 4 additions & 1 deletion atoma-state/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ config = { workspace = true }
flume = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
sqlx = { workspace = true, features = ["runtime-tokio-native-tls", "sqlite"] }
sqlx = { workspace = true, features = ["runtime-tokio", "postgres", "migrate"] }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["full"] }
tracing = { workspace = true }

[dev-dependencies]
serial_test = { workspace = true }
Loading

0 comments on commit 19fe7aa

Please sign in to comment.