From 99870234a0fb717498f48b411a30237c0a0690c4 Mon Sep 17 00:00:00 2001 From: Kenny Date: Thu, 11 Jan 2024 17:46:41 -0800 Subject: [PATCH] feat(tonic): Custom codecs for generated code Broadly, this change does 2 things: 1. Allow the built-in Prost codec to have its buffer sizes customized 2. Allow users to specify custom codecs on the tonic_build::prost::Builder The Prost codec is convenient, and handles any normal use case. However, the buffer sizes today are too large in some cases - and they may grow too aggressively. By exposing BufferSettings, users can make a small custom codec with their own BufferSettings to control their memory usage - or give enormous buffers to rpc's, as their use case requires. While one can define a custom service and methods with a custom codec today explicitly in Rust, the code generator does not have a means to supply a custom codec. I've reached for .codec... on the tonic_build::prost::Builder many times and keep forgetting it's not there. This change adds .codec_path to the Builder, so people can simply add their custom buffer codec or even their own full top level codec without reaching for manual service definition. --- examples/Cargo.toml | 8 ++++ examples/build.rs | 8 ++++ examples/src/codec_buffers/client.rs | 30 ++++++++++++ examples/src/codec_buffers/common.rs | 44 +++++++++++++++++ examples/src/codec_buffers/server.rs | 51 ++++++++++++++++++++ examples/src/json-codec/common.rs | 8 ++++ tonic-build/src/client.rs | 8 ++-- tonic-build/src/compile_settings.rs | 69 ++++++++++++++++++++++++++ tonic-build/src/lib.rs | 5 +- tonic-build/src/manual.rs | 4 +- tonic-build/src/prost.rs | 30 ++++++++++-- tonic-build/src/server.rs | 8 ++-- tonic/benches/decode.rs | 4 ++ tonic/src/codec/compression.rs | 25 +++++++--- tonic/src/codec/decode.rs | 27 +++++++---- tonic/src/codec/encode.rs | 31 ++++++++---- tonic/src/codec/mod.rs | 41 ++++++++++++++++ tonic/src/codec/prost.rs | 72 +++++++++++++++++++++++++--- 18 files changed, 425 insertions(+), 48 deletions(-) create mode 100644 examples/src/codec_buffers/client.rs create mode 100644 examples/src/codec_buffers/common.rs create mode 100644 examples/src/codec_buffers/server.rs create mode 100644 tonic-build/src/compile_settings.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 5c3e7e8b5..2239336e1 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -276,6 +276,14 @@ required-features = ["cancellation"] name = "cancellation-client" path = "src/cancellation/client.rs" +[[bin]] +name = "codec-buffers-server" +path = "src/codec_buffers/server.rs" + +[[bin]] +name = "codec-buffers-client" +path = "src/codec_buffers/client.rs" + [features] gcp = ["dep:prost-types", "tonic/tls"] diff --git a/examples/build.rs b/examples/build.rs index 892b0d96c..454a77214 100644 --- a/examples/build.rs +++ b/examples/build.rs @@ -33,6 +33,14 @@ fn main() { .unwrap(); build_json_codec_service(); + + let smallbuff_copy = out_dir.join("smallbuf"); + let _ = std::fs::create_dir(smallbuff_copy.clone()); // This will panic below if the directory failed to create + tonic_build::configure() + .out_dir(smallbuff_copy) + .codec_path("crate::common::SmallBufferCodec") + .compile(&["proto/helloworld/helloworld.proto"], &["proto"]) + .unwrap(); } // Manually define the json.helloworld.Greeter service which used a custom JsonCodec to use json diff --git a/examples/src/codec_buffers/client.rs b/examples/src/codec_buffers/client.rs new file mode 100644 index 000000000..267e19dbf --- /dev/null +++ b/examples/src/codec_buffers/client.rs @@ -0,0 +1,30 @@ +//! A HelloWorld example that uses a custom codec instead of the default Prost codec. +//! +//! Generated code is the output of codegen as defined in the `examples/build.rs` file. +//! The generation is the one with .codec_path("crate::common::SmallBufferCodec") +//! The generated code assumes that a module `crate::common` exists which defines +//! `SmallBufferCodec`, and `SmallBufferCodec` must have a Default implementation. + +pub mod common; + +pub mod small_buf { + include!(concat!(env!("OUT_DIR"), "/smallbuf/helloworld.rs")); +} +use small_buf::greeter_client::GreeterClient; + +use crate::small_buf::HelloRequest; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let mut client = GreeterClient::connect("http://[::1]:50051").await?; + + let request = tonic::Request::new(HelloRequest { + name: "Tonic".into(), + }); + + let response = client.say_hello(request).await?; + + println!("RESPONSE={:?}", response); + + Ok(()) +} diff --git a/examples/src/codec_buffers/common.rs b/examples/src/codec_buffers/common.rs new file mode 100644 index 000000000..bc156f83f --- /dev/null +++ b/examples/src/codec_buffers/common.rs @@ -0,0 +1,44 @@ +//! This module defines a common encoder with small buffers. This is useful +//! when you have many concurrent RPC's, and not a huge volume of data per +//! rpc normally. +//! +//! Note that you can customize your codecs per call to the code generator's +//! compile function. This lets you group services by their codec needs. +//! +//! While this codec demonstrates customizing the built-in Prost codec, you +//! can use this to implement other codecs as well, as long as they have a +//! Default implementation. + +use std::marker::PhantomData; + +use prost::Message; +use tonic::codec::{BufferSettings, Codec, ProstDecoder, ProstEncoder}; + +#[derive(Debug, Clone, Copy, Default)] +pub struct SmallBufferCodec(PhantomData<(T, U)>); + +impl Codec for SmallBufferCodec +where + T: Message + Send + 'static, + U: Message + Default + Send + 'static, +{ + type Encode = T; + type Decode = U; + + type Encoder = ProstEncoder; + type Decoder = ProstDecoder; + + fn encoder(&mut self) -> Self::Encoder { + ProstEncoder::new(BufferSettings { + buffer_size: 512, + yield_threshold: 4096, + }) + } + + fn decoder(&mut self) -> Self::Decoder { + ProstDecoder::new(BufferSettings { + buffer_size: 512, + yield_threshold: 4096, + }) + } +} diff --git a/examples/src/codec_buffers/server.rs b/examples/src/codec_buffers/server.rs new file mode 100644 index 000000000..b30c797d3 --- /dev/null +++ b/examples/src/codec_buffers/server.rs @@ -0,0 +1,51 @@ +//! A HelloWorld example that uses a custom codec instead of the default Prost codec. +//! +//! Generated code is the output of codegen as defined in the `examples/build.rs` file. +//! The generation is the one with .codec_path("crate::common::SmallBufferCodec") +//! The generated code assumes that a module `crate::common` exists which defines +//! `SmallBufferCodec`, and `SmallBufferCodec` must have a Default implementation. + +use tonic::{transport::Server, Request, Response, Status}; + +pub mod common; + +pub mod small_buf { + include!(concat!(env!("OUT_DIR"), "/smallbuf/helloworld.rs")); +} +use small_buf::{ + greeter_server::{Greeter, GreeterServer}, + HelloReply, HelloRequest, +}; + +#[derive(Default)] +pub struct MyGreeter {} + +#[tonic::async_trait] +impl Greeter for MyGreeter { + async fn say_hello( + &self, + request: Request, + ) -> Result, Status> { + println!("Got a request from {:?}", request.remote_addr()); + + let reply = HelloReply { + message: format!("Hello {}!", request.into_inner().name), + }; + Ok(Response::new(reply)) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let addr = "[::1]:50051".parse().unwrap(); + let greeter = MyGreeter::default(); + + println!("GreeterServer listening on {}", addr); + + Server::builder() + .add_service(GreeterServer::new(greeter)) + .serve(addr) + .await?; + + Ok(()) +} diff --git a/examples/src/json-codec/common.rs b/examples/src/json-codec/common.rs index 9f0ffeb54..d53584d47 100644 --- a/examples/src/json-codec/common.rs +++ b/examples/src/json-codec/common.rs @@ -30,6 +30,10 @@ impl Encoder for JsonEncoder { fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> { serde_json::to_writer(buf.writer(), &item).map_err(|e| Status::internal(e.to_string())) } + + fn buffer_settings(&self) -> tonic::codec::BufferSettings { + Default::default() + } } #[derive(Debug)] @@ -48,6 +52,10 @@ impl Decoder for JsonDecoder { serde_json::from_reader(buf.reader()).map_err(|e| Status::internal(e.to_string()))?; Ok(Some(item)) } + + fn buffer_settings(&self) -> tonic::codec::BufferSettings { + Default::default() + } } /// A [`Codec`] that implements `application/grpc+json` via the serde library. diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index 4023cb64b..180ad8294 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -221,7 +221,7 @@ fn generate_unary( proto_path: &str, compile_well_known_types: bool, ) -> TokenStream { - let codec_name = syn::parse_str::(method.codec_path()).unwrap(); + let codec_name = syn::parse_str::(&method.codec_path()).unwrap(); let ident = format_ident!("{}", method.name()); let (request, response) = method.request_response_name(proto_path, compile_well_known_types); let service_name = format_service_name(service, emit_package); @@ -252,7 +252,7 @@ fn generate_server_streaming( proto_path: &str, compile_well_known_types: bool, ) -> TokenStream { - let codec_name = syn::parse_str::(method.codec_path()).unwrap(); + let codec_name = syn::parse_str::(&method.codec_path()).unwrap(); let ident = format_ident!("{}", method.name()); let (request, response) = method.request_response_name(proto_path, compile_well_known_types); let service_name = format_service_name(service, emit_package); @@ -283,7 +283,7 @@ fn generate_client_streaming( proto_path: &str, compile_well_known_types: bool, ) -> TokenStream { - let codec_name = syn::parse_str::(method.codec_path()).unwrap(); + let codec_name = syn::parse_str::(&method.codec_path()).unwrap(); let ident = format_ident!("{}", method.name()); let (request, response) = method.request_response_name(proto_path, compile_well_known_types); let service_name = format_service_name(service, emit_package); @@ -314,7 +314,7 @@ fn generate_streaming( proto_path: &str, compile_well_known_types: bool, ) -> TokenStream { - let codec_name = syn::parse_str::(method.codec_path()).unwrap(); + let codec_name = syn::parse_str::(&method.codec_path()).unwrap(); let ident = format_ident!("{}", method.name()); let (request, response) = method.request_response_name(proto_path, compile_well_known_types); let service_name = format_service_name(service, emit_package); diff --git a/tonic-build/src/compile_settings.rs b/tonic-build/src/compile_settings.rs new file mode 100644 index 000000000..dc088fff4 --- /dev/null +++ b/tonic-build/src/compile_settings.rs @@ -0,0 +1,69 @@ +use std::{ + marker::PhantomData, + mem::take, + sync::{Mutex, MutexGuard}, +}; + +#[derive(Debug, Clone)] +pub(crate) struct CompileSettings { + pub(crate) codec_path: String, +} + +impl Default for CompileSettings { + fn default() -> Self { + Self { + codec_path: "tonic::codec::ProstCodec".to_string(), + } + } +} + +thread_local! { + static COMPILE_SETTINGS: Mutex> = Default::default(); +} + +/// Called before compile, this installs a CompileSettings in the current thread's +/// context, so that live code generation can access the settings. +/// The previous state is restored when you drop the SettingsGuard. +pub(crate) fn set_context(new_settings: CompileSettings) -> SettingsGuard { + COMPILE_SETTINGS.with(|settings| { + let mut guard = settings + .lock() + .expect("threadlocal mutex should always succeed"); + let old_settings = guard.clone(); + *guard = Some(new_settings); + SettingsGuard { + previous_settings: old_settings, + _pd: PhantomData, + } + }) +} + +/// Access the current compile settings. This is populated only during +/// code generation compile() or compile_with_config() time. +pub(crate) fn load() -> CompileSettings { + COMPILE_SETTINGS.with(|settings| { + settings + .lock() + .expect("threadlocal mutex should always succeed") + .clone() + .unwrap_or_default() + }) +} + +type PhantomUnsend = PhantomData>; + +pub(crate) struct SettingsGuard { + previous_settings: Option, + _pd: PhantomUnsend, +} + +impl Drop for SettingsGuard { + fn drop(&mut self) { + COMPILE_SETTINGS.with(|settings| { + let mut guard = settings + .lock() + .expect("threadlocal mutex should always succeed"); + *guard = take(&mut self.previous_settings); + }) + } +} diff --git a/tonic-build/src/lib.rs b/tonic-build/src/lib.rs index ddd739c62..81ad46e34 100644 --- a/tonic-build/src/lib.rs +++ b/tonic-build/src/lib.rs @@ -97,6 +97,9 @@ pub mod server; mod code_gen; pub use code_gen::CodeGenBuilder; +mod compile_settings; +pub(crate) use compile_settings::CompileSettings; + /// Service generation trait. /// /// This trait can be implemented and consumed @@ -137,7 +140,7 @@ pub trait Method { /// Identifier used to generate type name. fn identifier(&self) -> &str; /// Path to the codec. - fn codec_path(&self) -> &str; + fn codec_path(&self) -> String; /// Method is streamed by client. fn client_streaming(&self) -> bool; /// Method is streamed by server. diff --git a/tonic-build/src/manual.rs b/tonic-build/src/manual.rs index a6876cab9..b83df7ac5 100644 --- a/tonic-build/src/manual.rs +++ b/tonic-build/src/manual.rs @@ -195,8 +195,8 @@ impl crate::Method for Method { &self.route_name } - fn codec_path(&self) -> &str { - &self.codec_path + fn codec_path(&self) -> String { + self.codec_path.clone() } fn client_streaming(&self) -> bool { diff --git a/tonic-build/src/prost.rs b/tonic-build/src/prost.rs index 3202e2730..bd561b050 100644 --- a/tonic-build/src/prost.rs +++ b/tonic-build/src/prost.rs @@ -1,4 +1,4 @@ -use crate::code_gen::CodeGenBuilder; +use crate::{code_gen::CodeGenBuilder, compile_settings, CompileSettings}; use super::Attributes; use proc_macro2::TokenStream; @@ -41,6 +41,7 @@ pub fn configure() -> Builder { disable_comments: HashSet::default(), use_arc_self: false, generate_default_stubs: false, + compile_settings: CompileSettings::default(), } } @@ -61,8 +62,6 @@ pub fn compile_protos(proto: impl AsRef) -> io::Result<()> { Ok(()) } -const PROST_CODEC_PATH: &str = "tonic::codec::ProstCodec"; - /// Non-path Rust types allowed for request/response types. const NON_PATH_TYPE_ALLOWLIST: &[&str] = &["()"]; @@ -102,8 +101,17 @@ impl crate::Method for Method { &self.proto_name } - fn codec_path(&self) -> &str { - PROST_CODEC_PATH + /// For code generation, you can override the codec. + /// + /// You should set the codec path to an import path that has a free + /// function like `fn default()`. The default value is tonic::codec::ProstCodec, + /// which returns a default-configured ProstCodec. You may wish to configure + /// the codec, e.g., with a buffer configuration. + /// + /// Though ProstCodec implements Default, it is currently only required that + /// the function match the Default trait's function spec. + fn codec_path(&self) -> String { + compile_settings::load().codec_path } fn client_streaming(&self) -> bool { @@ -252,6 +260,7 @@ pub struct Builder { pub(crate) disable_comments: HashSet, pub(crate) use_arc_self: bool, pub(crate) generate_default_stubs: bool, + pub(crate) compile_settings: CompileSettings, out_dir: Option, } @@ -524,6 +533,16 @@ impl Builder { self } + /// Override the default codec. + /// + /// If set, writes `{codec_path}::default()` in generated code wherever a codec is created. + /// + /// This defaults to `"tonic::codec::ProstCodec"` + pub fn codec_path(mut self, codec_path: impl Into) -> Self { + self.compile_settings.codec_path = codec_path.into(); + self + } + /// Compile the .proto files and execute code generation. pub fn compile( self, @@ -541,6 +560,7 @@ impl Builder { protos: &[impl AsRef], includes: &[impl AsRef], ) -> io::Result<()> { + let _compile_settings_guard = compile_settings::set_context(self.compile_settings.clone()); let out_dir = if let Some(out_dir) = self.out_dir.as_ref() { out_dir.clone() } else { diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index d9ab1ad6b..d343cc567 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -457,7 +457,7 @@ fn generate_unary( server_trait: Ident, use_arc_self: bool, ) -> TokenStream { - let codec_name = syn::parse_str::(method.codec_path()).unwrap(); + let codec_name = syn::parse_str::(&method.codec_path()).unwrap(); let service_ident = quote::format_ident!("{}Svc", method.identifier()); @@ -517,7 +517,7 @@ fn generate_server_streaming( use_arc_self: bool, generate_default_stubs: bool, ) -> TokenStream { - let codec_name = syn::parse_str::(method.codec_path()).unwrap(); + let codec_name = syn::parse_str::(&method.codec_path()).unwrap(); let service_ident = quote::format_ident!("{}Svc", method.identifier()); @@ -587,7 +587,7 @@ fn generate_client_streaming( let service_ident = quote::format_ident!("{}Svc", method.identifier()); let (request, response) = method.request_response_name(proto_path, compile_well_known_types); - let codec_name = syn::parse_str::(method.codec_path()).unwrap(); + let codec_name = syn::parse_str::(&method.codec_path()).unwrap(); let inner_arg = if use_arc_self { quote!(inner) @@ -644,7 +644,7 @@ fn generate_streaming( use_arc_self: bool, generate_default_stubs: bool, ) -> TokenStream { - let codec_name = syn::parse_str::(method.codec_path()).unwrap(); + let codec_name = syn::parse_str::(&method.codec_path()).unwrap(); let service_ident = quote::format_ident!("{}Svc", method.identifier()); diff --git a/tonic/benches/decode.rs b/tonic/benches/decode.rs index 5c7cd0159..f5d613ce5 100644 --- a/tonic/benches/decode.rs +++ b/tonic/benches/decode.rs @@ -105,6 +105,10 @@ impl Decoder for MockDecoder { buf.advance(self.message_size); Ok(Some(out)) } + + fn buffer_settings(&self) -> tonic::codec::BufferSettings { + Default::default() + } } fn make_payload(message_length: usize, message_count: usize) -> Bytes { diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 70d758415..e00b8ca8f 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -1,4 +1,3 @@ -use super::encode::BUFFER_SIZE; use crate::{metadata::MetadataValue, Status}; use bytes::{Buf, BytesMut}; #[cfg(feature = "gzip")] @@ -70,6 +69,14 @@ impl EnabledCompressionEncodings { } } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) struct CompressionSettings { + pub(crate) encoding: CompressionEncoding, + /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste. + /// The default buffer growth interval is 8 kilobytes. + pub(crate) buffer_growth_interval: usize, +} + /// The compression encodings Tonic supports. #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[non_exhaustive] @@ -195,20 +202,22 @@ fn split_by_comma(s: &str) -> impl Iterator { } /// Compress `len` bytes from `decompressed_buf` into `out_buf`. +/// buffer_size_increment is a hint to control the growth of out_buf versus the cost of resizing it. #[allow(unused_variables, unreachable_code)] pub(crate) fn compress( - encoding: CompressionEncoding, + settings: CompressionSettings, decompressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error> { - let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE; + let buffer_growth_interval = settings.buffer_growth_interval; + let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval; out_buf.reserve(capacity); #[cfg(any(feature = "gzip", feature = "zstd"))] let mut out_writer = bytes::BufMut::writer(out_buf); - match encoding { + match settings.encoding { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => { let mut gzip_encoder = GzEncoder::new( @@ -237,19 +246,21 @@ pub(crate) fn compress( /// Decompress `len` bytes from `compressed_buf` into `out_buf`. #[allow(unused_variables, unreachable_code)] pub(crate) fn decompress( - encoding: CompressionEncoding, + settings: CompressionSettings, compressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error> { + let buffer_growth_interval = settings.buffer_growth_interval; let estimate_decompressed_len = len * 2; - let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE; + let capacity = + ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval; out_buf.reserve(capacity); #[cfg(any(feature = "gzip", feature = "zstd"))] let mut out_writer = bytes::BufMut::writer(out_buf); - match encoding { + match settings.encoding { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => { let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]); diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index cb88a0649..081f6193d 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -1,5 +1,5 @@ -use super::compression::{decompress, CompressionEncoding}; -use super::{DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE}; +use super::compression::{decompress, CompressionEncoding, CompressionSettings}; +use super::{BufferSettings, DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE}; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use http::StatusCode; @@ -13,8 +13,6 @@ use std::{ use tokio_stream::Stream; use tracing::{debug, trace}; -const BUFFER_SIZE: usize = 8 * 1024; - /// Streaming requests and responses. /// /// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface @@ -118,6 +116,7 @@ impl Streaming { B::Error: Into, D: Decoder + Send + 'static, { + let buffer_size = decoder.buffer_settings().buffer_size; Self { decoder: Box::new(decoder), inner: StreamingInner { @@ -127,7 +126,7 @@ impl Streaming { .boxed_unsync(), state: State::ReadHeader, direction, - buf: BytesMut::with_capacity(BUFFER_SIZE), + buf: BytesMut::with_capacity(buffer_size), trailers: None, decompress_buf: BytesMut::new(), encoding, @@ -138,7 +137,10 @@ impl Streaming { } impl StreamingInner { - fn decode_chunk(&mut self) -> Result>, Status> { + fn decode_chunk( + &mut self, + buffer_settings: BufferSettings, + ) -> Result>, Status> { if let State::ReadHeader = self.state { if self.buf.remaining() < HEADER_SIZE { return Ok(None); @@ -205,8 +207,15 @@ impl StreamingInner { let decode_buf = if let Some(encoding) = compression { self.decompress_buf.clear(); - if let Err(err) = decompress(encoding, &mut self.buf, &mut self.decompress_buf, len) - { + if let Err(err) = decompress( + CompressionSettings { + encoding, + buffer_growth_interval: buffer_settings.buffer_size, + }, + &mut self.buf, + &mut self.decompress_buf, + len, + ) { let message = if let Direction::Response(status) = self.direction { format!( "Error decompressing: {}, while receiving response with status: {}", @@ -364,7 +373,7 @@ impl Streaming { } fn decode_chunk(&mut self) -> Result, Status> { - match self.inner.decode_chunk()? { + match self.inner.decode_chunk(self.decoder.buffer_settings())? { Some(mut decode_buf) => match self.decoder.decode(&mut decode_buf)? { Some(msg) => { self.inner.state = State::ReadHeader; diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 13eb2c96d..396f77399 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -1,5 +1,7 @@ -use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride}; -use super::{EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE}; +use super::compression::{ + compress, CompressionEncoding, CompressionSettings, SingleMessageCompressionOverride, +}; +use super::{BufferSettings, EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE}; use crate::{Code, Status}; use bytes::{BufMut, Bytes, BytesMut}; use http::HeaderMap; @@ -11,9 +13,6 @@ use std::{ }; use tokio_stream::{Stream, StreamExt}; -pub(super) const BUFFER_SIZE: usize = 8 * 1024; -const YIELD_THRESHOLD: usize = 32 * 1024; - pub(crate) fn encode_server( encoder: T, source: U, @@ -90,7 +89,8 @@ where compression_override: SingleMessageCompressionOverride, max_message_size: Option, ) -> Self { - let buf = BytesMut::with_capacity(BUFFER_SIZE); + let buffer_settings = encoder.buffer_settings(); + let buf = BytesMut::with_capacity(buffer_settings.buffer_size); let compression_encoding = if compression_override == SingleMessageCompressionOverride::Disable { @@ -100,7 +100,7 @@ where }; let uncompression_buf = if compression_encoding.is_some() { - BytesMut::with_capacity(BUFFER_SIZE) + BytesMut::with_capacity(buffer_settings.buffer_size) } else { BytesMut::new() }; @@ -132,6 +132,7 @@ where buf, uncompression_buf, } = self.project(); + let buffer_settings = encoder.buffer_settings(); loop { match source.as_mut().poll_next(cx) { @@ -151,12 +152,13 @@ where uncompression_buf, *compression_encoding, *max_message_size, + buffer_settings, item, ) { return Poll::Ready(Some(Err(status))); } - if buf.len() >= YIELD_THRESHOLD { + if buf.len() >= buffer_settings.yield_threshold { return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze()))); } } @@ -174,6 +176,7 @@ fn encode_item( uncompression_buf: &mut BytesMut, compression_encoding: Option, max_message_size: Option, + buffer_settings: BufferSettings, item: T::Item, ) -> Result<(), Status> where @@ -195,8 +198,16 @@ where let uncompressed_len = uncompression_buf.len(); - compress(encoding, uncompression_buf, buf, uncompressed_len) - .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; + compress( + CompressionSettings { + encoding, + buffer_growth_interval: buffer_settings.buffer_size, + }, + uncompression_buf, + buf, + uncompressed_len, + ) + .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; } else { encoder .encode(item, &mut EncodeBuf::new(buf)) diff --git a/tonic/src/codec/mod.rs b/tonic/src/codec/mod.rs index 306621329..099f12c74 100644 --- a/tonic/src/codec/mod.rs +++ b/tonic/src/codec/mod.rs @@ -21,6 +21,41 @@ pub use self::decode::Streaming; #[cfg(feature = "prost")] #[cfg_attr(docsrs, doc(cfg(feature = "prost")))] pub use self::prost::ProstCodec; +#[cfg(feature = "prost")] +#[cfg_attr(docsrs, doc(cfg(feature = "prost")))] +pub use self::prost::ProstDecoder; +#[cfg(feature = "prost")] +#[cfg_attr(docsrs, doc(cfg(feature = "prost")))] +pub use self::prost::ProstEncoder; + +/// Unless overridden, this is the buffer size used for encoding requests. +/// This is spent per-rpc, so you may wish to adjust it. The default is +/// pretty good for most uses, but if you have a ton of concurrent rpcs +/// you may find it too expensive. +const DEFAULT_CODEC_BUFFER_SIZE: usize = 8 * 1024; +const DEFAULT_YIELD_THRESHOLD: usize = 32 * 1024; + +/// Settings for how tonic allocates and grows buffers. +#[derive(Clone, Copy, Debug)] +pub struct BufferSettings { + /// Initial buffer size, and the growth unit for cases where the size + /// is larger than the buffer's current capacity. Defaults to 8 KiB. + /// + /// Notably, this is eagerly allocated per streaming rpc. + pub buffer_size: usize, + + /// Soft maximum size for returning a stream's ready contents in a batch, + /// rather than one-by-one. Defaults to 32 KiB. + pub yield_threshold: usize, +} +impl Default for BufferSettings { + fn default() -> Self { + Self { + buffer_size: DEFAULT_CODEC_BUFFER_SIZE, + yield_threshold: DEFAULT_YIELD_THRESHOLD, + } + } +} // 5 bytes const HEADER_SIZE: usize = @@ -63,6 +98,9 @@ pub trait Encoder { /// Encodes a message into the provided buffer. fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error>; + + /// Controls how tonic creates and expands encode buffers. + fn buffer_settings(&self) -> BufferSettings; } /// Decodes gRPC message types @@ -79,4 +117,7 @@ pub trait Decoder { /// is no need to get the length from the bytes, gRPC framing is handled /// for you. fn decode(&mut self, src: &mut DecodeBuf<'_>) -> Result, Self::Error>; + + /// Controls how tonic creates and expands decode buffers. + fn buffer_settings(&self) -> BufferSettings; } diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index d2f1652f4..176329a64 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -1,4 +1,4 @@ -use super::{Codec, DecodeBuf, Decoder, Encoder}; +use super::{BufferSettings, Codec, DecodeBuf, Decoder, Encoder}; use crate::codec::EncodeBuf; use crate::{Code, Status}; use prost::Message; @@ -8,11 +8,23 @@ use std::marker::PhantomData; #[derive(Debug, Clone)] pub struct ProstCodec { _pd: PhantomData<(T, U)>, + buffer_settings: BufferSettings, +} + +impl ProstCodec { + /// Configure a ProstCodec with encoder/decoder buffer settings. This is used to control + /// how memory is allocated and grows per RPC. + pub fn new(buffer_settings: BufferSettings) -> Self { + Self { + _pd: PhantomData, + buffer_settings, + } + } } impl Default for ProstCodec { fn default() -> Self { - Self { _pd: PhantomData } + Self::new(Default::default()) } } @@ -28,17 +40,36 @@ where type Decoder = ProstDecoder; fn encoder(&mut self) -> Self::Encoder { - ProstEncoder(PhantomData) + ProstEncoder { + _pd: PhantomData, + buffer_settings: self.buffer_settings, + } } fn decoder(&mut self) -> Self::Decoder { - ProstDecoder(PhantomData) + ProstDecoder { + _pd: PhantomData, + buffer_settings: self.buffer_settings, + } } } /// A [`Encoder`] that knows how to encode `T`. #[derive(Debug, Clone, Default)] -pub struct ProstEncoder(PhantomData); +pub struct ProstEncoder { + _pd: PhantomData, + buffer_settings: BufferSettings, +} + +impl ProstEncoder { + /// Get a new encoder with explicit buffer settings + pub fn new(buffer_settings: BufferSettings) -> Self { + Self { + _pd: PhantomData, + buffer_settings, + } + } +} impl Encoder for ProstEncoder { type Item = T; @@ -50,11 +81,28 @@ impl Encoder for ProstEncoder { Ok(()) } + + fn buffer_settings(&self) -> BufferSettings { + self.buffer_settings + } } /// A [`Decoder`] that knows how to decode `U`. #[derive(Debug, Clone, Default)] -pub struct ProstDecoder(PhantomData); +pub struct ProstDecoder { + _pd: PhantomData, + buffer_settings: BufferSettings, +} + +impl ProstDecoder { + /// Get a new decoder with explicit buffer settings + pub fn new(buffer_settings: BufferSettings) -> Self { + Self { + _pd: PhantomData, + buffer_settings, + } + } +} impl Decoder for ProstDecoder { type Item = U; @@ -67,6 +115,10 @@ impl Decoder for ProstDecoder { Ok(item) } + + fn buffer_settings(&self) -> BufferSettings { + self.buffer_settings + } } fn from_decode_error(error: prost::DecodeError) -> crate::Status { @@ -249,6 +301,10 @@ mod tests { buf.put(&item[..]); Ok(()) } + + fn buffer_settings(&self) -> crate::codec::BufferSettings { + Default::default() + } } #[derive(Debug, Clone, Default)] @@ -263,6 +319,10 @@ mod tests { buf.advance(LEN); Ok(Some(out)) } + + fn buffer_settings(&self) -> crate::codec::BufferSettings { + Default::default() + } } mod body {