
The memory overflow #10

Open
Yongquan-He opened this issue Oct 3, 2021 · 5 comments

@Yongquan-He commented Oct 3, 2021

I am using your code as part of our experiments, but there is a problem with the test graph: because the code builds the graph from all training triplets during validation and testing, 64 GB of memory is not enough when I test the model. Could you give me some suggestions? Thank you very much!

@ShellingFord221 commented Oct 10, 2021

Hi, I have also encountered the out-of-memory problem. It seems that when there are too many triplets, a massive amount of CPU memory is needed:

[screenshot: out-of-memory error]

In my experiments, the statistics of the graph are as follows:

[screenshot: graph statistics]

It seems that the only way to avoid the OOM is to reduce the number of triplets, so that fewer edge_type entries index into w:

[screenshot: the index_select of w by edge_type]
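
As a rough back-of-the-envelope illustration of why indexing w by edge_type exhausts CPU memory (a sketch; the sizes below are assumptions for illustration, not values read from the repo or the screenshots):

num_edges = 500_000        # hypothetical number of edges in the test graph
in_dim = out_dim = 100     # hypothetical hidden size
bytes_per_float = 4        # float32

# torch.index_select(w, 0, edge_type) materialises one (in_dim x out_dim)
# weight matrix per edge, so its size grows linearly with the edge count:
size_gib = num_edges * in_dim * out_dim * bytes_per_float / 2**30
print(f'{size_gib:.1f} GiB')   # ~18.6 GiB for these sizes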

@zmce2018 commented Mar 1, 2022

> I am using your code as part of our experiments, but there is a problem with the test graph: because the code builds the graph from all training triplets during validation and testing, 64 GB of memory is not enough when I test the model.

Hi He, could you please tell me how you fixed this overflow issue? Many thanks in advance.

@workdesk96 commented Jul 26, 2023

One solution seems to be to modify valid() to batch the evaluation phase (and to modify calc_mrr accordingly so that it also returns the hits):

import torch

def valid(valid_triplets, model, test_graph, all_triplets, batch_size=1024):
    # calc_mrr comes from the repo's utilities, modified to also return hits.
    with torch.no_grad():
        model.eval()
        mrr = 0.0
        hits = {1: 0.0, 3: 0.0, 10: 0.0}
        # The embeddings depend only on the fixed test graph, so compute them
        # once instead of recomputing them for every batch.
        entity_embedding = model(test_graph.entity, test_graph.edge_index,
                                 test_graph.edge_type, test_graph.edge_norm)
        num_batches = 0
        for i in range(0, len(valid_triplets), batch_size):
            batch_valid_triplets = valid_triplets[i:i + batch_size]
            mrr_b, hits_bdict = calc_mrr(entity_embedding, model.relation_embedding,
                                         batch_valid_triplets, all_triplets,
                                         hits=[1, 3, 10])
            mrr += mrr_b
            for k in hits:
                hits[k] += hits_bdict[k]
            num_batches += 1
        # Average over the actual number of batches; note that
        # len(valid_triplets) // batch_size is 0 when there are fewer
        # triplets than one batch, and undercounts on a partial last batch.
        mrr /= num_batches
        for k in hits:
            hits[k] /= num_batches
        print(f'MRR: {mrr}, Hits@1: {hits[1]}, Hits@3: {hits[3]}, Hits@10: {hits[10]}')
    return mrr
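
Note that this batching only reduces the memory used by the ranking in calc_mrr; the full-graph forward pass that produces entity_embedding still runs over the entire test graph, which is where the profiling below shows the spike. So it helps when calc_mrr is the bottleneck, but it does not by itself fix the overflow inside message().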

@workdesk96

The above, however, does not seem to work for FB15k-237. Could the source of the issue be this line: https://github.com/JinheonBaek/RGCN/blob/818bf70b00d5cd178a7496a748e4f18da3bcde82/main.py#L25C41-L25C47

@workdesk96 commented Jul 29, 2023

In case it helps, here is the memory profile of the message function during training and during validation.

During Training:

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   188   1904.4 MiB   1904.4 MiB           1       @profile
   189                                             def message(self, x_j, edge_index_j, edge_type, edge_norm):
   190                                                 """
   191                                                 """
   192
   193                                                 # Call the function that might be causing the memory overflow
   194   1904.4 MiB      0.0 MiB           1           w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
   195
   196                                                 # If no node features are given, we implement a simple embedding
   197                                                 # loopkup based on the target node index and its edge type.
   198   1904.4 MiB      0.0 MiB           1           if x_j is None:
   199                                                     w = w.view(-1, self.out_channels)
   200                                                     index = edge_type * self.in_channels + edge_index_j
   201                                                     out = torch.index_select(w, 0, index)
   202                                                 else:
   203   1904.4 MiB      0.0 MiB           1               w = w.view(self.num_rel, self.in_chan, self.out_chan)
   204   3047.9 MiB   1143.5 MiB           1               w = torch.index_select(w, 0, edge_type)
   205   3047.9 MiB      0.0 MiB           1               out = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)
   206
   207   3047.9 MiB      0.0 MiB           1           if edge_norm is not None:
   208   3047.9 MiB      0.0 MiB           1               out = out * edge_norm.view(-1, 1)
   209
   210   3047.9 MiB      0.0 MiB           1           return out

During Validation:

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   188    844.2 MiB    844.2 MiB           1       @profile
   189                                             def message(self, x_j, edge_index_j, edge_type, edge_norm):
   190                                                 """
   191                                                 """
   192
   193                                                 # Call the function that might be causing the memory overflow
   194    844.2 MiB      0.0 MiB           1           w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
   195
   196                                                 # If no node features are given, we implement a simple embedding
   197                                                 # loopkup based on the target node index and its edge type.
   198    844.2 MiB      0.0 MiB           1           if x_j is None:
   199                                                     w = w.view(-1, self.out_channels)
   200                                                     index = edge_type * self.in_channels + edge_index_j
   201                                                     out = torch.index_select(w, 0, index)
   202                                                 else:
   203    844.2 MiB      0.0 MiB           1               w = w.view(self.num_rel, self.in_chan, self.out_chan)
   204  11635.1 MiB  10790.8 MiB           1               w = torch.index_select(w, 0, edge_type)
   205  11743.1 MiB    108.0 MiB           1               out = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)
   206
   207  11743.1 MiB      0.0 MiB           1           if edge_norm is not None:
   208  11743.2 MiB      0.2 MiB           1               out = out * edge_norm.view(-1, 1)
   209
   210  11743.2 MiB      0.0 MiB           1           return out

It appears that the memory overflow happens specifically during validation because edge_type is much larger during validation than during training:
During Training:

Size of edge_type 30000

During Validation:

Size of edge_type 282884
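
Assuming a hidden size of 100, the profiled increment is consistent with a per-edge weight tensor: 282884 × 100 × 100 × 4 bytes ≈ 10790 MiB. A possible mitigation (a sketch only; message_grouped and w_all are names I made up, not the repo's API) is to avoid materialising that tensor by grouping edges by relation and applying one matmul per relation:

import torch

def message_grouped(x_j, edge_type, w_all):
    # x_j:    (num_edges, in_dim) source-node features, one row per edge
    # w_all:  (num_rel, in_dim, out_dim) relation weight matrices
    # Returns (num_edges, out_dim) without building a per-edge weight tensor.
    out = x_j.new_empty(x_j.size(0), w_all.size(-1))
    for rel in torch.unique(edge_type):
        mask = edge_type == rel
        # One (n_edges_rel, in_dim) @ (in_dim, out_dim) matmul per relation;
        # peak extra memory is O(num_edges * out_dim) rather than
        # O(num_edges * in_dim * out_dim).
        out[mask] = x_j[mask] @ w_all[rel]
    return out

This trades the single bmm for a Python loop over relations, which is usually acceptable since the number of relations is small compared to the number of edges.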
