Skip to content

Commit

Permalink
refactor: making forward pass only set embeddings to batch, not reusi…
Browse files Browse the repository at this point in the history
…ng them

This addresses the issue of computational graphs breaking
  • Loading branch information
laserkelvin committed Jul 1, 2024
1 parent 515e4b8 commit 7f726b9
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions matsciml/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,11 +842,8 @@ def forward(
self,
batch: dict[str, torch.Tensor | dgl.DGLGraph | dict[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
if "embeddings" in batch:
embeddings = batch.get("embeddings")
else:
embeddings = self.encoder(batch)
batch["embeddings"] = embeddings
embeddings = self.encoder(batch)
batch["embeddings"] = embeddings
outputs = self.process_embedding(embeddings)
return outputs

Expand Down

0 comments on commit 7f726b9

Please sign in to comment.