
Why don't the candidate item embeddings go through the Model module's fc (i.e., the MLP module) at test time, instead using the item's modal embedding / id embedding directly? #7

Open
Jackie-gj opened this issue Dec 28, 2023 · 2 comments

Comments

@Jackie-gj

    self.fc = MLP_Layers(word_embedding_dim=num_fc_ftr,
                         item_embedding_dim=args.embedding_dim,
                         layers=[args.embedding_dim] * (args.dnn_layer + 1),
                         drop_rate=args.drop_rate)

During training, this code transforms the input embeddings, and the similarities against the candidate positive/negative samples and the BCE loss are then computed on the transformed embeddings. At prediction time (shown below), why are the item_embeddings used directly, without passing through the MLP_Layers above?

item_embeddings = item_embeddings.to(local_rank)
with torch.no_grad():
    eval_all_user = []
    item_rank = torch.Tensor(np.arange(item_num) + 1).to(local_rank)
    for data in eval_dl:
        user_ids, input_embs, log_mask, labels = data
        user_ids, input_embs, log_mask, labels = \
            user_ids.to(local_rank), input_embs.to(local_rank),\
            log_mask.to(local_rank), labels.to(local_rank).detach()
        prec_emb = model.module.user_encoder(input_embs, log_mask, local_rank)[:, -1].detach()
        scores = torch.matmul(prec_emb, item_embeddings.t()).squeeze(dim=-1).detach()
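
For reference, my understanding of the training-time path (a rough sketch only; the names pos_embs / neg_embs and the exact BCE construction are illustrative, not the repository's code):

import torch
import torch.nn.functional as F

# Assumed to come from the training loop: model, input_embs, pos_embs, neg_embs,
# log_mask, local_rank. Every item embedding is first projected by the same fc (MLP).
input_embs = model.module.fc(input_embs)        # (batch, seq_len, embedding_dim)
pos_embs = model.module.fc(pos_embs)            # positive candidates, same shape
neg_embs = model.module.fc(neg_embs)            # negative candidates, same shape

prec_emb = model.module.user_encoder(input_embs, log_mask, local_rank)
pos_score = (prec_emb * pos_embs).sum(dim=-1)   # dot-product similarity per step
neg_score = (prec_emb * neg_embs).sum(dim=-1)

loss = F.binary_cross_entropy_with_logits(
    torch.cat([pos_score, neg_score], dim=0),
    torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)], dim=0))

So at training time the fc projection is applied to every item embedding before scoring, which is why I expected the same projection at prediction time.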
@zyuanat
Contributor

zyuanat commented Feb 2, 2024

Thanks for the question, and sorry for the late reply.
The item embeddings you mention have in fact already been passed through the fc layer of the model module. This happens in the get_item_embeddings() function in metics.py, at the line "item_emb = model.module.fc(item_emb)":

def get_item_embeddings(model, item_image_embs, test_batch_size, args, local_rank):
    model.eval()
    item_dataset = Build_Id_Eval_Dataset(data=item_image_embs)
    item_dataloader = DataLoader(item_dataset, batch_size=test_batch_size,
                                 num_workers=args.num_workers, pin_memory=True)
    item_embeddings = []
    with torch.no_grad():
        for input_ids in item_dataloader:
            item_emb = input_ids.to(local_rank)
            item_emb = model.module.fc(item_emb)
            item_embeddings.extend(item_emb)
    return torch.stack(tensors=item_embeddings, dim=0).to(torch.device("cpu")).detach()
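
Putting the two snippets together: by the time the evaluation loop runs, the fc (MLP) projection has already been applied offline when building item_embeddings, so the matmul in the prediction code compares the user representation against projected item embeddings, not raw modal/id embeddings. A condensed view of the flow (reusing the names from the snippets above):

# Precompute once: raw modal/id embeddings -> fc (MLP) projection.
item_embeddings = get_item_embeddings(model, item_image_embs, test_batch_size, args, local_rank)
item_embeddings = item_embeddings.to(local_rank)

# Evaluate: the user representation is matched against the already-projected item table.
prec_emb = model.module.user_encoder(input_embs, log_mask, local_rank)[:, -1]
scores = torch.matmul(prec_emb, item_embeddings.t())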

@Jackie-gj
Author

Jackie-gj commented Feb 23, 2024 via email
