Skip to content

Commit

Permalink
feat(tonic): Custom codecs for generated code
Browse files Browse the repository at this point in the history
Broadly, this change does 2 things:
1. Allow the built-in Prost codec to have its buffer sizes customized
2. Allow users to specify custom codecs on the tonic_build::prost::Builder

The Prost codec is convenient, and handles any normal use case. However,
the buffer sizes today are too large in some cases - and they may grow too
aggressively. By exposing BufferSettings, users can make a small custom
codec with their own BufferSettings to control their memory usage - or give
enormous buffers to rpc's, as their use case requires.

While one can define a custom service and methods with a custom codec today
explicitly in Rust, the code generator does not have a means to supply a
custom codec. I've reached for .codec... on the tonic_build::prost::Builder
many times and keep forgetting it's not there. This change adds .codec_path
to the Builder, so people can simply add their custom buffer codec or even
their own full top level codec without reaching for manual service definition.
  • Loading branch information
kvcache committed Jan 12, 2024
1 parent 177c1f3 commit 9987023
Show file tree
Hide file tree
Showing 18 changed files with 425 additions and 48 deletions.
8 changes: 8 additions & 0 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,14 @@ required-features = ["cancellation"]
name = "cancellation-client"
path = "src/cancellation/client.rs"

[[bin]]
name = "codec-buffers-server"
path = "src/codec_buffers/server.rs"

[[bin]]
name = "codec-buffers-client"
path = "src/codec_buffers/client.rs"


[features]
gcp = ["dep:prost-types", "tonic/tls"]
Expand Down
8 changes: 8 additions & 0 deletions examples/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ fn main() {
.unwrap();

build_json_codec_service();

let smallbuff_copy = out_dir.join("smallbuf");
let _ = std::fs::create_dir(smallbuff_copy.clone()); // This will panic below if the directory failed to create
tonic_build::configure()
.out_dir(smallbuff_copy)
.codec_path("crate::common::SmallBufferCodec")
.compile(&["proto/helloworld/helloworld.proto"], &["proto"])
.unwrap();
}

// Manually define the json.helloworld.Greeter service which used a custom JsonCodec to use json
Expand Down
30 changes: 30 additions & 0 deletions examples/src/codec_buffers/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//! A HelloWorld example that uses a custom codec instead of the default Prost codec.
//!
//! Generated code is the output of codegen as defined in the `examples/build.rs` file.
//! The generation is the one with .codec_path("crate::common::SmallBufferCodec")
//! The generated code assumes that a module `crate::common` exists which defines
//! `SmallBufferCodec`, and `SmallBufferCodec` must have a Default implementation.
pub mod common;

pub mod small_buf {
include!(concat!(env!("OUT_DIR"), "/smallbuf/helloworld.rs"));
}
use small_buf::greeter_client::GreeterClient;

use crate::small_buf::HelloRequest;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut client = GreeterClient::connect("http://[::1]:50051").await?;

let request = tonic::Request::new(HelloRequest {
name: "Tonic".into(),
});

let response = client.say_hello(request).await?;

println!("RESPONSE={:?}", response);

Ok(())
}
44 changes: 44 additions & 0 deletions examples/src/codec_buffers/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//! This module defines a common encoder with small buffers. This is useful
//! when you have many concurrent RPC's, and not a huge volume of data per
//! rpc normally.
//!
//! Note that you can customize your codecs per call to the code generator's
//! compile function. This lets you group services by their codec needs.
//!
//! While this codec demonstrates customizing the built-in Prost codec, you
//! can use this to implement other codecs as well, as long as they have a
//! Default implementation.
use std::marker::PhantomData;

use prost::Message;
use tonic::codec::{BufferSettings, Codec, ProstDecoder, ProstEncoder};

#[derive(Debug, Clone, Copy, Default)]
pub struct SmallBufferCodec<T, U>(PhantomData<(T, U)>);

impl<T, U> Codec for SmallBufferCodec<T, U>
where
T: Message + Send + 'static,
U: Message + Default + Send + 'static,
{
type Encode = T;
type Decode = U;

type Encoder = ProstEncoder<T>;
type Decoder = ProstDecoder<U>;

fn encoder(&mut self) -> Self::Encoder {
ProstEncoder::new(BufferSettings {
buffer_size: 512,
yield_threshold: 4096,
})
}

fn decoder(&mut self) -> Self::Decoder {
ProstDecoder::new(BufferSettings {
buffer_size: 512,
yield_threshold: 4096,
})
}
}
51 changes: 51 additions & 0 deletions examples/src/codec_buffers/server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//! A HelloWorld example that uses a custom codec instead of the default Prost codec.
//!
//! Generated code is the output of codegen as defined in the `examples/build.rs` file.
//! The generation is the one with .codec_path("crate::common::SmallBufferCodec")
//! The generated code assumes that a module `crate::common` exists which defines
//! `SmallBufferCodec`, and `SmallBufferCodec` must have a Default implementation.
use tonic::{transport::Server, Request, Response, Status};

