Skip to content

Commit

Permalink
checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
nicarq committed Oct 16, 2023
1 parent 94b6aaf commit 9534cba
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 21 deletions.
6 changes: 6 additions & 0 deletions src/db/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ pub enum Topic {
MessageBoxSymmetricKeysTimes,
TempFilesInbox,
JobQueues,
ProxyIdentities,
MyProxy
}

impl Topic {
Expand Down Expand Up @@ -66,6 +68,8 @@ impl Topic {
Self::MessageBoxSymmetricKeysTimes => "message_box_symmetric_keys_times",
Self::TempFilesInbox => "temp_files_inbox",
Self::JobQueues => "job_queues",
Self::ProxyIdentities => "proxy_identities",
Self::MyProxy => "my_proxy"
}
}
}
Expand Down Expand Up @@ -158,6 +162,8 @@ impl ShinkaiDB {
Topic::MessageBoxSymmetricKeysTimes.as_str().to_string(),
Topic::TempFilesInbox.as_str().to_string(),
Topic::JobQueues.as_str().to_string(),
Topic::ProxyIdentities.as_str().to_string(),
Topic::MyProxy.as_str().to_string()
]
};

Expand Down
3 changes: 2 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ fn main() {
node_commands_receiver,
db_path,
node_env.first_device_needs_registration_code,
initial_agent
initial_agent,
None // TODO: Add a way to pass proxy settings from env
)
.await
}),
Expand Down
73 changes: 53 additions & 20 deletions src/network/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use shinkai_message_primitives::shinkai_utils::encryption::{
};
use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption};
use shinkai_message_primitives::shinkai_utils::signatures::clone_signature_secret_key;
use std::collections::HashMap;
use std::sync::Arc;
use std::{io, net::SocketAddr, time::Duration};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
Expand Down Expand Up @@ -221,17 +222,40 @@ pub enum NodeCommand {
// A type alias for a string that represents a profile name.
type ProfileName = String;

pub enum NodeProxyMode {
// Node acts as a proxy, holds identities it proxies for
// and a flag indicating if it allows new identities
// if the flag is also then it will also clean up saved identities
IsProxy(ProxyMode),
// Node is being proxied, holds its proxy's identity
IsProxied(ProxyIdentity),
// Node is not using a proxy
NoProxy,
}

#[derive(Clone, Debug)]
pub struct ProxyMode {
// If it should be strict to the given identities
pub strict: bool,
// Flag indicating if new identities can be added
pub allow_new_identities: bool,
// Starting node identities
pub proxy_node_identities: Vec<String>
pub proxy_node_identities: HashMap<String, ProxyIdentity>,
}

#[derive(Clone, Debug)]
pub struct ProxyIdentity {
// Address of the API proxy
pub api_peer: SocketAddr,
// Address of the TCP proxy
pub tcp_peer: SocketAddr,
// Name of the proxied node
// Or the name of my identity proxied
pub shinkai_name: ShinkaiName,
}

// The `Node` struct represents a single node in the network.
pub struct Node {
// Is the node in proxy mode?
pub proxy_mode: Option<ProxyMode>,
// The mode of the node
pub proxy_mode: NodeProxyMode,
// The profile name of the node.
pub node_profile_name: ShinkaiName,
// The secret key used for signing operations.
Expand Down Expand Up @@ -275,7 +299,7 @@ impl Node {
db_path: String,
first_device_needs_registration_code: bool,
initial_agent: Option<SerializedAgent>,
proxy_mode: Option<ProxyMode>,
proxy_mode: NodeProxyMode,
) -> Node {
// if is_valid_node_identity_name_and_no_subidentities is false panic
match ShinkaiName::new(node_profile_name.to_string().clone()) {
Expand All @@ -286,7 +310,11 @@ impl Node {
let identity_public_key = SignaturePublicKey::from(&identity_secret_key);
let encryption_public_key = EncryptionPublicKey::from(&encryption_secret_key);
let db = ShinkaiDB::new(&db_path).unwrap_or_else(|e| {
eprintln!("Error: {:?}", e);
shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Error,
&format!("Failed to open database: {}", db_path).as_str(),
);
panic!("Failed to open database: {}", db_path)
});
let db_arc = Arc::new(Mutex::new(db));
Expand Down Expand Up @@ -724,7 +752,8 @@ impl Node {
shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Debug,
&format!("save_to_db> message_to_save: {:?}", message_to_save.clone()));
&format!("save_to_db> message_to_save: {:?}", message_to_save.clone()),
);
let mut db = db.lock().await;
let db_result = db.unsafe_insert_inbox_message(&message_to_save);
match db_result {
Expand Down Expand Up @@ -753,13 +782,12 @@ impl Node {
maybe_db: Arc<Mutex<ShinkaiDB>>,
maybe_identity_manager: Arc<Mutex<IdentityManager>>,
) -> Result<(), NodeError> {
// TODO: it should check mode we are in and handle accordingly

shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Info,
&format!(
"{} > Got message from {:?}",
receiver_address, unsafe_sender_address
),
&format!("{} > Got message from {:?}", receiver_address, unsafe_sender_address),
);

// Extract and validate the message
Expand All @@ -786,18 +814,29 @@ impl Node {
shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Debug,
&format!("{} > Sender Profile Name: {:?}", receiver_address, sender_profile_name_string),
&format!(
"{} > Sender Profile Name: {:?}",
receiver_address, sender_profile_name_string
),
);
shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Debug,
&format!("{} > Node Sender Identity: {}", receiver_address, sender_identity));
&format!("{} > Node Sender Identity: {}", receiver_address, sender_identity),
);
shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Debug,
&format!("{} > Verified message signature", receiver_address),
);
shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Debug,
&format!("{} > Sender Identity: {}", receiver_address, sender_identity),
);

// TODO(Nico): split this part depending on Proxy Mode

// Save to db
{
Node::save_to_db(
Expand All @@ -810,12 +849,6 @@ impl Node {
.await?;
}

shinkai_log(
ShinkaiLogOption::Node,
ShinkaiLogLevel::Debug,
&format!("{} > Sender Identity: {}", receiver_address, sender_identity),
);

handle_based_on_message_content_and_encryption(
message.clone(),
sender_identity.node_encryption_public_key,
Expand Down
1 change: 1 addition & 0 deletions tests/agent_integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ fn node_agent_registration() {
node1_db_path,
true,
Some(agent),
None,
);

let node1_handler = tokio::spawn(async move {
Expand Down
2 changes: 2 additions & 0 deletions tests/node_integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ fn subidentity_registration() {
node1_commands_receiver,
node1_db_path,
true,
None,
None
);

Expand All @@ -110,6 +111,7 @@ fn subidentity_registration() {
node2_commands_receiver,
node2_db_path,
true,
None,
None
);

Expand Down
2 changes: 2 additions & 0 deletions tests/node_retrying_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ fn node_retrying_test() {
node1_commands_receiver,
node1_db_path,
true,
None,
None
);

Expand All @@ -104,6 +105,7 @@ fn node_retrying_test() {
node2_commands_receiver,
node2_db_path,
true,
None,
None
);

Expand Down
1 change: 1 addition & 0 deletions tests/utils/test_boilerplate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ where
node1_commands_receiver.clone(),
node1_db_path,
false,
None,
None
);

Expand Down

0 comments on commit 9534cba

Please sign in to comment.