Skip to content

Commit

Permalink
undo batcher changes
Browse files (browse the repository at this point in the history)
  • Loading branch information
TroyKomodo committed Oct 20, 2024
1 parent 49aadf4 commit 35fbf2e
Showing 1 changed file with 20 additions and 46 deletions.
66 changes: 20 additions & 46 deletions foundations/src/batcher/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::atomic::{AtomicU64, AtomicUsize};
use std::sync::Arc;

use tokio::sync::OnceCell;
use tracing::Instrument;

pub mod dataloader;

Expand Down Expand Up @@ -224,15 +225,14 @@ impl Drop for CancelOnDrop {
}

// Shared state behind a batcher, generic over the batch operation `T`.
// NOTE(review): this text is a diff rendering whose +/- markers were lost, so the
// struct shows both versions of the `semaphore` field. Presumably the first line of
// the pair is the one removed by the commit and the second the one added — confirm
// against the repository before relying on this.
struct BatcherInner<T: BatchOperation> {
// Presumably the pre-commit form: wrapped in `Arc` so the background task could
// clone it and call `acquire_owned()`.
semaphore: Arc<tokio::sync::Semaphore>,
// Presumably the post-commit form: held inline; `Batch::run` calls `acquire()` on it
// before processing. Created with `config.concurrency` permits in `Batcher::new`.
semaphore: tokio::sync::Semaphore,
// The background task spawned in `Batcher::new` awaits `notified()` on this.
notify: tokio::sync::Notify,
// NOTE(review): no use of this field is visible in this excerpt.
sleep_duration: AtomicU64,
// Initialised to 0 in `Batcher::new`.
batch_id: AtomicU64,
// Loaded with `Ordering::Relaxed` while items are being added to the active batch.
max_batch_size: AtomicUsize,
// The user-supplied operation; `process` is called on it with a batch's ops.
operation: T,
// Recorded as the `name` field of the tracing span on `Batch::run`.
name: String,
// The batch currently receiving items, if any; it is `take()`n when handed off
// and replaced with a fresh batch from `new_batch()`.
active_batch: tokio::sync::RwLock<Option<Batch<T>>>,
// Sender used by the queue-based code path (`queued_batches.send(batch)`);
// presumably removed by this commit along with the channel — confirm.
queued_batches: tokio::sync::mpsc::Sender<Batch<T>>,
}

