Skip to content

Commit

Permalink
Merge pull request #557 from dcSpark/nico/add_update_jobscope_
Browse files Browse the repository at this point in the history
add endpoint to update jobscope in a job
  • Loading branch information
nicarq authored Sep 15, 2024
2 parents 3053cbc + a3706f2 commit 47ffe3b
Show file tree
Hide file tree
Showing 6 changed files with 265 additions and 7 deletions.
4 changes: 4 additions & 0 deletions shinkai-bin/shinkai-node/src/db/db_inbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use shinkai_message_primitives::{
};
use tokio::sync::Mutex;

use crate::llm_provider::job::JobConfig;
use crate::network::ws_manager::WSMessageType;
use crate::network::ws_manager::WSUpdateHandler;
use crate::schemas::smart_inbox::LLMProviderSubset;
Expand Down Expand Up @@ -506,6 +507,7 @@ impl ShinkaiDB {

let mut job_scope_value: Option<Value> = None;
let mut datetime_created = String::new();
let mut job_config_value: Option<JobConfig> = None;

// Determine if the inbox is finished
let is_finished = if inbox_id.starts_with("job_inbox::") {
Expand All @@ -514,6 +516,7 @@ impl ShinkaiDB {
let job = self.get_job(&unique_id)?;
let scope_value = job.scope.to_json_value_minimal()?;
job_scope_value = Some(scope_value);
job_config_value = job.config;
datetime_created.clone_from(&job.datetime_created);
job.is_finished || job.is_hidden
}
Expand Down Expand Up @@ -555,6 +558,7 @@ impl ShinkaiDB {
is_finished,
job_scope: job_scope_value,
agent: agent_subset,
job_config: job_config_value,
};

smart_inboxes.push(smart_inbox);
Expand Down
21 changes: 21 additions & 0 deletions shinkai-bin/shinkai-node/src/network/handle_commands_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2941,6 +2941,27 @@ impl Node {
.await;
});
}
NodeCommand::V2ApiUpdateJobScope {
bearer,
job_id,
job_scope,
res,
} => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_update_job_scope(db_clone, bearer, job_id, job_scope, res).await;
});
}
NodeCommand::V2ApiGetJobScope {
bearer,
job_id,
res,
} => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_get_job_scope(db_clone, bearer, job_id, res).await;
});
}
_ => (),
}
}
Expand Down
13 changes: 12 additions & 1 deletion shinkai-bin/shinkai-node/src/network/node_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use shinkai_message_primitives::{
APIVecFsSearchItems, APIWorkflowKeyname, IdentityPermissions, JobCreationInfo, JobMessage,
RegistrationCodeType, V2ChatMessage,
},
},
}, shinkai_utils::job_scope::JobScope,
};

use crate::{
Expand Down Expand Up @@ -926,4 +926,15 @@ pub enum NodeCommand {
message_id: String,
res: Sender<Result<SendResponseBodyData, APIError>>,
},
V2ApiUpdateJobScope {
bearer: String,
job_id: String,
job_scope: JobScope,
res: Sender<Result<Value, APIError>>,
},
V2ApiGetJobScope {
bearer: String,
job_id: String,
res: Sender<Result<Value, APIError>>,
},
}
103 changes: 102 additions & 1 deletion shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use shinkai_message_primitives::{
shinkai_message::shinkai_message_schemas::{
APIChangeJobAgentRequest, JobCreationInfo, JobMessage, MessageSchemaType, V2ChatMessage,
},
shinkai_utils::job_scope::JobScope,
};

use tokio::sync::Mutex;
Expand Down Expand Up @@ -54,6 +55,7 @@ impl Node {
is_finished: smart_inbox.is_finished,
job_scope: smart_inbox.job_scope,
agent: smart_inbox.agent,
job_config: smart_inbox.job_config,
})
}

Expand Down Expand Up @@ -920,7 +922,7 @@ impl Node {
let shinkai_message = match Self::api_v2_create_shinkai_message(
sender,
recipient,
&serde_json::to_string(&job_message).unwrap(),
&serde_json::to_string(&job_message).unwrap(),
MessageSchemaType::JobMessageSchema,
node_encryption_sk,
node_signing_sk,
Expand Down Expand Up @@ -976,4 +978,103 @@ impl Node {
}
}
}

