diff --git a/gremlin-client/src/aio/client.rs b/gremlin-client/src/aio/client.rs index ad3c52f2..c7dfa1c8 100644 --- a/gremlin-client/src/aio/client.rs +++ b/gremlin-client/src/aio/client.rs @@ -10,6 +10,7 @@ use crate::ToGValue; use crate::{ConnectionOptions, GremlinError, GremlinResult}; use base64::encode; use futures::future::{BoxFuture, FutureExt}; +use futures::StreamExt; use mobc::{Connection, Pool}; use serde::Serialize; use std::collections::{HashMap, VecDeque}; @@ -162,8 +163,22 @@ impl GremlinClient { let payload = String::from("") + content_type + &message; let mut binary = payload.into_bytes(); binary.insert(0, content_type.len() as u8); - - let (response, receiver) = conn.send(id, binary).await?; + let mut receiver = conn.send(id, binary).await?; + let response = receiver + .next() + .await + .expect("It should contain the response")?; + //Prepare holding onto the connection for an auth challenge if we have credentials + //Tinkerpop performs authentication at the channel level, and if we let it go, + //a healthcheck may disrupt the challenge + //Otherwise drop the connection so it can be multiplexed + let retained_auth_context = match self.options.credentials.as_ref() { + None => { + drop(conn); + None + } + Some(credentials) => Some((credentials, conn)), + }; let (response, results) = match response.status.code { 200 | 206 => { @@ -176,8 +191,8 @@ impl GremlinClient { Ok((response, results)) } 204 => Ok((response, VecDeque::new())), - 407 => match &self.options.credentials { - Some(c) => { + 407 => match retained_auth_context { + Some((c, conn)) => { let mut args = HashMap::new(); args.insert( diff --git a/gremlin-client/src/aio/connection.rs b/gremlin-client/src/aio/connection.rs index b560d28f..23af3b10 100644 --- a/gremlin-client/src/aio/connection.rs +++ b/gremlin-client/src/aio/connection.rs @@ -43,6 +43,7 @@ use futures::{ use futures::channel::mpsc::{channel, Receiver, Sender}; use std::collections::HashMap; +use std::sync::atomic::{self, AtomicBool}; use std::sync::Arc; use url; use uuid::Uuid; @@ -64,7 +65,7 @@ pub enum Cmd { pub(crate) struct Conn { sender: Sender, - valid: bool, + valid: Arc, connection_uuid: Uuid, } @@ -170,16 +171,19 @@ impl Conn { sender_loop(connection_uuid.clone(), sink, requests.clone(), receiver); + let valid_flag = Arc::new(AtomicBool::new(true)); + receiver_loop( connection_uuid.clone(), stream, requests.clone(), sender.clone(), + valid_flag.clone(), ); Ok(Conn { sender, - valid: true, + valid: valid_flag, connection_uuid, }) } @@ -188,8 +192,8 @@ impl Conn { &mut self, id: Uuid, payload: Vec, - ) -> GremlinResult<(Response, Receiver>)> { - let (sender, mut receiver) = channel(1); + ) -> GremlinResult>> { + let (sender, receiver) = channel(1); self.sender .send(Cmd::Msg((sender, id, payload))) @@ -199,35 +203,14 @@ impl Conn { "{} Marking websocket connection invalid on send error", self.connection_uuid ); - self.valid = false; - e - })?; - - receiver - .next() - .await - .expect("It should contain the response") - .map(|r| (r, receiver)) - .map_err(|e| { - //If there's been an websocket layer error, mark the connection as invalid - match e { - GremlinError::WebSocket(_) - | GremlinError::WebSocketAsync(_) - | GremlinError::WebSocketPoolAsync(_) => { - error!( - "{} Marking websocket connection invalid on received error", - self.connection_uuid - ); - self.valid = false; - } - _ => {} - } - e + self.valid.store(false, atomic::Ordering::Release); + GremlinError::from(e) }) + .map(|_| receiver) } pub fn is_valid(&self) -> bool { - self.valid + self.valid.load(atomic::Ordering::Acquire) } } @@ -300,11 +283,14 @@ fn receiver_loop( mut stream: SplitStream, requests: Arc>>>>, mut sender: Sender, + connection_valid_flag: Arc, ) { task::spawn(async move { loop { match stream.next().await { Some(Err(error)) => { + //If there's been an websocket layer error, mark the connection as invalid + connection_valid_flag.store(false, atomic::Ordering::Release); let mut guard = requests.lock().await; let error = Arc::new(error); error!("{connection_uuid} Receiver loop error"); @@ -316,36 +302,38 @@ fn receiver_loop( } guard.clear(); } - Some(Ok(item)) => match item { - Message::Binary(data) => { - let response: Response = serde_json::from_slice(&data).unwrap(); - let mut guard = requests.lock().await; - if response.status.code != 206 { - let item = guard.remove(&response.request_id); - drop(guard); - if let Some(mut s) = item { - match s.send(Ok(response)).await { - Ok(_r) => {} - Err(_e) => {} - }; - } - } else { - let item = guard.get_mut(&response.request_id); - if let Some(s) = item { - match s.send(Ok(response)).await { - Ok(_r) => {} - Err(_e) => {} - }; + Some(Ok(item)) => { + match item { + Message::Binary(data) => { + let response: Response = serde_json::from_slice(&data).unwrap(); + let mut guard = requests.lock().await; + if response.status.code != 206 { + let item = guard.remove(&response.request_id); + drop(guard); + if let Some(mut s) = item { + match s.send(Ok(response)).await { + Ok(_r) => {} + Err(_e) => {} + }; + } + } else { + let item = guard.get_mut(&response.request_id); + if let Some(s) = item { + match s.send(Ok(response)).await { + Ok(_r) => {} + Err(_e) => {} + }; + } + drop(guard); } - drop(guard); } + Message::Ping(data) => { + info!("{connection_uuid} Received Ping"); + let _ = sender.send(Cmd::Pong(data)).await; + } + _ => {} } - Message::Ping(data) => { - info!("{connection_uuid} Received Ping"); - let _ = sender.send(Cmd::Pong(data)).await; - } - _ => {} - }, + } None => { warn!("{connection_uuid} Receiver loop breaking"); break; diff --git a/gremlin-client/src/aio/pool.rs b/gremlin-client/src/aio/pool.rs index 3d2ecfca..a9050ae0 100644 --- a/gremlin-client/src/aio/pool.rs +++ b/gremlin-client/src/aio/pool.rs @@ -1,3 +1,4 @@ +use futures::StreamExt; use mobc::Manager; use crate::aio::connection::Conn; @@ -56,7 +57,12 @@ impl Manager for GremlinConnectionManager { let mut binary = payload.into_bytes(); binary.insert(0, content_type.len() as u8); - let (response, _receiver) = conn.send(id, binary).await?; + let response = conn + .send(id, binary) + .await? + .next() + .await + .expect("Should have received response")?; match response.status.code { 200 | 206 => Ok(conn), @@ -87,7 +93,12 @@ impl Manager for GremlinConnectionManager { let mut binary = payload.into_bytes(); binary.insert(0, content_type.len() as u8); - let (response, _receiver) = conn.send(id, binary).await?; + let response = conn + .send(id, binary) + .await? + .next() + .await + .expect("Should have received response")?; match response.status.code { 200 | 206 => Ok(conn),