diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs
index 207fd3d6204b..d1bc877d3387 100644
--- a/crates/llama-cpp-bindings/src/lib.rs
+++ b/crates/llama-cpp-bindings/src/lib.rs
@@ -114,22 +114,22 @@ impl AsyncTextInferenceEngine {
     ) -> BoxStream {
         let stop_condition = self.stop_condition_factory.create(prompt, options.language);
 
-        let (tx, mut rx) = channel::(4);
-        {
-            let mut engine = self.engine.lock().await;
-
-            let mut request_id = self.next_request_id.lock().await;
+        let request_id = self.alloc_request_id().await;
+        let mut rx = {
+            let (tx, rx) = channel::(4);
             self.requests
                 .lock()
                 .await
-                .insert(*request_id, InferenceRequest { tx, stop_condition });
+                .insert(request_id, InferenceRequest { tx, stop_condition });
+            rx
+        };
+
+        {
+            let mut engine = self.engine.lock().await;
             engine
                 .as_mut()
                 .unwrap()
-                .add_request(*request_id, prompt, options.max_input_length);
-
-            // 2048 should be large enough to avoid collision.
-            *request_id = (*request_id + 1) % 2048;
+                .add_request(request_id, prompt, options.max_input_length);
         }
 
         let s = stream! {
@@ -147,6 +147,16 @@ impl AsyncTextInferenceEngine {
 
         Box::pin(s)
     }
+
+    async fn alloc_request_id(&self) -> u32 {
+        let mut request_id = self.next_request_id.lock().await;
+        let ret: u32 = *request_id;
+
+        // 2048 should be large enough to avoid collision.
+        *request_id = (*request_id + 1) % 2048;
+
+        ret
+    }
 }
 
 #[derive(Builder, Debug)]