pub async fn v2_api_update_job_scope(
db: Arc<ShinkaiDB>,
bearer: String,
job_id: String,
job_scope: JobScope,
res: Sender<Result<Value, APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}

// Check if the job exists
match db.get_job(&job_id) {
Ok(_) => {
// Job exists, proceed with updating the job scope
match db.update_job_scope(job_id.clone(), job_scope.clone()) {
Ok(_) => {
match serde_json::to_value(&job_scope) {
Ok(job_scope_value) => {
let _ = res.send(Ok(job_scope_value)).await;
}
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to serialize job scope: {}", err),
};
let _ = res.send(Err(api_error)).await;
}
}
Ok(())
}
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to update job scope: {}", err),
};
let _ = res.send(Err(api_error)).await;
Ok(())
}
}
}
Err(_) => {
let api_error = APIError {
code: StatusCode::NOT_FOUND.as_u16(),
error: "Not Found".to_string(),
message: format!("Job with ID {} not found", job_id),
};
let _ = res.send(Err(api_error)).await;
Ok(())
}
}
}

pub async fn v2_api_get_job_scope(
db: Arc<ShinkaiDB>,
bearer: String,
job_id: String,
res: Sender<Result<Value, APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}

// Check if the job exists
match db.get_job(&job_id) {
Ok(job) => {
// Job exists, proceed with getting the job scope
let job_scope = job.scope();
match serde_json::to_value(&job_scope) {
Ok(job_scope_value) => {
let _ = res.send(Ok(job_scope_value)).await;
}
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to serialize job scope: {}", err),
};
let _ = res.send(Err(api_error)).await;
}
}
Ok(())
}
Err(_) => {
let api_error = APIError {
code: StatusCode::NOT_FOUND.as_u16(),
error: "Not Found".to_string(),
message: format!("Job with ID {} not found", job_id),
};
let _ = res.send(Err(api_error)).await;
Ok(())
}
}
}
}
117 changes: 114 additions & 3 deletions shinkai-bin/shinkai-node/src/network/v2_api/api_v2_handlers_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ use futures::StreamExt;
use reqwest::StatusCode;
use serde::Deserialize;
use serde_json::json;
use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::{
APIChangeJobAgentRequest, JobCreationInfo, JobMessage,
use shinkai_message_primitives::{
shinkai_message::shinkai_message_schemas::{APIChangeJobAgentRequest, JobCreationInfo, JobMessage},
shinkai_utils::job_scope::JobScope,
};
use utoipa::OpenApi;
use warp::multipart::FormData;
Expand Down Expand Up @@ -113,6 +114,20 @@ pub fn job_routes(
.and(warp::body::json())
.and_then(retry_message_handler);

let update_job_scope_route = warp::path("update_job_scope")
.and(warp::post())
.and(with_sender(node_commands_sender.clone()))
.and(warp::header::<String>("authorization"))
.and(warp::body::json())
.and_then(update_job_scope_handler);

let get_job_scope_route = warp::path("get_job_scope")
.and(warp::get())
.and(with_sender(node_commands_sender.clone()))
.and(warp::header::<String>("authorization"))
.and(warp::query::<GetJobScopeRequest>())
.and_then(get_job_scope_handler);

create_job_route
.or(job_message_route)
.or(get_last_messages_route)
Expand All @@ -126,6 +141,8 @@ pub fn job_routes(
.or(get_last_messages_with_branches_route)
.or(get_job_config_route)
.or(retry_message_route)
.or(update_job_scope_route)
.or(get_job_scope_route)
}

#[derive(Deserialize)]
Expand Down Expand Up @@ -798,6 +815,98 @@ pub async fn get_job_config_handler(
}
}

#[derive(Deserialize)]
pub struct UpdateJobScopeRequest {
pub job_id: String,
pub job_scope: JobScope,
}

#[utoipa::path(
post,
path = "/v2/update_job_scope",
request_body = UpdateJobScopeRequest,
responses(
(status = 200, description = "Successfully updated job scope", body = Value),
(status = 400, description = "Bad request", body = APIError),
(status = 500, description = "Internal server error", body = APIError)
)
)]
pub async fn update_job_scope_handler(
node_commands_sender: Sender<NodeCommand>,
authorization: String,
payload: UpdateJobScopeRequest,
) -> Result<impl warp::Reply, warp::Rejection> {
let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string();
let (res_sender, res_receiver) = async_channel::bounded(1);
node_commands_sender
.send(NodeCommand::V2ApiUpdateJobScope {
bearer,
job_id: payload.job_id,
job_scope: payload.job_scope,
res: res_sender,
})
.await
.map_err(|_| warp::reject::reject())?;
let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?;

match result {
Ok(_) => {
let response = create_success_response(json!({ "result": "Job scope updated successfully" }));
Ok(warp::reply::with_status(warp::reply::json(&response), StatusCode::OK))
}
Err(error) => Ok(warp::reply::with_status(
warp::reply::json(&error),
StatusCode::from_u16(error.code).unwrap(),
)),
}
}

