diff --git a/MARBLE/main.py b/MARBLE/main.py index 7adf9313..d612f673 100644 --- a/MARBLE/main.py +++ b/MARBLE/main.py @@ -231,6 +231,7 @@ def forward(self, data, n_id, adjs=None): # restrict to current batch x = x[n_id] + mask = mask[n_id] if data.kernels[0].size(0) == n * d: n_id = utils.expand_index(n_id, d) else: @@ -270,7 +271,7 @@ def forward(self, data, n_id, adjs=None): emb = self.enc(out) - return emb, mask[n_id][: size[1]] + return emb, mask[: size[1]] def evaluate(self, data): """Forward pass @ evaluation (no minibatches)""" diff --git a/examples/ex_vector_field_flat_surface.py b/examples/ex_vector_field_flat_surface.py index 99025c64..e785ea63 100644 --- a/examples/ex_vector_field_flat_surface.py +++ b/examples/ex_vector_field_flat_surface.py @@ -42,11 +42,11 @@ def main(): # train model params = { - "epochs": 50, # optimisation epochs + "epochs": 100, # optimisation epochs "order": 1, # first-order derivatives are enough because the vector field have at most first-order features - "hidden_channels": 16, # 16 is enough in this simple example + "hidden_channels": 32, # 16 is enough in this simple example "out_channels": 3, # 3 is enough in this simple example - "inner_product_features": True, # try changing this to False and see how the embeddings change + "inner_product_features": False, # try changing this to False and see how the embeddings change } model = net(data, params=params) model.run_training(data) @@ -62,7 +62,7 @@ def main(): plotting.fields(data, titles=titles, col=2) # plt.savefig('../results/fields.svg') plotting.embedding(data, data.y.numpy(), titles=titles, clusters_visible=True) - # plt.savefig('../results/embedding.svg') + plt.savefig('../results/embedding.svg') plotting.histograms(data, titles=titles) # plt.savefig('../results/histogram.svg') plotting.neighbourhoods(data) diff --git a/examples/log/events.out.tfevents.1688827441.SV-87M-007 b/examples/log/events.out.tfevents.1688827441.SV-87M-007 new file mode 100644 index 00000000..a6859620 Binary files /dev/null and b/examples/log/events.out.tfevents.1688827441.SV-87M-007 differ diff --git a/examples/log/events.out.tfevents.1688827497.SV-87M-007 b/examples/log/events.out.tfevents.1688827497.SV-87M-007 new file mode 100644 index 00000000..81f01ac9 Binary files /dev/null and b/examples/log/events.out.tfevents.1688827497.SV-87M-007 differ diff --git a/examples/log/events.out.tfevents.1688827544.SV-87M-007 b/examples/log/events.out.tfevents.1688827544.SV-87M-007 new file mode 100644 index 00000000..ec2476bc Binary files /dev/null and b/examples/log/events.out.tfevents.1688827544.SV-87M-007 differ diff --git a/examples/log/events.out.tfevents.1688931165.SV-87M-007 b/examples/log/events.out.tfevents.1688931165.SV-87M-007 new file mode 100644 index 00000000..af73eb10 Binary files /dev/null and b/examples/log/events.out.tfevents.1688931165.SV-87M-007 differ diff --git a/examples/log/events.out.tfevents.1688931288.SV-87M-007 b/examples/log/events.out.tfevents.1688931288.SV-87M-007 new file mode 100644 index 00000000..c697f0fe Binary files /dev/null and b/examples/log/events.out.tfevents.1688931288.SV-87M-007 differ diff --git a/examples/log/events.out.tfevents.1688931533.SV-87M-007 b/examples/log/events.out.tfevents.1688931533.SV-87M-007 new file mode 100644 index 00000000..4697b526 Binary files /dev/null and b/examples/log/events.out.tfevents.1688931533.SV-87M-007 differ diff --git a/examples/log/events.out.tfevents.1688931570.SV-87M-007 b/examples/log/events.out.tfevents.1688931570.SV-87M-007 new file mode 100644 index 00000000..292f1fd6 Binary files /dev/null and b/examples/log/events.out.tfevents.1688931570.SV-87M-007 differ diff --git a/examples/log/events.out.tfevents.1688931715.SV-87M-007 b/examples/log/events.out.tfevents.1688931715.SV-87M-007 new file mode 100644 index 00000000..73d64250 Binary files /dev/null and b/examples/log/events.out.tfevents.1688931715.SV-87M-007 differ diff --git a/examples/log/events.out.tfevents.1688931841.SV-87M-007 b/examples/log/events.out.tfevents.1688931841.SV-87M-007 new file mode 100644 index 00000000..d3a845a0 Binary files /dev/null and b/examples/log/events.out.tfevents.1688931841.SV-87M-007 differ diff --git a/examples/log/events.out.tfevents.1688932181.SV-87M-007 b/examples/log/events.out.tfevents.1688932181.SV-87M-007 new file mode 100644 index 00000000..59f950b1 Binary files /dev/null and b/examples/log/events.out.tfevents.1688932181.SV-87M-007 differ diff --git a/examples/log/events.out.tfevents.1697705047.SV-87M-007 b/examples/log/events.out.tfevents.1697705047.SV-87M-007 new file mode 100644 index 00000000..e69de29b diff --git a/examples/log/events.out.tfevents.1697706454.SV-87M-007 b/examples/log/events.out.tfevents.1697706454.SV-87M-007 new file mode 100644 index 00000000..e69de29b diff --git a/examples/log/events.out.tfevents.1697706489.SV-87M-007 b/examples/log/events.out.tfevents.1697706489.SV-87M-007 new file mode 100644 index 00000000..e69de29b diff --git a/examples/log/events.out.tfevents.1697706571.SV-87M-007 b/examples/log/events.out.tfevents.1697706571.SV-87M-007 new file mode 100644 index 00000000..bfd662a8 Binary files /dev/null and b/examples/log/events.out.tfevents.1697706571.SV-87M-007 differ