struct Batch<T: BatchOperation> {
Expand Down Expand Up @@ -288,14 +288,18 @@ impl<E: std::error::Error> From<E> for BatcherError<E> {

impl<T: BatchOperation + 'static + Send + Sync> Batch<T> {
// Runs this batch: calls `get_or_init` on `self.results` with a future that invokes
// `inner.operation.process(self.ops)` and maps its error into `BatcherError::Batch`.
// The tracing span records `inner.name` and skips all arguments.
//
// NOTE(review): this text is a diff rendering whose +/- markers were lost, so two
// versions of the signature and body are interleaved below. Presumably the first
// line of each pair is the removed (pre-commit) one — confirm against the repository.
#[tracing::instrument(skip_all, fields(name = %inner.name))]
// Presumably pre-commit: the caller passes in an already-acquired owned permit.
async fn run(self, inner: Arc<BatcherInner<T>>, ticket: tokio::sync::OwnedSemaphorePermit) {
// Presumably post-commit: no permit parameter; the permit is acquired in the body.
async fn run(self, inner: Arc<BatcherInner<T>>) {
self.results
.get_or_init(|| async move {
// Presumably pre-commit body: process directly, since the permit was already held.
inner.operation.process(self.ops).await.map_err(BatcherError::Batch)
// Presumably post-commit body: wait for a permit inside a "Semaphore" debug span.
// A failed acquire becomes `BatcherError::AcquireSemaphore`. Binding to `_ticket`
// (not `_`) keeps the permit alive until the end of this async block, i.e. for
// the whole `process` call.
let _ticket = inner
.semaphore
.acquire()
.instrument(tracing::debug_span!("Semaphore"))
.await
.map_err(|_| BatcherError::AcquireSemaphore)?;
Ok(inner.operation.process(self.ops).await.map_err(BatcherError::Batch)?)
})
.await;

// Presumably pre-commit: release the caller-supplied permit once the results are set.
drop(ticket);
}
}

Expand All @@ -308,8 +312,8 @@ pub struct BatcherConfig {
}

impl<T: BatchOperation + 'static + Send + Sync> BatcherInner<T> {
// Spawns `batch.run(..)` as a Tokio task, handing it a clone of this `Arc`.
// The value returned by `tokio::spawn` is discarded.
// NOTE(review): this text is a diff rendering whose +/- markers were lost, so both
// versions of the method are interleaved. Presumably the first signature/body pair
// is the removed (pre-commit) one — confirm against the repository.
// Presumably pre-commit: forwards a pre-acquired semaphore permit to `run`.
fn spawn_batch(self: &Arc<Self>, batch: Batch<T>, ticket: tokio::sync::OwnedSemaphorePermit) {
tokio::spawn(batch.run(self.clone(), ticket));
// Presumably post-commit: no permit argument; `run` acquires its own.
fn spawn_batch(self: &Arc<Self>, batch: Batch<T>) {
tokio::spawn(batch.run(self.clone()));
}

fn new_batch(&self) -> Batch<T> {
Expand All @@ -330,7 +334,6 @@ impl<T: BatchOperation + 'static + Send + Sync> BatcherInner<T> {
let mut waiters = vec![];
let mut batch = self.active_batch.write().await;
let max_documents = self.max_batch_size.load(std::sync::atomic::Ordering::Relaxed);
let mut batches = vec![];

for document in T::Mode::filter_item_iter(documents) {
if batch
Expand All @@ -339,7 +342,7 @@ impl<T: BatchOperation + 'static + Send + Sync> BatcherInner<T> {
.unwrap_or(true)
{
if let Some(b) = batch.take() {
batches.push(b);
self.spawn_batch(b);
}

*batch = Some(self.new_batch());
Expand All @@ -363,10 +366,6 @@ impl<T: BatchOperation + 'static + Send + Sync> BatcherInner<T> {
T::Mode::input_add(&mut b.ops, tracker, document);
}

for batch in batches {
self.queued_batches.send(batch).await.ok();
}

waiters
}
}
Expand All @@ -375,11 +374,8 @@ impl<T: BatchOperation + 'static + Send + Sync> Batcher<T> {
pub fn new(operation: T) -> Self {
let config = operation.config();

let (tx, mut rx) = tokio::sync::mpsc::channel(64);

let inner = Arc::new(BatcherInner {
semaphore: Arc::new(tokio::sync::Semaphore::new(config.concurrency)),
queued_batches: tx.clone(),
semaphore: tokio::sync::Semaphore::new(config.concurrency),
notify: tokio::sync::Notify::new(),
batch_id: AtomicU64::new(0),
active_batch: tokio::sync::RwLock::new(None),
Expand All @@ -393,43 +389,21 @@ impl<T: BatchOperation + 'static + Send + Sync> Batcher<T> {
inner: inner.clone(),
_auto_loader_abort: CancelOnDrop(
tokio::task::spawn(async move {
let mut next_wakeup = None;
loop {
tokio::select! {
Some(batch) = rx.recv() => {
let ticket = inner.semaphore.clone().acquire_owned().await.unwrap();
inner.spawn_batch(batch, ticket);
},
_ = async {
if let Some(expires_at) = next_wakeup {
tokio::time::sleep_until(expires_at).await;
} else {
inner.notify.notified().await;
}
} => {},
_ = inner.notify.notified() => {},
}

inner.notify.notified().await;
let Some((id, expires_at)) = inner.active_batch.read().await.as_ref().map(|b| (b.id, b.expires_at))
else {
continue;
};

if expires_at > tokio::time::Instant::now() {
next_wakeup = Some(expires_at);
continue;
} else {
next_wakeup = None;
tokio::time::sleep_until(expires_at).await;
}

let mut batch = inner.active_batch.write().await;
let batch = if batch.as_ref().is_some_and(|b| b.id == id) {
batch.take().unwrap()
} else {
continue;
};

tx.send(batch).await.ok();
let mut batch = inner.active_batch.write().await;
if batch.as_ref().is_some_and(|b| b.id == id) {
inner.spawn_batch(batch.take().unwrap());
}
}
})
.abort_handle(),
Expand Down

0 comments on commit 35fbf2e

Please sign in to comment.