Graph Neural Networks (GNNs) provide a powerful and flexible framework for simulating complex physical systems, including fluid dynamics, by representing particles as nodes and their interactions as edges. This approach reduces reliance on computationally expensive physical models, such as those used in traditional computational fluid dynamics (CFD), which often require solving complex partial differential equations in real time. Our implementation is based on DeepMind's Learning to Simulate paper. To validate our approach, we use the WaterDrop dataset, originally generated with the Material Point Method (MPM). While our GNN models produce fluid behavior that is visually plausible and captures key dynamics such as splashes and waves, the resulting simulations are less dynamic and somewhat “heavier” than the ground truth. Despite this limitation, the learned model remains more computationally efficient than traditional simulators, making it a promising tool for real-time applications such as gaming and virtual reality.
Graph Neural Networks (GNNs) are a relatively new deep learning methodology. Proposed in 2009 by Scarselli et al. [8], they introduced a new way to structure deep learning problems: they allow data to be less structured than in established methods. Non-graph neural networks typically expect inputs in a regular layout, such as sequences of text or grids of image pixels. With a graph representation, learning tasks like physics simulation become feasible. In this report we investigate using GNNs to simulate fluids, a problem classically addressed by computational fluid dynamics (CFD).
Standard physics simulators, particularly CFD solvers, rely on numerical methods that are computationally heavy compared with a learned model. Approaches such as smoothed particle hydrodynamics (SPH) and the material point method (MPM) require solving complex differential equations or similar formulations that directly model the fluid [9][10]. This computational cost makes it difficult to run these simulations in real time with a large number of particles.
GNNs utilize a graph structure in which nodes represent particles (or other objects) and edges represent the interactions between them. The model first passes the raw features through an encoder, which projects them into a higher-dimensional latent space where the underlying interactions are easier to learn. The encoded features then go through a processing phase consisting of several rounds of message passing, a function that combines node features and edge features to learn the edge-based interactions. Finally, a decoder, implemented as a node MLP, brings the processed node embeddings back down in dimensionality to produce the predictions.
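As a shape-only toy sketch of this data flow (the stand-in modules below are illustrative stubs, not the real model; the actual encoder, processor, and decoder are defined later in this report and also consume edge features):

import torch
import torch.nn as nn

num_particles, raw_dim, latent_dim = 100, 14, 128
encoder = nn.Linear(raw_dim, latent_dim)  # Encoder: raw -> latent
processor = nn.ModuleList(
    [nn.Linear(latent_dim, latent_dim) for _ in range(3)]  # Processor stubs
)
decoder = nn.Linear(latent_dim, 2)  # Decoder: latent -> (ax, ay)

x = torch.randn(num_particles, raw_dim)  # raw per-particle features
h = encoder(x)
for layer in processor:  # message passing would happen here
    h = h + torch.relu(layer(h))  # residual update (stub)
accelerations = decoder(h)  # per-particle 2-D acceleration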
Using GNNs, any simulator that outputs position vectors can be learned. While the learned predictions may not be as physically accurate, many applications, such as movies and video games, do not require quantitatively exact results. For real-time applications, the speedup from a learned simulation over a computed one is a worthwhile tradeoff.
Graph Neural Networks have many applications outside of physics simulations, including natural language processing [7], recommender systems [3], and even drug discovery [6]. We chose to base our work on the Google DeepMind paper Learning to Simulate [1], but many other papers explore physics simulations with graph neural networks, including [2] and [11]. We chose to implement the DeepMind paper specifically because their method generalizes well and is not restricted to a specific problem type: it can be extrapolated to many different kinds of substances and can be extended to three dimensions. As written in the paper, "we found our single GNS model performed well across dozens of experiments and was generally robust to hyperparameter choices" [1].
We used the dataset generated by Google DeepMind to ensure consistent results, but our method can be run with any dataset consisting of a set of simulations, each made up of position vectors (positions) and a list of particle types (particle_types), saved in the .npz format shown below. This format is based on the GeoElements [13] implementation of a GNN for physical simulations. We chose this encoding because it was easier to work with than the .tfrecord format that the original paper uses.
simulation_data = {
    "simulation_0": (
        positions,       # list of per-timestep (num_particles, 2) float32 arrays
        particle_types,  # (num_particles,) int64 array
    ),
    "simulation_1": (
        ...
    ),
    ...
    "simulation_n": (
        ...
    ),
}
The specific dataset we used is called "WaterDrop" [12] and was generated using MPM. It consists of 1000 simulations, each with differing initial conditions, containing up to 1000 randomly generated particles tracked across 1000 timesteps. We wrote a custom data loader to convert the original dataset into our .npz format:
import numpy as np
import tensorflow as tf
from tqdm import tqdm


def convert_to_npz(file_path, output_path):
    """Convert a DeepMind .tfrecord dataset into our .npz format."""
    context_features = {
        "key": tf.io.FixedLenFeature([], tf.int64, default_value=0),
        "particle_type": tf.io.VarLenFeature(tf.string),
    }
    sequence_features = {
        "position": tf.io.VarLenFeature(tf.string),
    }

    simulation_count = 0
    simulation_data = {}
    raw_dataset = tf.data.TFRecordDataset(file_path)

    for raw_record in tqdm(raw_dataset):
        timesteps = []
        try:
            context, parsed_features = tf.io.parse_single_sequence_example(
                raw_record,
                context_features=context_features,
                sequence_features=sequence_features,
            )
            # Skip records with no particle-type information.
            if "particle_type" not in context or len(context["particle_type"].values) == 0:
                continue
            particle_types = context["particle_type"]
            particle_types = np.frombuffer(
                particle_types.values[0].numpy(), dtype=np.int64
            )
            for feature, value in parsed_features.items():
                if len(value.values) == 0:
                    continue
                # Each entry is a serialized (num_particles, 2) float32 array.
                for i in range(len(value.values)):
                    positions = np.frombuffer(value.values[i].numpy(), dtype=np.float32)
                    if len(positions) % 2 != 0:
                        print(f"Skipping malformed position data in feature {feature}")
                        continue
                    rows = len(positions) // 2
                    positions = positions.reshape(rows, 2)
                    timesteps.append(positions)
            simulation_data[f"simulation_{simulation_count}"] = (
                timesteps,
                particle_types,
            )
            simulation_count += 1
        except Exception as e:
            print(f"Error processing record: {e}")

    np.savez(output_path, simulation_data=simulation_data)
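Converting a split and loading it back then looks like this (the file paths are hypothetical; allow_pickle is required because the dict round-trips through a NumPy object array):

convert_to_npz("WaterDrop/train.tfrecord", "waterdrop_train.npz")  # hypothetical paths

data = np.load("waterdrop_train.npz", allow_pickle=True)["simulation_data"].item()
positions, particle_types = data["simulation_0"]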
[Figure: rendered frames of a ground-truth WaterDrop simulation.]
The model architecture is based on the Learning to Simulate paper [1]: a specialized Graph Neural Network (GNN) designed to model interactions within a complex system of interconnected particles. The architecture is structured into three components, the Encoder, the Interaction Network, and the Decoder, each serving a critical role in processing and understanding particle dynamics.
The encoder is responsible for transforming raw input features into a high-dimensional latent space, facilitating more complex and meaningful representations for subsequent processing. The input features are categorized into Node features and Edge features:
- Node Features - These include information specific to each particle, such as particle type and the velocity sequence. The Encode_NodeMLP module processes these features through a series of linear layers with LeakyReLU activations, ending with layer normalization, resulting in high-dimensional node embeddings.
- Edge Features - These capture the relationships between particles, including displacement and normalized distances. The Encode_EdgeMLP module similarly processes edge features through a multi-layer perceptron (MLP), producing high-dimensional edge embeddings.
import torch
import torch.nn as nn


class Encode_NodeMLP(torch.nn.Module):
    def __init__(self, window_size=5):
        super().__init__()
        # Static per-node features (16 dims) plus a window of 2-D velocities.
        input_size = 16 + 2 * window_size
        # Three hidden layers with LeakyReLU, a linear output, then LayerNorm.
        self.layers = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 128),
            nn.LayerNorm(128),
        )

    def forward(self, x):
        return self.layers(x)
class Encode_EdgeMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Three hidden layers with LeakyReLU, a linear output, then LayerNorm.
        self.layers = nn.Sequential(
            nn.Linear(3, 128),  # 2-D displacement + scalar distance
            nn.LeakyReLU(0.01),
            nn.Linear(128, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 128),
            nn.LayerNorm(128),
        )

    def forward(self, x):
        return self.layers(x)
The Interaction Network serves as the backbone of the model, leveraging Message Passing to learn and model the interactions between particles. It consists of multiple stacked submodules, each representing a layer of message passing. The primary components and processes within the Interaction Network include:
- Message Passing Mechanism: Implemented via the InteractionNetwork class, which extends PyTorch Geometric's MessagePassing class. In each message-passing layer:
  - Message Function: For each edge, messages are computed by concatenating the features of the source and target nodes with the edge features. These concatenated features are then passed through the Processor_EdgeMLP to generate edge-specific messages.
  - Aggregation Function: Messages from all incoming edges to a node are aggregated (summed) using the torch_scatter.scatter function.
  - Update Function: The aggregated messages are concatenated with the node's existing features and processed through the Processor_NodeMLP to update the node's embeddings.
- Stacking Layers: Multiple Interaction Network layers (as defined by gnn_layers) are stacked to allow the model to capture higher-order interactions and dependencies between particles over several iterations of message passing; a sketch of this stacking follows the code below.
import torch_scatter
import torch_geometric as pyg


class InteractionNetwork(pyg.nn.MessagePassing):
    def __init__(self):
        super().__init__()
        self.NodeMLP = Processor_NodeMLP()
        self.EdgeMLP = Processor_EdgeMLP()

    def forward(self, x, edge_index, edge_feature):
        # propagate() runs message() then aggregate(); our aggregate()
        # returns both the per-edge messages and their per-node sums.
        edge_out, aggr = self.propagate(edge_index, x=(x, x), edge_feature=edge_feature)
        node_out = self.NodeMLP(torch.cat((x, aggr), dim=-1))
        # Residual connections on both the edge and node embeddings.
        edge_out = edge_feature + edge_out
        node_out = x + node_out
        return node_out, edge_out

    def message(self, x_i, x_j, edge_feature):
        # Concatenate receiver features, sender features, and edge features.
        x = torch.cat((x_i, x_j, edge_feature), dim=-1)
        x = self.EdgeMLP(x)
        return x

    def aggregate(self, inputs, index, dim_size=None):
        # Sum all incoming messages for each receiving node.
        out = torch_scatter.scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="sum"
        )
        return (inputs, out)
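The stacking described above can be as simple as the following sketch (gnn_layers = 10 is an illustrative default, not a value from our sweep):

class Processor(torch.nn.Module):
    """Chains gnn_layers rounds of message passing (a sketch)."""
    def __init__(self, gnn_layers=10):
        super().__init__()
        self.layers = nn.ModuleList([InteractionNetwork() for _ in range(gnn_layers)])

    def forward(self, node_emb, edge_index, edge_emb):
        for layer in self.layers:
            node_emb, edge_emb = layer(node_emb, edge_index, edge_emb)
        return node_emb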
The decoder translates the high-dimensional node embeddings produced by the Interaction Network back into the target space, which in this application consists of the x and y acceleration values for each particle. This is achieved through the Decoder_MLP module, which comprises a series of linear layers with LeakyReLU activations and a final layer that outputs the 2-D acceleration values.
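Since the decoder code is not shown above, here is a minimal sketch, assuming it mirrors the encoder MLPs (the depth and width are our assumptions; note there is deliberately no LayerNorm on the output):

class Decoder_MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(128, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 2),  # per-particle x and y acceleration
        )

    def forward(self, x):
        return self.layers(x)

At rollout time, the predicted accelerations (after denormalization) are integrated back into positions with a unit timestep, v_{t+1} = v_t + a_t and p_{t+1} = p_t + v_{t+1}, matching the semi-implicit Euler update used in [1].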
Modeling particle interactions effectively requires structuring the input data as a graph, where each node corresponds to an individual particle and edges encode pairwise relationships based on spatial proximity. The following subsections detail the processes involved in preparing this graph representation and the associated input features.
- Function: generate_noise
- Purpose: Enhances model robustness by adding Gaussian noise to particle velocities, counteracting the accumulation of prediction errors during simulation rollouts.
- Process:
  - Noise Creation: Gaussian noise, scaled by the noise standard deviation, is generated and accumulated over time to form a noisy velocity trajectory.
  - Position Adjustment: Particle positions are recomputed to remain consistent with the newly perturbed velocity sequence, ensuring that the noisy velocities accurately reflect the particle trajectories. A sketch of this random-walk scheme follows below.
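A minimal sketch of generate_noise, assuming time is the leading dimension of both tensors (the per-step scale noise_std / sqrt(T), which makes the accumulated walk have std of roughly noise_std, is our assumption):

def generate_noise(position_seq, velocity_seq, noise_std):
    """Random-walk noise on velocities, with positions re-derived to match."""
    num_steps = velocity_seq.size(0)
    step_noise = torch.randn_like(velocity_seq) * (noise_std / num_steps**0.5)
    velocity_noise = step_noise.cumsum(dim=0)  # noise compounds over time
    velocity_seq = velocity_seq + velocity_noise
    # Positions after the first must shift by the accumulated velocity noise
    # so that finite differences of positions reproduce the noisy velocities.
    position_seq = position_seq.clone()
    position_seq[1:] = position_seq[1:] + velocity_noise.cumsum(dim=0)
    return position_seq, velocity_seq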
- Function: to_graph
- Purpose: Transforms raw trajectory data into a graph format suitable for processing by the GNN.
- Process:
  - Velocity Calculation: Derives velocity sequences from input positions and applies the generated noise.
  - Determining Connectivity: Uses a radius-based approach (pyg.nn.radius_graph) to determine which particles are connected based on their proximity.
  - Feature Normalization: Normalizes the particle velocities and computes normalized distances to the dataset boundaries, ensuring the model remains invariant to global position and velocity scales.
  - Edge Feature Computation: Calculates the displacement vectors and normalized distances between connected particles, encapsulating relational information essential to model particle interactions.
  - Target Acceleration Extraction: Determines ground-truth accelerations from velocity differences. These values are normalized before serving as training targets, maintaining consistency with the model's output space.
def to_graph(
    particle_type: torch.Tensor,
    position_seq: torch.Tensor,
    target_position: torch.Tensor,
    metadata: dict,
    noise_std: float = 0.0,
) -> pyg.data.Data:
    """Preprocess a trajectory and construct the graph."""
    # Finite-difference velocities between consecutive timesteps.
    velocity_seq = position_seq[1:] - position_seq[:-1]
    position_seq, velocity_seq = generate_noise(position_seq, velocity_seq, noise_std)
    recent_position = position_seq[-1]

    # Radius-based connectivity plus displacement/distance edge features.
    edge_index, edge_attr = compute_edges(recent_position, metadata)
    normal_velocity_seq = compute_normalized_velocity(velocity_seq, metadata, noise_std)
    distance_to_boundary = compute_distance_to_boundary(recent_position, metadata)

    # Compute acceleration targets (if target_position is provided).
    acceleration = compute_acceleration(
        recent_position, velocity_seq, target_position, metadata, noise_std
    )

    # Per-node features: flattened velocity window + boundary distances.
    final_node_features = torch.cat(
        (
            normal_velocity_seq.reshape(normal_velocity_seq.size(1), -1),
            distance_to_boundary,
        ),
        dim=-1,
    )

    graph = pyg.data.Data(
        x=particle_type,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=acceleration,
        pos=final_node_features,
    )
    return graph
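For reference, a minimal sketch of what compute_edges can look like, assuming the normalization-by-radius scheme from [1] (the helper itself is ours; default_connectivity_radius is the radius field in the dataset's metadata). It also previews the direct-indexing choice discussed in the next section:

def compute_edges(recent_position, metadata):
    """Radius-based connectivity and edge features (a sketch)."""
    radius = metadata["default_connectivity_radius"]
    # Connect every pair of particles closer than the connectivity radius.
    edge_index = pyg.nn.radius_graph(recent_position, radius, loop=True)
    senders, receivers = edge_index
    # Direct indexing with edge_index; no torch.gather needed.
    displacement = (recent_position[senders] - recent_position[receivers]) / radius
    distance = displacement.norm(dim=-1, keepdim=True)
    edge_attr = torch.cat((displacement, distance), dim=-1)  # 3 features per edge
    return edge_index, edge_attr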
In developing the function for constructing graph representations of the particle systems, several key design decisions were made that diverged from the implementation presented in a reference Medium article. These differences were driven by the aim to enhance model performance and code readability. Below, we outline the primary distinctions between our implementation and the reference implementation, along with the rationale behind our choices.
- Reference: They used torch.clip to normalize distance_to_boundary to the range [-1, 1].
- Our Implementation: We apply torch.tanh to map distance_to_boundary smoothly into (-1, 1).
- Rationale: tanh is a smooth, differentiable normalization function, facilitating better gradient flow and information preservation. We found that this choice leads to improved qualitative results, enabling the model to learn more nuanced boundary interactions. (See the sketch after this list.)

- Reference: Their implementation uses torch.gather to compute displacement vectors between connected particles, which they found computationally heavier and more complex than necessary.
- Our Implementation: We compute displacement by directly indexing with edge_index, simplifying the process and enhancing computational efficiency.
- Rationale: Direct indexing reduces computational overhead and simplifies the code, making it more scalable and efficient.

- Reference: Due to the structure of their dataloader, the authors had to switch the column order of their position sequences before processing them (to make time the first dimension), which introduced unnecessary complexity throughout the processing pipeline.
- Our Implementation: We preprocessed the dataset so that no column reordering is needed.
- Rationale: Avoiding the bookkeeping of which order the vector is currently in improves readability and comprehension.
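To make the first difference concrete, here is the contrast in isolation (distance_to_boundary is the per-node boundary feature computed in to_graph):

# Reference: hard clipping; the gradient is exactly zero outside [-1, 1].
dist_ref = torch.clip(distance_to_boundary, -1.0, 1.0)

# Ours: smooth squashing into (-1, 1); gradients shrink but never vanish.
dist_ours = torch.tanh(distance_to_boundary)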
We found that our GNN model is capable of effectively learning fluid simulations, generating results which resemble the ground truth and thus appear realistic to the human eye. However, while our results capture the general fluid movement and behaviour, they deviate from the ground truth and do not achieve a perfect match. The particles in our simulations appear more weighed down and less mobile than the ground truth, overall featuring less movement. We also found that our model performs best after about 80,000 steps, failing to improve with any subsequent training; this is likely due to overfitting.
Our initial training produced strange results which did not resemble reality, with particles floating around in a seemingly zero-gravity environment and passing through the container boundaries. In diagnosing the issue, we realized that we had not been normalizing the velocity for the node features; fixing this produced results which began to resemble the ground truth.
We completed one epoch after normalizing the velocity and observed better results, with particles falling to the floor and beginning to spill to either side. However, upon impacting the floor, the particles began rapidly accelerating in all directions, more closely resembling a simulation of gas than liquid. Our next step was to add noise to the velocities during graph construction, which had a further positive impact.
(We also saw some overfitting behaviour with this added normalization.)
The addition of noise made an immediate impact, already producing results which showed some resemblance to the ground truth after only 3000 steps. Despite this, the model was not able to progress much further, and was still subject to the rapid multidirectional acceleration seen with the previous one.
Our next step was to perform a Weights & Biases sweep to find the best hyperparameters for our model. Our new hyperparameters addressed the scattering effect seen in the previous simulations, and the particles now clumped together more tightly, bearing stronger resemblance to the ground truth.
For the first time, our simulation featured waves from the water hitting the walls, as the ground truth does. However, our waves were much less pronounced, and appeared only on one side at a time depending on the number of training steps. We let this model train the longest, reaching over 620,000 steps; however, the results did not improve much past 90,000 steps.
As discussed in the methodology section, we then changed how we normalized the node feature distance_to_boundary, using the continuous function tanh instead of clipping. While this produced similar quantitative results in terms of loss, the change yielded much better qualitative results, retaining the improved clumping from previous iterations while achieving more realistic particle 'flow'. This is our final model.
Our final model demonstrates a significant improvement in the qualitative realism of the fluid simulation compared to our earlier models. Quantitatively, however, our optimal hyperparameters made the models converge to an evaluation loss of 0.42 both with and without the new distance_to_boundary normalization. This loss can be interpreted as the difference in energy between our simulation and the ground truth, indicating a less dynamic system, which matches our observation that the particles in our simulation move more slowly and less dynamically.
Despite virtually identical quantitative results across these two training approaches, the continuous normalization method significantly improves the qualitative outcome, more closely resembling the ground truth. We attribute the improved performance of tanh normalization over clipping to its differentiability: clipping has zero gradient outside the clip range, whereas tanh is smooth and continuous everywhere, which likely aids gradient descent by providing smoother gradients and preserving more information about the relative distances to the boundary. Clipping discards variation beyond a fixed range, while tanh maps it into a continuous range, allowing the model to better learn small-scale interactions between particles. Overall, our model shows promising results in achieving both physical and visual plausibility in fluid simulations.
Exploring graph neural networks for physical simulations was a great learning experience. We were pleasantly surprised with the results we achieved as they were somewhat similar to the validation dataset. We are eager to further explore the robustness of GNNs by applying them to other novel problems. Our main takeaways from this project were: hyperparameters make a substantial difference when it comes to training, qualitative data is just as, if not more, important than quantitative data for this type of model, and careful data preprocessing can make implementation smoother. For future work, we hope to replicate the other types of substances in the original paper, while also improving our model.
[1] A. Sanchez-Gonzalez, J. Godwin, T. Pfaff, R. Ying, J. Leskovec, and P. Battaglia, “Learning to Simulate Complex Physics with Graph Networks,” in Proceedings of the 37th International Conference on Machine Learning, PMLR, Nov. 2020, pp. 8459–8468. Accessed: Dec. 04, 2024. [Online]. Available: https://proceedings.mlr.press/v119/sanchez-gonzalez20a.html
[2] Z. Jin, B. Zheng, C. Kim, and G. X. Gu, “Leveraging graph neural networks and neural operator techniques for high-fidelity mesh-based physics simulations,” APL Machine Learning, vol. 1, no. 4, p. 046109, Nov. 2023, doi: 10.1063/5.0167014.
[3] S. Wu, F. Sun, W. Zhang, X. Xie, and B. Cui, “Graph Neural Networks in Recommender Systems: A Survey,” ACM Comput. Surv., vol. 55, no. 5, pp. 1–37, May 2023, doi: 10.1145/3535101.
[4] “The Graph Neural Network Model.” Accessed: Dec. 04, 2024. [Online]. Available: https://ieeexplore.ieee.org/abstract/document/4700287
[5] Z. Wu, S. Pan, F. Chen, G. Long, C. Zhang, and P. S. Yu, “A Comprehensive Survey on Graph Neural Networks,” IEEE Transactions on Neural Networks and Learning Systems, vol. 32, no. 1, pp. 4–24, Jan. 2021, doi: 10.1109/TNNLS.2020.2978386.
[6] P. Bongini, M. Bianchini, and F. Scarselli, “Molecular generative Graph Neural Networks for Drug Discovery,” Neurocomputing, vol. 450, pp. 242–252, Aug. 2021, doi: 10.1016/j.neucom.2021.04.039.
[7] L. Wu et al., “Graph Neural Networks for Natural Language Processing: A Survey,” Foundations and Trends in Machine Learning, vol. 16, no. 2, pp. 119–328, Jan. 2023, doi: 10.1561/2200000096.
[8] F. Scarselli, M. Gori, A. C. Tsoi, M. Hagenbuchner, and G. Monfardini, “The Graph Neural Network Model,” IEEE Transactions on Neural Networks, vol. 20, no. 1, pp. 61–80, Jan. 2009, doi: 10.1109/TNN.2008.2005605.
[9] M. Müller, D. Charypar, and M. Gross, “Particle-Based Fluid Simulation for Interactive Applications,” in Proceedings of the 2003 ACM SIGGRAPH/Eurographics Symposium on Computer Animation (SCA ’03), 2003, pp. 154–159.
[10] M. L. Hosain and R. B. Fdhila, “Literature Review of Accelerated CFD Simulation Methods towards Online Application,” Energy Procedia, vol. 75, pp. 3307–3314, Aug. 2015, doi: 10.1016/j.egypro.2015.07.714.
[11] L. Wu, P. Cui, J. Pei, L. Zhao, and X. Guo, “Graph Neural Networks: Foundation, Frontiers and Applications,” in Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, in KDD ’22. New York, NY, USA: Association for Computing Machinery, Aug. 2022, pp. 4840–4841. doi: 10.1145/3534678.3542609.
[12] “Learning to simulate.” Accessed: Dec. 04, 2024. [Online]. Available: https://sites.google.com/view/learning-to-simulate
[13] “Graph Network Simulator.” Accessed: Dec. 04, 2024. [Online]. Available: https://www.geoelements.org/gns/#/