From a33692e070284121934173032c0cfcad5d1a3781 Mon Sep 17 00:00:00 2001
From: Andrew Lapp
Date: Tue, 13 Feb 2024 08:11:31 +0000
Subject: [PATCH] Put `ancestors` on same device as `next_token_logits` (#651)

Fixes https://github.com/outlines-dev/outlines/issues/649

---------

Co-authored-by: Andrew Lapp
---
 outlines/samplers.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/outlines/samplers.py b/outlines/samplers.py
index 393b16af2..71b1592af 100644
--- a/outlines/samplers.py
+++ b/outlines/samplers.py
@@ -66,7 +66,9 @@ def __call__(
         logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
         next_token_ids = torch.argmax(logprobs, dim=-1, keepdim=True)
 
-        ancestors = torch.arange(next_token_logits.shape[0])
+        ancestors = torch.arange(
+            next_token_logits.shape[0], device=next_token_logits.device
+        )
         weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze()
 
         return next_token_ids, ancestors, weights
@@ -144,7 +146,9 @@ def __call__(
         next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng)
 
         logprobs = torch.nn.functional.log_softmax(altered_next_token_logits, dim=-1)
-        ancestors = torch.arange(altered_next_token_logits.shape[0])
+        ancestors = torch.arange(
+            altered_next_token_logits.shape[0], device=next_token_logits.device
+        )
         weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze()
 
         return next_token_ids, ancestors, weights
@@ -292,7 +296,7 @@ def __call__(
 
         # Re-shape the weights, next_token_ids and ancestors to (n_batch * n_samples, 1)
         first_batch_idx = torch.arange(
-            0, batch_size * self.samples, self.samples
+            0, batch_size * self.samples, self.samples, device=next_token_logits.device
         ).unsqueeze(1)
         ancestors = ancestors + first_batch_idx
 
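
Note: a minimal sketch of the failure mode this patch fixes, assuming a CUDA
device is available (the tensor shape and setup here are illustrative, not
taken from the sampler code):

    import torch

    next_token_logits = torch.randn(4, 32000, device="cuda")

    # Before the patch: torch.arange defaults to the CPU, so `ancestors`
    # ends up on a different device than `next_token_logits`.
    ancestors = torch.arange(next_token_logits.shape[0])
    assert ancestors.device.type == "cpu"

    # After the patch: passing device= keeps the index tensor on the same
    # device as the logits, avoiding a cross-device RuntimeError downstream.
    ancestors = torch.arange(
        next_token_logits.shape[0], device=next_token_logits.device
    )
    assert ancestors.device == next_token_logits.device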