Skip to content

Commit

Permalink
Merge pull request #39 from akiradeveloper/no-send
Browse files Browse the repository at this point in the history
Support ?Send
  • Loading branch information
akiradeveloper authored Nov 26, 2021
2 parents d64fef3 + bfc821a commit cafe260
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 7 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ members = [
"norpc",
"norpc-macros",
"example/hello-world",
"example/hello-world-no-send",
"example/kvstore",
"example/concurrent-message",
"example/rate-limit",
Expand Down
12 changes: 12 additions & 0 deletions example/hello-world-no-send/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "hello-world-no-send"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tokio = { version = "*", features = ["full"] }
tower = { version = "*", features = ["full"] }

norpc = { path = "../../norpc" }
38 changes: 38 additions & 0 deletions example/hello-world-no-send/tests/hello_world_no_send.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use std::rc::Rc;
use tokio::sync::mpsc;
use tower::Service;

#[norpc::service(?Send)]
trait HelloWorld {
// Rc<T> is !Send
fn hello(s: Rc<String>) -> Rc<String>;
}

#[derive(Clone)]
struct HelloWorldApp;
#[norpc::async_trait(?Send)]
impl HelloWorld for HelloWorldApp {
async fn hello(self, s: Rc<String>) -> Rc<String> {
format!("Hello, {}", s).into()
}
}
#[tokio::test(flavor = "multi_thread")]
async fn test_hello_world_no_send() {
let local = tokio::task::LocalSet::new();
let (tx, rx) = mpsc::unbounded_channel();
local.spawn_local(async move {
let app = HelloWorldApp;
let service = HelloWorldService::new(app);
let server = norpc::no_send::ServerChannel::new(rx, service);
server.serve().await
});
local.spawn_local(async move {
let chan = norpc::no_send::ClientChannel::new(tx);
let mut cli = HelloWorldClient::new(chan);
assert_eq!(
cli.hello("World".to_owned().into()).await.unwrap(),
"Hello, World".to_string().into()
);
});
local.await;
}
8 changes: 5 additions & 3 deletions norpc-macros/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,14 @@ impl Generator {
}
format!(
"
#[norpc::async_trait]
#[norpc::async_trait{no_send}]
pub trait {svc_name}: Clone {{
{}
}}
",
itertools::join(methods, ""),
svc_name = svc.name,
no_send = if self.no_send { "(?Send)" } else { "" },
)
}
fn generate_client_impl(&self, svc: &Service) -> String {
Expand Down Expand Up @@ -164,10 +165,10 @@ impl Generator {
Self {{ app }}
}}
}}
impl<App: {svc_name} + 'static + Send> tower::Service<{svc_name}Request> for {svc_name}Service<App> {{
impl<App: {svc_name} + 'static {no_send}> tower::Service<{svc_name}Request> for {svc_name}Service<App> {{
type Response = {svc_name}Response;
type Error = ();
type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> {no_send}>>;
fn poll_ready(
&mut self,
_: &mut std::task::Context<'_>,
Expand All @@ -186,6 +187,7 @@ impl Generator {
",
itertools::join(match_arms, ","),
svc_name = svc.name,
no_send = if self.no_send { "" } else { "+ Send" },
)
}
pub(super) fn generate(&self, svc: Service) -> String {
Expand Down
35 changes: 31 additions & 4 deletions norpc-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,43 @@
use proc_macro::TokenStream;
use quote::quote;
use std::str::FromStr;
use syn::parse::{Parse, ParseStream, Result};
use syn::*;

mod generator;

struct Args {
local: bool,
}

mod kw {
syn::custom_keyword!(Send);
}

fn try_parse(input: ParseStream) -> Result<Args> {
if input.peek(Token![?]) {
input.parse::<Token![?]>()?;
input.parse::<kw::Send>()?;
Ok(Args { local: true })
} else {
Ok(Args { local: false })
}
}

impl Parse for Args {
fn parse(input: ParseStream) -> Result<Self> {
let args: Args = try_parse(input)?;
Ok(args)
}
}

#[proc_macro_attribute]
pub fn service(_: TokenStream, item: TokenStream) -> TokenStream {
pub fn service(args: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as Args);
let t = syn::parse::<ItemTrait>(item).unwrap();
let svc = parse_service(&t);
let generator = generator::Generator {
no_send: false,
no_send: args.local,
};
let code = generator.generate(svc);
TokenStream::from_str(&code).unwrap()
Expand Down Expand Up @@ -80,10 +107,10 @@ fn parse_func(f: &TraitItem) -> Function {
match &sig.output {
ReturnType::Type(_, x) => {
output_ty = quote!(#x).to_string();
},
}
ReturnType::Default => {
output_ty = "()".to_string();
},
}
}
Function {
name: func_name,
Expand Down
2 changes: 2 additions & 0 deletions norpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tower_service::Service;

pub mod no_send;

// Re-exported for compiler
pub use async_trait::async_trait;
pub use futures::future::poll_fn;
Expand Down
73 changes: 73 additions & 0 deletions norpc/src/no_send.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tower_service::Service;

use crate::{Error, Request};

/// mpsc channel wrapper on the client-side.
pub struct ClientChannel<X, Y> {
tx: mpsc::UnboundedSender<Request<X, Y>>,
}
impl<X, Y> ClientChannel<X, Y> {
pub fn new(tx: mpsc::UnboundedSender<Request<X, Y>>) -> Self {
Self { tx }
}
}
impl<X, Y> Clone for ClientChannel<X, Y> {
fn clone(&self) -> Self {
Self {
tx: self.tx.clone(),
}
}
}
impl<X: 'static, Y: 'static> Service<X> for ClientChannel<X, Y> {
type Response = Y;
type Error = Error;
type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Y, Self::Error>>>>;

fn poll_ready(
&mut self,
_: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Ok(()).into()
}

fn call(&mut self, req: X) -> Self::Future {
let tx = self.tx.clone();
Box::pin(async move {
let (tx1, rx1) = oneshot::channel::<Y>();
let req = Request {
inner: req,
tx: tx1,
};
tx.send(req).map_err(|_| Error::SendError)?;
let rep = rx1.await.map_err(|_| Error::RecvError)?;
Ok(rep)
})
}
}

/// mpsc channel wrapper on the server-side.
pub struct ServerChannel<Req, Svc: Service<Req>> {
service: Svc,
rx: mpsc::UnboundedReceiver<Request<Req, Svc::Response>>,
}
impl<Req: 'static, Svc: Service<Req> + 'static> ServerChannel<Req, Svc> {
pub fn new(rx: mpsc::UnboundedReceiver<Request<Req, Svc::Response>>, service: Svc) -> Self {
Self { service, rx }
}
pub async fn serve(mut self) {
while let Some(Request { tx, inner }) = self.rx.recv().await {
// back-pressure
futures::future::poll_fn(|ctx| self.service.poll_ready(ctx))
.await
.ok();
let fut = self.service.call(inner);
tokio::task::spawn_local(async move {
if let Ok(rep) = fut.await {
tx.send(rep).ok();
}
});
}
}
}

0 comments on commit cafe260

Please sign in to comment.