Skip to content

Commit

Permalink
test batcher
Browse files Browse the repository at this point in the history
  • Loading branch information
TroyKomodo committed Oct 20, 2024
1 parent 66bac44 commit b44abd8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 19 deletions.
46 changes: 28 additions & 18 deletions foundations/src/batcher/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::sync::atomic::{AtomicU64, AtomicUsize};
use std::sync::Arc;

use tokio::sync::OnceCell;
use tracing::Instrument;

pub mod dataloader;

Expand Down Expand Up @@ -225,14 +224,15 @@ impl Drop for CancelOnDrop {
}

struct BatcherInner<T: BatchOperation> {
semaphore: tokio::sync::Semaphore,
semaphore: Arc<tokio::sync::Semaphore>,
notify: tokio::sync::Notify,
sleep_duration: AtomicU64,
batch_id: AtomicU64,
max_batch_size: AtomicUsize,
operation: T,
name: String,
active_batch: tokio::sync::RwLock<Option<Batch<T>>>,
queued_batches: tokio::sync::mpsc::Sender<Batch<T>>,
}

struct Batch<T: BatchOperation> {
Expand Down Expand Up @@ -288,18 +288,14 @@ impl<E: std::error::Error> From<E> for BatcherError<E> {

impl<T: BatchOperation + 'static + Send + Sync> Batch<T> {
#[tracing::instrument(skip_all, fields(name = %inner.name))]
async fn run(self, inner: Arc<BatcherInner<T>>) {
async fn run(self, inner: Arc<BatcherInner<T>>, ticket: tokio::sync::OwnedSemaphorePermit) {
self.results
.get_or_init(|| async move {
let _ticket = inner
.semaphore
.acquire()
.instrument(tracing::debug_span!("Semaphore"))
.await
.map_err(|_| BatcherError::AcquireSemaphore)?;
Ok(inner.operation.process(self.ops).await.map_err(BatcherError::Batch)?)
inner.operation.process(self.ops).await.map_err(BatcherError::Batch)
})
.await;

drop(ticket);
}
}

Expand All @@ -312,8 +308,8 @@ pub struct BatcherConfig {
}

impl<T: BatchOperation + 'static + Send + Sync> BatcherInner<T> {
fn spawn_batch(self: &Arc<Self>, batch: Batch<T>) {
tokio::spawn(batch.run(self.clone()));
fn spawn_batch(self: &Arc<Self>, batch: Batch<T>, ticket: tokio::sync::OwnedSemaphorePermit) {
tokio::spawn(batch.run(self.clone(), ticket));
}

fn new_batch(&self) -> Batch<T> {
Expand Down Expand Up @@ -342,7 +338,7 @@ impl<T: BatchOperation + 'static + Send + Sync> BatcherInner<T> {
.unwrap_or(true)
{
if let Some(b) = batch.take() {
self.spawn_batch(b);
self.queued_batches.send(b).await.ok();
}

*batch = Some(self.new_batch());
Expand Down Expand Up @@ -374,8 +370,11 @@ impl<T: BatchOperation + 'static + Send + Sync> Batcher<T> {
pub fn new(operation: T) -> Self {
let config = operation.config();

let (tx, mut rx) = tokio::sync::mpsc::channel(64);

let inner = Arc::new(BatcherInner {
semaphore: tokio::sync::Semaphore::new(config.concurrency),
semaphore: Arc::new(tokio::sync::Semaphore::new(config.concurrency)),
queued_batches: tx.clone(),
notify: tokio::sync::Notify::new(),
batch_id: AtomicU64::new(0),
active_batch: tokio::sync::RwLock::new(None),
Expand All @@ -390,6 +389,13 @@ impl<T: BatchOperation + 'static + Send + Sync> Batcher<T> {
_auto_loader_abort: CancelOnDrop(
tokio::task::spawn(async move {
loop {
tokio::select! {
Some(batch) = rx.recv() => {
let ticket = inner.semaphore.clone().acquire_owned().await.unwrap();
inner.spawn_batch(batch, ticket);
},
_ = inner.notify.notified() => {},
}
inner.notify.notified().await;
let Some((id, expires_at)) = inner.active_batch.read().await.as_ref().map(|b| (b.id, b.expires_at))
else {
Expand All @@ -399,11 +405,15 @@ impl<T: BatchOperation + 'static + Send + Sync> Batcher<T> {
if expires_at > tokio::time::Instant::now() {
tokio::time::sleep_until(expires_at).await;
}

let mut batch = inner.active_batch.write().await;
if batch.as_ref().is_some_and(|b| b.id == id) {
inner.spawn_batch(batch.take().unwrap());
}
let batch = if batch.as_ref().is_some_and(|b| b.id == id) {
batch.take().unwrap()
} else {
continue;
};

tx.send(batch).await.ok();
}
})
.abort_handle(),
Expand Down
6 changes: 5 additions & 1 deletion image-processor/src/management/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ impl ManagementServer {
#[tracing::instrument(skip_all)]
pub async fn run_grpc(&self, addr: std::net::SocketAddr) -> Result<(), tonic::transport::Error> {
let server = tonic::transport::Server::builder()
.add_service(scuffle_image_processor_proto::image_processor_server::ImageProcessorServer::new(self.clone()))
.add_service(
scuffle_image_processor_proto::image_processor_server::ImageProcessorServer::new(self.clone())
.max_decoding_message_size(128 * 1024 * 1024)
.max_encoding_message_size(128 * 1024 * 1024)
)
.serve_with_shutdown(addr, scuffle_foundations::context::Context::global().into_done());

tracing::info!("gRPC management server listening on {}", addr);
Expand Down

0 comments on commit b44abd8

Please sign in to comment.