Skip to content

Commit

Permalink
refactor: transport requires type-erased futures. improved batch ergo
Browse files Browse the repository at this point in the history
  • Loading branch information
prestwich committed Jul 25, 2023
1 parent 12f78fa commit 2be245a
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 146 deletions.
117 changes: 74 additions & 43 deletions crates/transports/src/batch.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
collections::HashMap,
future::Future,
future::{Future, IntoFuture},
marker::PhantomData,
pin::Pin,
task::{self, ready},
Expand All @@ -10,16 +10,21 @@ use futures_channel::oneshot;
use serde_json::value::RawValue;
use tower::Service;

use crate::{error::TransportError, transports::BatchTransportFuture};
use alloy_json_rpc::{Id, JsonRpcRequest, JsonRpcResponse, RpcResult, RpcReturn};
use crate::{
error::TransportError,
transports::{BatchFutureOf, Transport},
RpcClient,
};
use alloy_json_rpc::{Id, JsonRpcRequest, RpcParam, RpcResult, RpcReturn};

type Channel = oneshot::Sender<RpcResult<Box<RawValue>, TransportError>>;
type ChannelMap = HashMap<Id, Channel>;

#[must_use = "A BatchRequest does nothing unless sent via `send_batch` or via `.await`"]
/// A Batch JSON-RPC request, awaiting dispatch.
#[derive(Debug, Default)]
pub struct BatchRequest<T> {
transport: T,
#[derive(Debug)]
pub struct BatchRequest<'a, T> {
transport: &'a RpcClient<T>,

requests: Vec<JsonRpcRequest>,

Expand All @@ -32,6 +37,15 @@ pub struct Waiter<Resp> {
_resp: PhantomData<Resp>,
}

impl<Resp> From<oneshot::Receiver<RpcResult<Box<RawValue>, TransportError>>> for Waiter<Resp> {
fn from(rx: oneshot::Receiver<RpcResult<Box<RawValue>, TransportError>>) -> Self {
Self {
rx,
_resp: PhantomData,
}
}
}

impl<Resp> std::future::Future for Waiter<Resp>
where
Resp: RpcReturn,
Expand All @@ -52,35 +66,34 @@ where
}

#[pin_project::pin_project(project = CallStateProj)]
pub enum BatchFuture<T>
pub enum BatchFuture<Conn>
where
T: Service<
Vec<JsonRpcRequest>,
Response = Vec<JsonRpcResponse>,
Error = TransportError,
Future = BatchTransportFuture,
>,
Conn: Transport,
{
Prepared(BatchRequest<T>),
Prepared {
transport: Conn,
requests: Vec<JsonRpcRequest>,
channels: ChannelMap,
},
SerError(Option<TransportError>),
AwaitingResponse {
channels: ChannelMap,
#[pin]
fut: <T as Service<Vec<JsonRpcRequest>>>::Future,
fut: BatchFutureOf<Conn>,
},
Complete,
}

