
KeyError: 'id' is reported when I run the pretrain.py #55

Open
GrayChan04 opened this issue Feb 28, 2022 · 12 comments

@GrayChan04

Excuse me, does anyone know how to deal with this?

batched_graph.ndata['h'] = ent_embeds[batched_graph.ndata['id']].view(-1, ent_embeds.shape[1])
File "/home/xxxx/anaconda3/envs/renet/lib/python3.6/site-packages/dgl/view.py", line 60, in __getitem__
return self._graph.get_n_repr(self._nodes)[key]
KeyError: 'id'

Something seems to be wrong with the line
batched_graph.ndata['h'] = ent_embeds[batched_graph.ndata['id']].view(-1, ent_embeds.shape[1])
but I don't know how to fix the problem.

@GrayChan04

I changed GDELT to ICEWS14, but the same problem occurs...
Following readme.md, I configured the environment, but I still CANNOT run the code.
Am I missing something?

$ python3 pretrain.py -d ICEWS14 --gpu 0 --dropout 0.5 --n-hidden 200 --lr 1e-3 --max-epochs 20 --batch-size 1024
Using backend: pytorch
Namespace(batch_size=1024, dataset='ICEWS14', dropout=0.5, gpu=0, grad_norm=1.0, lr=0.001, max_epochs=20, maxpool=1, model=3, n_hidden=200, num_k=10, rnn_layers=1, seq_len=10)
start training...
/home/xxxx/RE-Net-cp/Aggregator.py:32: UserWarning: This overload of nonzero is deprecated:
nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
nonzero(Tensor input, *, bool as_tuple) (Triggered internally at /pytorch/torch/csrc/utils/python_arg_parser.cpp:766.)
num_non_zero = len(torch.nonzero(t_list))
/home/xxxx/RE-Net-cp/utils.py:290: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
return torch.mean(torch.sum(- soft_targets * logsoftmax(pred), 1))
Traceback (most recent call last):
File "pretrain.py", line 139, in <module>
train(args)
File "pretrain.py", line 92, in train
model.global_emb = model.get_global_emb(train_times_origin, graph_dict)
File "/home/xxxx/RE-Net-cp/global_model.py", line 67, in get_global_emb
emb, _, _ = self.predict(t, graph_dict)
File "/home/xxxx/RE-Net-cp/global_model.py", line 88, in predict
rnn_inp = self.aggregator.predict(t, self.ent_embeds, graph_dict, reverse=reverse)
File "/home/xxxx/RE-Net-cp/Aggregator.py", line 97, in predict
batched_graph.ndata['h'] = ent_embeds[batched_graph.ndata['id']].view(-1, ent_embeds.shape[1])
File "/home/xxxx/anaconda3/envs/renet/lib/python3.6/site-packages/dgl/view.py", line 60, in __getitem__
return self._graph.get_n_repr(self._nodes)[key]
KeyError: 'id'

@Zhai-Huichen

batched_graph.ndata['h'] = ent_embeds[batched_graph.ndata['id']].view(-1, ent_embeds.shape[1])
The problem is probably that batched_graph lives on the CPU while the values being assigned to it are on the GPU, so the assignment fails. Move batched_graph to the GPU before the failing line.

@GrayChan04

I already moved it to the GPU, but the same error is still raised.

@Zhai-Huichen

In your get_history_graph.py, did you change the method get_data_with_t?
Change x = data[np.where(data[3] == tim)].copy() to x = data[np.where(data[:, 3] == tim)].copy()
I ran it on ICEWS18; ICEWS14 seems to have a problem.
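The difference between the two indexing expressions above can be shown with a small NumPy sketch; the quadruple layout (subject, relation, object, timestamp) is assumed from the thread:

```python
import numpy as np

# Toy quadruples: (subject, relation, object, timestamp) -- layout assumed
# from the thread; real datasets have many more rows.
data = np.array([
    [0, 1, 2, 0],
    [3, 4, 5, 24],
    [6, 7, 8, 24],
    [9, 10, 11, 48],
])
tim = 24

# Buggy version: data[3] selects ROW 3 ([9, 10, 11, 48]) and compares its
# four entries to tim, producing a column mask -- not a row filter.
wrong = data[np.where(data[3] == tim)].copy()

# Fixed version: data[:, 3] is the timestamp COLUMN, so the mask picks
# exactly the rows whose timestamp equals tim.
right = data[np.where(data[:, 3] == tim)].copy()

print(wrong.shape)  # (0, 4) -- no facts selected, so the graph ends up empty
print(right.shape)  # (2, 4) -- the two facts at timestamp 24
```

An empty selection here would explain the later KeyError: a graph built from zero facts never gets its 'id' node feature.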


@GrayChan04

In your get_history_graph.py, did you change the method get_data_with_t? Change x = data[np.where(data[3] == tim)].copy() to x = data[np.where(data[:, 3] == tim)].copy() I ran it on ICEWS18; ICEWS14 seems to have a problem.

I didn't change it; I think this code is fine.

@WWWindrunner

In your get_history_graph.py, did you change the method get_data_with_t? Change x = data[np.where(data[3] == tim)].copy() to x = data[np.where(data[:, 3] == tim)].copy() I ran it on ICEWS18; ICEWS14 seems to have a problem.

I didn't change it; I think this code is fine.

Have you printed the intermediate results of this method? I hit the ['id'] error before, and it was because this step executed incorrectly, leaving batched_graph empty.
I changed it to
triples = [[quad[0], quad[1], quad[2]] for quad in data if quad[3] == tim]
and then it worked.
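The list-comprehension workaround above filters the quadruples in plain Python instead of relying on NumPy indexing; a minimal sketch, assuming the (subject, relation, object, timestamp) layout described in the thread:

```python
# Toy quadruples: (subject, relation, object, timestamp) -- layout assumed
# from the thread.
data = [
    (0, 1, 2, 0),
    (3, 4, 5, 24),
    (6, 7, 8, 24),
]
tim = 24

# Keep the (s, r, o) triple of every fact whose timestamp equals tim.
triples = [[quad[0], quad[1], quad[2]] for quad in data if quad[3] == tim]
print(triples)  # [[3, 4, 5], [6, 7, 8]]
```

This sidesteps the row-vs-column indexing ambiguity entirely, at the cost of a Python-level loop instead of vectorized NumPy.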

@MrLiuCC

MrLiuCC commented Mar 31, 2022

In your get_history_graph.py, did you change the method get_data_with_t? Change x = data[np.where(data[3] == tim)].copy() to x = data[np.where(data[:, 3] == tim)].copy() I ran it on ICEWS18; ICEWS14 seems to have a problem.

ICEWS14 does not contain a valid.txt.

@MrLiuCC

MrLiuCC commented Mar 31, 2022

[image]
Is this how this part should be changed?

@WWWindrunner

WWWindrunner commented Apr 5, 2022 via email

@s821220

s821220 commented Jun 20, 2022

Was this finally resolved? What was the solution?

@Young0222

Young0222 commented Jul 20, 2022

I solved this problem; please refer to:
batched_graph.ndata['h'] = ent_embeds[batched_graph.ndata['id']].view(-1, ent_embeds.shape[1]).to('cpu')
batched_graph = batched_graph.to('cuda:0')
Also, delete the line move_dgl_to_cuda(batched_graph).
