diff --git a/Cargo.toml b/Cargo.toml index c279d3e..93af6f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "norpc", "norpc-macros", "example/hello-world", + "example/hello-world-no-send", "example/kvstore", "example/concurrent-message", "example/rate-limit", diff --git a/example/hello-world-no-send/Cargo.toml b/example/hello-world-no-send/Cargo.toml new file mode 100644 index 0000000..bfd21fc --- /dev/null +++ b/example/hello-world-no-send/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "hello-world-no-send" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tokio = { version = "*", features = ["full"] } +tower = { version = "*", features = ["full"] } + +norpc = { path = "../../norpc" } \ No newline at end of file diff --git a/example/hello-world-no-send/tests/hello_world_no_send.rs b/example/hello-world-no-send/tests/hello_world_no_send.rs new file mode 100644 index 0000000..9bc1ebb --- /dev/null +++ b/example/hello-world-no-send/tests/hello_world_no_send.rs @@ -0,0 +1,38 @@ +use std::rc::Rc; +use tokio::sync::mpsc; +use tower::Service; + +#[norpc::service(?Send)] +trait HelloWorld { + // Rc is !Send + fn hello(s: Rc) -> Rc; +} + +#[derive(Clone)] +struct HelloWorldApp; +#[norpc::async_trait(?Send)] +impl HelloWorld for HelloWorldApp { + async fn hello(self, s: Rc) -> Rc { + format!("Hello, {}", s).into() + } +} +#[tokio::test(flavor = "multi_thread")] +async fn test_hello_world_no_send() { + let local = tokio::task::LocalSet::new(); + let (tx, rx) = mpsc::unbounded_channel(); + local.spawn_local(async move { + let app = HelloWorldApp; + let service = HelloWorldService::new(app); + let server = norpc::no_send::ServerChannel::new(rx, service); + server.serve().await + }); + local.spawn_local(async move { + let chan = norpc::no_send::ClientChannel::new(tx); + let mut cli = HelloWorldClient::new(chan); + assert_eq!( + cli.hello("World".to_owned().into()).await.unwrap(), + "Hello, World".to_string().into() + ); + }); + local.await; +} diff --git a/norpc-macros/src/generator.rs b/norpc-macros/src/generator.rs index 6ef0611..2576c73 100644 --- a/norpc-macros/src/generator.rs +++ b/norpc-macros/src/generator.rs @@ -76,13 +76,14 @@ impl Generator { } format!( " - #[norpc::async_trait] + #[norpc::async_trait{no_send}] pub trait {svc_name}: Clone {{ {} }} ", itertools::join(methods, ""), svc_name = svc.name, + no_send = if self.no_send { "(?Send)" } else { "" }, ) } fn generate_client_impl(&self, svc: &Service) -> String { @@ -164,10 +165,10 @@ impl Generator { Self {{ app }} }} }} - impl tower::Service<{svc_name}Request> for {svc_name}Service {{ + impl tower::Service<{svc_name}Request> for {svc_name}Service {{ type Response = {svc_name}Response; type Error = (); - type Future = std::pin::Pin> + Send>>; + type Future = std::pin::Pin> {no_send}>>; fn poll_ready( &mut self, _: &mut std::task::Context<'_>, @@ -186,6 +187,7 @@ impl Generator { ", itertools::join(match_arms, ","), svc_name = svc.name, + no_send = if self.no_send { "" } else { "+ Send" }, ) } pub(super) fn generate(&self, svc: Service) -> String { diff --git a/norpc-macros/src/lib.rs b/norpc-macros/src/lib.rs index 0634964..3a6f748 100644 --- a/norpc-macros/src/lib.rs +++ b/norpc-macros/src/lib.rs @@ -1,16 +1,43 @@ use proc_macro::TokenStream; use quote::quote; use std::str::FromStr; +use syn::parse::{Parse, ParseStream, Result}; use syn::*; mod generator; +struct Args { + local: bool, +} + +mod kw { + syn::custom_keyword!(Send); +} + +fn try_parse(input: ParseStream) -> Result { + if input.peek(Token![?]) { + input.parse::()?; + input.parse::()?; + Ok(Args { local: true }) + } else { + Ok(Args { local: false }) + } +} + +impl Parse for Args { + fn parse(input: ParseStream) -> Result { + let args: Args = try_parse(input)?; + Ok(args) + } +} + #[proc_macro_attribute] -pub fn service(_: TokenStream, item: TokenStream) -> TokenStream { +pub fn service(args: TokenStream, item: TokenStream) -> TokenStream { + let args = parse_macro_input!(args as Args); let t = syn::parse::(item).unwrap(); let svc = parse_service(&t); let generator = generator::Generator { - no_send: false, + no_send: args.local, }; let code = generator.generate(svc); TokenStream::from_str(&code).unwrap() @@ -80,10 +107,10 @@ fn parse_func(f: &TraitItem) -> Function { match &sig.output { ReturnType::Type(_, x) => { output_ty = quote!(#x).to_string(); - }, + } ReturnType::Default => { output_ty = "()".to_string(); - }, + } } Function { name: func_name, diff --git a/norpc/src/lib.rs b/norpc/src/lib.rs index 654e4c3..fd4bb4f 100644 --- a/norpc/src/lib.rs +++ b/norpc/src/lib.rs @@ -2,6 +2,8 @@ use tokio::sync::mpsc; use tokio::sync::oneshot; use tower_service::Service; +pub mod no_send; + // Re-exported for compiler pub use async_trait::async_trait; pub use futures::future::poll_fn; diff --git a/norpc/src/no_send.rs b/norpc/src/no_send.rs new file mode 100644 index 0000000..d40e5f3 --- /dev/null +++ b/norpc/src/no_send.rs @@ -0,0 +1,73 @@ +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tower_service::Service; + +use crate::{Error, Request}; + +/// mpsc channel wrapper on the client-side. +pub struct ClientChannel { + tx: mpsc::UnboundedSender>, +} +impl ClientChannel { + pub fn new(tx: mpsc::UnboundedSender>) -> Self { + Self { tx } + } +} +impl Clone for ClientChannel { + fn clone(&self) -> Self { + Self { + tx: self.tx.clone(), + } + } +} +impl Service for ClientChannel { + type Response = Y; + type Error = Error; + type Future = std::pin::Pin>>>; + + fn poll_ready( + &mut self, + _: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Ok(()).into() + } + + fn call(&mut self, req: X) -> Self::Future { + let tx = self.tx.clone(); + Box::pin(async move { + let (tx1, rx1) = oneshot::channel::(); + let req = Request { + inner: req, + tx: tx1, + }; + tx.send(req).map_err(|_| Error::SendError)?; + let rep = rx1.await.map_err(|_| Error::RecvError)?; + Ok(rep) + }) + } +} + +/// mpsc channel wrapper on the server-side. +pub struct ServerChannel> { + service: Svc, + rx: mpsc::UnboundedReceiver>, +} +impl + 'static> ServerChannel { + pub fn new(rx: mpsc::UnboundedReceiver>, service: Svc) -> Self { + Self { service, rx } + } + pub async fn serve(mut self) { + while let Some(Request { tx, inner }) = self.rx.recv().await { + // back-pressure + futures::future::poll_fn(|ctx| self.service.poll_ready(ctx)) + .await + .ok(); + let fut = self.service.call(inner); + tokio::task::spawn_local(async move { + if let Ok(rep) = fut.await { + tx.send(rep).ok(); + } + }); + } + } +}