
PredictiveScienceLab/jaxpi

 
 


JAX-PI

This repository is a comprehensive implementation of physics-informed neural networks (PINNs) that integrates several advanced network architectures and training algorithms from the PINN literature.

This repository also provides an extensive range of benchmarking examples that demonstrate the effectiveness and robustness of our implementation. Training is supported on both single-GPU and multi-GPU setups, while evaluation is currently limited to a single GPU.

Installation

Ensure that you have Python 3.8 or later installed on your system. Our code is GPU-only. We highly recommend using the most recent versions of JAX and jaxlib, along with compatible CUDA and cuDNN versions. The code has been tested and confirmed to work with the following versions:

  • JAX 0.4.16
  • CUDA 12.1
  • cuDNN 8.9
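
Before installing, it can help to check what your environment actually provides. The snippet below is a minimal sketch, not part of JAX-PI: the helper names `version_tuple` and `describe_environment` are hypothetical, and the only JAX calls used are `jax.__version__`, `jax.default_backend()`, and `jax.devices()`. It compares the installed JAX version against the tested one listed above and reports which backend and devices JAX can see.

```python
import importlib.util

# JAX version the README reports as tested (taken from the list above).
TESTED_JAX = "0.4.16"


def version_tuple(version: str) -> tuple:
    """Turn '0.4.16' into (0, 4, 16), stopping at the first non-numeric part."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def describe_environment() -> str:
    """Report the installed JAX version and the devices it can see."""
    if importlib.util.find_spec("jax") is None:
        return "JAX is not installed"
    import jax  # imported lazily so the check works even without JAX

    older = version_tuple(jax.__version__) < version_tuple(TESTED_JAX)
    relation = "older than" if older else "at least"
    return (
        f"JAX {jax.__version__} ({relation} tested {TESTED_JAX}), "
        f"backend={jax.default_backend()}, devices={jax.devices()}"
    )


print(describe_environment())
```

If the reported backend is `cpu`, JAX did not find a usable GPU, which usually points to a mismatch between jaxlib and the installed CUDA or cuDNN versions.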

Install JAX-PI with the following commands:

git clone https://github.com/PredictiveIntelligenceLab/JAX-PI
cd JAX-PI
pip install .

Quickstart

We use Weights & Biases to log and monitor training metrics. Before proceeding, make sure Weights & Biases is installed and set up with your account; see the Weights & Biases installation guide for instructions.

To illustrate how to use our code, we will use the advection equation as an example. First, navigate to the advection directory within the examples folder:

cd JAX-PI/examples/advection

To train the model, run the following command:

python3 main.py 

Our code automatically supports multi-GPU execution. You can specify the GPUs you want to use with the CUDA_VISIBLE_DEVICES environment variable. For example, to use the first two GPUs (0 and 1), use the following command:

CUDA_VISIBLE_DEVICES=0,1 python3 main.py

Note on Memory Usage: Different models and examples may require varying amounts of GPU memory. If you encounter an out-of-memory error, you can decrease the batch size using the --config.batch_size_per_device option.
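
If you launch runs from a script, the GPU selection and batch size override can be combined in one place. This is a hedged sketch rather than part of JAX-PI: `build_launch` is a hypothetical helper, and the value `1024` is only an illustrative batch size, not a recommended setting. It assembles the command and environment without running anything.

```python
import os
import sys


def build_launch(gpu_ids, batch_size_per_device=None, script="main.py"):
    """Return (command, environment) for one training run; nothing is executed."""
    command = [sys.executable, script]
    if batch_size_per_device is not None:
        # Flag name as given in the memory-usage note above.
        command.append(f"--config.batch_size_per_device={batch_size_per_device}")
    env = dict(os.environ)
    # Restrict the run to the requested GPUs, e.g. "0,1".
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return command, env


# Illustrative values: GPUs 0 and 1, and a smaller per-device batch.
command, env = build_launch(gpu_ids=[0, 1], batch_size_per_device=1024)
print("CUDA_VISIBLE_DEVICES=" + env["CUDA_VISIBLE_DEVICES"], " ".join(command[1:]))

# To actually start training from the example directory:
#   import subprocess
#   subprocess.run(command, env=env, check=True)
```

Keeping the launch logic in a function makes it easy to retry with a smaller batch size after an out-of-memory error.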

To evaluate the model's performance, you can switch to evaluation mode with the following command:

python3 main.py --config.mode=eval

Examples

In the following table, we present a comparison of various benchmarks. Each row contains information about the specific benchmark, its relative $L^2$ error, and links to the corresponding model checkpoints and Weights & Biases logs.

| Benchmark | Relative $L^2$ error | Checkpoint | Weights & Biases |
|---|---|---|---|
| Allen-Cahn equation | $5.37 \times 10^{-5}$ | allen_cahn | allen_cahn |
| Advection equation | $6.88 \times 10^{-4}$ | adv | adv |
| Stokes flow | $8.04 \times 10^{-5}$ | stokes | stokes |
| Kuramoto–Sivashinsky equation | $1.61 \times 10^{-1}$ | ks | ks |
| Lid-driven cavity flow | $1.58 \times 10^{-1}$ | ldc | ldc |
| Navier–Stokes flow in tori | $3.53 \times 10^{-1}$ | ns_tori | ns_tori |
| Navier–Stokes flow around a cylinder | - | ns_cylinder | ns_cylinder |
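
For readers unfamiliar with the metric, the relative $L^2$ error is conventionally the norm of the prediction error divided by the norm of the reference solution. The sketch below uses that standard definition in plain Python. It is an assumption that the benchmark numbers above follow exactly this convention (for example, which evaluation grid is used), and `relative_l2_error` is a hypothetical helper, not a JAX-PI function.

```python
import math


def relative_l2_error(u_pred, u_ref):
    """Return ||u_pred - u_ref||_2 / ||u_ref||_2 for two equal-length sequences."""
    if len(u_pred) != len(u_ref):
        raise ValueError("prediction and reference must have the same length")
    diff_norm = math.sqrt(sum((p - r) ** 2 for p, r in zip(u_pred, u_ref)))
    ref_norm = math.sqrt(sum(r ** 2 for r in u_ref))
    if ref_norm == 0.0:
        raise ValueError("reference solution has zero norm")
    return diff_norm / ref_norm


# Toy data: a sine wave as the reference, and a prediction that is 0.1% too large.
u_ref = [math.sin(2 * math.pi * i / 100) for i in range(100)]
u_pred = [1.001 * value for value in u_ref]
print(f"{relative_l2_error(u_pred, u_ref):.2e}")
```

Because the error is normalized by the reference norm, values are comparable across benchmarks whose solutions have very different magnitudes.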

Decaying Navier-Stokes flow in tori

[Figure: ns_tori]

Vortex shedding

[Figures: ns_cylinder, three panels]

Citation

@article{wang2023expert,
  title={An Expert's Guide to Training Physics-informed Neural Networks},
  author={Wang, Sifan and Sankaran, Shyam and Wang, Hanwen and Perdikaris, Paris},
  journal={arXiv preprint arXiv:2308.08468},
  year={2023}
}

About

Forked version of jaxpi to be used in ME 697.
