Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ?Send #39

Merged
merged 6 commits into from
Nov 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ members = [
"norpc",
"norpc-macros",
"example/hello-world",
"example/hello-world-no-send",
"example/kvstore",
"example/concurrent-message",
"example/rate-limit",
Expand Down
12 changes: 12 additions & 0 deletions example/hello-world-no-send/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "hello-world-no-send"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tokio = { version = "*", features = ["full"] }
tower = { version = "*", features = ["full"] }

norpc = { path = "../../norpc" }
38 changes: 38 additions & 0 deletions example/hello-world-no-send/tests/hello_world_no_send.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use std::rc::Rc;
use tokio::sync::mpsc;
use tower::Service;

#[norpc::service(?Send)]
trait HelloWorld {
// Rc<T> is !Send
fn hello(s: Rc<String>) -> Rc<String>;
}

#[derive(Clone)]
struct HelloWorldApp;
#[norpc::async_trait(?Send)]
impl HelloWorld for HelloWorldApp {
async fn hello(self, s: Rc<String>) -> Rc<String> {
format!("Hello, {}", s).into()
}
}
#[tokio::test(flavor = "multi_thread")]
async fn test_hello_world_no_send() {
let local = tokio::task::LocalSet::new();
let (tx, rx) = mpsc::unbounded_channel();
local.spawn_local(async move {
let app = HelloWorldApp;
let service = HelloWorldService::new(app);
let server = norpc::no_send::ServerChannel::new(rx, service);
server.serve().await
});
local.spawn_local(async move {
let chan = norpc::no_send::ClientChannel::new(tx);
let mut cli = HelloWorldClient::new(chan);
assert_eq!(
cli.hello("World".to_owned().into()).await.unwrap(),
"Hello, World".to_string().into()
);
});
local.await;
}
8 changes: 5 additions & 3 deletions norpc-macros/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,14 @@ impl Generator {
}
format!(
"
#[norpc::async_trait]
#[norpc::async_trait{no_send}]
pub trait {svc_name}: Clone {{
{}
}}
",
itertools::join(methods, ""),
svc_name = svc.name,
no_send = if self.no_send { "(?Send)" } else { "" },
)
}
fn generate_client_impl(&self, svc: &Service) -> String {
Expand Down Expand Up @@ -164,10 +165,10 @@ impl Generator {
Self {{ app }}
}}
}}
impl<App: {svc_name} + 'static + Send> tower::Service<{svc_name}Request> for {svc_name}Service<App> {{
impl<App: {svc_name} + 'static {no_send}> tower::Service<{svc_name}Request> for {svc_name}Service<App> {{
type Response = {svc_name}Response;
type Error = ();
type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> {no_send}>>;
fn poll_ready(
&mut self,
_: &mut std::task::Context<'_>,
Expand All @@ -186,6 +187,7 @@ impl Generator {
",
itertools::join(match_arms, ","),
svc_name = svc.name,
no_send = if self.no_send { "" } else { "+ Send" },
)
}
pub(super) fn generate(&self, svc: Service) -> String {
Expand Down
35 changes: 31 additions & 4 deletions norpc-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,43 @@
use proc_macro::TokenStream;
use quote::quote;
use std::str::FromStr;
use syn::parse::{Parse, ParseStream, Result};
use syn::*;

mod generator;

struct Args {
local: bool,
}

mod kw {
syn::custom_keyword!(Send);
}

fn try_parse(input: ParseStream) -> Result<Args> {
if input.peek(Token![?]) {
input.parse::<Token![?]>()?;
input.parse::<kw::Send>()?;
Ok(Args { local: true })
} else {
Ok(Args { local: false })
}
}

impl Parse for Args {
fn parse(input: ParseStream) -> Result<Self> {
let args: Args = try_parse(input)?;
Ok(args)
}
}

#[proc_macro_attribute]
pub fn service(_: TokenStream, item: TokenStream) -> TokenStream {
pub fn service(args: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as Args);
let t = syn::parse::<ItemTrait>(item).unwrap();
let svc = parse_service(&t);
let generator = generator::Generator {
no_send: false,
no_send: args.local,
};
let code = generator.generate(svc);
TokenStream::from_str(&code).unwrap()
Expand Down Expand Up @@ -80,10 +107,10 @@ fn parse_func(f: &TraitItem) -> Function {
match &sig.output {
ReturnType::Type(_, x) => {
output_ty = quote!(#x).to_string();
},
}
ReturnType::Default => {
output_ty = "()".to_string();
},
}
}
Function {
name: func_name,
Expand Down
2 changes: 2 additions & 0 deletions norpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tower_service::Service;

pub mod no_send;

// Re-exported for compiler
pub use async_trait::async_trait;
pub use futures::future::poll_fn;
Expand Down
73 changes: 73 additions & 0 deletions norpc/src/no_send.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tower_service::Service;

use crate::{Error, Request};

/// mpsc channel wrapper on the client-side.
pub struct ClientChannel<X, Y> {
tx: mpsc::UnboundedSender<Request<X, Y>>,
}
impl<X, Y> ClientChannel<X, Y> {
pub fn new(tx: mpsc::UnboundedSender<Request<X, Y>>) -> Self {
Self { tx }
}
}
impl<X, Y> Clone for ClientChannel<X, Y> {
fn clone(&self) -> Self {
Self {
tx: self.tx.clone(),
}
}
}
impl<X: 'static, Y: 'static> Service<X> for ClientChannel<X, Y> {
type Response = Y;
type Error = Error;
type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Y, Self::Error>>>>;

fn poll_ready(
&mut self,
_: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Ok(()).into()
}

fn call(&mut self, req: X) -> Self::Future {
let tx = self.tx.clone();
Box::pin(async move {
let (tx1, rx1) = oneshot::channel::<Y>();
let req = Request {
inner: req,
tx: tx1,
};
tx.send(req).map_err(|_| Error::SendError)?;
let rep = rx1.await.map_err(|_| Error::RecvError)?;
Ok(rep)
})
}
}

/// mpsc channel wrapper on the server-side.
pub struct ServerChannel<Req, Svc: Service<Req>> {
service: Svc,
rx: mpsc::UnboundedReceiver<Request<Req, Svc::Response>>,
}
impl<Req: 'static, Svc: Service<Req> + 'static> ServerChannel<Req, Svc> {
pub fn new(rx: mpsc::UnboundedReceiver<Request<Req, Svc::Response>>, service: Svc) -> Self {
Self { service, rx }
}
pub async fn serve(mut self) {
while let Some(Request { tx, inner }) = self.rx.recv().await {
// back-pressure
futures::future::poll_fn(|ctx| self.service.poll_ready(ctx))
.await
.ok();
let fut = self.service.call(inner);
tokio::task::spawn_local(async move {
if let Ok(rep) = fut.await {
tx.send(rep).ok();
}
});
}
}
}