Skip to content

Commit

Permalink
Update model_vqa.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Junnan Li authored Jul 5, 2022
1 parent 6224e78 commit fb38420
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion models/model_vqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def forward(self, image, quesiton, answer=None, alpha=0, k=None, weights=None, t
labels = answer_targets,
return_dict = True,
soft_labels = F.softmax(logits_m,dim=-1),
alpha = alpha,
reduction = 'none',
)
else:
Expand Down Expand Up @@ -210,4 +211,4 @@ def tile(x, dim, n_tile):
repeat_idx[dim] = n_tile
x = x.repeat(*(repeat_idx))
order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
return torch.index_select(x, dim, order_index.to(x.device))
return torch.index_select(x, dim, order_index.to(x.device))

0 comments on commit fb38420

Please sign in to comment.