【PPSCI Export&Infer No.22】VP_NSFNet4 #864

Merged: 5 commits, Apr 24, 2024
19 changes: 19 additions & 0 deletions docs/zh/examples/nsfnet4.md
@@ -29,6 +29,25 @@

```

=== "模型导出命令"

``` sh
python VP_NSFNet4.py mode=export
```

=== "模型推理命令"

``` sh
# VP_NSFNet4
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/NSFNet/NSF4_data.zip -P ./data/
unzip ./data/NSF4_data.zip
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/NSFNet/NSF4_data.zip --output ./data/NSF4_data.zip
# unzip ./data/NSF4_data.zip
python VP_NSFNet4.py mode=infer
```

## 1. Background

In recent years deep learning has achieved remarkable success in many fields, most notably computer vision and natural language processing. Inspired by this rapid progress and by the strong function-approximation ability of neural networks, deep learning has also found success in scientific computing. Current research falls into two broad categories: one incorporates physical information and physical constraints into the loss function used to train the network, represented by PINN and Deep Ritz Net; the other builds data-driven deep neural operators, represented by FNO and DeepONet. These methods are widely used in scientific practice, for example in weather forecasting, quantum chemistry, bioengineering, and computational fluid dynamics. To fully explore the ability of PINNs to solve fluid equations, the authors of the reproduced [paper](https://arxiv.org/abs/2003.06496) designed NSFNets and trained them on forward problems, using two- and three-dimensional Navier-Stokes equations with analytical or numerical solutions, as well as a dataset solved to high accuracy with DNS, as references. The experiments in the paper show that PINNs can solve the incompressible Navier-Stokes equations with excellent numerical accuracy; the main goal of this project is to use PaddleScience to reproduce the paper's high-accuracy Navier-Stokes solver.
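
To make the first category above concrete, the sketch below trains a tiny network on a 1-D Poisson problem by penalizing the PDE residual at random collocation points together with the boundary condition. It is a minimal illustration of the physics-informed-loss idea only; the problem, network size, sampling, and iteration count are arbitrary assumptions and not the NSFNet4 setup.

``` py
import paddle

# Toy 1-D Poisson problem u''(x) = -sin(x) on [0, 2*pi] with u(0) = u(2*pi) = 0.
# Illustrative only: layer sizes, sampling and loss weighting are arbitrary choices.
net = paddle.nn.Sequential(
    paddle.nn.Linear(1, 32), paddle.nn.Tanh(), paddle.nn.Linear(32, 1)
)
opt = paddle.optimizer.Adam(learning_rate=1e-3, parameters=net.parameters())

def pde_residual(x):
    x.stop_gradient = False  # enable gradients with respect to the inputs
    u = net(x)
    u_x = paddle.grad(u, x, create_graph=True)[0]
    u_xx = paddle.grad(u_x, x, create_graph=True)[0]
    return u_xx + paddle.sin(x)  # residual of u_xx = -sin(x)

for _ in range(1000):
    x_f = paddle.rand([128, 1]) * 2 * 3.141592653589793      # interior collocation points
    x_b = paddle.to_tensor([[0.0], [2 * 3.141592653589793]])  # boundary points, u = 0
    # physics-informed loss = PDE residual term + boundary data term
    loss = paddle.mean(pde_residual(x_f) ** 2) + paddle.mean(net(x_b) ** 2)
    loss.backward()
    opt.step()
    opt.clear_grad()
```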
170 changes: 167 additions & 3 deletions examples/nsfnet/VP_NSFNet4.py
@@ -389,7 +389,7 @@ def evaluate(cfg: DictConfig):
    t_plot = paddle.to_tensor((t[-1]) * np.ones(x_plot.shape), paddle.float32)
    sol = model({"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot})
    fig, ax = plt.subplots(1, 4, figsize=(16, 4))
-    cmap = plt.cm.get_cmap("jet")
+    cmap = matplotlib.colormaps.get_cmap("jet")

    ax[0].contourf(grid_x, grid_y, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap)
    ax[0].set_title("u prediction")
@@ -422,7 +422,167 @@ def evaluate(cfg: DictConfig):
    t_plot = paddle.to_tensor((t[-1]) * np.ones(x_plot.shape), paddle.float32)
    sol = model({"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot})
    fig, ax = plt.subplots(1, 4, figsize=(16, 4))
-    cmap = plt.cm.get_cmap("jet")
+    cmap = matplotlib.colormaps.get_cmap("jet")

    ax[0].contourf(grid_y, grid_z, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap)
    ax[0].set_title("u prediction")
    ax[1].contourf(grid_y, grid_z, sol["v"].reshape(grid_x.shape), levels=50, cmap=cmap)
    ax[1].set_title("v prediction")
    ax[2].contourf(grid_y, grid_z, sol["w"].reshape(grid_x.shape), levels=50, cmap=cmap)
    ax[2].set_title("w prediction")
    ax[3].contourf(grid_y, grid_z, sol["p"].reshape(grid_x.shape), levels=50, cmap=cmap)
    ax[3].set_title("p prediction")
    norm = matplotlib.colors.Normalize(
        vmin=sol["u"].min(), vmax=sol["u"].max()
    )  # set maximum and minimum
    im = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
    ax13 = fig.add_axes([0.125, 0.0, 0.175, 0.02])
    plt.colorbar(im, cax=ax13, orientation="horizontal")
    ax13 = fig.add_axes([0.325, 0.0, 0.175, 0.02])
    plt.colorbar(im, cax=ax13, orientation="horizontal")
    ax13 = fig.add_axes([0.525, 0.0, 0.175, 0.02])
    plt.colorbar(im, cax=ax13, orientation="horizontal")
    ax13 = fig.add_axes([0.725, 0.0, 0.175, 0.02])
    plt.colorbar(im, cax=ax13, orientation="horizontal")
    plt.savefig(osp.join(cfg.output_dir, "x=0 plane"))


def export(cfg: DictConfig):
    from paddle.static import InputSpec

    # set models
    model = ppsci.arch.MLP(**cfg.MODEL)

    # load pretrained model
    solver = ppsci.solver.Solver(
        model=model, pretrained_model_path=cfg.INFER.pretrained_model_path
    )

    # export models
    input_spec = [
        {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
    ]
    solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
    from deploy.python_infer import pinn_predictor

    # set model predictor
    predictor = pinn_predictor.PINNPredictor(cfg)

    # infer data
    test_x = np.load(osp.join(cfg.data_dir, "test43_l.npy")).astype(np.float32)
    test_v = np.load(osp.join(cfg.data_dir, "test43_vp.npy")).astype(np.float32)
    t = np.array([0.0065, 4 * 0.0065, 7 * 0.0065, 10 * 0.0065, 13 * 0.0065]).astype(
        np.float32
    )
    t_star = np.tile(t.reshape(5, 1), (1, 3000)).reshape(-1, 1)
    x_star = np.tile(test_x[:, 0:1], (5, 1)).reshape(-1, 1)
    y_star = np.tile(test_x[:, 1:2], (5, 1)).reshape(-1, 1)
    z_star = np.tile(test_x[:, 2:3], (5, 1)).reshape(-1, 1)
    u_star = test_v[:, 0:1]
    v_star = test_v[:, 1:2]
    w_star = test_v[:, 2:3]
    p_star = test_v[:, 3:4]

    pred = predictor.predict(
        {
            "x": x_star,
            "y": y_star,
            "z": z_star,
            "t": t_star,
        },
        cfg.INFER.batch_size,
    )

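    # NOTE: the exported model may return its outputs in a different order than the
    # dynamic-graph model (suspected jit.save/jit.load issue, see the review thread
    # below), so predictions are re-mapped onto cfg.INFER.output_keys by position.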
    pred = {
        store_key: pred[infer_key]
        for store_key, infer_key in zip(cfg.INFER.output_keys, pred.keys())
    }

[Review comment, Collaborator] Replace every INFER.output_keys with MODEL.output_keys.

[Reply, Contributor (author)] No. I don't know why, but the outputs come back shuffled; only with this mapping do the outputs match the predictions.

[Reply, Collaborator] This should be a known issue: a bug in jit.save/jit.load probably makes the exported model's output order differ from the dynamic graph.

[Reply, Contributor (author)] Should I delete it, then?

[Reply, Collaborator] Leave it as is for now; I'll change it once the bug is fixed.

    u_pred = pred["u"].reshape((5, -1))
    v_pred = pred["v"].reshape((5, -1))
    w_pred = pred["w"].reshape((5, -1))
    p_pred = pred["p"].reshape((5, -1))
    u_star = u_star.reshape((5, -1))
    v_star = v_star.reshape((5, -1))
    w_star = w_star.reshape((5, -1))
    p_star = p_star.reshape((5, -1))

    # the NS equations determine pressure only up to a constant, so shift the
    # prediction by the reference mean pressure p_star.mean()
    p_pred = p_pred - p_pred.mean() + p_star.mean()

    u_error = np.linalg.norm(u_pred - u_star, axis=1) / np.linalg.norm(u_star, axis=1)
    v_error = np.linalg.norm(v_pred - v_star, axis=1) / np.linalg.norm(v_star, axis=1)
    w_error = np.linalg.norm(w_pred - w_star, axis=1) / np.linalg.norm(w_star, axis=1)
    p_error = np.linalg.norm(p_pred - p_star, axis=1) / np.linalg.norm(p_star, axis=1)
    t = np.array([0.0065, 4 * 0.0065, 7 * 0.0065, 10 * 0.0065, 13 * 0.0065])
    plt.plot(t, np.array(u_error))
    plt.plot(t, np.array(v_error))
    plt.plot(t, np.array(w_error))
    plt.plot(t, np.array(p_error))
    plt.legend(["u_error", "v_error", "w_error", "p_error"])
    plt.xlabel("t")
    plt.ylabel("Relative l2 Error")
    plt.title("Relative l2 Error, on test dataset")
    plt.savefig(osp.join(cfg.output_dir, "error.jpg"))

    grid_x, grid_y = np.mgrid[
        x_star.min() : x_star.max() : 100j, y_star.min() : y_star.max() : 100j
    ].astype(np.float32)
    x_plot = grid_x.reshape(-1, 1)
    y_plot = grid_y.reshape(-1, 1)
    z_plot = (z_star.min() * np.ones(y_plot.shape)).astype(np.float32)
    t_plot = ((t[-1]) * np.ones(x_plot.shape)).astype(np.float32)
    sol = predictor.predict(
        {"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot}, cfg.INFER.batch_size
    )
    sol = {
        store_key: sol[infer_key]
        for store_key, infer_key in zip(cfg.INFER.output_keys, sol.keys())
    }
    fig, ax = plt.subplots(1, 4, figsize=(16, 4))
    cmap = matplotlib.colormaps.get_cmap("jet")

    ax[0].contourf(grid_x, grid_y, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap)
    ax[0].set_title("u prediction")
    ax[1].contourf(grid_x, grid_y, sol["v"].reshape(grid_x.shape), levels=50, cmap=cmap)
    ax[1].set_title("v prediction")
    ax[2].contourf(grid_x, grid_y, sol["w"].reshape(grid_x.shape), levels=50, cmap=cmap)
    ax[2].set_title("w prediction")
    ax[3].contourf(grid_x, grid_y, sol["p"].reshape(grid_x.shape), levels=50, cmap=cmap)
    ax[3].set_title("p prediction")
    norm = matplotlib.colors.Normalize(
        vmin=sol["u"].min(), vmax=sol["u"].max()
    )  # set maximum and minimum
    im = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
    ax13 = fig.add_axes([0.125, 0.0, 0.175, 0.02])
    plt.colorbar(im, cax=ax13, orientation="horizontal")
    ax13 = fig.add_axes([0.325, 0.0, 0.175, 0.02])
    plt.colorbar(im, cax=ax13, orientation="horizontal")
    ax13 = fig.add_axes([0.525, 0.0, 0.175, 0.02])
    plt.colorbar(im, cax=ax13, orientation="horizontal")
    ax13 = fig.add_axes([0.725, 0.0, 0.175, 0.02])
    plt.colorbar(im, cax=ax13, orientation="horizontal")
    plt.savefig(osp.join(cfg.output_dir, "z=0 plane"))

    grid_y, grid_z = np.mgrid[
        y_star.min() : y_star.max() : 100j, z_star.min() : z_star.max() : 100j
    ].astype(np.float32)
    z_plot = grid_z.reshape(-1, 1)
    y_plot = grid_y.reshape(-1, 1)
    x_plot = (x_star.min() * np.ones(y_plot.shape)).astype(np.float32)
    t_plot = ((t[-1]) * np.ones(x_plot.shape)).astype(np.float32)
    sol = predictor.predict(
        {"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot}, cfg.INFER.batch_size
    )
    sol = {
        store_key: sol[infer_key]
        for store_key, infer_key in zip(cfg.INFER.output_keys, sol.keys())
    }
    fig, ax = plt.subplots(1, 4, figsize=(16, 4))
    cmap = matplotlib.colormaps.get_cmap("jet")

    ax[0].contourf(grid_y, grid_z, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap)
    ax[0].set_title("u prediction")
@@ -453,9 +613,13 @@ def main(cfg: DictConfig):
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
-            osp.join("cfg.mode should in ['train', 'eval'], but got", cfg.mode)
+            f"cfg.mode should be in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


19 changes: 19 additions & 0 deletions examples/nsfnet/conf/VP_NSFNet4.yaml
@@ -23,6 +23,7 @@ hydra:
seed: 1234
output_dir: ${hydra:run.dir}
data_dir: ./data/
log_freq: 20
MODEL:
  input_keys: ["x", "y", "z", "t"]
  output_keys: ["u", "v", "w", "p"]
@@ -52,3 +53,21 @@ EVAL:
  pretrained_model_path: null
  eval_with_no_grad: true


INFER:
  pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/nsfnet/nsfnet4.pdparams
  export_path: ./inference/VP_NSFNet4
  pdmodel_path: ${INFER.export_path}.pdmodel
  pdpiparams_path: ${INFER.export_path}.pdiparams
  output_keys: ['p', 'u', 'v', 'w']
[Review comment, Collaborator] Delete this key; it duplicates MODEL.output_keys.

  device: gpu
  engine: native
  precision: fp32
  onnx_path: ${INFER.export_path}.onnx
  ir_optim: true
  min_subgraph_size: 10
  gpu_mem: 4000
  gpu_id: 0
  max_batch_size: 64
  num_cpu_threads: 4
  batch_size: 16
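
For orientation, here is a minimal sketch of how the paths and device settings in the INFER block above typically map onto a native paddle.inference configuration. This is an assumption for illustration, not the code path taken by PaddleScience's PINNPredictor; the file names simply follow INFER.export_path.

``` py
import paddle.inference as paddle_infer

# file names follow INFER.export_path; gpu_mem / gpu_id / ir_optim mirror the YAML values
config = paddle_infer.Config(
    "./inference/VP_NSFNet4.pdmodel", "./inference/VP_NSFNet4.pdiparams"
)
config.enable_use_gpu(4000, 0)  # memory pool size in MB, GPU id
config.switch_ir_optim(True)    # enable graph-level IR optimizations
predictor = paddle_infer.create_predictor(config)
print(predictor.get_input_names())  # expected: ["x", "y", "z", "t"]
```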