diff --git a/src/retrievals/models/rerank.py b/src/retrievals/models/rerank.py index fdb9c3e2..c94ba602 100644 --- a/src/retrievals/models/rerank.py +++ b/src/retrievals/models/rerank.py @@ -118,7 +118,7 @@ def forward( logger.warning('loss_fn is not setup, use BCEWithLogitsLoss') self.loss_fn = nn.BCEWithLogitsLoss(reduction='mean') - loss = self.loss_fn(logits, labels) + loss = self.loss_fn(logits, labels.float()) if return_dict: outputs_dict['loss'] = loss return outputs_dict