From 3c3d27c2032fe32dec8aa64308d2c2d6446345e7 Mon Sep 17 00:00:00 2001 From: smallpoxscattered <1989838596@qq.com> Date: Sun, 21 Apr 2024 15:06:56 +0000 Subject: [PATCH 1/4] [SCI Export&Infer No.24] biharmonic2d --- docs/zh/examples/biharmonic2d.md | 12 +++ examples/biharmonic2d/biharmonic2d.py | 102 ++++++++++++++++++- examples/biharmonic2d/conf/biharmonic2d.yaml | 20 ++++ 3 files changed, 132 insertions(+), 2 deletions(-) diff --git a/docs/zh/examples/biharmonic2d.md b/docs/zh/examples/biharmonic2d.md index 816dbf5bc..c54aa322e 100644 --- a/docs/zh/examples/biharmonic2d.md +++ b/docs/zh/examples/biharmonic2d.md @@ -14,6 +14,18 @@ python biharmonic2d.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/biharmonic2d/biharmonic2d_pretrained.pdparams ``` +=== "模型导出命令" + + ``` sh + python biharmonic2d.py mode=export + ``` + +=== "模型推理命令" + + ``` sh + python biharmonic2d.py mode=infer + ``` + | 预训练模型 | 指标 | |:--| :--| | [biharmonic2d_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/biharmonic2d/biharmonic2d_pretrained.pdparams) | l2_error: 0.02774 | diff --git a/examples/biharmonic2d/biharmonic2d.py b/examples/biharmonic2d/biharmonic2d.py index ad4f08256..4c1539b41 100644 --- a/examples/biharmonic2d/biharmonic2d.py +++ b/examples/biharmonic2d/biharmonic2d.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from os import path as osp import hydra @@ -31,6 +32,8 @@ def plotting(figname, output_dir, data, griddata_points, griddata_xi, boundary): + if not osp.exists(output_dir): + os.makedirs(output_dir) plt.clf() fig = plt.figure(figname, figsize=(15, 12)) gs = gridspec.GridSpec(2, 3) @@ -39,7 +42,9 @@ def plotting(figname, output_dir, data, griddata_points, griddata_xi, boundary): for i, key in enumerate(data): plot_data = griddata( griddata_points, - data[key].numpy().flatten(), + data[key].flatten() + if isinstance(data[key], np.ndarray) + else data[key].numpy().flatten(), griddata_xi, method="cubic", ) @@ -350,14 +355,107 @@ def compute_outs(w, x, y): ) +def export(cfg: DictConfig): + from paddle import nn + from paddle.static import InputSpec + + # set models + disp_net = ppsci.arch.MLP(**cfg.MODEL) + + # load pretrained model + solver = ppsci.solver.Solver( + model=disp_net, pretrained_model_path=cfg.INFER.pretrained_model_path + ) + + class Wrapped_Model(nn.Layer): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, x): + model_out = self.model(x) + outs = self.compute_outs(model_out["u"], x["x"], x["y"]) + return outs + + def compute_outs(self, w, x, y): + D = cfg.E * (cfg.HEIGHT**3) / (12.0 * (1.0 - cfg.NU**2)) + w_x2 = hessian(w, x) + w_y2 = hessian(w, y) + w_x_y = jacobian(jacobian(w, x), y) + M_x = -(w_x2 + cfg.NU * w_y2) * D + M_y = -(cfg.NU * w_x2 + w_y2) * D + M_xy = (1 - cfg.NU) * w_x_y * D + Q_x = -jacobian((w_x2 + w_y2), x) * D + Q_y = -jacobian((w_x2 + w_y2), y) * D + return {"Mx": M_x, "Mxy": M_xy, "My": M_y, "Qx": Q_x, "Qy": Q_y, "w": w} + + solver.model = Wrapped_Model(solver.model) + + # export models + input_spec = [ + {key: InputSpec([None, 1], "float32", name=key) for key in disp_net.input_keys}, + ] + solver.export(input_spec, cfg.INFER.export_path) + + +def inference(cfg: DictConfig): + from deploy.python_infer import pinn_predictor + + # set model predictor + predictor = pinn_predictor.PINNPredictor(cfg) + + # generate samples + num_x = 201 + num_y = 301 + x_grad, y_grad = np.meshgrid( + np.linspace( + start=0, stop=cfg.LENGTH, num=num_x, endpoint=True, dtype=np.float32 + ), + np.linspace( + start=0, stop=cfg.WIDTH, num=num_y, endpoint=True, dtype=np.float32 + ), + ) + x_faltten = x_grad.reshape(-1, 1) + y_faltten = y_grad.reshape(-1, 1) + + output_dict = predictor.predict( + {"x": x_faltten, "y": y_faltten}, cfg.INFER.batch_size + ) + + # mapping data to cfg.INFER.output_keys + output_dict = { + store_key: output_dict[infer_key] + for store_key, infer_key in zip(cfg.INFER.output_keys, output_dict.keys()) + } + + # plotting + griddata_points = np.concatenate([x_faltten, y_faltten], axis=-1) + griddata_xi = (x_grad, y_grad) + boundary = [0, cfg.LENGTH, 0, cfg.WIDTH] + plotting( + "eval_Mx_Mxy_My_Qx_Qy_w", + "./biharmonic2d_pred", + output_dict, + griddata_points, + griddata_xi, + boundary, + ) + + @hydra.main(version_base=None, config_path="./conf", config_name="biharmonic2d.yaml") def main(cfg: DictConfig): if cfg.mode == "train": train(cfg) elif cfg.mode == "eval": evaluate(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "infer": + inference(cfg) else: - raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" + ) if __name__ == "__main__": diff --git a/examples/biharmonic2d/conf/biharmonic2d.yaml b/examples/biharmonic2d/conf/biharmonic2d.yaml index 67bd20f92..cb9e2fce6 100644 --- a/examples/biharmonic2d/conf/biharmonic2d.yaml +++ b/examples/biharmonic2d/conf/biharmonic2d.yaml @@ -11,6 +11,8 @@ hydra: - TRAIN.checkpoint_path - TRAIN.pretrained_model_path - EVAL.pretrained_model_path + - INFER.pretrained_model_path + - INFER.export_path - mode - output_dir - log_freq @@ -72,3 +74,21 @@ EVAL: eval_with_no_grad: true batch_size: sup_validator: 128 + +INFER: + pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/biharmonic2d/biharmonic2d_pretrained.pdparams + export_path: ./inference/biharmonic2d + pdmodel_path: ${INFER.export_path}.pdmodel + pdpiparams_path: ${INFER.export_path}.pdiparams + output_keys: ["Mx", "Mxy", "My", "Qx", "Qy", "w"] + device: gpu + engine: native + precision: fp32 + onnx_path: ${INFER.export_path}.onnx + ir_optim: true + min_subgraph_size: 10 + gpu_mem: 4000 + gpu_id: 0 + max_batch_size: 128 + num_cpu_threads: 4 + batch_size: 128 From 4bf0622b1085a81889965d50673ee3907a18684c Mon Sep 17 00:00:00 2001 From: smallpoxscattered <1989838596@qq.com> Date: Mon, 22 Apr 2024 04:42:00 +0000 Subject: [PATCH 2/4] P[PSCI Export&Infer No.724] biharmonic2d fix --- examples/biharmonic2d/biharmonic2d.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/examples/biharmonic2d/biharmonic2d.py b/examples/biharmonic2d/biharmonic2d.py index 4c1539b41..ec599f6ce 100644 --- a/examples/biharmonic2d/biharmonic2d.py +++ b/examples/biharmonic2d/biharmonic2d.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os from os import path as osp import hydra @@ -32,8 +31,6 @@ def plotting(figname, output_dir, data, griddata_points, griddata_xi, boundary): - if not osp.exists(output_dir): - os.makedirs(output_dir) plt.clf() fig = plt.figure(figname, figsize=(15, 12)) gs = gridspec.GridSpec(2, 3) @@ -42,9 +39,7 @@ def plotting(figname, output_dir, data, griddata_points, griddata_xi, boundary): for i, key in enumerate(data): plot_data = griddata( griddata_points, - data[key].flatten() - if isinstance(data[key], np.ndarray) - else data[key].numpy().flatten(), + data[key].flatten(), griddata_xi, method="cubic", ) @@ -348,7 +343,7 @@ def compute_outs(w, x, y): plotting( "eval_Mx_Mxy_My_Qx_Qy_w", cfg.output_dir, - outs, + {k: v.numpy() for k, v in outs.items()}, griddata_points, griddata_xi, boundary, @@ -434,7 +429,7 @@ def inference(cfg: DictConfig): boundary = [0, cfg.LENGTH, 0, cfg.WIDTH] plotting( "eval_Mx_Mxy_My_Qx_Qy_w", - "./biharmonic2d_pred", + cfg.output_dir, output_dict, griddata_points, griddata_xi, From 2ba677afe3eb31b3ca242a1753b483572f3150f9 Mon Sep 17 00:00:00 2001 From: smallpoxscattered <1989838596@qq.com> Date: Tue, 23 Apr 2024 01:30:28 +0000 Subject: [PATCH 3/4] add export&infer nsfnet4 --- docs/zh/examples/nsfnet4.md | 19 +++ examples/nsfnet/VP_NSFNet4.py | 166 ++++++++++++++++++++++++++- examples/nsfnet/conf/VP_NSFNet4.yaml | 19 +++ 3 files changed, 203 insertions(+), 1 deletion(-) diff --git a/docs/zh/examples/nsfnet4.md b/docs/zh/examples/nsfnet4.md index 7ddd435cd..c93f9e8ad 100644 --- a/docs/zh/examples/nsfnet4.md +++ b/docs/zh/examples/nsfnet4.md @@ -29,6 +29,25 @@ ``` +=== "模型导出命令" + + ``` sh + python VP_NSFNet4.py mode=export + ``` + +=== "模型推理命令" + + ``` sh + # VP_NSFNet4 + # linux + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/NSFNet/NSF4_data.zip -P ./data/ + unzip ./data/NSF4_data.zip + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/NSFNet/NSF4_data.zip --output ./data/NSF4_data.zip + # unzip ./data/NSF4_data.zip + python VP_NSFNet4.py mode=infer + ``` + ## 1. 背景简介 最近几年, 深度学习在很多领域取得了非凡的成就, 尤其是计算机视觉和自然语言处理方面, 而受启发于深度学习的快速发展, 基于深度学习强大的函数逼近能力, 神经网络在科学计算领域也取得了成功, 现阶段的研究主要分为两大类, 一类是将物理信息以及物理限制加入损失函数来对神经网络进行训练, 其代表有 PINN 以及 Deep Ritz Net, 另一类是通过数据驱动的深度神经网络算子, 其代表有 FNO 以及 DeepONet。这些方法都在科学实践中获得了广泛应用, 比如天气预测, 量子化学, 生物工程, 以及计算流体等领域。而为充分探索PINN对流体方程的求解能力, 本次复现[论文](https://arxiv.org/abs/2003.06496)作者设计了NSFNets, 并且先后使用具有解析解或数值解的二维、三维纳韦斯托克方程以及使用DNS方法进行高精度求解的数据集作为参考, 进行正问题求解训练。论文实验表明PINN对不可压纳韦斯托克方程具有优秀的数值求解能力, 本项目主要目标是使用PaddleScience复现论文所实现的高精度求解纳韦斯托克方程的代码。 diff --git a/examples/nsfnet/VP_NSFNet4.py b/examples/nsfnet/VP_NSFNet4.py index 8ecedce7c..6e034d213 100644 --- a/examples/nsfnet/VP_NSFNet4.py +++ b/examples/nsfnet/VP_NSFNet4.py @@ -447,15 +447,179 @@ def evaluate(cfg: DictConfig): plt.savefig(osp.join(cfg.output_dir, "x=0 plane")) +def export(cfg: DictConfig): + from paddle.static import InputSpec + + # set models + model = ppsci.arch.MLP(**cfg.MODEL) + + # load pretrained model + solver = ppsci.solver.Solver( + model=model, pretrained_model_path=cfg.INFER.pretrained_model_path + ) + + # export models + input_spec = [ + {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, + ] + solver.export(input_spec, cfg.INFER.export_path) + + +def inference(cfg: DictConfig): + from deploy.python_infer import pinn_predictor + + # set model predictor + predictor = pinn_predictor.PINNPredictor(cfg) + + # infer Data + test_x = np.load(osp.join(cfg.data_dir, "test43_l.npy")).astype(np.float32) + test_v = np.load(osp.join(cfg.data_dir, "test43_vp.npy")).astype(np.float32) + t = np.array([0.0065, 4 * 0.0065, 7 * 0.0065, 10 * 0.0065, 13 * 0.0065]).astype( + np.float32 + ) + t_star = np.tile(t.reshape(5, 1), (1, 3000)).reshape(-1, 1) + x_star = np.tile(test_x[:, 0:1], (5, 1)).reshape(-1, 1) + y_star = np.tile(test_x[:, 1:2], (5, 1)).reshape(-1, 1) + z_star = np.tile(test_x[:, 2:3], (5, 1)).reshape(-1, 1) + u_star = test_v[:, 0:1] + v_star = test_v[:, 1:2] + w_star = test_v[:, 2:3] + p_star = test_v[:, 3:4] + + pred = predictor.predict( + { + "x": x_star, + "y": y_star, + "z": z_star, + "t": t_star, + }, + cfg.INFER.batch_size, + ) + + pred = { + store_key: pred[infer_key] + for store_key, infer_key in zip(cfg.INFER.output_keys, pred.keys()) + } + + u_pred = pred["u"].reshape((5, -1)) + v_pred = pred["v"].reshape((5, -1)) + w_pred = pred["w"].reshape((5, -1)) + p_pred = pred["p"].reshape((5, -1)) + u_star = u_star.reshape((5, -1)) + v_star = v_star.reshape((5, -1)) + w_star = w_star.reshape((5, -1)) + p_star = p_star.reshape((5, -1)) + + # NS equation can figure out pressure drop, need background pressure p_star.mean() + p_pred = p_pred - p_pred.mean() + p_star.mean() + + u_error = np.linalg.norm(u_pred - u_star, axis=1) / np.linalg.norm(u_star, axis=1) + v_error = np.linalg.norm(v_pred - v_star, axis=1) / np.linalg.norm(v_star, axis=1) + w_error = np.linalg.norm(w_pred - w_star, axis=1) / np.linalg.norm(w_star, axis=1) + p_error = np.linalg.norm(p_pred - p_star, axis=1) / np.linalg.norm(w_star, axis=1) + t = np.array([0.0065, 4 * 0.0065, 7 * 0.0065, 10 * 0.0065, 13 * 0.0065]) + plt.plot(t, np.array(u_error)) + plt.plot(t, np.array(v_error)) + plt.plot(t, np.array(w_error)) + plt.plot(t, np.array(p_error)) + plt.legend(["u_error", "v_error", "w_error", "p_error"]) + plt.xlabel("t") + plt.ylabel("Relative l2 Error") + plt.title("Relative l2 Error, on test dataset") + plt.savefig(osp.join(cfg.output_dir, "error.jpg")) + + grid_x, grid_y = np.mgrid[ + x_star.min() : x_star.max() : 100j, y_star.min() : y_star.max() : 100j + ].astype(np.float32) + x_plot = grid_x.reshape(-1, 1) + y_plot = grid_y.reshape(-1, 1) + z_plot = (z_star.min() * np.ones(y_plot.shape)).astype(np.float32) + t_plot = ((t[-1]) * np.ones(x_plot.shape)).astype(np.float32) + sol = predictor.predict( + {"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot}, cfg.INFER.batch_size + ) + sol = { + store_key: sol[infer_key] + for store_key, infer_key in zip(cfg.INFER.output_keys, sol.keys()) + } + fig, ax = plt.subplots(1, 4, figsize=(16, 4)) + cmap = plt.cm.get_cmap("jet") + + ax[0].contourf(grid_x, grid_y, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap) + ax[0].set_title("u prediction") + ax[1].contourf(grid_x, grid_y, sol["v"].reshape(grid_x.shape), levels=50, cmap=cmap) + ax[1].set_title("v prediction") + ax[2].contourf(grid_x, grid_y, sol["w"].reshape(grid_x.shape), levels=50, cmap=cmap) + ax[2].set_title("w prediction") + ax[3].contourf(grid_x, grid_y, sol["p"].reshape(grid_x.shape), levels=50, cmap=cmap) + ax[3].set_title("p prediction") + norm = matplotlib.colors.Normalize( + vmin=sol["u"].min(), vmax=sol["u"].max() + ) # set maximum and minimum + im = plt.cm.ScalarMappable(norm=norm, cmap=cmap) + ax13 = fig.add_axes([0.125, 0.0, 0.175, 0.02]) + plt.colorbar(im, cax=ax13, orientation="horizontal") + ax13 = fig.add_axes([0.325, 0.0, 0.175, 0.02]) + plt.colorbar(im, cax=ax13, orientation="horizontal") + ax13 = fig.add_axes([0.525, 0.0, 0.175, 0.02]) + plt.colorbar(im, cax=ax13, orientation="horizontal") + ax13 = fig.add_axes([0.725, 0.0, 0.175, 0.02]) + plt.colorbar(im, cax=ax13, orientation="horizontal") + plt.savefig(osp.join(cfg.output_dir, "z=0 plane")) + + grid_y, grid_z = np.mgrid[ + y_star.min() : y_star.max() : 100j, z_star.min() : z_star.max() : 100j + ].astype(np.float32) + z_plot = grid_z.reshape(-1, 1) + y_plot = grid_y.reshape(-1, 1) + x_plot = (x_star.min() * np.ones(y_plot.shape)).astype(np.float32) + t_plot = ((t[-1]) * np.ones(x_plot.shape)).astype(np.float32) + sol = predictor.predict( + {"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot}, cfg.INFER.batch_size + ) + sol = { + store_key: sol[infer_key] + for store_key, infer_key in zip(cfg.INFER.output_keys, sol.keys()) + } + fig, ax = plt.subplots(1, 4, figsize=(16, 4)) + cmap = plt.cm.get_cmap("jet") + + ax[0].contourf(grid_y, grid_z, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap) + ax[0].set_title("u prediction") + ax[1].contourf(grid_y, grid_z, sol["v"].reshape(grid_x.shape), levels=50, cmap=cmap) + ax[1].set_title("v prediction") + ax[2].contourf(grid_y, grid_z, sol["w"].reshape(grid_x.shape), levels=50, cmap=cmap) + ax[2].set_title("w prediction") + ax[3].contourf(grid_y, grid_z, sol["p"].reshape(grid_x.shape), levels=50, cmap=cmap) + ax[3].set_title("p prediction") + norm = matplotlib.colors.Normalize( + vmin=sol["u"].min(), vmax=sol["u"].max() + ) # set maximum and minimum + im = plt.cm.ScalarMappable(norm=norm, cmap=cmap) + ax13 = fig.add_axes([0.125, 0.0, 0.175, 0.02]) + plt.colorbar(im, cax=ax13, orientation="horizontal") + ax13 = fig.add_axes([0.325, 0.0, 0.175, 0.02]) + plt.colorbar(im, cax=ax13, orientation="horizontal") + ax13 = fig.add_axes([0.525, 0.0, 0.175, 0.02]) + plt.colorbar(im, cax=ax13, orientation="horizontal") + ax13 = fig.add_axes([0.725, 0.0, 0.175, 0.02]) + plt.colorbar(im, cax=ax13, orientation="horizontal") + plt.savefig(osp.join(cfg.output_dir, "x=0 plane")) + + @hydra.main(version_base=None, config_path="./conf", config_name="VP_NSFNet4.yaml") def main(cfg: DictConfig): if cfg.mode == "train": train(cfg) elif cfg.mode == "eval": evaluate(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "infer": + inference(cfg) else: raise ValueError( - osp.join("cfg.mode should in ['train', 'eval'], but got", cfg.mode) + f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" ) diff --git a/examples/nsfnet/conf/VP_NSFNet4.yaml b/examples/nsfnet/conf/VP_NSFNet4.yaml index 5d09c6ae7..258ace18d 100644 --- a/examples/nsfnet/conf/VP_NSFNet4.yaml +++ b/examples/nsfnet/conf/VP_NSFNet4.yaml @@ -23,6 +23,7 @@ hydra: seed: 1234 output_dir: ${hydra:run.dir} data_dir: ./data/ +log_freq: 20 MODEL: input_keys: ["x", "y","z","t"] output_keys: ["u", "v", "w","p"] @@ -52,3 +53,21 @@ EVAL: pretrained_model_path: null eval_with_no_grad: true + +INFER: + pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/nsfnet/nsfnet4.pdparams + export_path: ./inference/VP_NSFNet4 + pdmodel_path: ${INFER.export_path}.pdmodel + pdpiparams_path: ${INFER.export_path}.pdiparams + output_keys: ['p', 'u', 'v', 'w'] + device: gpu + engine: native + precision: fp32 + onnx_path: ${INFER.export_path}.onnx + ir_optim: true + min_subgraph_size: 10 + gpu_mem: 4000 + gpu_id: 0 + max_batch_size: 64 + num_cpu_threads: 4 + batch_size: 16 From 50d06fb10c112f3b32ed35686ad08e61f688b4b1 Mon Sep 17 00:00:00 2001 From: smallpoxscattered <1989838596@qq.com> Date: Tue, 23 Apr 2024 08:11:46 +0000 Subject: [PATCH 4/4] add export&infer nsfnet4 --- examples/nsfnet/VP_NSFNet4.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/nsfnet/VP_NSFNet4.py b/examples/nsfnet/VP_NSFNet4.py index 6e034d213..f60f76607 100644 --- a/examples/nsfnet/VP_NSFNet4.py +++ b/examples/nsfnet/VP_NSFNet4.py @@ -389,7 +389,7 @@ def evaluate(cfg: DictConfig): t_plot = paddle.to_tensor((t[-1]) * np.ones(x_plot.shape), paddle.float32) sol = model({"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot}) fig, ax = plt.subplots(1, 4, figsize=(16, 4)) - cmap = plt.cm.get_cmap("jet") + cmap = matplotlib.colormaps.get_cmap("jet") ax[0].contourf(grid_x, grid_y, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap) ax[0].set_title("u prediction") @@ -422,7 +422,7 @@ def evaluate(cfg: DictConfig): t_plot = paddle.to_tensor((t[-1]) * np.ones(x_plot.shape), paddle.float32) sol = model({"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot}) fig, ax = plt.subplots(1, 4, figsize=(16, 4)) - cmap = plt.cm.get_cmap("jet") + cmap = matplotlib.colormaps.get_cmap("jet") ax[0].contourf(grid_y, grid_z, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap) ax[0].set_title("u prediction") @@ -543,7 +543,7 @@ def inference(cfg: DictConfig): for store_key, infer_key in zip(cfg.INFER.output_keys, sol.keys()) } fig, ax = plt.subplots(1, 4, figsize=(16, 4)) - cmap = plt.cm.get_cmap("jet") + cmap = matplotlib.colormaps.get_cmap("jet") ax[0].contourf(grid_x, grid_y, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap) ax[0].set_title("u prediction") @@ -582,7 +582,7 @@ def inference(cfg: DictConfig): for store_key, infer_key in zip(cfg.INFER.output_keys, sol.keys()) } fig, ax = plt.subplots(1, 4, figsize=(16, 4)) - cmap = plt.cm.get_cmap("jet") + cmap = matplotlib.colormaps.get_cmap("jet") ax[0].contourf(grid_y, grid_z, sol["u"].reshape(grid_x.shape), levels=50, cmap=cmap) ax[0].set_title("u prediction")