Skip to content

Commit

Permalink
bug fix with masking
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai committed Oct 19, 2023
1 parent 61f3c92 commit 175dff4
Show file tree
Hide file tree
Showing 16 changed files with 6 additions and 5 deletions.
3 changes: 2 additions & 1 deletion MARBLE/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def forward(self, data, n_id, adjs=None):

# restrict to current batch
x = x[n_id]
mask = mask[n_id]
if data.kernels[0].size(0) == n * d:
n_id = utils.expand_index(n_id, d)
else:
Expand Down Expand Up @@ -270,7 +271,7 @@ def forward(self, data, n_id, adjs=None):

emb = self.enc(out)

return emb, mask[n_id][: size[1]]
return emb, mask[: size[1]]

def evaluate(self, data):
"""Forward pass @ evaluation (no minibatches)"""
Expand Down
8 changes: 4 additions & 4 deletions examples/ex_vector_field_flat_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def main():

# train model
params = {
"epochs": 50, # optimisation epochs
"epochs": 100, # optimisation epochs
"order": 1, # first-order derivatives are enough because the vector field have at most first-order features
"hidden_channels": 16, # 16 is enough in this simple example
"hidden_channels": 32, # 16 is enough in this simple example
"out_channels": 3, # 3 is enough in this simple example
"inner_product_features": True, # try changing this to False and see how the embeddings change
"inner_product_features": False, # try changing this to False and see how the embeddings change
}
model = net(data, params=params)
model.run_training(data)
Expand All @@ -62,7 +62,7 @@ def main():
plotting.fields(data, titles=titles, col=2)
# plt.savefig('../results/fields.svg')
plotting.embedding(data, data.y.numpy(), titles=titles, clusters_visible=True)
# plt.savefig('../results/embedding.svg')
plt.savefig('../results/embedding.svg')
plotting.histograms(data, titles=titles)
# plt.savefig('../results/histogram.svg')
plotting.neighbourhoods(data)
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Empty file.
Empty file.
Empty file.
Binary file not shown.

0 comments on commit 175dff4

Please sign in to comment.