diff --git a/Cargo.toml b/Cargo.toml index 117ab4de4..c2beca789 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "examples", "codegen", "interop", # Tests + "benches/arrow-flight", "tests/disable_comments", "tests/included_service", "tests/same_name", diff --git a/benches/arrow-flight/Cargo.toml b/benches/arrow-flight/Cargo.toml new file mode 100644 index 000000000..dac05825b --- /dev/null +++ b/benches/arrow-flight/Cargo.toml @@ -0,0 +1,25 @@ +[package] +authors = ["Xiaoya Wei "] +edition = "2021" +license = "MIT" +name = "arrow-flight" +publish = false +version = "0.1.0" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +prost = "0.12" +tokio = { version = "1.0", features = ["rt-multi-thread"] } +tonic = { path = "../../tonic" } + +[build-dependencies] +tonic-build = { path = "../../tonic-build", features = ["prost"] } + +[dev-dependencies] +bencher = "0.1.5" +futures-util = { version = "0.3", default-features = false } + +[[bench]] +harness = false +name = "end_to_end" diff --git a/benches/arrow-flight/benches/end_to_end.rs b/benches/arrow-flight/benches/end_to_end.rs new file mode 100644 index 000000000..5594364b0 --- /dev/null +++ b/benches/arrow-flight/benches/end_to_end.rs @@ -0,0 +1,222 @@ +use arrow_flight::{arrow, client::FlightClient, server::FlightService}; +use bencher::{benchmark_group, benchmark_main}; +use prost::bytes::Bytes; +use std::{sync::Arc, time::Duration}; +use tokio::{sync::mpsc, time}; +use tonic::codegen::tokio_stream::{wrappers::ReceiverStream, StreamExt}; +use tonic::transport::Server; + +#[derive(Default, Debug)] +struct Opts { + manual_client: bool, + manual_server: bool, + request_chunks: usize, + request_payload: arrow::FlightData, + response_payload: arrow::FlightData, +} + +fn req_64kb_resp_64kb_10_chunks_native_client_native_server(b: &mut bencher::Bencher) { + opts() + .request_chunks(10) + .request_payload(make_payload(64 * 1024)) + .response_payload(make_payload(64 * 1024)) + .bench(b); +} + +fn req_64kb_resp_64kb_10_chunks_native_client_manual_server(b: &mut bencher::Bencher) { + opts() + .manual_server() + .request_chunks(10) + .request_payload(make_payload(64 * 1024)) + .response_payload(make_payload(64 * 1024)) + .bench(b); +} + +fn req_64kb_resp_64kb_10_chunks_manual_client_native_server(b: &mut bencher::Bencher) { + opts() + .manual_client() + .request_chunks(10) + .request_payload(make_payload(64 * 1024)) + .response_payload(make_payload(64 * 1024)) + .bench(b); +} + +fn req_64kb_resp_64kb_10_chunks_manual_client_manual_server(b: &mut bencher::Bencher) { + opts() + .manual_server() + .manual_client() + .request_chunks(10) + .request_payload(make_payload(64 * 1024)) + .response_payload(make_payload(64 * 1024)) + .bench(b); +} + +fn req_1mb_resp_1mb_10_chunks_native_client_native_server(b: &mut bencher::Bencher) { + opts() + .request_chunks(10) + .request_payload(make_payload(1 * 1024 * 1024)) + .response_payload(make_payload(1 * 1024 * 1024)) + .bench(b); +} + +fn req_1mb_resp_1mb_10_chunks_native_client_manual_server(b: &mut bencher::Bencher) { + opts() + .manual_server() + .request_chunks(10) + .request_payload(make_payload(1 * 1024 * 1024)) + .response_payload(make_payload(1 * 1024 * 1024)) + .bench(b); +} + +fn req_1mb_resp_1mb_10_chunks_manual_client_native_server(b: &mut bencher::Bencher) { + opts() + .manual_client() + .request_chunks(10) + .request_payload(make_payload(1 * 1024 * 1024)) + .response_payload(make_payload(1 * 1024 * 1024)) + .bench(b); +} + +fn req_1mb_resp_1mb_10_chunks_manual_client_manual_server(b: &mut bencher::Bencher) { + opts() + .manual_server() + .manual_client() + .request_chunks(10) + .request_payload(make_payload(1 * 1024 * 1024)) + .response_payload(make_payload(1 * 1024 * 1024)) + .bench(b); +} + +fn make_payload(size: usize) -> arrow::FlightData { + arrow::FlightData { + flight_descriptor: Some(arrow::FlightDescriptor { + cmd: Bytes::from("cmd"), + path: vec!["/path/to/data".to_string()], + ..Default::default() + }), + data_header: Bytes::from("data_header"), + app_metadata: Bytes::from("app_metadata"), + data_body: Bytes::from(vec![b'a'; size]), + } +} + +fn opts() -> Opts { + Opts { + request_chunks: 1, + ..Default::default() + } +} + +impl Opts { + fn manual_client(mut self) -> Self { + self.manual_client = true; + self + } + + fn manual_server(mut self) -> Self { + self.manual_server = true; + self + } + + fn request_chunks(mut self, chunks: usize) -> Self { + self.request_chunks = chunks; + self + } + + fn request_payload(mut self, payload: arrow::FlightData) -> Self { + self.request_payload = payload; + self + } + + fn response_payload(mut self, payload: arrow::FlightData) -> Self { + self.response_payload = payload; + self + } + + fn bench(self, b: &mut bencher::Bencher) { + let rt = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("runtime"), + ); + + b.bytes = ((self.request_payload.data_body.len() + self.response_payload.data_body.len()) + * self.request_chunks) as u64; + + spawn_server(&rt, &self); + + let channel = rt.block_on(async { + time::sleep(Duration::from_millis(100)).await; + tonic::transport::Endpoint::from_static("http://127.0.0.1:1500") + .connect() + .await + .unwrap() + }); + + let do_exchange = || async { + let mut client = if self.manual_client { + FlightClient::manual(channel.clone()) + } else { + FlightClient::native(channel.clone()) + }; + let (tx, rx) = mpsc::channel(8192); + let mut server_stream = client + .do_exchange(ReceiverStream::new(rx)) + .await + .unwrap() + .into_inner(); + for _ in 0..self.request_chunks { + tx.send(self.response_payload.clone()).await.unwrap(); + server_stream.next().await.unwrap().unwrap(); + } + }; + + b.iter(move || { + rt.block_on(do_exchange()); + }); + } +} + +fn spawn_server(rt: &tokio::runtime::Runtime, opts: &Opts) { + let addr = "127.0.0.1:1500"; + + let response_payload = opts.response_payload.clone(); + let manual_server = opts.manual_server; + let srv = rt.block_on(async move { + let flight_service = FlightService::new(response_payload); + if manual_server { + Server::builder() + .add_service( + arrow::manual::flight_service_server::FlightServiceServer::new(flight_service), + ) + .serve(addr.parse().unwrap()) + } else { + Server::builder() + .add_service(arrow::flight_service_server::FlightServiceServer::new( + flight_service, + )) + .serve(addr.parse().unwrap()) + } + }); + + rt.spawn(async move { srv.await.unwrap() }); +} + +benchmark_group!( + req_64kb_resp_64kb_10_chunks, + req_64kb_resp_64kb_10_chunks_native_client_native_server, + req_64kb_resp_64kb_10_chunks_manual_client_native_server, + req_64kb_resp_64kb_10_chunks_native_client_manual_server, + req_64kb_resp_64kb_10_chunks_manual_client_manual_server +); + +benchmark_group!( + req_1mb_resp_1mb_10_chunks, + req_1mb_resp_1mb_10_chunks_native_client_native_server, + req_1mb_resp_1mb_10_chunks_manual_client_native_server, + req_1mb_resp_1mb_10_chunks_native_client_manual_server, + req_1mb_resp_1mb_10_chunks_manual_client_manual_server +); + +benchmark_main!(req_64kb_resp_64kb_10_chunks, req_1mb_resp_1mb_10_chunks); diff --git a/benches/arrow-flight/build.rs b/benches/arrow-flight/build.rs new file mode 100644 index 000000000..c3114fbec --- /dev/null +++ b/benches/arrow-flight/build.rs @@ -0,0 +1,23 @@ +fn main() { + tonic_build::configure() + .bytes(&["."]) + .generate_default_stubs(true) + .compile(&["proto/flight.proto"], &["proto"]) + .unwrap(); + + tonic_build::manual::Builder::new().compile(&[tonic_build::manual::Service::builder() + .name("FlightService") + .package("arrow.flight.protocol") + .method( + tonic_build::manual::Method::builder() + .name("do_exchange") + .route_name("DoExchange") + .input_type("crate::arrow::FlightData") + .output_type("crate::arrow::FlightData") + .codec_path("crate::codec::FlightDataCodec") + .client_streaming() + .server_streaming() + .build(), + ) + .build()]); +} diff --git a/benches/arrow-flight/proto/flight.proto b/benches/arrow-flight/proto/flight.proto new file mode 100644 index 000000000..ee489b557 --- /dev/null +++ b/benches/arrow-flight/proto/flight.proto @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +option java_package = "org.apache.arrow.flight.impl"; +option go_package = "github.com/apache/arrow/go/arrow/flight/internal/flight"; +option csharp_namespace = "Apache.Arrow.Flight.Protocol"; + +package arrow.flight.protocol; + +/* + * A flight service is an endpoint for retrieving or storing Arrow data. A + * flight service can expose one or more predefined endpoints that can be + * accessed using the Arrow Flight Protocol. Additionally, a flight service + * can expose a set of actions that are available. + */ +service FlightService { + + /* + * Handshake between client and server. Depending on the server, the + * handshake may be required to determine the token that should be used for + * future operations. Both request and response are streams to allow multiple + * round-trips depending on auth mechanism. + */ + rpc Handshake(stream HandshakeRequest) returns (stream HandshakeResponse) {} + + /* + * Get a list of available streams given a particular criteria. Most flight + * services will expose one or more streams that are readily available for + * retrieval. This api allows listing the streams available for + * consumption. A user can also provide a criteria. The criteria can limit + * the subset of streams that can be listed via this interface. Each flight + * service allows its own definition of how to consume criteria. + */ + rpc ListFlights(Criteria) returns (stream FlightInfo) {} + + /* + * For a given FlightDescriptor, get information about how the flight can be + * consumed. This is a useful interface if the consumer of the interface + * already can identify the specific flight to consume. This interface can + * also allow a consumer to generate a flight stream through a specified + * descriptor. For example, a flight descriptor might be something that + * includes a SQL statement or a Pickled Python operation that will be + * executed. In those cases, the descriptor will not be previously available + * within the list of available streams provided by ListFlights but will be + * available for consumption for the duration defined by the specific flight + * service. + */ + rpc GetFlightInfo(FlightDescriptor) returns (FlightInfo) {} + + /* + * For a given FlightDescriptor, get the Schema as described in Schema.fbs::Schema + * This is used when a consumer needs the Schema of flight stream. Similar to + * GetFlightInfo this interface may generate a new flight that was not previously + * available in ListFlights. + */ + rpc GetSchema(FlightDescriptor) returns (SchemaResult) {} + + /* + * Retrieve a single stream associated with a particular descriptor + * associated with the referenced ticket. A Flight can be composed of one or + * more streams where each stream can be retrieved using a separate opaque + * ticket that the flight service uses for managing a collection of streams. + */ + rpc DoGet(Ticket) returns (stream FlightData) {} + + /* + * Push a stream to the flight service associated with a particular + * flight stream. This allows a client of a flight service to upload a stream + * of data. Depending on the particular flight service, a client consumer + * could be allowed to upload a single stream per descriptor or an unlimited + * number. In the latter, the service might implement a 'seal' action that + * can be applied to a descriptor once all streams are uploaded. + */ + rpc DoPut(stream FlightData) returns (stream PutResult) {} + + /* + * Open a bidirectional data channel for a given descriptor. This + * allows clients to send and receive arbitrary Arrow data and + * application-specific metadata in a single logical stream. In + * contrast to DoGet/DoPut, this is more suited for clients + * offloading computation (rather than storage) to a Flight service. + */ + rpc DoExchange(stream FlightData) returns (stream FlightData) {} + + /* + * Flight services can support an arbitrary number of simple actions in + * addition to the possible ListFlights, GetFlightInfo, DoGet, DoPut + * operations that are potentially available. DoAction allows a flight client + * to do a specific action against a flight service. An action includes + * opaque request and response objects that are specific to the type action + * being undertaken. + */ + rpc DoAction(Action) returns (stream Result) {} + + /* + * A flight service exposes all of the available action types that it has + * along with descriptions. This allows different flight consumers to + * understand the capabilities of the flight service. + */ + rpc ListActions(Empty) returns (stream ActionType) {} + +} + +/* + * The request that a client provides to a server on handshake. + */ +message HandshakeRequest { + + /* + * A defined protocol version + */ + uint64 protocol_version = 1; + + /* + * Arbitrary auth/handshake info. + */ + bytes payload = 2; +} + +message HandshakeResponse { + + /* + * A defined protocol version + */ + uint64 protocol_version = 1; + + /* + * Arbitrary auth/handshake info. + */ + bytes payload = 2; +} + +/* + * A message for doing simple auth. + */ +message BasicAuth { + string username = 2; + string password = 3; +} + +message Empty {} + +/* + * Describes an available action, including both the name used for execution + * along with a short description of the purpose of the action. + */ +message ActionType { + string type = 1; + string description = 2; +} + +/* + * A service specific expression that can be used to return a limited set + * of available Arrow Flight streams. + */ +message Criteria { + bytes expression = 1; +} + +/* + * An opaque action specific for the service. + */ +message Action { + string type = 1; + bytes body = 2; +} + +/* + * An opaque result returned after executing an action. + */ +message Result { + bytes body = 1; +} + +/* + * Wrap the result of a getSchema call + */ +message SchemaResult { + // The schema of the dataset in its IPC form: + // 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + // 4 bytes - the byte length of the payload + // a flatbuffer Message whose header is the Schema + bytes schema = 1; +} + +/* + * The name or tag for a Flight. May be used as a way to retrieve or generate + * a flight or be used to expose a set of previously defined flights. + */ +message FlightDescriptor { + + /* + * Describes what type of descriptor is defined. + */ + enum DescriptorType { + + // Protobuf pattern, not used. + UNKNOWN = 0; + + /* + * A named path that identifies a dataset. A path is composed of a string + * or list of strings describing a particular dataset. This is conceptually + * similar to a path inside a filesystem. + */ + PATH = 1; + + /* + * An opaque command to generate a dataset. + */ + CMD = 2; + } + + DescriptorType type = 1; + + /* + * Opaque value used to express a command. Should only be defined when + * type = CMD. + */ + bytes cmd = 2; + + /* + * List of strings identifying a particular dataset. Should only be defined + * when type = PATH. + */ + repeated string path = 3; +} + +/* + * The access coordinates for retrieval of a dataset. With a FlightInfo, a + * consumer is able to determine how to retrieve a dataset. + */ +message FlightInfo { + // The schema of the dataset in its IPC form: + // 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + // 4 bytes - the byte length of the payload + // a flatbuffer Message whose header is the Schema + bytes schema = 1; + + /* + * The descriptor associated with this info. + */ + FlightDescriptor flight_descriptor = 2; + + /* + * A list of endpoints associated with the flight. To consume the + * whole flight, all endpoints (and hence all Tickets) must be + * consumed. Endpoints can be consumed in any order. + * + * In other words, an application can use multiple endpoints to + * represent partitioned data. + * + * If the returned data has an ordering, an application can use + * "FlightInfo.ordered = true" or should return the all data in a + * single endpoint. Otherwise, there is no ordering defined on + * endpoints or the data within. + * + * A client can read ordered data by reading data from returned + * endpoints, in order, from front to back. + * + * Note that a client may ignore "FlightInfo.ordered = true". If an + * ordering is important for an application, an application must + * choose one of them: + * + * * An application requires that all clients must read data in + * returned endpoints order. + * * An application must return the all data in a single endpoint. + */ + repeated FlightEndpoint endpoint = 3; + + // Set these to -1 if unknown. + int64 total_records = 4; + int64 total_bytes = 5; + + /* + * FlightEndpoints are in the same order as the data. + */ + bool ordered = 6; +} + +/* + * A particular stream or split associated with a flight. + */ +message FlightEndpoint { + + /* + * Token used to retrieve this stream. + */ + Ticket ticket = 1; + + /* + * A list of URIs where this ticket can be redeemed via DoGet(). + * + * If the list is empty, the expectation is that the ticket can only + * be redeemed on the current service where the ticket was + * generated. + * + * If the list is not empty, the expectation is that the ticket can + * be redeemed at any of the locations, and that the data returned + * will be equivalent. In this case, the ticket may only be redeemed + * at one of the given locations, and not (necessarily) on the + * current service. + * + * In other words, an application can use multiple locations to + * represent redundant and/or load balanced services. + */ + repeated Location location = 2; +} + +/* + * A location where a Flight service will accept retrieval of a particular + * stream given a ticket. + */ +message Location { + string uri = 1; +} + +/* + * An opaque identifier that the service can use to retrieve a particular + * portion of a stream. + * + * Tickets are meant to be single use. It is an error/application-defined + * behavior to reuse a ticket. + */ +message Ticket { + bytes ticket = 1; +} + +/* + * A batch of Arrow data as part of a stream of batches. + */ +message FlightData { + + /* + * The descriptor of the data. This is only relevant when a client is + * starting a new DoPut stream. + */ + FlightDescriptor flight_descriptor = 1; + + /* + * Header for message data as described in Message.fbs::Message. + */ + bytes data_header = 2; + + /* + * Application-defined metadata. + */ + bytes app_metadata = 3; + + /* + * The actual batch of Arrow data. Preferably handled with minimal-copies + * coming last in the definition to help with sidecar patterns (it is + * expected that some implementations will fetch this field off the wire + * with specialized code to avoid extra memory copies). + */ + bytes data_body = 1000; +} + +/** + * The response message associated with the submission of a DoPut. + */ +message PutResult { + bytes app_metadata = 1; +} \ No newline at end of file diff --git a/benches/arrow-flight/src/client.rs b/benches/arrow-flight/src/client.rs new file mode 100644 index 000000000..b6ee5e51b --- /dev/null +++ b/benches/arrow-flight/src/client.rs @@ -0,0 +1,36 @@ +use crate::arrow; +use tonic::codegen::{Body, StdError}; + +pub enum FlightClient { + Manual(arrow::manual::flight_service_client::FlightServiceClient), + Native(arrow::flight_service_client::FlightServiceClient), +} + +impl FlightClient +where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Data: Into + Send, + ::Error: Into + Send, +{ + pub fn manual(inner: T) -> Self { + FlightClient::Manual(arrow::manual::flight_service_client::FlightServiceClient::new(inner)) + } + + pub fn native(inner: T) -> Self { + FlightClient::Native(arrow::flight_service_client::FlightServiceClient::new( + inner, + )) + } + + pub async fn do_exchange( + &mut self, + request: impl tonic::IntoStreamingRequest, + ) -> Result>, tonic::Status> { + match self { + FlightClient::Manual(client) => client.do_exchange(request).await, + FlightClient::Native(client) => client.do_exchange(request).await, + } + } +} diff --git a/benches/arrow-flight/src/codec.rs b/benches/arrow-flight/src/codec.rs new file mode 100644 index 000000000..fc3ca1bd8 --- /dev/null +++ b/benches/arrow-flight/src/codec.rs @@ -0,0 +1,44 @@ +use crate::arrow; +use prost::{bytes::Buf, Message}; +use std::mem; +use tonic::{ + codec::{Codec, EncodeBuf, Encoder, ProstCodec}, + Status, +}; + +#[derive(Default)] +pub(crate) struct FlightDataCodec; + +impl Codec for FlightDataCodec { + type Encode = arrow::FlightData; + type Decode = arrow::FlightData; + type Encoder = FlightDataEncoder; + type Decoder = as Codec>::Decoder; + + fn encoder(&mut self) -> Self::Encoder { + FlightDataEncoder::default() + } + + fn decoder(&mut self) -> Self::Decoder { + ProstCodec::<(), arrow::FlightData>::default().decoder() + } +} + +#[derive(Default)] +pub(crate) struct FlightDataEncoder; + +impl Encoder for FlightDataEncoder { + type Item = arrow::FlightData; + type Error = Status; + + fn encode(&mut self, mut item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> { + let body = mem::take(&mut item.data_body); + item.encode_raw(buf); + if body.has_remaining() { + prost::encoding::encode_key(1000, prost::encoding::WireType::LengthDelimited, buf); + prost::encoding::encode_varint(body.len() as u64, buf); + buf.insert_slice(body); + } + Ok(()) + } +} diff --git a/benches/arrow-flight/src/lib.rs b/benches/arrow-flight/src/lib.rs new file mode 100644 index 000000000..ba7e583f0 --- /dev/null +++ b/benches/arrow-flight/src/lib.rs @@ -0,0 +1,11 @@ +pub mod client; +mod codec; +pub mod server; + +pub mod arrow { + tonic::include_proto!("arrow.flight.protocol"); + + pub mod manual { + tonic::include_proto!("arrow.flight.protocol.FlightService"); + } +} diff --git a/benches/arrow-flight/src/server.rs b/benches/arrow-flight/src/server.rs new file mode 100644 index 000000000..32b266e19 --- /dev/null +++ b/benches/arrow-flight/src/server.rs @@ -0,0 +1,56 @@ +use crate::arrow; +use tokio::sync::mpsc; +use tonic::{ + codegen::{ + tokio_stream::{wrappers::ReceiverStream, StreamExt}, + BoxStream, + }, + Request, Response, Status, Streaming, +}; + +pub struct FlightService { + payload: arrow::FlightData, +} + +impl FlightService { + pub fn new(payload: arrow::FlightData) -> Self { + FlightService { payload } + } + + async fn exchange( + &self, + request: Request>, + ) -> Result>, Status> { + let mut stream = request.into_inner(); + let payload = self.payload.clone(); + let (tx, rx) = mpsc::channel(8192); + tokio::spawn(async move { + while let Some(_data) = stream.next().await.transpose().unwrap() { + tx.send(Ok(payload.clone())).await.unwrap(); + } + }); + Ok(Response::new(Box::pin(ReceiverStream::new(rx)))) + } +} + +#[tonic::async_trait] +impl arrow::flight_service_server::FlightService for FlightService { + async fn do_exchange( + &self, + request: Request>, + ) -> Result>, Status> { + self.exchange(request).await + } +} + +#[tonic::async_trait] +impl arrow::manual::flight_service_server::FlightService for FlightService { + type DoExchangeStream = BoxStream; + + async fn do_exchange( + &self, + request: Request>, + ) -> Result, Status> { + self.exchange(request).await + } +} diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 5c3e7e8b5..5a6da9b5e 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -285,7 +285,7 @@ autoreload = ["tokio-stream/net", "dep:listenfd"] health = ["dep:tonic-health"] grpc-web = ["dep:tonic-web", "dep:bytes", "dep:http", "dep:hyper", "dep:tracing-subscriber", "dep:tower"] tracing = ["dep:tracing", "dep:tracing-subscriber"] -hyper-warp = ["dep:either", "dep:tower", "dep:hyper", "dep:http", "dep:http-body", "dep:warp"] +hyper-warp = ["dep:either", "dep:tower", "dep:hyper", "dep:http", "dep:http-body", "dep:warp", "dep:bytes"] hyper-warp-multiplex = ["hyper-warp"] uds = ["tokio-stream/net", "dep:tower", "dep:hyper"] streaming = ["tokio-stream", "dep:h2"] diff --git a/examples/src/h2c/server.rs b/examples/src/h2c/server.rs index 92d08a417..9aa7ec929 100644 --- a/examples/src/h2c/server.rs +++ b/examples/src/h2c/server.rs @@ -61,7 +61,7 @@ mod h2c { impl Service> for H2c where - S: Service, Response = Response> + S: Service, Response = Response> + Clone + Send + 'static, diff --git a/examples/src/hyper_warp/server.rs b/examples/src/hyper_warp/server.rs index a79caf401..323499c98 100644 --- a/examples/src/hyper_warp/server.rs +++ b/examples/src/hyper_warp/server.rs @@ -3,8 +3,10 @@ //! To hit the warp server you can run this command: //! `curl localhost:50051/hello` +use bytes::Buf; use either::Either; use http::version::Version; +use http_body::Body; use hyper::{service::make_service_fn, Server}; use std::convert::Infallible; use std::{ @@ -65,7 +67,13 @@ async fn main() -> Result<(), Box> { Version::HTTP_2 => Either::Right({ let res = tonic.call(req); Box::pin(async move { - let res = res.await.map(|res| res.map(EitherBody::Right))?; + let res = res.await.map(|res| { + res.map(|body| { + EitherBody::Right( + body.map_data(|mut buf| buf.copy_to_bytes(buf.remaining())), + ) + }) + })?; Ok::<_, Error>(res) }) }), diff --git a/examples/src/hyper_warp_multiplex/server.rs b/examples/src/hyper_warp_multiplex/server.rs index deea8bea5..ea83c170c 100644 --- a/examples/src/hyper_warp_multiplex/server.rs +++ b/examples/src/hyper_warp_multiplex/server.rs @@ -5,13 +5,14 @@ use either::Either; use http::version::Version; +use http_body::Body; use hyper::{service::make_service_fn, Server}; use std::convert::Infallible; use std::{ pin::Pin, task::{Context, Poll}, }; -use tonic::{transport::Server as TonicServer, Request, Response, Status}; +use tonic::{codec::SliceBuffer, transport::Server as TonicServer, Request, Response, Status}; use tower::Service; use warp::Filter; @@ -85,7 +86,10 @@ async fn main() -> Result<(), Box> { Version::HTTP_11 | Version::HTTP_10 => Either::Left({ let res = warp.call(req); Box::pin(async move { - let res = res.await.map(|res| res.map(EitherBody::Left))?; + let res = res.await.map(|res| { + res.map(|body| body.map_data(SliceBuffer::from)) + .map(EitherBody::Left) + })?; Ok::<_, Error>(res) }) }), diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs index 28fa5d96a..8d8d80982 100644 --- a/tests/compression/src/util.rs +++ b/tests/compression/src/util.rs @@ -11,8 +11,10 @@ use std::{ task::{ready, Context, Poll}, }; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tonic::codec::CompressionEncoding; -use tonic::transport::{server::Connected, Channel}; +use tonic::{ + codec::{CompressionEncoding, SliceBuffer}, + transport::{server::Connected, Channel}, +}; use tower_http::map_request_body::MapRequestBodyLayer; macro_rules! parametrized_tests { @@ -41,7 +43,7 @@ pub struct CountBytesBody { impl Body for CountBytesBody where - B: Body, + B: Body, { type Data = B::Data; type Error = B::Error; diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index 4023cb64b..c82b346d3 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -68,7 +68,8 @@ pub(crate) fn generate_internal( where T: tonic::client::GrpcService, T::Error: Into, - T::ResponseBody: Body + Send + 'static, + T::ResponseBody: Body + Send + 'static, + ::Data: Into + Send, ::Error: Into + Send, { pub fn new(inner: T) -> Self { diff --git a/tonic-health/src/generated/grpc_health_v1.rs b/tonic-health/src/generated/grpc_health_v1.rs index 9662361f0..f1c96c4c6 100644 --- a/tonic-health/src/generated/grpc_health_v1.rs +++ b/tonic-health/src/generated/grpc_health_v1.rs @@ -69,7 +69,8 @@ pub mod health_client { where T: tonic::client::GrpcService, T::Error: Into, - T::ResponseBody: Body + Send + 'static, + T::ResponseBody: Body + Send + 'static, + ::Data: Into + Send, ::Error: Into + Send, { pub fn new(inner: T) -> Self { diff --git a/tonic-reflection/src/generated/grpc_reflection_v1alpha.rs b/tonic-reflection/src/generated/grpc_reflection_v1alpha.rs index 7efa6d51a..ce983374d 100644 --- a/tonic-reflection/src/generated/grpc_reflection_v1alpha.rs +++ b/tonic-reflection/src/generated/grpc_reflection_v1alpha.rs @@ -166,7 +166,8 @@ pub mod server_reflection_client { where T: tonic::client::GrpcService, T::Error: Into, - T::ResponseBody: Body + Send + 'static, + T::ResponseBody: Body + Send + 'static, + ::Data: Into + Send, ::Error: Into + Send, { pub fn new(inner: T) -> Self { diff --git a/tonic-web/src/call.rs b/tonic-web/src/call.rs index 731b9f667..9cd4c76e4 100644 --- a/tonic-web/src/call.rs +++ b/tonic-web/src/call.rs @@ -1,20 +1,26 @@ -use std::error::Error; -use std::pin::Pin; -use std::task::{ready, Context, Poll}; +use base64::{engine::Config, Engine}; +use std::{ + error::Error, + io::{self, Seek}, + pin::Pin, + task::{ready, Context, Poll}, +}; -use base64::Engine as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; use http::{header, HeaderMap, HeaderName, HeaderValue}; use http_body::{Body, SizeHint}; use pin_project::pin_project; use tokio_stream::Stream; -use tonic::Status; +use tonic::{codec::SliceBuffer, Status}; use self::content_types::*; +use super::internal_error; // A grpc header is u8 (flag) + u32 (msg len) const GRPC_HEADER_SIZE: usize = 1 + 4; +const SLICE_SIZE: usize = 128; + pub(crate) mod content_types { use http::{header::CONTENT_TYPE, HeaderMap}; @@ -35,8 +41,6 @@ pub(crate) mod content_types { } } -const BUFFER_SIZE: usize = 8 * 1024; - const FRAME_HEADER_SIZE: usize = 5; // 8th (MSB) bit of the 1st gRPC frame byte @@ -62,7 +66,7 @@ pub(crate) enum Encoding { pub struct GrpcWebCall { #[pin] inner: B, - buf: BytesMut, + buf: SliceBuffer, direction: Direction, encoding: Encoding, poll_trailers: bool, @@ -104,10 +108,7 @@ impl GrpcWebCall { fn new_client(inner: B, direction: Direction, encoding: Encoding) -> Self { GrpcWebCall { inner, - buf: BytesMut::with_capacity(match (direction, encoding) { - (Direction::Encode, Encoding::Base64) => BUFFER_SIZE, - _ => 0, - }), + buf: SliceBuffer::with_capacity(SLICE_SIZE, 0), direction, encoding, poll_trailers: true, @@ -119,10 +120,7 @@ impl GrpcWebCall { fn new(inner: B, direction: Direction, encoding: Encoding) -> Self { GrpcWebCall { inner, - buf: BytesMut::with_capacity(match (direction, encoding) { - (Direction::Encode, Encoding::Base64) => BUFFER_SIZE, - _ => 0, - }), + buf: SliceBuffer::with_capacity(SLICE_SIZE, 0), direction, encoding, poll_trailers: true, @@ -147,33 +145,36 @@ impl GrpcWebCall { // Split `buf` at the largest index that is multiple of 4. Decode the // returned `Bytes`, keeping the rest for the next attempt to decode. let index = self.max_decodable(); - - crate::util::base64::STANDARD - .decode(self.as_mut().project().buf.split_to(index)) - .map(|decoded| Some(Bytes::from(decoded))) - .map_err(internal_error) + let mut decoder = base64::read::DecoderReader::new( + self.as_mut().project().buf.split_to(index).reader(), + &crate::util::base64::STANDARD, + ); + let mut buf = BytesMut::with_capacity(base64::decoded_len_estimate(index)).writer(); + io::copy(&mut decoder, &mut buf).map_err(internal_error)?; + Ok(Some(buf.into_inner().freeze())) } } impl GrpcWebCall where - B: Body, + B: Body, + B::Data: Into, B::Error: Error, { fn poll_decode( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { match self.encoding { Encoding::Base64 => loop { if let Some(bytes) = self.as_mut().decode_chunk()? { - return Poll::Ready(Some(Ok(bytes))); + return Poll::Ready(Some(Ok(bytes.into()))); } let mut this = self.as_mut().project(); match ready!(this.inner.as_mut().poll_data(cx)) { - Some(Ok(data)) => this.buf.put(data), + Some(Ok(data)) => this.buf.append(&mut data.into()), Some(Err(e)) => return Poll::Ready(Some(Err(internal_error(e)))), None => { return if this.buf.has_remaining() { @@ -186,7 +187,7 @@ where }, Encoding::None => match ready!(self.project().inner.poll_data(cx)) { - Some(res) => Poll::Ready(Some(res.map_err(internal_error))), + Some(res) => Poll::Ready(Some(res.map(Into::into).map_err(internal_error))), None => Poll::Ready(None), }, } @@ -195,15 +196,30 @@ where fn poll_encode( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { let mut this = self.as_mut().project(); - if let Some(mut res) = ready!(this.inner.as_mut().poll_data(cx)) { + if let Some(res) = ready!(this.inner.as_mut().poll_data(cx)) { + let mut res = res.map(Into::into).map_err(internal_error); if *this.encoding == Encoding::Base64 { - res = res.map(|b| crate::util::base64::STANDARD.encode(b).into()) + res = res.and_then(|buf| { + let mut encoder = base64::write::EncoderWriter::new( + BytesMut::with_capacity( + base64::encoded_len( + buf.remaining(), + crate::util::base64::STANDARD.config().encode_padding(), + ) + .unwrap_or_default(), + ) + .writer(), + &crate::util::base64::STANDARD, + ); + io::copy(&mut buf.reader(), &mut encoder).map_err(internal_error)?; + Ok(encoder.into_inner().into_inner().freeze().into()) + }) } - return Poll::Ready(Some(res.map_err(internal_error))); + return Poll::Ready(Some(res)); } // this flag is needed because the inner stream never @@ -214,7 +230,7 @@ where let mut frame = make_trailers_frame(map); if *this.encoding == Encoding::Base64 { - frame = crate::util::base64::STANDARD.encode(frame).into_bytes(); + frame = crate::util::base64::STANDARD.encode(frame).into(); } *this.poll_trailers = false; @@ -231,10 +247,11 @@ where impl Body for GrpcWebCall where - B: Body, + B: Body, + B::Data: Buf + Into + From, B::Error: Error, { - type Data = Bytes; + type Data = B::Data; type Error = Status; fn poll_data( @@ -245,7 +262,7 @@ where let mut me = self.as_mut(); loop { - let incoming_buf = match ready!(me.as_mut().poll_decode(cx)) { + let mut incoming_buf = match ready!(me.as_mut().poll_decode(cx)) { Some(Ok(incoming_buf)) => incoming_buf, None => { // TODO: Consider eofing here? @@ -258,13 +275,13 @@ where let buf = &mut me.as_mut().project().buf; - buf.put(incoming_buf); + buf.append(&mut incoming_buf); - return match find_trailers(&buf[..])? { + return match find_trailers(buf)? { FindTrailers::Trailer(len) => { // Extract up to len of where the trailers are at - let msg_buf = buf.copy_to_bytes(len); - match decode_trailers_frame(buf.split().freeze()) { + let msg_buf = buf.split_to(len); + match decode_trailers_frame(buf) { Ok(Some(trailers)) => { self.project().trailers.replace(trailers); } @@ -273,20 +290,24 @@ where } if msg_buf.has_remaining() { - Poll::Ready(Some(Ok(msg_buf))) + Poll::Ready(Some(Ok(msg_buf.into()))) } else { Poll::Ready(None) } } FindTrailers::IncompleteBuf => continue, - FindTrailers::Done(len) => Poll::Ready(Some(Ok(buf.split_to(len).freeze()))), + FindTrailers::Done(len) => Poll::Ready(Some(Ok(buf.split_to(len).into()))), }; } } match self.direction { - Direction::Decode => self.poll_decode(cx), - Direction::Encode => self.poll_encode(cx), + Direction::Decode => self + .poll_decode(cx) + .map(|buf| buf.map(|buf| buf.map(Into::into))), + Direction::Encode => self + .poll_encode(cx) + .map(|buf| buf.map(|buf| buf.map(Into::into))), Direction::Empty => Poll::Ready(None), } } @@ -310,10 +331,11 @@ where impl Stream for GrpcWebCall where - B: Body, + B: Body, + B::Data: Buf + Into + From, B::Error: Error, { - type Item = Result; + type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Body::poll_data(self, cx) @@ -344,45 +366,30 @@ impl Encoding { } } -fn internal_error(e: impl std::fmt::Display) -> Status { - Status::internal(format!("tonic-web: {}", e)) -} - -// Key-value pairs encoded as a HTTP/1 headers block (without the terminating newline) -fn encode_trailers(trailers: HeaderMap) -> Vec { - trailers.iter().fold(Vec::new(), |mut acc, (key, value)| { - acc.put_slice(key.as_ref()); - acc.push(b':'); - acc.put_slice(value.as_bytes()); - acc.put_slice(b"\r\n"); - acc - }) -} - -fn decode_trailers_frame(mut buf: Bytes) -> Result, Status> { +fn decode_trailers_frame(buf: &mut SliceBuffer) -> Result, Status> { if buf.remaining() < GRPC_HEADER_SIZE { return Ok(None); } - buf.get_u8(); - buf.get_u32(); - + buf.advance(GRPC_HEADER_SIZE); let mut map = HeaderMap::new(); - let mut temp_buf = buf.clone(); - let mut trailers = Vec::new(); - let mut cursor_pos = 0; - - for (i, b) in buf.iter().enumerate() { - if b == &b'\r' && buf.get(i + 1) == Some(&b'\n') { - let trailer = temp_buf.copy_to_bytes(i - cursor_pos); - cursor_pos = i; - trailers.push(trailer); - if temp_buf.has_remaining() { - temp_buf.get_u8(); - temp_buf.get_u8(); + + loop { + let mut len = 0; + let mut last_b = b'0'; + for (i, b) in buf.iter().enumerate() { + if b == b'\n' && last_b == b'\r' { + len = i; + break; } + last_b = b; } + if len == 0 { + break; + } + trailers.push(buf.copy_to_bytes(len - 1)); + buf.advance(2); } for trailer in trailers { @@ -409,35 +416,45 @@ fn decode_trailers_frame(mut buf: Bytes) -> Result, Status> { Ok(Some(map)) } -fn make_trailers_frame(trailers: HeaderMap) -> Vec { - let trailers = encode_trailers(trailers); - let len = trailers.len(); - assert!(len <= u32::MAX as usize); - - let mut frame = Vec::with_capacity(len + FRAME_HEADER_SIZE); - frame.push(GRPC_WEB_TRAILERS_BIT); - frame.put_u32(len as u32); - frame.extend(trailers); - - frame +fn make_trailers_frame(trailers: HeaderMap) -> Bytes { + let encoded_len: usize = trailers + .iter() + .map(|(key, value)| { + key.as_str().len() + + 1 /* b':' */ + + value.as_bytes().len() + + 2 /* b"\r\n" */ + }) + .sum(); + + let mut frame = BytesMut::with_capacity(encoded_len + FRAME_HEADER_SIZE); + frame.put_u8(GRPC_WEB_TRAILERS_BIT); + frame.put_u32(encoded_len as u32); + for (key, value) in trailers.iter() { + frame.put_slice(key.as_ref()); + frame.put_u8(b':'); + frame.put_slice(value.as_bytes()); + frame.put_slice(b"\r\n"); + } + frame.freeze() } /// Search some buffer for grpc-web trailers headers and return /// its location in the original buf. If `None` is returned we did /// not find a trailers in this buffer either because its incomplete /// or the buffer just contained grpc message frames. -fn find_trailers(buf: &[u8]) -> Result { +fn find_trailers(buf: &mut SliceBuffer) -> Result { let mut len = 0; - let mut temp_buf = buf; + let mut cursor = buf.cursor(); loop { // To check each frame, there must be at least GRPC_HEADER_SIZE // amount of bytes available otherwise the buffer is incomplete. - if temp_buf.is_empty() || temp_buf.len() < GRPC_HEADER_SIZE { + if cursor.remaining() < GRPC_HEADER_SIZE { return Ok(FindTrailers::Done(len)); } - let header = temp_buf.get_u8(); + let header = cursor.get_u8(); if header == GRPC_WEB_TRAILERS_BIT { return Ok(FindTrailers::Trailer(len)); @@ -447,17 +464,14 @@ fn find_trailers(buf: &[u8]) -> Result { return Err(Status::internal("Invalid header bit {} expected 0 or 1")); } - let msg_len = temp_buf.get_u32(); - + let msg_len = cursor.get_u32(); len += msg_len as usize + 4 + 1; // If the msg len of a non-grpc-web trailer frame is larger than // the overall buffer we know within that buffer there are no trailers. - if len > buf.len() { + if cursor.seek(io::SeekFrom::Current(msg_len as i64)).is_err() { return Ok(FindTrailers::IncompleteBuf); } - - temp_buf = &buf[len..]; } } @@ -503,9 +517,9 @@ mod tests { let trailers = make_trailers_frame(headers.clone()); - let buf = Bytes::from(trailers); + let mut buf = SliceBuffer::from(trailers); - let map = decode_trailers_frame(buf).unwrap().unwrap(); + let map = decode_trailers_frame(&mut buf).unwrap().unwrap(); assert_eq!(headers, map); } @@ -514,12 +528,11 @@ mod tests { fn find_trailers_non_buffered() { // Byte version of this: // b"\x80\0\0\0\x0fgrpc-status:0\r\n" - let buf = [ + let mut buf = SliceBuffer::from(Bytes::from(vec![ 128, 0, 0, 0, 15, 103, 114, 112, 99, 45, 115, 116, 97, 116, 117, 115, 58, 48, 13, 10, - ]; - - let out = find_trailers(&buf[..]).unwrap(); + ])); + let out = find_trailers(&mut buf).unwrap(); assert_eq!(out, FindTrailers::Trailer(0)); } @@ -527,28 +540,26 @@ mod tests { fn find_trailers_buffered() { // Byte version of this: // b"\0\0\0\0L\n$975738af-1a17-4aea-b887-ed0bbced6093\x1a$da609e9b-f470-4cc0-a691-3fd6a005a436\x80\0\0\0\x0fgrpc-status:0\r\n" - let buf = [ + let mut buf = SliceBuffer::from(Bytes::from(vec![ 0, 0, 0, 0, 76, 10, 36, 57, 55, 53, 55, 51, 56, 97, 102, 45, 49, 97, 49, 55, 45, 52, 97, 101, 97, 45, 98, 56, 56, 55, 45, 101, 100, 48, 98, 98, 99, 101, 100, 54, 48, 57, 51, 26, 36, 100, 97, 54, 48, 57, 101, 57, 98, 45, 102, 52, 55, 48, 45, 52, 99, 99, 48, 45, 97, 54, 57, 49, 45, 51, 102, 100, 54, 97, 48, 48, 53, 97, 52, 51, 54, 128, 0, 0, 0, 15, 103, 114, 112, 99, 45, 115, 116, 97, 116, 117, 115, 58, 48, 13, 10, - ]; - - let out = find_trailers(&buf[..]).unwrap(); + ])); + let out = find_trailers(&mut buf).unwrap(); assert_eq!(out, FindTrailers::Trailer(81)); - let trailers = decode_trailers_frame(Bytes::copy_from_slice(&buf[81..])) - .unwrap() - .unwrap(); + buf.advance(81); + let trailers = decode_trailers_frame(&mut buf).unwrap().unwrap(); let status = trailers.get("grpc-status").unwrap(); assert_eq!(status.to_str().unwrap(), "0") } #[test] fn find_trailers_buffered_incomplete_message() { - let buf = vec![ + let mut buf = SliceBuffer::from(Bytes::from(vec![ 0, 0, 0, 9, 238, 10, 233, 19, 18, 230, 19, 10, 9, 10, 1, 120, 26, 4, 84, 69, 88, 84, 18, 60, 10, 58, 10, 56, 3, 0, 0, 0, 44, 0, 0, 0, 0, 0, 0, 0, 116, 104, 105, 115, 32, 118, 97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114, 105, 116, 116, 101, 110, 32, @@ -575,18 +586,19 @@ mod tests { 101, 112, 108, 105, 99, 97, 33, 18, 62, 10, 60, 10, 58, 3, 0, 0, 0, 46, 0, 0, 0, 0, 0, 0, 0, 116, 104, 105, 115, 32, 118, 97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114, 105, 116, 116, 101, 110, 32, 98, 121, 32, - ]; - - let out = find_trailers(&buf[..]).unwrap(); + ])); + let out = find_trailers(&mut buf).unwrap(); assert_eq!(out, FindTrailers::IncompleteBuf); } #[test] #[ignore] fn find_trailers_buffered_incomplete_buf_bug() { - let buf = std::fs::read("tests/incomplete-buf-bug.bin").unwrap(); - let out = find_trailers(&buf[..]).unwrap_err(); + let mut buf = SliceBuffer::from(Bytes::from( + std::fs::read("tests/incomplete-buf-bug.bin").unwrap(), + )); + let out = find_trailers(&mut buf).unwrap_err(); assert_eq!(out.code(), Code::Internal); } diff --git a/tonic-web/src/client.rs b/tonic-web/src/client.rs index 774fe5fdd..cf02f339c 100644 --- a/tonic-web/src/client.rs +++ b/tonic-web/src/client.rs @@ -1,4 +1,3 @@ -use bytes::Bytes; use http::header::CONTENT_TYPE; use http::{Request, Response, Version}; use http_body::Body; @@ -7,6 +6,7 @@ use std::error::Error; use std::future::Future; use std::pin::Pin; use std::task::{ready, Context, Poll}; +use tonic::codec::SliceBuffer; use tower_layer::Layer; use tower_service::Service; use tracing::debug; @@ -60,7 +60,9 @@ impl Service> for GrpcWebClientService where S: Service>, Response = Response>, B1: Body, - B2: Body, + B1::Data: Into, + B2: Body, + B2::Data: Into, B2::Error: Error, { type Response = Response>; @@ -100,7 +102,7 @@ pub struct ResponseFuture { impl Future for ResponseFuture where - B: Body, + B: Body, F: Future, E>>, { type Output = Result>, E>; diff --git a/tonic-web/src/layer.rs b/tonic-web/src/layer.rs index 77b03c77e..e4134fd36 100644 --- a/tonic-web/src/layer.rs +++ b/tonic-web/src/layer.rs @@ -1,7 +1,6 @@ -use super::{BoxBody, BoxError, GrpcWebService}; +use super::GrpcWebService; use tower_layer::Layer; -use tower_service::Service; /// Layer implementing the grpc-web protocol. #[derive(Debug, Clone)] @@ -22,13 +21,7 @@ impl Default for GrpcWebLayer { } } -impl Layer for GrpcWebLayer -where - S: Service, Response = http::Response>, - S: Send + 'static, - S::Future: Send + 'static, - S::Error: Into + Send, -{ +impl Layer for GrpcWebLayer { type Service = GrpcWebService; fn layer(&self, inner: S) -> Self::Service { diff --git a/tonic-web/src/lib.rs b/tonic-web/src/lib.rs index cc11ed56b..981df0237 100644 --- a/tonic-web/src/lib.rs +++ b/tonic-web/src/lib.rs @@ -109,7 +109,7 @@ mod service; use http::header::HeaderName; use std::time::Duration; -use tonic::{body::BoxBody, server::NamedService}; +use tonic::{body::BoxBody, server::NamedService, Status}; use tower_http::cors::{AllowOrigin, CorsLayer}; use tower_layer::Layer; use tower_service::Service; @@ -190,6 +190,10 @@ where const NAME: &'static str = S::NAME; } +pub(crate) fn internal_error(e: impl std::fmt::Display) -> Status { + Status::internal(format!("tonic-web: {}", e)) +} + pub(crate) mod util { pub(crate) mod base64 { use base64::{ diff --git a/tonic-web/src/service.rs b/tonic-web/src/service.rs index af4c5276f..2153bce5c 100644 --- a/tonic-web/src/service.rs +++ b/tonic-web/src/service.rs @@ -216,7 +216,7 @@ fn coerce_request(mut req: Request, encoding: Encoding) -> Request { HeaderValue::from_static("identity,deflate,gzip"), ); - req.map(|b| GrpcWebCall::request(b, encoding)) + req.map(|body| GrpcWebCall::request(body, encoding)) .map(Body::wrap_stream) } diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 013cc6e72..18e384591 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -34,7 +34,7 @@ tls-roots-common = ["tls"] tls-webpki-roots = ["tls-roots-common", "dep:webpki-roots"] transport = [ "dep:async-stream", - "dep:axum", + "dep:matchit", "channel", "dep:h2", "dep:hyper", @@ -74,7 +74,7 @@ hyper = {version = "0.14.26", features = ["full"], optional = true} hyper-timeout = {version = "0.4", optional = true} tokio-stream = "0.1" tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true} -axum = {version = "0.6.9", default_features = false, optional = true} +matchit = { version = "0.7", optional = true } # rustls async-stream = { version = "0.3", optional = true } @@ -94,7 +94,7 @@ quickcheck = "1.0" quickcheck_macros = "1.0" rand = "0.8" static_assertions = "1.0" -tokio = {version = "1.0", features = ["rt", "macros"]} +tokio = {version = "1.0", features = ["rt-multi-thread", "macros"]} tower = {version = "0.4.7", features = ["full"]} [package.metadata.docs.rs] diff --git a/tonic/benches/decode.rs b/tonic/benches/decode.rs index 5c7cd0159..a3ca1e569 100644 --- a/tonic/benches/decode.rs +++ b/tonic/benches/decode.rs @@ -97,13 +97,11 @@ impl MockDecoder { } impl Decoder for MockDecoder { - type Item = Vec; + type Item = Bytes; type Error = Status; fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result, Self::Error> { - let out = Vec::from(buf.chunk()); - buf.advance(self.message_size); - Ok(Some(out)) + Ok(Some(buf.copy_to_bytes(self.message_size))) } } diff --git a/tonic/src/body.rs b/tonic/src/body.rs index ef95eec47..0de3c91de 100644 --- a/tonic/src/body.rs +++ b/tonic/src/body.rs @@ -1,17 +1,21 @@ //! HTTP specific body utilities. -use http_body::Body; +use crate::codec::SliceBuffer; +use http_body::{combinators::UnsyncBoxBody, Body}; /// A type erased HTTP body used for tonic services. -pub type BoxBody = http_body::combinators::UnsyncBoxBody; +pub type BoxBody = UnsyncBoxBody; /// Convert a [`http_body::Body`] into a [`BoxBody`]. pub(crate) fn boxed(body: B) -> BoxBody where - B: http_body::Body + Send + 'static, + B: http_body::Body + Send + 'static, + B::Data: Into, B::Error: Into, { - body.map_err(crate::Status::map_error).boxed_unsync() + body.map_data(Into::into) + .map_err(crate::Status::map_error) + .boxed_unsync() } /// Create an empty `BoxBody` diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index e070f08d3..22a4bff5e 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -1,7 +1,7 @@ -use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings}; use crate::{ - body::BoxBody, + body::{boxed, BoxBody}, client::GrpcService, + codec::compression::{CompressionEncoding, EnabledCompressionEncodings}, codec::{encode_client, Codec, Decoder, Streaming}, request::SanitizeHeaders, Code, Request, Response, Status, @@ -301,7 +301,7 @@ impl Grpc { self.config.max_encoding_message_size, ) }) - .map(BoxBody::new); + .map(boxed); let request = self.config.prepare_request(request, path); @@ -318,13 +318,15 @@ impl Grpc { // Keeping this code in a separate function from Self::streaming lets functions that return the // same output share the generated binary code - fn create_response( + fn create_response( &self, decoder: impl Decoder + Send + 'static, response: http::Response, ) -> Result>, Status> where - T: GrpcService, + B: Body, + ::Error: Into, + T: GrpcService, T::ResponseBody: Body + Send + 'static, ::Error: Into, { @@ -367,11 +369,11 @@ impl Grpc { } impl GrpcConfig { - fn prepare_request( - &self, - request: Request, - path: PathAndQuery, - ) -> http::Request { + fn prepare_request(&self, request: Request, path: PathAndQuery) -> http::Request + where + B: Body + 'static, + B::Error: Into, + { let mut parts = self.origin.clone().into_parts(); match &parts.path_and_query { diff --git a/tonic/src/codec/buffer.rs b/tonic/src/codec/buffer.rs index fcce82e1d..c0341a7d0 100644 --- a/tonic/src/codec/buffer.rs +++ b/tonic/src/codec/buffer.rs @@ -1,21 +1,21 @@ -use bytes::buf::UninitSlice; -use bytes::{Buf, BufMut, BytesMut}; +use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut}; +use std::{cmp, collections::VecDeque, io, iter, ops::Deref}; /// A specialized buffer to decode gRPC messages from. #[derive(Debug)] pub struct DecodeBuf<'a> { - buf: &'a mut BytesMut, + buf: &'a mut SliceBuffer, len: usize, } /// A specialized buffer to encode gRPC messages into. #[derive(Debug)] pub struct EncodeBuf<'a> { - buf: &'a mut BytesMut, + buf: &'a mut SliceBuffer, } impl<'a> DecodeBuf<'a> { - pub(crate) fn new(buf: &'a mut BytesMut, len: usize) -> Self { + pub(crate) fn new(buf: &'a mut SliceBuffer, len: usize) -> Self { DecodeBuf { buf, len } } } @@ -43,10 +43,17 @@ impl Buf for DecodeBuf<'_> { self.buf.advance(cnt); self.len -= cnt; } + + #[inline] + fn copy_to_bytes(&mut self, len: usize) -> Bytes { + assert!(len <= self.len); + self.len -= len; + self.buf.copy_to_bytes(len) + } } impl<'a> EncodeBuf<'a> { - pub(crate) fn new(buf: &'a mut BytesMut) -> Self { + pub(crate) fn new(buf: &'a mut SliceBuffer) -> Self { EncodeBuf { buf } } } @@ -61,6 +68,15 @@ impl EncodeBuf<'_> { pub fn reserve(&mut self, additional: usize) { self.buf.reserve(additional); } + + /// Inserts a byte slice directly into the underlying [`SliceBuffer`] without copying. + /// Instead of copying the data to the active buffer, it appends to the collection + /// of slices within the [`SliceBuffer`]. This operation completes in constant time, + /// provided no memory reallocation occurs. + #[inline] + pub fn insert_slice(&mut self, slice: Bytes) { + self.buf.insert_slice(slice) + } } unsafe impl BufMut for EncodeBuf<'_> { @@ -78,15 +94,723 @@ unsafe impl BufMut for EncodeBuf<'_> { fn chunk_mut(&mut self) -> &mut UninitSlice { self.buf.chunk_mut() } + + #[inline] + fn put(&mut self, src: T) + where + Self: Sized, + { + self.buf.put(src) + } + + #[inline] + fn put_slice(&mut self, src: &[u8]) { + self.buf.put_slice(src) + } + + #[inline] + fn put_bytes(&mut self, val: u8, cnt: usize) { + self.buf.put_bytes(val, cnt) + } +} + +#[derive(Debug, PartialEq)] +enum Slice { + Fixed(Bytes), + Mutable(BytesMut), +} + +impl Slice { + #[inline] + fn len(&self) -> usize { + match self { + Slice::Fixed(buf) => buf.len(), + Slice::Mutable(buf) => buf.len(), + } + } + + #[inline] + fn split_to(&mut self, at: usize) -> Self { + match self { + Slice::Fixed(buf) => Slice::Fixed(buf.split_to(at)), + Slice::Mutable(buf) => Slice::Mutable(buf.split_to(at)), + } + } +} + +impl Default for Slice { + fn default() -> Self { + Slice::Fixed(Default::default()) + } +} + +impl Buf for Slice { + #[inline] + fn remaining(&self) -> usize { + match self { + Slice::Fixed(buf) => buf.remaining(), + Slice::Mutable(buf) => buf.remaining(), + } + } + + #[inline] + fn chunk(&self) -> &[u8] { + match self { + Slice::Fixed(buf) => buf.chunk(), + Slice::Mutable(buf) => buf.chunk(), + } + } + + #[inline] + fn advance(&mut self, cnt: usize) { + match self { + Slice::Fixed(buf) => buf.advance(cnt), + Slice::Mutable(buf) => buf.advance(cnt), + } + } + + #[inline] + fn copy_to_bytes(&mut self, len: usize) -> Bytes { + match self { + Slice::Fixed(buf) => buf.copy_to_bytes(len), + Slice::Mutable(buf) => buf.copy_to_bytes(len), + } + } +} + +impl Deref for Slice { + type Target = [u8]; + + #[inline] + fn deref(&self) -> &[u8] { + match self { + Slice::Fixed(buf) => buf.deref(), + Slice::Mutable(buf) => buf.deref(), + } + } +} + +/// `SliceBuffer` represents a buffer containing non-contiguous memory segments, which implements +/// [`bytes::Buf`] and [`bytes::BufMut`]. While traditional buffers typically rely on contiguous +/// memory, `SliceBuffer` offers a unique design allowing the seamless insertion of immutable byte +/// chunks without necessitating memory copying or potential memory reallocation. +/// +/// Internally, `SliceBuffer` consists of two main components: +/// 1. A collection of slices. +/// 2. An active buffer. +/// +/// Together, these components present the bytes as a concatenated sequence of all the slices and +/// the active buffer. The active buffer behaves similarly to [`bytes::BytesMut`], copying inserted +/// bytes into it. For larger slices represented as [`bytes::Bytes`], they can be directly appended +/// to the slice collection, bypassing data copying. +#[derive(Default, Debug)] +pub struct SliceBuffer { + active: BytesMut, + len: usize, + slices: VecDeque, +} + +impl SliceBuffer { + /// Constructs a new `SliceBuffer` with given capacities for both the slice collection and the + /// active buffer. The created `SliceBuffer` ensures the slice collection can accommodate + /// at least `slice_capacity` slices, and the active buffer can contain at least + /// `buffer_capacity` bytes. + /// + /// Note: This function determines the capacity, not the length, of the returned `SliceBuffer`. + #[inline] + pub fn with_capacity(slice_capacity: usize, buffer_capacity: usize) -> Self { + SliceBuffer { + active: match buffer_capacity { + 0 => BytesMut::new(), + _ => BytesMut::with_capacity(buffer_capacity), + }, + slices: match slice_capacity { + 0 => VecDeque::new(), + _ => VecDeque::with_capacity(slice_capacity), + }, + ..Default::default() + } + } + + /// Divides the `SliceBuffer` into two at the specified index. + /// + /// Following the split, `self` retains elements from `[at, len)`, while the returned + /// `SliceBuffer` encapsulates elements from `[0, at)`. This operation has a time complexity + /// of `O(N)`, where `N` denotes the number of slices in the `SliceBuffer`. + /// + /// # Examples + /// + /// ``` + /// use bytes::{Bytes, Buf}; + /// use tonic::codec::SliceBuffer; + /// + /// let mut buf1 = SliceBuffer::from(Bytes::from(&b"hello world"[..])); + /// let mut buf2 = buf1.split_to(5); + /// + /// assert_eq!(buf1.copy_to_bytes(buf1.len()), Bytes::from(&b" world"[..])); + /// assert_eq!(buf2.copy_to_bytes(buf2.len()), Bytes::from(&b"hello"[..])); + /// ``` + /// + /// # Panics + /// + /// This method will panic if `at` exceeds `len`. + pub fn split_to(&mut self, at: usize) -> Self { + assert!( + at <= self.len(), + "split_to out of bounds: {:?} <= {:?}", + at, + self.len(), + ); + + self.len -= at; + let at_pos = self.cursor().forward(&Position::default(), at); + let mut buf = SliceBuffer { + active: Default::default(), + len: at, + slices: VecDeque::with_capacity(at_pos.slice_idx + cmp::min(at_pos.rel_pos, 1)), + }; + buf.slices.extend(self.slices.drain(..at_pos.slice_idx)); + + if at_pos.rel_pos > 0 { + let slice = if let Some(slice) = self.slices.front_mut() { + slice.split_to(at_pos.rel_pos) + } else { + Slice::Mutable(self.active.split_to(at_pos.rel_pos)) + }; + buf.slices.push_back(slice); + } + buf + } + + /// Transfers all elements from `other` to `self`, emptying `other` in the process. Slices from + /// `other` are appended to `self`'s slice collection, while `other`'s active buffer is copied + /// over to `self`'s active buffer. + /// + /// # Examples + /// + /// ``` + /// use bytes::{Buf, BufMut, Bytes}; + /// use tonic::codec::SliceBuffer; + /// + /// let mut buf1 = SliceBuffer::default(); + /// buf1.put_slice(b"foo"); + /// let mut buf2 = SliceBuffer::default(); + /// buf2.insert_slice(Bytes::copy_from_slice(b"bar")); + /// buf2.put(Bytes::copy_from_slice(b"foo")); + /// buf1.append(&mut buf2); + /// assert_eq!(buf1.copy_to_bytes(buf1.len()), Bytes::copy_from_slice(b"foobarfoo")); + /// assert_eq!(buf2.len(), 0); + /// ``` + /// + /// # Panics + /// + /// This will panic if the total number of elements in `self` exceeds the maximum `usize` value. + pub fn append(&mut self, other: &mut Self) { + if other.slices.len() > 0 { + self.slices + .reserve(other.slices.len() + cmp::min(1, self.active.len())); + if self.active.has_remaining() { + self.slices + .push_back(Slice::Mutable(self.active.split_to(self.active.len()))); + } + self.slices.append(&mut other.slices); + } + self.active.put(other.active.split_to(other.active.len())); + self.len += other.len; + other.len = 0; + } + + /// Returns the number of bytes contained in this `SliceBuffer`. + /// + /// # Examples + /// + /// ``` + /// use bytes::Bytes; + /// use tonic::codec::SliceBuffer; + /// + /// let b = SliceBuffer::from(Bytes::copy_from_slice(b"hello world")); + /// assert_eq!(b.len(), 11); + /// ``` + #[inline] + pub fn len(&self) -> usize { + self.len + } + + /// Returns true if the `SliceBuffer` has a length of 0. + /// + /// # Examples + /// + /// ``` + /// use tonic::codec::SliceBuffer; + /// + /// let b = SliceBuffer::default(); + /// assert!(b.is_empty()); + /// ``` + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Reserves space for an additional `additional` bytes in the active buffer + /// of the `SliceBuffer`. + /// + /// This directly invokes the `reserve` method on the active buffer, which is of type + /// [`bytes::BytesMut`]. Refer to [`bytes::BytesMut::reserve`] for further details. + /// + /// # Panics + /// + /// This will panic if the new capacity exceeds the maximum `usize` value. + #[inline] + pub fn reserve(&mut self, additional: usize) { + self.active.reserve(additional); + } + + /// Inserts an immutable slice, represented by [`bytes::Bytes`], into the `SliceBuffer`. + /// + /// By directly adding the slice to the `SliceBuffer`'s slice collection, this operation + /// bypasses any memory copying. If there's existing content in the active buffer, it will be + /// transferred to the slice collection as a mutable slice prior to the insertion. + /// + /// # Examples + /// + /// ``` + /// use bytes::{BufMut, Bytes, Buf}; + /// use tonic::codec::SliceBuffer; + /// + /// let mut buf = SliceBuffer::with_capacity(128, 1024); + /// buf.put_slice(b"foo"); + /// buf.insert_slice(Bytes::copy_from_slice(b"bar")); + /// assert_eq!(buf.copy_to_bytes(buf.len()), Bytes::copy_from_slice(b"foobar")); + /// ``` + #[inline] + pub fn insert_slice(&mut self, slice: Bytes) { + if slice.len() > 0 { + if !self.active.is_empty() { + self.slices + .push_back(Slice::Mutable(self.active.split_to(self.active.len()))); + } + self.len += slice.len(); + self.slices.push_back(Slice::Fixed(slice)); + } + } + + /// Clears the `SliceBuffer`, removing all data. Existing capacity is preserved. + /// + /// # Examples + /// + /// ``` + /// use bytes::Bytes; + /// use tonic::codec::SliceBuffer; + /// + /// let mut buf = SliceBuffer::from(Bytes::copy_from_slice(b"hello world")); + /// assert!(!buf.is_empty()); + /// buf.clear(); + /// assert!(buf.is_empty()); + /// ``` + #[inline] + pub fn clear(&mut self) { + self.active.clear(); + self.len = 0; + self.slices.clear(); + } +} + +impl Buf for SliceBuffer { + #[inline] + fn remaining(&self) -> usize { + self.len() + } + + #[inline] + fn chunk(&self) -> &[u8] { + match self.slices.front() { + Some(bytes) => bytes.chunk(), + None => self.active.chunk(), + } + } + + fn chunks_vectored<'a>(&'a self, dst: &mut [io::IoSlice<'a>]) -> usize { + let mut n = 0; + for slice in self.slices.iter() { + n += slice.chunks_vectored(&mut dst[n..]); + } + n += self.active.chunks_vectored(&mut dst[n..]); + n + } + + #[inline] + fn advance(&mut self, mut cnt: usize) { + assert!( + cnt <= self.len(), + "cannot advance past `remaining`: {:?} <= {:?}", + cnt, + self.len(), + ); + + self.len -= cnt; + while cnt > 0 { + match self.slices.front_mut() { + Some(slice) => { + if slice.len() <= cnt { + cnt -= slice.len(); + self.slices.pop_front(); + } else { + slice.advance(cnt); + cnt = 0; + } + } + None => { + self.active.advance(cnt); + cnt = 0; + } + } + } + } + + fn copy_to_bytes(&mut self, len: usize) -> Bytes { + match self.slices.front_mut() { + Some(slice) if slice.len() > len => { + self.len -= len; + slice.copy_to_bytes(len) + } + Some(slice) if slice.len() == len => { + self.len -= len; + let buf = slice.copy_to_bytes(len); + self.slices.pop_front(); + buf + } + None => { + self.len -= len; + self.active.copy_to_bytes(len) + } + _ => { + assert!(len <= self.remaining(), "`len` greater than remaining"); + let mut buf = BytesMut::with_capacity(len); + buf.put(self.take(len)); + buf.freeze() + } + } + } +} + +unsafe impl BufMut for SliceBuffer { + #[inline] + fn remaining_mut(&self) -> usize { + self.active.remaining_mut() + } + + #[inline] + unsafe fn advance_mut(&mut self, cnt: usize) { + self.len += cnt; + self.active.advance_mut(cnt) + } + + #[inline] + fn chunk_mut(&mut self) -> &mut UninitSlice { + self.active.chunk_mut() + } + + #[inline] + fn put(&mut self, src: T) + where + Self: Sized, + { + self.len += src.remaining(); + self.active.put(src) + } + + #[inline] + fn put_slice(&mut self, src: &[u8]) { + self.len += src.len(); + self.active.put_slice(src) + } + + #[inline] + fn put_bytes(&mut self, val: u8, cnt: usize) { + self.len += cnt; + self.active.put_bytes(val, cnt) + } +} + +impl From for SliceBuffer { + fn from(bytes: Bytes) -> Self { + SliceBuffer { + active: BytesMut::new(), + len: bytes.len(), + slices: VecDeque::from_iter(iter::once(Slice::Fixed(bytes))), + } + } +} + +impl From for Bytes { + fn from(mut buf: SliceBuffer) -> Self { + if cmp::min(buf.active.len(), 1) + cmp::min(buf.slices.len(), 1) > 1 { + tracing::warn!("multiple chunks exist in the slice buffer") + } + buf.copy_to_bytes(buf.remaining()) + } +} + +#[derive(Clone, Debug, Default)] +struct Position { + abs_pos: usize, + rel_pos: usize, + slice_idx: usize, +} + +impl SliceBuffer { + /// Creates a `Cursor` tailored to the `SliceBuffer`. + /// + /// This `Cursor` grants mutable access to the `SliceBuffer` while keeping track of its + /// position. By implementing [`std::io::Seek`], [`bytes::Buf`], and [`std::io::Write`], + /// it allows random access similar to a seamless memory segment. + /// + /// **Note**: + /// 1. Writing to an immutable slice: If you try to write to an immutable slice in the + /// collection using [`std::io::Write::write`], an error will be returned. + /// 2. Seeking beyond bounds: Attempting to move the cursor beyond the buffer's range + /// with [`std::io::Seek::seek`] will also yield an error. + /// + /// # Examples + /// + /// ``` + /// use bytes::{Buf, BufMut, Bytes}; + /// use tonic::codec::SliceBuffer; + /// use std::io::{Seek, SeekFrom, Write}; + /// + /// let mut buf = SliceBuffer::default(); + /// buf.insert_slice(Bytes::copy_from_slice(b"foo")); + /// buf.put(Bytes::copy_from_slice(b"bar")); + /// + /// let mut cursor = buf.cursor(); + /// cursor.seek(SeekFrom::Start(1)).unwrap(); + /// assert_eq!(cursor.copy_to_bytes(2), Bytes::copy_from_slice(b"oo")); + /// + /// cursor.seek(SeekFrom::Start(1)).unwrap(); + /// assert!(cursor.write(&b"b"[..]).is_err()); + /// + /// cursor.seek(SeekFrom::Current(3)).unwrap(); + /// cursor.write(&b"b"[..]).unwrap(); + /// drop(cursor); + /// assert_eq!(buf.copy_to_bytes(buf.len()), Bytes::copy_from_slice(b"foobbr")); + /// ``` + pub fn cursor(&mut self) -> Cursor<'_> { + Cursor { + buffer: self, + pos: Position::default(), + } + } + + /// Returns a front-to-back iterator over the bytes in the `SliceBuffer`. + /// + /// # Examples + /// + /// ``` + /// use bytes::{BufMut, Bytes}; + /// use tonic::codec::SliceBuffer; + /// + /// let mut buf = SliceBuffer::default(); + /// buf.insert_slice(Bytes::copy_from_slice(b"foo")); + /// buf.put(Bytes::copy_from_slice(b"bar")); + /// + /// let iter = buf.iter(); + /// assert_eq!(iter.collect::>(), b"foobar".to_vec()); + /// ``` + pub fn iter(&mut self) -> Iter<'_> { + Iter { + cursor: self.cursor(), + } + } +} + +#[derive(Debug)] +pub struct Cursor<'a> { + buffer: &'a mut SliceBuffer, + pos: Position, +} + +impl<'a> Cursor<'a> { + fn forward(&self, from: &Position, mut offset: usize) -> Position { + let mut pos = from.to_owned(); + for slice in self.buffer.slices.range(from.slice_idx..) { + if slice.len() - pos.rel_pos <= offset { + pos.abs_pos += slice.len() - pos.rel_pos; + offset -= slice.len() - pos.rel_pos; + pos.rel_pos = 0; + pos.slice_idx += 1; + } else { + pos.abs_pos += offset; + pos.rel_pos += offset; + return pos; + } + } + let curr_pos = cmp::min(self.buffer.active.len(), pos.rel_pos + offset); + pos.abs_pos += curr_pos - pos.rel_pos; + pos.rel_pos = curr_pos; + pos + } + + fn backward(&self, from: &Position, mut offset: usize) -> Position { + let mut pos = from.to_owned(); + if offset <= pos.rel_pos { + pos.abs_pos -= offset; + pos.rel_pos -= offset; + return pos; + } else { + pos.abs_pos -= pos.rel_pos; + offset -= pos.rel_pos; + pos.rel_pos = 0; + } + for slice in self.buffer.slices.range(0..from.slice_idx).rev() { + pos.slice_idx -= 1; + if offset <= slice.len() { + pos.abs_pos -= offset; + pos.rel_pos = slice.len() - offset; + return pos; + } else { + pos.abs_pos -= slice.len(); + offset -= slice.len(); + } + } + Position::default() + } + + fn write_once(&mut self, buf: &[u8]) -> io::Result { + use io::{Seek, Write}; + + let n = if let Some(slice) = self.buffer.slices.get_mut(self.pos.slice_idx) { + match slice { + Slice::Mutable(slice) => slice.split_at_mut(self.pos.rel_pos).1.write(buf)?, + _ => { + return Err(io::Error::new( + io::ErrorKind::Unsupported, + "cannot write to immutable slice", + )); + } + } + } else if self.pos.rel_pos < self.buffer.active.len() { + self.buffer + .active + .split_at_mut(self.pos.rel_pos) + .1 + .write(buf)? + } else { + self.buffer.active.extend_from_slice(buf); + self.buffer.len += buf.len(); + buf.len() + }; + self.seek(io::SeekFrom::Current(n as i64))?; + Ok(n) + } +} + +impl<'a> io::Seek for Cursor<'a> { + fn seek(&mut self, pos: io::SeekFrom) -> io::Result { + let (base_pos, offset) = match pos { + io::SeekFrom::Start(offset) => (Position::default(), offset as i64), + io::SeekFrom::End(offset) => ( + Position { + abs_pos: self.buffer.len(), + rel_pos: self.buffer.active.len(), + slice_idx: self.buffer.slices.len(), + }, + offset, + ), + io::SeekFrom::Current(offset) => (self.pos.clone(), offset), + }; + + if base_pos + .abs_pos + .checked_add_signed(offset as isize) + .map(|pos| pos <= self.buffer.len()) + .unwrap_or_default() + { + self.pos = if offset >= 0 { + self.forward(&base_pos, offset as usize) + } else { + self.backward(&base_pos, offset.abs() as usize) + }; + Ok(self.pos.abs_pos as u64) + } else { + Err(io::Error::new(io::ErrorKind::InvalidInput, "out of range")) + } + } +} + +impl<'a> io::Write for Cursor<'a> { + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut n = 0; + while n < buf.len() { + n += self.write_once(&buf[n..])?; + } + Ok(n) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl<'a> Buf for Cursor<'a> { + #[inline] + fn remaining(&self) -> usize { + self.buffer.len() - self.pos.abs_pos + } + + #[inline] + fn chunk(&self) -> &[u8] { + let slice = match self.buffer.slices.get(self.pos.slice_idx) { + Some(bytes) => bytes.chunk(), + None => &self.buffer.active.chunk(), + }; + &slice[self.pos.rel_pos..] + } + + fn advance(&mut self, cnt: usize) { + assert!( + cnt <= self.remaining(), + "cannot advance past `remaining`: {:?} <= {:?}", + cnt, + self.remaining(), + ); + self.pos = self.forward(&self.pos, cnt); + } +} + +#[derive(Debug)] +pub struct Iter<'a> { + cursor: Cursor<'a>, +} + +impl<'a> Iterator for Iter<'a> { + type Item = u8; + + fn next(&mut self) -> Option { + use io::Seek; + + if self.cursor.has_remaining() { + let item = if let Some(slice) = self.cursor.buffer.slices.get(self.cursor.pos.slice_idx) + { + slice[self.cursor.pos.rel_pos] + } else { + self.cursor.buffer.active[self.cursor.pos.rel_pos] + }; + self.cursor.seek(io::SeekFrom::Current(1)).ok()?; + Some(item) + } else { + None + } + } } #[cfg(test)] mod tests { use super::*; + use std::io::{Seek, Write}; #[test] fn decode_buf() { - let mut payload = BytesMut::with_capacity(100); + let mut payload = SliceBuffer::with_capacity(128, 1024); payload.put(&vec![0u8; 50][..]); let mut buf = DecodeBuf::new(&mut payload, 20); @@ -108,7 +832,7 @@ mod tests { #[test] fn encode_buf() { - let mut bytes = BytesMut::with_capacity(100); + let mut bytes = SliceBuffer::with_capacity(128, 1024); let mut buf = EncodeBuf::new(&mut bytes); let initial = buf.remaining_mut(); @@ -118,4 +842,243 @@ mod tests { buf.put_u8(b'a'); assert_eq!(buf.remaining_mut(), initial - 20 - 1); } + + #[test] + fn sequential_read() { + let mut buf = SliceBuffer::with_capacity(128, 1024); + buf.put_slice(b"foo"); + buf.insert_slice(Bytes::copy_from_slice(b"bar")); + buf.put_slice(b"foofoo"); + buf.insert_slice(Bytes::copy_from_slice(b"barbar")); + buf.put_slice(b"foobar"); + assert_eq!(24, buf.remaining()); + assert_eq!(Bytes::copy_from_slice(b"foob"), buf.copy_to_bytes(4)); + assert_eq!(Bytes::copy_from_slice(b"arfoo"), buf.copy_to_bytes(5)); + assert_eq!(Bytes::copy_from_slice(b"foobar"), buf.copy_to_bytes(6)); + + let mut buf = buf.split_to(buf.remaining()); + assert_eq!(Bytes::copy_from_slice(b"barfoo"), buf.copy_to_bytes(6)); + assert_eq!(Bytes::copy_from_slice(b"bar"), buf.copy_to_bytes(3)); + assert_eq!(0, buf.remaining()); + } + + #[test] + fn sequential_vectored_read() { + let mut buf = SliceBuffer::with_capacity(128, 1024); + buf.put_slice(b"foo"); + buf.insert_slice(Bytes::copy_from_slice(b"bar")); + buf.put_slice(b"foobar"); + + let mut iovs = [io::IoSlice::new(&[]); 2]; + assert_eq!(2, buf.chunks_vectored(&mut iovs)); + assert_eq!(b"foo", iovs[0].as_ref()); + assert_eq!(b"bar", iovs[1].as_ref()); + } + + #[test] + fn random_read() { + let mut buf = SliceBuffer::with_capacity(128, 1024); + buf.put_slice(b"foo"); + buf.insert_slice(Bytes::copy_from_slice(b"bar")); + buf.put_slice(b"foofoo"); + buf.insert_slice(Bytes::copy_from_slice(b"barbar")); + buf.put_slice(b"foobar"); + assert_eq!(24, buf.remaining()); + { + let mut cursor = buf.cursor(); + cursor.seek(io::SeekFrom::Current(5)).unwrap(); + assert_eq!(Bytes::copy_from_slice(b"rfo"), cursor.copy_to_bytes(3)); + cursor.seek(io::SeekFrom::Current(5)).unwrap(); + assert_eq!(Bytes::copy_from_slice(b"arb"), cursor.copy_to_bytes(3)); + assert_eq!(b'a', cursor.get_u8()); + } + assert_eq!(24, buf.remaining()); + } + + #[test] + fn split_to() { + let mut buf = SliceBuffer::with_capacity(128, 1024); + buf.put_slice(b"foo"); + buf.insert_slice(Bytes::copy_from_slice(b"bar")); + buf.put_slice(b"foofoo"); + buf.insert_slice(Bytes::copy_from_slice(b"barbar")); + buf.put_slice(b"foobar"); + assert_eq!(24, buf.remaining()); + assert_eq!( + Bytes::copy_from_slice(b"foo"), + buf.split_to(3).copy_to_bytes(3) + ); + assert_eq!( + Bytes::copy_from_slice(b"barfoo"), + buf.split_to(6).copy_to_bytes(6) + ); + + let split_buf = buf.split_to(6); + assert_eq!(2, split_buf.slices.len()); + assert_eq!( + Slice::Mutable(BytesMut::from(b"foo".as_ref())), + split_buf.slices[0] + ); + assert_eq!( + Slice::Fixed(Bytes::copy_from_slice(b"bar")), + split_buf.slices[1] + ); + + let split_buf = buf.split_to(6); + assert_eq!(2, split_buf.slices.len()); + assert_eq!( + Slice::Fixed(Bytes::copy_from_slice(b"bar")), + split_buf.slices[0] + ); + assert_eq!( + Slice::Mutable(BytesMut::from(b"foo".as_ref())), + split_buf.slices[1] + ); + assert_eq!(3, buf.len()); + assert_eq!(0, buf.slices.len()); + assert_eq!(Bytes::copy_from_slice(b"bar"), buf.active.freeze()); + } + + #[test] + fn append() { + let mut buf = SliceBuffer::with_capacity(128, 1024); + buf.put_slice(b"foo"); + buf.insert_slice(Bytes::copy_from_slice(b"bar")); + buf.put_slice(b"foofoo"); + let mut buf2 = SliceBuffer::with_capacity(128, 1024); + buf2.insert_slice(Bytes::copy_from_slice(b"barbar")); + buf2.put_slice(b"foobar"); + buf.append(&mut buf2); + + assert_eq!(24, buf.len()); + assert_eq!(4, buf.slices.len()); + assert_eq!( + Slice::Mutable(BytesMut::from(b"foo".as_ref())), + buf.slices[0] + ); + assert_eq!(Slice::Fixed(Bytes::copy_from_slice(b"bar")), buf.slices[1]); + assert_eq!( + Slice::Mutable(BytesMut::from(b"foofoo".as_ref())), + buf.slices[2] + ); + assert_eq!( + Slice::Fixed(Bytes::copy_from_slice(b"barbar")), + buf.slices[3] + ); + assert_eq!(BytesMut::from(b"foobar".as_ref()), buf.active); + + assert_eq!(0, buf2.len()); + assert!(buf2.slices.is_empty()); + assert!(buf2.active.is_empty()); + } + + macro_rules! random_write_test { + ( + name: $name:ident, + start_pos: $start_pos:expr, + seek: $seek:expr, + buf: $buf:expr, + expect: $expect:expr, + ) => { + #[test] + fn $name() { + let mut buf = SliceBuffer::with_capacity(128, 1024); + buf.put_slice(b"foo"); + buf.insert_slice(Bytes::copy_from_slice(b"bar")); + buf.put_slice(b"foobar"); + let mut cursor = buf.cursor(); + cursor.pos = $start_pos; + cursor.seek($seek).unwrap(); + if let Some(expect) = $expect { + cursor.write($buf).unwrap(); + assert_eq!( + Bytes::copy_from_slice(expect), + buf.copy_to_bytes(buf.remaining()) + ); + } else { + assert!(cursor.write($buf).is_err()); + } + } + }; + } + + random_write_test!( + name: seek_from_start_1, + start_pos: Position::default(), + seek: io::SeekFrom::Start(1), + buf: b"fo", + expect: Some(b"ffobarfoobar"), + ); + + random_write_test!( + name: seek_from_start_2, + start_pos: Position::default(), + seek: io::SeekFrom::Start(1), + buf: b"fo", + expect: Some(b"ffobarfoobar"), + ); + + random_write_test!( + name: seek_from_start_3, + start_pos: Position::default(), + seek: io::SeekFrom::Start(5), + buf: b"foo", + expect: None, + ); + + random_write_test!( + name: seek_from_start_4, + start_pos: Position::default(), + seek: io::SeekFrom::Start(11), + buf: b"foo", + expect: Some(b"foobarfoobafoo"), + ); + + random_write_test!( + name: seek_from_end_1, + start_pos: Position::default(), + seek: io::SeekFrom::End(-2), + buf: b"foo", + expect: Some(b"foobarfoobfoo"), + ); + + random_write_test!( + name: seek_from_end_2, + start_pos: Position::default(), + seek: io::SeekFrom::End(-5), + buf: b"foo", + expect: Some(b"foobarffooar"), + ); + + random_write_test!( + name: seek_from_end_3, + start_pos: Position::default(), + seek: io::SeekFrom::End(-11), + buf: b"foo", + expect: None, + ); + + random_write_test!( + name: seek_from_current_1, + start_pos: Position{abs_pos: 4, rel_pos: 1, slice_idx: 1}, + seek: io::SeekFrom::Current(7), + buf: b"foo", + expect: Some(b"foobarfoobafoo"), + ); + + random_write_test!( + name: seek_from_current_2, + start_pos: Position{abs_pos: 4, rel_pos: 1, slice_idx: 1}, + seek: io::SeekFrom::Current(2), + buf: b"foo", + expect: Some(b"foobarfoobar"), + ); + + random_write_test!( + name: seek_from_current_3, + start_pos: Position{abs_pos: 4, rel_pos: 1, slice_idx: 1}, + seek: io::SeekFrom::Current(-2), + buf: b"foo", + expect: None, + ); } diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 70d758415..bad3e93bb 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -1,6 +1,5 @@ use super::encode::BUFFER_SIZE; -use crate::{metadata::MetadataValue, Status}; -use bytes::{Buf, BytesMut}; +use crate::{codec::SliceBuffer, metadata::MetadataValue, Status}; #[cfg(feature = "gzip")] use flate2::read::{GzDecoder, GzEncoder}; use std::fmt; @@ -198,10 +197,10 @@ fn split_by_comma(s: &str) -> impl Iterator { #[allow(unused_variables, unreachable_code)] pub(crate) fn compress( encoding: CompressionEncoding, - decompressed_buf: &mut BytesMut, - out_buf: &mut BytesMut, - len: usize, + decompressed_buf: &mut SliceBuffer, + out_buf: &mut SliceBuffer, ) -> Result<(), std::io::Error> { + let len = decompressed_buf.len(); let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE; out_buf.reserve(capacity); @@ -212,7 +211,7 @@ pub(crate) fn compress( #[cfg(feature = "gzip")] CompressionEncoding::Gzip => { let mut gzip_encoder = GzEncoder::new( - &decompressed_buf[0..len], + bytes::Buf::reader(decompressed_buf), // FIXME: support customizing the compression level flate2::Compression::new(6), ); @@ -221,7 +220,7 @@ pub(crate) fn compress( #[cfg(feature = "zstd")] CompressionEncoding::Zstd => { let mut zstd_encoder = Encoder::new( - &decompressed_buf[0..len], + bytes::Buf::reader(decompressed_buf), // FIXME: support customizing the compression level zstd::DEFAULT_COMPRESSION_LEVEL, )?; @@ -229,8 +228,6 @@ pub(crate) fn compress( } } - decompressed_buf.advance(len); - Ok(()) } @@ -238,13 +235,12 @@ pub(crate) fn compress( #[allow(unused_variables, unreachable_code)] pub(crate) fn decompress( encoding: CompressionEncoding, - compressed_buf: &mut BytesMut, - out_buf: &mut BytesMut, + compressed_buf: &mut SliceBuffer, + out_buf: &mut SliceBuffer, len: usize, ) -> Result<(), std::io::Error> { let estimate_decompressed_len = len * 2; let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE; - out_buf.reserve(capacity); #[cfg(any(feature = "gzip", feature = "zstd"))] let mut out_writer = bytes::BufMut::writer(out_buf); @@ -252,18 +248,16 @@ pub(crate) fn decompress( match encoding { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => { - let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]); + let mut gzip_decoder = GzDecoder::new(bytes::Buf::reader(compressed_buf.split_to(len))); std::io::copy(&mut gzip_decoder, &mut out_writer)?; } #[cfg(feature = "zstd")] CompressionEncoding::Zstd => { - let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?; + let mut zstd_decoder = Decoder::new(bytes::Buf::reader(compressed_buf.split_to(len)))?; std::io::copy(&mut zstd_decoder, &mut out_writer)?; } } - compressed_buf.advance(len); - Ok(()) } diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index cb88a0649..c27718ad6 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -1,7 +1,9 @@ -use super::compression::{decompress, CompressionEncoding}; -use super::{DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE}; -use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; -use bytes::{Buf, BufMut, BytesMut}; +use super::{ + compression::{decompress, CompressionEncoding}, + DecodeBuf, Decoder, SliceBuffer, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE, +}; +use crate::{metadata::MetadataMap, Code, Status}; +use bytes::Buf; use http::StatusCode; use http_body::Body; use std::{ @@ -15,6 +17,8 @@ use tracing::{debug, trace}; const BUFFER_SIZE: usize = 8 * 1024; +const SLICE_SIZE: usize = 128; + /// Streaming requests and responses. /// /// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface @@ -25,12 +29,12 @@ pub struct Streaming { } struct StreamingInner { - body: BoxBody, + body: http_body::combinators::UnsyncBoxBody, state: State, direction: Direction, - buf: BytesMut, + buf: SliceBuffer, trailers: Option, - decompress_buf: BytesMut, + decompress_buf: SliceBuffer, encoding: Option, max_message_size: Option, } @@ -118,6 +122,12 @@ impl Streaming { B::Error: Into, D: Decoder + Send + 'static, { + let decompress_buf = if encoding.is_some() { + SliceBuffer::with_capacity(0, BUFFER_SIZE) + } else { + SliceBuffer::default() + }; + Self { decoder: Box::new(decoder), inner: StreamingInner { @@ -127,9 +137,9 @@ impl Streaming { .boxed_unsync(), state: State::ReadHeader, direction, - buf: BytesMut::with_capacity(BUFFER_SIZE), + buf: SliceBuffer::with_capacity(SLICE_SIZE, 0), trailers: None, - decompress_buf: BytesMut::new(), + decompress_buf, encoding, max_message_size, }, @@ -187,8 +197,6 @@ impl StreamingInner { )); } - self.buf.reserve(len); - self.state = State::ReadBody { compression: compression_encoding, len, @@ -246,7 +254,7 @@ impl StreamingInner { }; Poll::Ready(if let Some(data) = chunk { - self.buf.put(data); + self.buf.insert_slice(data); Ok(Some(())) } else { // FIXME: improve buf usage. diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 13eb2c96d..3db472f03 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -1,17 +1,23 @@ -use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride}; -use super::{EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE}; -use crate::{Code, Status}; -use bytes::{BufMut, Bytes, BytesMut}; +use super::{ + compression::{compress, CompressionEncoding, SingleMessageCompressionOverride}, + Encoder, SliceBuffer, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE, +}; +use crate::{codec::EncodeBuf, Code, Status}; +use bytes::{Buf, BufMut}; use http::HeaderMap; use http_body::Body; use pin_project::pin_project; +use std::io::Seek; use std::{ + io, pin::Pin, task::{ready, Context, Poll}, }; use tokio_stream::{Stream, StreamExt}; pub(super) const BUFFER_SIZE: usize = 8 * 1024; + +const SLICE_SIZE: usize = 128; const YIELD_THRESHOLD: usize = 32 * 1024; pub(crate) fn encode_server( @@ -20,7 +26,7 @@ pub(crate) fn encode_server( compression_encoding: Option, compression_override: SingleMessageCompressionOverride, max_message_size: Option, -) -> EncodeBody>> +) -> EncodeBody>> where T: Encoder, U: Stream>, @@ -41,7 +47,7 @@ pub(crate) fn encode_client( source: U, compression_encoding: Option, max_message_size: Option, -) -> EncodeBody>> +) -> EncodeBody>> where T: Encoder, U: Stream, @@ -57,7 +63,7 @@ where } /// Combinator for efficient encoding of messages into reasonably sized buffers. -/// EncodedBytes encodes ready messages from its delegate stream into a BytesMut, +/// EncodedBytes encodes ready messages from its delegate stream into a SliceBuffer, /// splitting off and yielding a buffer when either: /// * The delegate stream polls as not ready, or /// * The encoded buffer surpasses YIELD_THRESHOLD. @@ -73,8 +79,8 @@ where encoder: T, compression_encoding: Option, max_message_size: Option, - buf: BytesMut, - uncompression_buf: BytesMut, + buf: SliceBuffer, + uncompression_buf: SliceBuffer, } impl EncodedBytes @@ -90,7 +96,7 @@ where compression_override: SingleMessageCompressionOverride, max_message_size: Option, ) -> Self { - let buf = BytesMut::with_capacity(BUFFER_SIZE); + let buf = SliceBuffer::with_capacity(SLICE_SIZE, BUFFER_SIZE); let compression_encoding = if compression_override == SingleMessageCompressionOverride::Disable { @@ -100,9 +106,9 @@ where }; let uncompression_buf = if compression_encoding.is_some() { - BytesMut::with_capacity(BUFFER_SIZE) + SliceBuffer::with_capacity(SLICE_SIZE, BUFFER_SIZE) } else { - BytesMut::new() + SliceBuffer::default() }; Self { @@ -121,7 +127,7 @@ where T: Encoder, U: Stream>, { - type Item = Result; + type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let EncodedBytesProj { @@ -142,7 +148,7 @@ where return Poll::Ready(None); } Poll::Pending | Poll::Ready(None) => { - return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze()))); + return Poll::Ready(Some(Ok(buf.split_to(buf.remaining())))); } Poll::Ready(Some(Ok(item))) => { if let Err(status) = encode_item( @@ -157,7 +163,7 @@ where } if buf.len() >= YIELD_THRESHOLD { - return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze()))); + return Poll::Ready(Some(Ok(buf.split_to(buf.remaining())))); } } Poll::Ready(Some(Err(status))) => { @@ -170,8 +176,8 @@ where fn encode_item( encoder: &mut T, - buf: &mut BytesMut, - uncompression_buf: &mut BytesMut, + buf: &mut SliceBuffer, + uncompression_buf: &mut SliceBuffer, compression_encoding: Option, max_message_size: Option, item: T::Item, @@ -193,9 +199,7 @@ where .encode(item, &mut EncodeBuf::new(uncompression_buf)) .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; - let uncompressed_len = uncompression_buf.len(); - - compress(encoding, uncompression_buf, buf, uncompressed_len) + compress(encoding, uncompression_buf, buf) .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; } else { encoder @@ -204,15 +208,26 @@ where } // now that we know length, we can write the header - finish_encoding(compression_encoding, max_message_size, &mut buf[offset..]) + let encoded_len = buf.len() - offset; + let mut cursor = buf.cursor(); + cursor + .seek(io::SeekFrom::End(-1 * encoded_len as i64)) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; + finish_encoding( + compression_encoding, + max_message_size, + encoded_len, + &mut cursor, + ) } fn finish_encoding( compression_encoding: Option, max_message_size: Option, - buf: &mut [u8], + encoded_len: usize, + buf: &mut dyn io::Write, ) -> Result<(), Status> { - let len = buf.len() - HEADER_SIZE; + let len = encoded_len - HEADER_SIZE; let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE); if len > limit { return Err(Status::new( @@ -229,11 +244,10 @@ fn finish_encoding( "Cannot return body with more than 4GB of data but got {len} bytes" ))); } - { - let mut buf = &mut buf[..HEADER_SIZE]; - buf.put_u8(compression_encoding.is_some() as u8); - buf.put_u32(len as u32); - } + buf.write(&[compression_encoding.is_some() as u8]) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; + buf.write(&(len as u32).to_be_bytes()) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; Ok(()) } @@ -261,7 +275,7 @@ struct EncodeState { impl EncodeBody where - S: Stream>, + S: Stream>, { pub(crate) fn new_client(inner: S) -> Self { Self { @@ -310,9 +324,9 @@ impl EncodeState { impl Body for EncodeBody where - S: Stream>, + S: Stream>, { - type Data = Bytes; + type Data = SliceBuffer; type Error = Status; fn is_end_stream(&self) -> bool { diff --git a/tonic/src/codec/mod.rs b/tonic/src/codec/mod.rs index 306621329..fca8aadb5 100644 --- a/tonic/src/codec/mod.rs +++ b/tonic/src/codec/mod.rs @@ -15,7 +15,7 @@ use std::io; pub(crate) use self::encode::{encode_client, encode_server}; -pub use self::buffer::{DecodeBuf, EncodeBuf}; +pub use self::buffer::{DecodeBuf, EncodeBuf, SliceBuffer}; pub use self::compression::{CompressionEncoding, EnabledCompressionEncodings}; pub use self::decode::Streaming; #[cfg(feature = "prost")] diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index ec94b97fb..791918ff6 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -1,9 +1,11 @@ -use crate::codec::compression::{ - CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride, -}; use crate::{ body::BoxBody, - codec::{encode_server, Codec, Streaming}, + codec::{ + compression::{ + CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride, + }, + encode_server, Codec, Streaming, + }, server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService}, Code, Request, Status, }; @@ -317,7 +319,11 @@ where self.send_compression_encodings, ); - let request = t!(self.map_request_streaming(req)); + // let request = t!(self.map_request_streaming(req)); + let request = match self.map_request_streaming(req) { + Ok(value) => value, + Err(status) => return status.to_http(), + }; let response = service .call(request) diff --git a/tonic/src/service/interceptor.rs b/tonic/src/service/interceptor.rs index cadff466f..3c50470f2 100644 --- a/tonic/src/service/interceptor.rs +++ b/tonic/src/service/interceptor.rs @@ -4,10 +4,10 @@ use crate::{ body::{boxed, BoxBody}, + codec::SliceBuffer, request::SanitizeHeaders, Status, }; -use bytes::Bytes; use pin_project::pin_project; use std::{ fmt, @@ -121,11 +121,13 @@ where impl Service> for InterceptedService where - ResBody: Default + http_body::Body + Send + 'static, F: Interceptor, S: Service, Response = http::Response>, S::Error: Into, - ResBody: http_body::Body + Send + 'static, + ReqBody: http_body::Body + Send + 'static, + ReqBody::Data: bytes::Buf, + ResBody: Default + http_body::Body + Send + 'static, + ResBody::Data: Into, ResBody::Error: Into, { type Response = http::Response; @@ -205,7 +207,8 @@ impl Future for ResponseFuture where F: Future, E>>, E: Into, - B: Default + http_body::Body + Send + 'static, + B: Default + http_body::Body + Send + 'static, + B::Data: Into, B::Error: Into, { type Output = Result, E>; @@ -243,7 +246,7 @@ mod tests { struct TestBody; impl http_body::Body for TestBody { - type Data = Bytes; + type Data = bytes::Bytes; type Error = Status; fn poll_data( diff --git a/tonic/src/status.rs b/tonic/src/status.rs index da8b792e5..6f58528e1 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -1,5 +1,7 @@ -use crate::body::BoxBody; -use crate::metadata::MetadataMap; +use crate::{ + body::{empty_body, BoxBody}, + metadata::MetadataMap, +}; use base64::Engine as _; use bytes::Bytes; use http::header::{HeaderMap, HeaderValue}; @@ -594,7 +596,7 @@ impl Status { self.add_header(&mut parts.headers).unwrap(); - http::Response::from_parts(parts, crate::body::empty_body()) + http::Response::from_parts(parts, empty_body()) } } diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index c676bfb92..f290ed38c 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -108,7 +108,6 @@ pub use self::tls::Certificate; #[doc(inline)] /// A deprecated re-export. Please use `tonic::server::NamedService` directly. pub use crate::server::NamedService; -pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter}; pub use hyper::{Body, Uri}; pub(crate) use self::service::executor::Executor; diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index e10f11f68..943d4b10e 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -38,8 +38,7 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, ServerIo}; -use crate::body::BoxBody; -use bytes::Bytes; +use crate::{body::BoxBody, codec::SliceBuffer}; use http::{Request, Response}; use http_body::Body as _; use hyper::{server::accept, Body}; @@ -65,7 +64,7 @@ use tower::{ Service, ServiceBuilder, }; -type BoxHttpBody = http_body::combinators::UnsyncBoxBody; +type BoxHttpBody = http_body::combinators::UnsyncBoxBody; type BoxService = tower::util::BoxService, Response, crate::Error>; type TraceInterceptor = Arc) -> tracing::Span + Send + Sync + 'static>; @@ -504,7 +503,8 @@ impl Server { IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, F: Future, - ResBody: http_body::Body + Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Data: Into, ResBody::Error: Into, { let trace_interceptor = self.trace_interceptor.clone(); @@ -603,9 +603,9 @@ impl Router { self } - /// Convert this tonic `Router` into an axum `Router` consuming the tonic one. - pub fn into_router(self) -> axum::Router { - self.routes.into_router() + /// Convert this `Router` to a [`Service`] router consuming the tonic one. + pub fn into_router(self) -> Routes { + self.routes } /// Consume this [`Server`] creating a future that will execute the server @@ -619,14 +619,15 @@ impl Router { L::Service: Service, Response = Response> + Clone + Send + 'static, <>::Service as Service>>::Future: Send + 'static, <>::Service as Service>>::Error: Into + Send, - ResBody: http_body::Body + Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Data: Into, ResBody::Error: Into, { let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) .map_err(super::Error::from_source)?; self.server .serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>( - self.routes.prepare(), + self.routes, incoming, None, ) @@ -649,13 +650,14 @@ impl Router { L::Service: Service, Response = Response> + Clone + Send + 'static, <>::Service as Service>>::Future: Send + 'static, <>::Service as Service>>::Error: Into + Send, - ResBody: http_body::Body + Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Data: Into, ResBody::Error: Into, { let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) .map_err(super::Error::from_source)?; self.server - .serve_with_shutdown(self.routes.prepare(), incoming, Some(signal)) + .serve_with_shutdown(self.routes, incoming, Some(signal)) .await } @@ -678,12 +680,13 @@ impl Router { L::Service: Service, Response = Response> + Clone + Send + 'static, <>::Service as Service>>::Future: Send + 'static, <>::Service as Service>>::Error: Into + Send, - ResBody: http_body::Body + Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Data: Into, ResBody::Error: Into, { self.server .serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>( - self.routes.prepare(), + self.routes, incoming, None, ) @@ -713,11 +716,12 @@ impl Router { L::Service: Service, Response = Response> + Clone + Send + 'static, <>::Service as Service>>::Future: Send + 'static, <>::Service as Service>>::Error: Into + Send, - ResBody: http_body::Body + Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Data: Into, ResBody::Error: Into, { self.server - .serve_with_shutdown(self.routes.prepare(), incoming, Some(signal)) + .serve_with_shutdown(self.routes, incoming, Some(signal)) .await } @@ -728,10 +732,11 @@ impl Router { L::Service: Service, Response = Response> + Clone + Send + 'static, <>::Service as Service>>::Future: Send + 'static, <>::Service as Service>>::Error: Into + Send, - ResBody: http_body::Body + Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Data: Into, ResBody::Error: Into, { - self.server.service_builder.service(self.routes.prepare()) + self.server.service_builder.service(self.routes) } } @@ -750,7 +755,8 @@ impl Service> for Svc where S: Service, Response = Response>, S::Error: Into, - ResBody: http_body::Body + Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Data: Into, ResBody::Error: Into, { type Response = Response; @@ -794,7 +800,8 @@ impl Future for SvcFuture where F: Future, E>>, E: Into, - ResBody: http_body::Body + Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Data: Into, ResBody::Error: Into, { type Output = Result, crate::Error>; @@ -804,7 +811,8 @@ where let _guard = this.span.enter(); let response: Response = ready!(this.inner.poll(cx)).map_err(Into::into)?; - let response = response.map(|body| body.map_err(Into::into).boxed_unsync()); + let response = + response.map(|body| body.map_data(Into::into).map_err(Into::into).boxed_unsync()); Poll::Ready(Ok(response)) } } @@ -829,7 +837,8 @@ where S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, - ResBody: http_body::Body + Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Data: Into, ResBody::Error: Into, { type Response = BoxService; diff --git a/tonic/src/transport/service/router.rs b/tonic/src/transport/service/router.rs index 85636c4d4..f52902620 100644 --- a/tonic/src/transport/service/router.rs +++ b/tonic/src/transport/service/router.rs @@ -1,24 +1,28 @@ use crate::{ - body::{boxed, BoxBody}, + body::{empty_body, BoxBody}, server::NamedService, + transport::BoxFuture, }; -use http::{Request, Response}; +use http::{HeaderValue, Request, Response}; use hyper::Body; -use pin_project::pin_project; use std::{ convert::Infallible, fmt, - future::Future, - pin::Pin, - task::{ready, Context, Poll}, + task::{Context, Poll}, }; -use tower::ServiceExt; +use tower::{util::BoxCloneService, ServiceExt}; use tower_service::Service; /// A [`Service`] router. -#[derive(Debug, Default, Clone)] +#[derive(Default, Clone)] pub struct Routes { - router: axum::Router, + router: matchit::Router, Response, crate::Error>>, +} + +impl fmt::Debug for Routes { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Routes").finish() + } } #[derive(Debug, Default, Clone)] @@ -40,7 +44,7 @@ impl RoutesBuilder { S::Error: Into + Send, { let routes = self.routes.take().unwrap_or_default(); - self.routes.replace(routes.add_service(svc)); + self.routes.replace(routes.add_service(svc.clone())); self } @@ -61,8 +65,10 @@ impl Routes { S::Future: Send + 'static, S::Error: Into + Send, { - let router = axum::Router::new().fallback(unimplemented); - Self { router }.add_service(svc) + Self { + router: matchit::Router::new(), + } + .add_service(svc) } /// Add a new service. @@ -76,37 +82,32 @@ impl Routes { S::Future: Send + 'static, S::Error: Into + Send, { - let svc = svc.map_response(|res| res.map(axum::body::boxed)); - self.router = self - .router - .route_service(&format!("/{}/*rest", S::NAME), svc); - self - } - - pub(crate) fn prepare(self) -> Self { - Self { - // this makes axum perform update some internals of the router that improves perf - // see https://docs.rs/axum/latest/axum/routing/struct.Router.html#a-note-about-performance - router: self.router.with_state(()), - } - } - - /// Convert this `Routes` into an [`axum::Router`]. - pub fn into_router(self) -> axum::Router { self.router + .insert( + format!("/{}/*rest", S::NAME), + BoxCloneService::new(svc.map_err(Into::into)), + ) + .expect("Failed to route path to service."); + self } } -async fn unimplemented() -> impl axum::response::IntoResponse { - let status = http::StatusCode::OK; - let headers = [("grpc-status", "12"), ("content-type", "application/grpc")]; - (status, headers) +async fn unimplemented() -> Result, crate::Error> { + let mut response = Response::new(empty_body()); + *response.status_mut() = http::StatusCode::OK; + response + .headers_mut() + .insert("grpc-status", HeaderValue::from_static("12")); + response + .headers_mut() + .insert("content-type", HeaderValue::from_static("application/grpc")); + Ok(response) } impl Service> for Routes { type Response = Response; type Error = crate::Error; - type Future = RoutesFuture; + type Future = BoxFuture<'static, Result>; #[inline] fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { @@ -114,26 +115,9 @@ impl Service> for Routes { } fn call(&mut self, req: Request) -> Self::Future { - RoutesFuture(self.router.call(req)) - } -} - -#[pin_project] -pub struct RoutesFuture(#[pin] axum::routing::future::RouteFuture); - -impl fmt::Debug for RoutesFuture { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("RoutesFuture").finish() - } -} - -impl Future for RoutesFuture { - type Output = Result, crate::Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match ready!(self.project().0.poll(cx)) { - Ok(res) => Ok(res.map(boxed)).into(), - Err(err) => match err {}, + match self.router.at_mut(req.uri().path()) { + Ok(m) => m.value.call(req), + Err(_) => Box::pin(unimplemented()), } } }