From e3602cbc6338864e4313f794d8e3e76dc02affc7 Mon Sep 17 00:00:00 2001 From: Apostolos Ntelopoulos Date: Mon, 20 May 2024 23:13:12 +0200 Subject: [PATCH] fixed runtime device error --- eval_retrieval_video.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/eval_retrieval_video.py b/eval_retrieval_video.py index 07ebab7f..3425ab27 100644 --- a/eval_retrieval_video.py +++ b/eval_retrieval_video.py @@ -74,8 +74,8 @@ def evaluation(model, data_loader, tokenizer, device, config): video_feats.append(video_feat.cpu()) video_embeds.append(video_embed) - video_feats = torch.cat(video_feats,dim=0) - video_embeds = torch.cat(video_embeds,dim=0) + video_feats = torch.cat(video_feats,dim=0).to(device) + video_embeds = torch.cat(video_embeds,dim=0).to(device) sims_matrix = video_embeds @ text_embeds.t() score_matrix_v2t = torch.full((len(texts),len(texts)),-100.0).to(device) @@ -110,6 +110,8 @@ def evaluation(model, data_loader, tokenizer, device, config): for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0) + topk_idx = topk_idx.to(device,non_blocking=True) + topk_sim = topk_sim.to(device,non_blocking=True) encoder_output = video_feats[topk_idx].to(device,non_blocking=True) encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True) output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),