Fail to train on gen2 dataset #159

Open
Guackkk opened this issue Jul 3, 2023 · 1 comment
Labels: question ❓ (Further information is requested), reproducibility 🔬 (Question about how to reproduce something)

Guackkk commented Jul 3, 2023

Hi!
I wanted to reproduce the espaloma results, so I followed the training procedure described at https://espaloma.wangyq.net/experiments/qm_fitting.html.
However, the following error occurs when I calculate the training set performance:

```
Traceback (most recent call last):
  File "train_gen2.py", line 86, in <module>
    u = torch.cat(u, dim=0)
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 42 but got size 21 for tensor number 1 in the list
```

For both the training set and the validation set performance, I used the same code as in https://espaloma.wangyq.net/experiments/qm_fitting.html:
```python
import torch

# espaloma_model, ds_tr, ds_vl and inspect_metric are defined earlier in the script (not shown)
loss_tr = []
loss_vl = []

with torch.no_grad():
    for idx_epoch in range(10000):
        # load the checkpoint saved at this epoch
        espaloma_model.load_state_dict(
            torch.load("%s.th" % idx_epoch)
        )

        # training set performance
        u = []
        u_ref = []
        for g in ds_tr:
            if torch.cuda.is_available():
                g.heterograph = g.heterograph.to("cuda:0")
            espaloma_model(g.heterograph)  # predicted energies end up in g.nodes['g'].data['u']
            u.append(g.nodes['g'].data['u'])
            u_ref.append(g.nodes['g'].data['u_ref'])  # reference energies (a tensor, not the node view)
        u = torch.cat(u, dim=0)  # raises the RuntimeError shown in the traceback
        u_ref = torch.cat(u_ref, dim=0)
        loss_tr.append(inspect_metric(u, u_ref))

        # validation set performance
        u = []
        u_ref = []
        for g in ds_vl:
            if torch.cuda.is_available():
                g.heterograph = g.heterograph.to("cuda:0")
            espaloma_model(g.heterograph)
            u.append(g.nodes['g'].data['u'])
            u_ref.append(g.nodes['g'].data['u_ref'])
        u = torch.cat(u, dim=0)
        u_ref = torch.cat(u_ref, dim=0)
        loss_vl.append(inspect_metric(u, u_ref))
```
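The error message itself points to a shape mismatch. A minimal sketch of that reading, assuming (this is not confirmed in the thread) that each `g.nodes['g'].data['u']` has shape `(1, n_conformers)` and that the gen2 molecules carry different numbers of conformers: `torch.cat(..., dim=0)` requires every dimension except dimension 0 to match, so a 42-conformer tensor and a 21-conformer tensor cannot be stacked that way. The energies below are random placeholders, and `concat_energies` is a hypothetical helper, not part of espaloma.

```python
import torch

# Placeholder energies standing in for g.nodes['g'].data['u'] of two molecules.
# ASSUMPTION: each tensor has shape (1, n_conformers); 42 and 21 are the sizes
# quoted in the error message.
u = [torch.randn(1, 42), torch.randn(1, 21)]


def concat_energies(tensors):
    """Hypothetical helper: flatten each (1, n_conformers) tensor to 1-D and
    concatenate, so molecules with different conformer counts can be combined."""
    return torch.cat([t.flatten() for t in tensors], dim=0)


# Concatenating along dim=0 requires dim 1 (the conformer axis) to match,
# so 42 vs 21 conformers reproduces the RuntimeError from the traceback.
try:
    torch.cat(u, dim=0)
except RuntimeError as err:
    print("torch.cat(u, dim=0) failed:", err)

flat = concat_energies(u)
print(flat.shape)  # torch.Size([63])
```

Flattening (or, equivalently here, concatenating along `dim=1`) gives one tensor holding every conformer energy; whether that is the input `inspect_metric` expects depends on how the metric is defined in the script.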

Thanks in advance for any help.

mikemhenry added the question ❓ and reproducibility 🔬 labels on Jul 21, 2023
@mikemhenry (Contributor)

We have made some API changes lately; I will defer to @yuanqing-wang and @kntkb to take a look at this question. Thanks!
