Commit

🎨 Add OAI-DID
luoshuijs committed Mar 26, 2024
1 parent 8438e6d commit f9f1a91
Showing 5 changed files with 102 additions and 1 deletion.
1 change: 1 addition & 0 deletions crates/openai/src/constant.rs
@@ -16,6 +16,7 @@ pub(crate) const API_AUTH_SESSION_COOKIE_KEY: &str = "__Secure-next-auth.session

/// Serve
pub(crate) const PUID: &str = "_puid";
pub(crate) const OAODID: &str = "oai-did";
pub(crate) const CF_CLEARANCE: &str = "cf_clearance";
pub(crate) const MODEL: &str = "model";
pub(crate) const ARKOSE_TOKEN: &str = "arkose_token";
1 change: 1 addition & 0 deletions crates/openai/src/serve/mod.rs
@@ -9,6 +9,7 @@ mod router;
mod signal;
mod turnstile;
mod whitelist;
mod oaidid;

use self::proxy::ext::RequestExt;
use self::proxy::ext::SendRequestExt;
56 changes: 56 additions & 0 deletions crates/openai/src/serve/oaidid.rs
@@ -0,0 +1,56 @@
use super::error::{ProxyError, ResponseError};
use crate::{gpt_model::GPTModel, with_context, URL_CHATGPT_API};
use moka::sync::Cache;
use std::str::FromStr;
use tokio::sync::OnceCell;

static OAIDID_CACHE: OnceCell<Cache<String, String>> = OnceCell::const_new();

pub(super) fn reduce_key(token: &str) -> Result<String, ResponseError> {
    let token_profile = crate::token::check(token)
        .map_err(ResponseError::Unauthorized)?
        .ok_or(ResponseError::BadRequest(ProxyError::InvalidAccessToken))?;
    Ok(token_profile.email().to_owned())
}

async fn cache() -> &'static Cache<String, String> {
    OAIDID_CACHE
        .get_or_init(|| async {
            Cache::builder()
                .time_to_live(std::time::Duration::from_secs(3600 * 24))
                .build()
        })
        .await
}

pub(super) async fn get_or_init(
    token: &str,
    model: &str,
    cache_id: String,
) -> Result<Option<String>, ResponseError> {
    let token = token.trim_start_matches("Bearer ");
    let oaidid_cache = cache().await;

    if let Some(p) = oaidid_cache.get(&cache_id) {
        return Ok(Some(p.clone()));
    }

    if GPTModel::from_str(model)?.is_gpt4() {
        let resp = with_context!(api_client)
            .get(format!("{URL_CHATGPT_API}/backend-api/models"))
            .bearer_auth(token)
            .send()
            .await
            .map_err(ResponseError::InternalServerError)?
            .error_for_status()
            .map_err(ResponseError::BadRequest)?;

        if let Some(c) = resp.cookies().into_iter().find(|c| c.name().eq("oai-did")) {
            let oaidid = c.value().to_owned();
            oaidid_cache.insert(cache_id, oaidid.clone());
            return Ok(Some(oaidid));
        };
    }

    Ok(None)
}
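
For context, the module above is a lazily-initialized, process-wide TTL cache: a tokio OnceCell guards a moka Cache whose entries (account email mapped to the oai-did value) expire 24 hours after insertion, so the upstream /backend-api/models call happens at most once per account per day. A minimal standalone sketch of that pattern, with an illustrative key and value that are not part of the commit:

use std::time::Duration;

use moka::sync::Cache;
use tokio::sync::OnceCell;

// Process-wide cache, created on first use.
static DEVICE_ID_CACHE: OnceCell<Cache<String, String>> = OnceCell::const_new();

async fn cache() -> &'static Cache<String, String> {
    DEVICE_ID_CACHE
        .get_or_init(|| async {
            // Entries are evicted 24 hours after insertion.
            Cache::builder()
                .time_to_live(Duration::from_secs(3600 * 24))
                .build()
        })
        .await
}

#[tokio::main]
async fn main() {
    let c = cache().await;
    c.insert("user@example.com".to_owned(), "example-device-id".to_owned());
    // Lookups within the TTL return the cached value without another upstream fetch.
    assert_eq!(c.get("user@example.com").as_deref(), Some("example-device-id"));
}

The real module keys the cache on the email returned by reduce_key, which is why callers pass a cache_id alongside the access token.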
39 changes: 38 additions & 1 deletion crates/openai/src/serve/proxy/req.rs
@@ -10,7 +10,7 @@ use http::{HeaderMap, Method};
use serde_json::{json, Value};

use crate::arkose::{ArkoseContext, ArkoseToken, Type};
use crate::constant::{ARKOSE_TOKEN, EMPTY, MODEL, NULL, PUID};
use crate::constant::{ARKOSE_TOKEN, EMPTY, MODEL, NULL, PUID, OAODID};
use crate::gpt_model::GPTModel;
use crate::{arkose, with_context};

@@ -19,6 +19,7 @@ use super::header_convert;
use super::toapi;
use crate::serve::error::{ProxyError, ResponseError};
use crate::serve::puid::{get_or_init, reduce_key};
use crate::serve::oaidid::{get_or_init as get_or_init_oaidid, reduce_key as reduce_key_oaidid};

#[async_trait]
impl SendRequestExt for reqwest::Client {
@@ -71,6 +72,15 @@ pub(super) fn has_puid(headers: &HeaderMap) -> Result<bool, ResponseError> {
    }
}

pub(super) fn has_oaodid(headers: &HeaderMap) -> Result<bool, ResponseError> {
    if let Some(hv) = headers.get(header::COOKIE) {
        let cookie_str = hv.to_str().map_err(ResponseError::BadRequest)?;
        Ok(cookie_str.contains(OAODID))
    } else {
        Ok(false)
    }
}

/// Handle conversation request
async fn handle_conv_request(req: &mut RequestExt) -> Result<(), ResponseError> {
// Only handle POST request
@@ -119,6 +129,33 @@ async fn handle_conv_request(req: &mut RequestExt) -> Result<(), ResponseError>
}
}

    // If the oai-did cookie already exists, skip injection
    if !has_oaodid(&req.headers)? {
        // Extract the cache key (account email) from the access token
        let cache_id = reduce_key_oaidid(&token)?;

        // Get or init oai-did
        let oaidid = get_or_init_oaidid(&token, model, cache_id).await?;

        if let Some(oaidid) = oaidid {
            req.headers.insert(
                header::COOKIE,
                header::HeaderValue::from_str(&format!("{OAODID}={oaidid};"))
                    .map_err(ResponseError::BadRequest)?,
            );
            req.headers.insert(
                HeaderName::from_str("Oai-Device-Id").map_err(ResponseError::BadRequest)?,
                header::HeaderValue::from_str(&format!("{OAODID}={oaidid};"))
                    .map_err(ResponseError::BadRequest)?,
            );
            req.headers.insert(
                HeaderName::from_str("Oai-Language").map_err(ResponseError::BadRequest)?,
                header::HeaderValue::from_str(&format!("zh-Hans"))
                    .map_err(ResponseError::BadRequest)?,
            );
        }
    }

// Parse model
let model = GPTModel::from_str(model).map_err(ResponseError::BadRequest)?;

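The block added above is plain http header plumbing: build a Cookie value and two custom headers, and map any construction error to a 400. A short standalone sketch of the same mechanics; the bare device-id value used for Oai-Device-Id is an assumption of the sketch (the commit itself reuses the oai-did=<value>; cookie string there), and the id is illustrative:

use std::str::FromStr;

use http::header::{HeaderMap, HeaderName, HeaderValue, COOKIE};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut headers = HeaderMap::new();
    let device_id = "example-device-id"; // illustrative, not a real id

    // The device id travels both as a cookie and as a dedicated header.
    headers.insert(
        COOKIE,
        HeaderValue::from_str(&format!("oai-did={device_id};"))?,
    );
    headers.insert(
        HeaderName::from_str("Oai-Device-Id")?,
        HeaderValue::from_str(device_id)?,
    );
    headers.insert(
        HeaderName::from_str("Oai-Language")?,
        HeaderValue::from_static("zh-Hans"),
    );

    // HeaderMap stores names lowercased; lookups are case-insensitive.
    assert!(headers.contains_key("oai-device-id"));
    Ok(())
}
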
6 changes: 6 additions & 0 deletions crates/openai/src/serve/proxy/toapi/mod.rs
@@ -35,6 +35,7 @@ use crate::{
},
uuid::uuid,
};
use crate::serve::oaidid::{get_or_init as get_or_init_oaidid, reduce_key as reduce_key_oaidid};

use super::ext::{Context, RequestExt, ResponseExt};
use super::header_convert;
@@ -146,6 +147,11 @@ pub(super) async fn send_request(req: RequestExt) -> Result<ResponseExt, Respons
builder = builder.header(header::COOKIE, format!("_puid={puid};"))
}

let oaidid = get_or_init_oaidid(baerer, &body.model, cache_id).await?;
if let Some(oaidid) = oaidid {
builder = builder.header(header::COOKIE, format!("oai-did={oaidid};"))
}

// Send request
let resp = builder
.json(&req_body)