#[derive(Deserialize)]
pub struct GetJobScopeRequest {
pub job_id: String,
}

#[utoipa::path(
get,
path = "/v2/get_job_scope",
params(
("job_id" = String, Query, description = "Job ID to retrieve scope for")
),
responses(
(status = 200, description = "Successfully retrieved job scope", body = JobScope),
(status = 400, description = "Bad request", body = APIError),
(status = 500, description = "Internal server error", body = APIError)
)
)]
pub async fn get_job_scope_handler(
node_commands_sender: Sender<NodeCommand>,
authorization: String,
query: GetJobScopeRequest,
) -> Result<impl warp::Reply, warp::Rejection> {
let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string();
let (res_sender, res_receiver) = async_channel::bounded(1);
node_commands_sender
.send(NodeCommand::V2ApiGetJobScope {
bearer,
job_id: query.job_id,
res: res_sender,
})
.await
.map_err(|_| warp::reject::reject())?;
let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?;

match result {
Ok(response) => {
let response = create_success_response(response);
Ok(warp::reply::with_status(warp::reply::json(&response), StatusCode::OK))
}
Err(error) => Ok(warp::reply::with_status(
warp::reply::json(&error),
StatusCode::from_u16(error.code).unwrap(),
)),
}
}

#[derive(OpenApi)]
#[openapi(
paths(
Expand All @@ -813,7 +922,9 @@ pub async fn get_job_config_handler(
get_last_messages_with_branches_handler,
update_job_config_handler,
get_job_config_handler,
retry_message_handler
retry_message_handler,
update_job_scope_handler,
get_job_scope_handler
),
components(
schemas(SendResponseBody, SendResponseBodyData, APIError)
Expand Down
14 changes: 12 additions & 2 deletions shinkai-bin/shinkai-node/src/schemas/smart_inbox.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use shinkai_message_primitives::{schemas::{llm_providers::serialized_llm_provider::{LLMProviderInterface, SerializedLLMProvider}, shinkai_name::ShinkaiName}, shinkai_message::{shinkai_message::ShinkaiMessage, shinkai_message_schemas::V2ChatMessage}};
use shinkai_message_primitives::{
schemas::{
llm_providers::serialized_llm_provider::{LLMProviderInterface, SerializedLLMProvider},
shinkai_name::ShinkaiName,
},
shinkai_message::{shinkai_message::ShinkaiMessage, shinkai_message_schemas::V2ChatMessage},
};

use crate::llm_provider::job::JobConfig;

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMProviderSubset {
Expand Down Expand Up @@ -28,6 +36,7 @@ pub struct SmartInbox {
pub is_finished: bool,
pub job_scope: Option<Value>,
pub agent: Option<LLMProviderSubset>,
pub job_config: Option<JobConfig>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
Expand All @@ -37,6 +46,7 @@ pub struct V2SmartInbox {
pub datetime_created: String,
pub last_message: Option<V2ChatMessage>,
pub is_finished: bool,
pub job_scope: Option<Value>,
pub agent: Option<LLMProviderSubset>,
pub job_scope: Option<Value>,
pub job_config: Option<JobConfig>,
}

0 comments on commit 47ffe3b

Please sign in to comment.