pub mod common;

pub mod small_buf {
include!(concat!(env!("OUT_DIR"), "/smallbuf/helloworld.rs"));
}
use small_buf::{
greeter_server::{Greeter, GreeterServer},
HelloReply, HelloRequest,
};

#[derive(Default)]
pub struct MyGreeter {}

#[tonic::async_trait]
impl Greeter for MyGreeter {
async fn say_hello(
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
println!("Got a request from {:?}", request.remote_addr());

let reply = HelloReply {
message: format!("Hello {}!", request.into_inner().name),
};
Ok(Response::new(reply))
}
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr = "[::1]:50051".parse().unwrap();
let greeter = MyGreeter::default();

println!("GreeterServer listening on {}", addr);

Server::builder()
.add_service(GreeterServer::new(greeter))
.serve(addr)
.await?;

Ok(())
}
8 changes: 8 additions & 0 deletions examples/src/json-codec/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ impl<T: serde::Serialize> Encoder for JsonEncoder<T> {
fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
serde_json::to_writer(buf.writer(), &item).map_err(|e| Status::internal(e.to_string()))
}

fn buffer_settings(&self) -> tonic::codec::BufferSettings {
Default::default()
}
}

#[derive(Debug)]
Expand All @@ -48,6 +52,10 @@ impl<U: serde::de::DeserializeOwned> Decoder for JsonDecoder<U> {
serde_json::from_reader(buf.reader()).map_err(|e| Status::internal(e.to_string()))?;
Ok(Some(item))
}

fn buffer_settings(&self) -> tonic::codec::BufferSettings {
Default::default()
}
}

