diff --git a/examples/model_demos/faenet_pyg.py b/examples/model_demos/faenet_pyg.py index 603a2475..85bb6c77 100644 --- a/examples/model_demos/faenet_pyg.py +++ b/examples/model_demos/faenet_pyg.py @@ -4,12 +4,10 @@ from matsciml.datasets.transforms import ( FrameAveraging, - GraphToGraphTransform, PointCloudToGraphTransform, - UnitCellCalculator, ) from matsciml.lightning.data_utils import MatSciMLDataModule -from matsciml.models.base import ScalarRegressionTask +from matsciml.models.base import ForceRegressionTask from matsciml.models.pyg import FAENet """ @@ -17,23 +15,23 @@ in combination with a PyG implementation of FAENet. """ -# construct IS2RE relaxed energy regression with PyG implementation of FAENet -task = ScalarRegressionTask( +# construct S2EF force regression task with PyG implementation of FAENet +task = ForceRegressionTask( encoder_class=FAENet, encoder_kwargs={ - "average_frame_embeddings": True, + "average_frame_embeddings": False, # set to false for use with FA transform "pred_as_dict": False, "hidden_dim": 128, - "out_dim": 64, + "out_dim": 128, "tag_hidden_channels": 0, }, - output_kwargs={"lazy": False, "input_dim": 64, "hidden_dim": 64}, - task_keys=["energy_relaxed"], + output_kwargs={"lazy": False, "input_dim": 128, "hidden_dim": 128}, + task_keys=["force"], ) # ### matsciml devset for OCP are serialized with DGL - this transform goes between the two frameworks dm = MatSciMLDataModule.from_devset( - "IS2REDataset", + "S2EFDataset", dset_kwargs={ "transforms": [ PointCloudToGraphTransform( @@ -49,42 +47,4 @@ # run a quick training loop trainer = pl.Trainer(fast_dev_run=10) -trainer.fit(task, datamodule=dm) - - -######################################################################################## -######################################################################################## - - -# construct Materials Project band gap regression with PyG implementation of FAENet -task = ScalarRegressionTask( - encoder_class=FAENet, - encoder_kwargs={ - "pred_as_dict": False, - "hidden_dim": 128, - "out_dim": 64, - "tag_hidden_channels": 0, - "input_dim": 128, - }, - output_kwargs={"lazy": False, "input_dim": 64, "hidden_dim": 64}, - task_keys=["band_gap"], -) - -dm = MatSciMLDataModule.from_devset( - "MaterialsProjectDataset", - dset_kwargs={ - "transforms": [ - UnitCellCalculator(), - PointCloudToGraphTransform( - "pyg", - cutoff_dist=20.0, - node_keys=["pos", "atomic_numbers"], - ), - FrameAveraging(frame_averaging="3D", fa_method="stochastic"), - ], - }, -) - -# run a quick training loop -trainer = pl.Trainer(fast_dev_run=10) -trainer.fit(task, datamodule=dm) +trainer.fit(task, datamodule=dm) \ No newline at end of file diff --git a/matsciml/models/base.py b/matsciml/models/base.py index 05095f9d..994103ab 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -1282,7 +1282,6 @@ def _compute_losses( losses[key] = self.loss_func(predictions[key], target_val) * ( coefficient / predictions[key].numel() ) - total_loss: torch.Tensor = sum(losses.values()) return {"loss": total_loss, "log": losses} @@ -1536,15 +1535,19 @@ def forward( pos: torch.Tensor = graph.ndata.get("pos") # for frame averaging fa_rot = graph.ndata.get("fa_rot", None) + fa_pos = graph.ndata.get("fa_pos", None) else: # otherwise assume it's PyG pos: torch.Tensor = graph.pos + # for frame averaging fa_rot = getattr(graph, "fa_rot", None) + fa_pos = getattr(graph, "fa_pos", None) else: # assume point cloud otherwise pos: torch.Tensor = batch.get("pos") # no frame averaging architecture yet for point clouds fa_rot = None + fa_pos = None if pos is None: raise ValueError( "No atomic positions were found in batch - neither as standalone tensor nor graph.", @@ -1557,11 +1560,16 @@ def forward( raise ValueError( f"'pos' data is required for force calculation, but isn't a tensor or a list of tensors: {type(pos)}.", ) + if isinstance(fa_pos, torch.Tensor): + fa_pos.requires_grad_(True) + elif isinstance(fa_pos, list): + [f_p.requires_grad_(True) for f_p in fa_pos] if "embeddings" in batch: embeddings = batch.get("embeddings") else: embeddings = self.encoder(batch) - outputs = self.process_embedding(embeddings, pos, fa_rot) + natoms = batch.get("natoms", None) + outputs = self.process_embedding(embeddings, pos, fa_rot, fa_pos, natoms) return outputs def process_embedding( @@ -1569,41 +1577,56 @@ def process_embedding( embeddings: Embeddings, pos: torch.Tensor, fa_rot: None | torch.Tensor = None, + fa_pos: None | torch.Tensor = None, + natoms: None | torch.Tensor = None, ) -> dict[str, torch.Tensor]: outputs = {} - energy = self.output_heads["energy"](embeddings.system_embedding) - # now use autograd for force calculation - force = ( - -1 - * torch.autograd.grad( - energy, - pos, - grad_outputs=torch.ones_like(energy), - create_graph=True, - )[0] - ) + + def energy_and_force( + pos: torch.Tensor, system_embedding: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + energy = self.output_heads["energy"](system_embedding) + # now use autograd for force calculation + force = ( + -1 + * torch.autograd.grad( + energy, + pos, + grad_outputs=torch.ones_like(energy), + create_graph=True, + )[0] + ) + return energy, force + + if fa_pos is None: + energy, force = energy_and_force(pos, embeddings.system_embedding) + else: + energy = [] + force = [] + for idx, pos in enumerate(fa_pos): + frame_embedding = embeddings.system_embedding[:, idx, :] + frame_energy, frame_force = energy_and_force(pos, frame_embedding) + force.append(frame_force) + energy.append(frame_energy) + # check to see if we are frame averaging - if isinstance(fa_rot, torch.Tensor): - natoms = pos.size(0) + if fa_rot is not None: all_forces = [] # loop over each frame prediction, and transform to guarantee # equivariance of frame averaging method - for frame_idx, frame_rot in fa_rot: + natoms = natoms.squeeze(-1).to(int) + for frame_idx, frame_rot in enumerate(fa_rot): repeat_rot = torch.repeat_interleave( frame_rot, natoms, dim=0, ).to(self.device) rotated_forces = ( - force[:, frame_idx, :] - .view(-1, 1, 3) - .bmm( - repeat_rot.transpose(1, 2), - ) + force[frame_idx].view(-1, 1, 3).bmm(repeat_rot.transpose(1, 2)) ) - all_forces.append(rotated_forces.view(natoms, 3)) + all_forces.append(rotated_forces) # combine all the force data into a single tensor - force = torch.stack(all_forces, dim=1) + force = torch.cat(all_forces, dim=1) # reduce outputs to what are expected shapes outputs["force"] = reduce( force, diff --git a/matsciml/models/pyg/faenet/faenet.py b/matsciml/models/pyg/faenet/faenet.py index b02572b5..d893f89e 100644 --- a/matsciml/models/pyg/faenet/faenet.py +++ b/matsciml/models/pyg/faenet/faenet.py @@ -2,6 +2,7 @@ FAENet: Frame Averaging Equivariant graph neural Network Simple, scalable and expressive model for property prediction on 3D atomic systems. """ + from __future__ import annotations from copy import deepcopy @@ -354,17 +355,8 @@ def first_forward( Returns: (dict): predicted energy, forces and final atomic hidden states """ - if self.training: - mode = "train" - else: - mode = "inference" - preproc = True + preproc = False data = graph - - # energy gradient w.r.t. positions will be computed - if mode == "train" or self.regress_forces == "from_energy": - data.pos.requires_grad_(True) - # produce final embeddings after going through model embeddings = self.energy_forward(data, preproc) return embeddings