impl<T> BatchRequest<T> {
pub fn new(transport: T) -> Self {
impl<'a, T> BatchRequest<'a, T> {
pub fn new(transport: &'a RpcClient<T>) -> Self {
Self {
transport,
requests: Vec::with_capacity(10),
channels: HashMap::with_capacity(10),
}
}

pub fn push_req(
fn push_raw(
&mut self,
request: JsonRpcRequest,
) -> oneshot::Receiver<RpcResult<Box<RawValue>, TransportError>> {
Expand All @@ -89,46 +102,69 @@ impl<T> BatchRequest<T> {
self.requests.push(request);
rx
}

fn push<Resp: RpcReturn>(&mut self, request: JsonRpcRequest) -> Waiter<Resp> {
self.push_raw(request).into()
}
}

impl<T> BatchRequest<T>
impl<'a, T> BatchRequest<'a, T>
where
T: Service<
Vec<JsonRpcRequest>,
Response = Vec<JsonRpcResponse>,
Error = TransportError,
Future = BatchTransportFuture,
>,
T: Transport,
{
#[must_use = "Waiters do nothing unless polled. A Waiter will never resolve unless its batch is sent."]
/// Add a call to the batch.
pub fn add_call<Params: RpcParam, Resp: RpcReturn>(
&mut self,
method: &'static str,
params: Params,
) -> Waiter<Resp> {
let request = self.transport.make_request(method, params).unwrap();
self.push(request)
}

/// Send the batch future via its connection.
pub fn send(self) -> BatchFuture<T> {
BatchFuture::Prepared(self)
pub fn send_batch(self) -> BatchFuture<T> {
BatchFuture::Prepared {
transport: self.transport.transport.clone(),
requests: self.requests,
channels: self.channels,
}
}
}

impl<'a, T> IntoFuture for BatchRequest<'a, T>
where
T: Transport,
{
type Output = <BatchFuture<T> as Future>::Output;
type IntoFuture = BatchFuture<T>;

fn into_future(self) -> Self::IntoFuture {
self.send_batch()
}
}

impl<T> BatchFuture<T>
where
T: Service<
Vec<JsonRpcRequest>,
Response = Vec<JsonRpcResponse>,
Error = TransportError,
Future = BatchTransportFuture,
>,
T: Transport,
{
fn poll_prepared(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> task::Poll<<Self as Future>::Output> {
let CallStateProj::Prepared(BatchRequest {
let CallStateProj::Prepared {
transport,
requests,
channels,
}) = self.as_mut().project()
} = self.as_mut().project()
else {
unreachable!("Called poll_prepared in incorrect state")
};

if let Err(e) = task::ready!(transport.poll_ready(cx)) {
if let Err(e) = task::ready!(<T as Service<Vec<JsonRpcRequest>>>::poll_ready(
transport, cx
)) {
self.set(BatchFuture::Complete);
return task::Poll::Ready(Err(e));
}
Expand Down Expand Up @@ -196,17 +232,12 @@ where

impl<T> Future for BatchFuture<T>
where
T: Service<
Vec<JsonRpcRequest>,
Response = Vec<JsonRpcResponse>,
Error = TransportError,
Future = BatchTransportFuture,
>,
T: Transport,
{
type Output = Result<(), TransportError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
if matches!(*self.as_mut(), BatchFuture::Prepared(_)) {
if matches!(*self.as_mut(), BatchFuture::Prepared { .. }) {
return self.poll_prepared(cx);
}

Expand Down
46 changes: 8 additions & 38 deletions crates/transports/src/call.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{error::TransportError, transports::TransportFuture};
use crate::{error::TransportError, transports::FutureOf};

use alloy_json_rpc::{JsonRpcRequest, JsonRpcResponse, RpcParam, RpcResult, RpcReturn};
use serde_json::value::RawValue;
Expand All @@ -13,33 +13,23 @@ use tower::Service;
#[pin_project::pin_project(project = CallStateProj)]
enum CallState<Conn>
where
Conn: Service<
JsonRpcRequest,
Response = JsonRpcResponse,
Error = TransportError,
Future = TransportFuture,
>,
Conn: Service<JsonRpcRequest, Response = JsonRpcResponse, Error = TransportError>,
{
Prepared {
request: Option<JsonRpcRequest>,
connection: Conn,
},
AwaitingResponse {
#[pin]
fut: <Conn as Service<JsonRpcRequest>>::Future,
fut: FutureOf<Conn>,
},
Complete,
SerError(Option<TransportError>),
}

impl<Conn> CallState<Conn>
where
Conn: Service<
JsonRpcRequest,
Response = JsonRpcResponse,
Error = TransportError,
Future = TransportFuture,
>,
Conn: Service<JsonRpcRequest, Response = JsonRpcResponse, Error = TransportError>,
{
fn poll_prepared(
mut self: Pin<&mut Self>,
Expand Down Expand Up @@ -98,12 +88,7 @@ where

impl<Conn> Future for CallState<Conn>
where
Conn: Service<
JsonRpcRequest,
Response = JsonRpcResponse,
Error = TransportError,
Future = TransportFuture,
>,
Conn: Service<JsonRpcRequest, Response = JsonRpcResponse, Error = TransportError>,
{
type Output = RpcResult<Box<RawValue>, TransportError>;

Expand All @@ -127,12 +112,7 @@ where
#[pin_project::pin_project]
pub struct RpcCall<Conn, Params, Resp>
where
Conn: Service<
JsonRpcRequest,
Response = JsonRpcResponse,
Error = TransportError,
Future = TransportFuture,
>,
Conn: Service<JsonRpcRequest, Response = JsonRpcResponse, Error = TransportError>,
Params: RpcParam,
{
#[pin]
Expand All @@ -142,12 +122,7 @@ where

impl<Conn, Params, Resp> RpcCall<Conn, Params, Resp>
where
Conn: Service<
JsonRpcRequest,
Response = JsonRpcResponse,
Error = TransportError,
Future = TransportFuture,
>,
Conn: Service<JsonRpcRequest, Response = JsonRpcResponse, Error = TransportError>,
Params: RpcParam,
{
pub fn new(request: Result<JsonRpcRequest, TransportError>, connection: Conn) -> Self {
Expand All @@ -168,12 +143,7 @@ where

impl<Conn, Params, Resp> Future for RpcCall<Conn, Params, Resp>
where
Conn: Service<
JsonRpcRequest,
Response = JsonRpcResponse,
Error = TransportError,
Future = TransportFuture,
>,
Conn: Service<JsonRpcRequest, Response = JsonRpcResponse, Error = TransportError>,
Params: RpcParam,
Resp: RpcReturn,
{
Expand Down
Loading

0 comments on commit 2be245a

Please sign in to comment.