/// A [`Codec`] that implements `application/grpc+json` via the serde library.
Expand Down
8 changes: 4 additions & 4 deletions tonic-build/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ fn generate_unary<T: Service>(
proto_path: &str,
compile_well_known_types: bool,
) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
let codec_name = syn::parse_str::<syn::Path>(&method.codec_path()).unwrap();
let ident = format_ident!("{}", method.name());
let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
let service_name = format_service_name(service, emit_package);
Expand Down Expand Up @@ -252,7 +252,7 @@ fn generate_server_streaming<T: Service>(
proto_path: &str,
compile_well_known_types: bool,
) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
let codec_name = syn::parse_str::<syn::Path>(&method.codec_path()).unwrap();
let ident = format_ident!("{}", method.name());
let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
let service_name = format_service_name(service, emit_package);
Expand Down Expand Up @@ -283,7 +283,7 @@ fn generate_client_streaming<T: Service>(
proto_path: &str,
compile_well_known_types: bool,
) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
let codec_name = syn::parse_str::<syn::Path>(&method.codec_path()).unwrap();
let ident = format_ident!("{}", method.name());
let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
let service_name = format_service_name(service, emit_package);
Expand Down Expand Up @@ -314,7 +314,7 @@ fn generate_streaming<T: Service>(
proto_path: &str,
compile_well_known_types: bool,
) -> TokenStream {
let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
let codec_name = syn::parse_str::<syn::Path>(&method.codec_path()).unwrap();
let ident = format_ident!("{}", method.name());
let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
let service_name = format_service_name(service, emit_package);
Expand Down
69 changes: 69 additions & 0 deletions tonic-build/src/compile_settings.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::{
marker::PhantomData,
mem::take,
sync::{Mutex, MutexGuard},
};

#[derive(Debug, Clone)]
pub(crate) struct CompileSettings {
pub(crate) codec_path: String,
}

impl Default for CompileSettings {
fn default() -> Self {
Self {
codec_path: "tonic::codec::ProstCodec".to_string(),
}
}
}

thread_local! {
static COMPILE_SETTINGS: Mutex<Option<CompileSettings>> = Default::default();
}

/// Called before compile, this installs a CompileSettings in the current thread's
/// context, so that live code generation can access the settings.
/// The previous state is restored when you drop the SettingsGuard.
pub(crate) fn set_context(new_settings: CompileSettings) -> SettingsGuard {
COMPILE_SETTINGS.with(|settings| {
let mut guard = settings
.lock()
.expect("threadlocal mutex should always succeed");
let old_settings = guard.clone();
*guard = Some(new_settings);
SettingsGuard {
previous_settings: old_settings,
_pd: PhantomData,
}
})
}

/// Access the current compile settings. This is populated only during
/// code generation compile() or compile_with_config() time.
pub(crate) fn load() -> CompileSettings {
COMPILE_SETTINGS.with(|settings| {
settings
.lock()
.expect("threadlocal mutex should always succeed")
.clone()
.unwrap_or_default()
})
}

type PhantomUnsend = PhantomData<MutexGuard<'static, ()>>;

pub(crate) struct SettingsGuard {
previous_settings: Option<CompileSettings>,
_pd: PhantomUnsend,
}

impl Drop for SettingsGuard {
fn drop(&mut self) {
COMPILE_SETTINGS.with(|settings| {
let mut guard = settings
.lock()
.expect("threadlocal mutex should always succeed");
*guard = take(&mut self.previous_settings);
})
}
}
5 changes: 4 additions & 1 deletion tonic-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ pub mod server;
mod code_gen;
pub use code_gen::CodeGenBuilder;

mod compile_settings;
pub(crate) use compile_settings::CompileSettings;

/// Service generation trait.
///
/// This trait can be implemented and consumed
Expand Down Expand Up @@ -137,7 +140,7 @@ pub trait Method {
/// Identifier used to generate type name.
fn identifier(&self) -> &str;
/// Path to the codec.
fn codec_path(&self) -> &str;
fn codec_path(&self) -> String;
/// Method is streamed by client.
fn client_streaming(&self) -> bool;
/// Method is streamed by server.
Expand Down
4 changes: 2 additions & 2 deletions tonic-build/src/manual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ impl crate::Method for Method {
&self.route_name
}

fn codec_path(&self) -> &str {
&self.codec_path
fn codec_path(&self) -> String {
self.codec_path.clone()
}

fn client_streaming(&self) -> bool {
Expand Down
30 changes: 25 additions & 5 deletions tonic-build/src/prost.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::code_gen::CodeGenBuilder;
use crate::{code_gen::CodeGenBuilder, compile_settings, CompileSettings};

use super::Attributes;
use proc_macro2::TokenStream;
Expand Down Expand Up @@ -41,6 +41,7 @@ pub fn configure() -> Builder {
disable_comments: HashSet::default(),
use_arc_self: false,
generate_default_stubs: false,
compile_settings: CompileSettings::default(),
}
}

Expand All @@ -61,8 +62,6 @@ pub fn compile_protos(proto: impl AsRef<Path>) -> io::Result<()> {
Ok(())
}

const PROST_CODEC_PATH: &str = "tonic::codec::ProstCodec";

/// Non-path Rust types allowed for request/response types.
const NON_PATH_TYPE_ALLOWLIST: &[&str] = &["()"];

Expand Down Expand Up @@ -102,8 +101,17 @@ impl crate::Method for Method {
&self.proto_name
}

fn codec_path(&self) -> &str {
PROST_CODEC_PATH
/// For code generation, you can override the codec.
///
/// You should set the codec path to an import path that has a free
/// function like `fn default()`. The default value is tonic::codec::ProstCodec,
/// which returns a default-configured ProstCodec. You may wish to configure
/// the codec, e.g., with a buffer configuration.
///
/// Though ProstCodec implements Default, it is currently only required that
/// the function match the Default trait's function spec.
fn codec_path(&self) -> String {
compile_settings::load().codec_path
}

fn client_streaming(&self) -> bool {
Expand Down Expand Up @@ -252,6 +260,7 @@ pub struct Builder {
pub(crate) disable_comments: HashSet<String>,
pub(crate) use_arc_self: bool,
pub(crate) generate_default_stubs: bool,
pub(crate) compile_settings: CompileSettings,

out_dir: Option<PathBuf>,
}
Expand Down Expand Up @@ -524,6 +533,16 @@ impl Builder {
self
}

/// Override the default codec.
///
/// If set, writes `{codec_path}::default()` in generated code wherever a codec is created.
///
/// This defaults to `"tonic::codec::ProstCodec"`
pub fn codec_path(mut self, codec_path: impl Into<String>) -> Self {
self.compile_settings.codec_path = codec_path.into();
self
}

/// Compile the .proto files and execute code generation.
pub fn compile(
self,
Expand All @@ -541,6 +560,7 @@ impl Builder {
protos: &[impl AsRef<Path>],
includes: &[impl AsRef<Path>],
) -> io::Result<()> {
let _compile_settings_guard = compile_settings::set_context(self.compile_settings.clone());
let out_dir = if let Some(out_dir) = self.out_dir.as_ref() {
out_dir.clone()
} else {
Expand Down
Loading

0 comments on commit 9987023

Please sign in to comment.