Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PPSCI Export&Infer No.11-12】 #883

Merged
merged 30 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
573fd22
fix doc bugs
wufei2 Apr 8, 2024
dfa2162
fix codestyle bugs
wufei2 Apr 9, 2024
417ac44
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleScien…
wufei2 Apr 9, 2024
fc60eb8
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleScien…
wufei2 Apr 10, 2024
eb99d69
【PPSCI Export&Infer No.15-16】
wufei2 May 2, 2024
f00b73c
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleScien…
wufei2 May 2, 2024
2986349
fix codestyle bug for PPSCI Export&Infer No.15-16】
wufei2 May 2, 2024
cf0d347
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleScien…
wufei2 May 2, 2024
cfab600
fix codestyle bugs for 【PPSCI Export&Infer No.15-16】
wufei2 May 2, 2024
cf0c33f
fix codestyle bugs for 【PPSCI Export&Infer No.15-16】
wufei2 May 2, 2024
7f8bfb2
fix codestyle bugs for 【PPSCI Export&Infer No.15-16】
wufei2 May 2, 2024
ddee33e
fix bugs for 【PPSCI Export&Infer No.15-16】
wufei2 May 6, 2024
0e45563
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleScien…
wufei2 May 6, 2024
dfb3649
fix codestyle bugs
wufei2 May 6, 2024
ae4b627
【PPSCI Export&Infer No.11-12】
wufei2 May 8, 2024
7188afb
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleScien…
wufei2 May 8, 2024
20ba5f7
change predictor
wufei2 May 9, 2024
b20be86
fix bugs in change predictor
wufei2 May 9, 2024
800cf77
cancel extra doc commit
wufei2 May 9, 2024
9af1f85
fix codestyle bugs
wufei2 May 9, 2024
b60953e
Update examples/cylinder/2d_unsteady/cylinder2d_unsteady_Re100.py
HydrogenSulfate May 9, 2024
e1f57bf
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleScien…
wufei2 May 10, 2024
1b62517
Update examples/cylinder/2d_unsteady/transformer_physx/train_transfor…
HydrogenSulfate May 10, 2024
f6f3e76
Update examples/cylinder/2d_unsteady/cylinder2d_unsteady_Re100.py
HydrogenSulfate May 10, 2024
6368b7b
Update examples/cylinder/2d_unsteady/transformer_physx/train_transfor…
wufei2 May 10, 2024
ff28ae0
Merge branch 'my-cool-stuff' of https://github.com/wufei2/PaddleScien…
wufei2 May 10, 2024
2f75180
Update examples/cylinder/2d_unsteady/transformer_physx/train_transfor…
wufei2 May 10, 2024
d058198
cancel extra changes
wufei2 May 10, 2024
0ea2af6
cancel extra changes
wufei2 May 10, 2024
d0a5d37
update examples/cylinder/2d_unsteady/transformer_physx/conf/transform…
wufei2 May 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/zh/examples/cylinder2d_unsteady.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,24 @@
python cylinder2d_unsteady_Re100.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/cylinder2d_unsteady_Re100/cylinder2d_unsteady_Re100_pretrained.pdparams
```

=== "模型导出命令"

``` sh
python cylinder2d_unsteady_Re100.py mode=export
```

=== "模型推理命令"

``` sh
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/cylinder2d_unsteady_Re100/cylinder2d_unsteady_Re100_dataset.tar
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/cylinder2d_unsteady_Re100/cylinder2d_unsteady_Re100_dataset.tar --output cylinder2d_unsteady_Re100_dataset.tar
# unzip it
tar -xvf cylinder2d_unsteady_Re100_dataset.tar
python cylinder2d_unsteady_Re100.py mode=infer
```

| 预训练模型 | 指标 |
|:--| :--|
| [cylinder2d_unsteady_Re100_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/cylinder2d_unsteady_Re100/cylinder2d_unsteady_Re100_pretrained.pdparams) | loss(Residual): 0.00398<br>MSE.continuity(Residual): 0.00126<br>MSE.momentum_x(Residual): 0.00151<br>MSE.momentum_y(Residual): 0.00120 |
Expand Down
17 changes: 17 additions & 0 deletions docs/zh/examples/cylinder2d_unsteady_transformer_physx.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,23 @@
python train_enn.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/cylinder/cylinder_pretrained.pdparams
python train_transformer.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/cylinder/cylinder_transformer_pretrained.pdparams EMBEDDING_MODEL_PATH=https://paddle-org.bj.bcebos.com/paddlescience/models/cylinder/cylinder_pretrained.pdparams
```
=== "模型导出命令"

``` sh
python train_transformer.py mode=export EMBEDDING_MODEL_PATH=https://paddle-org.bj.bcebos.com/paddlescience/models/cylinder/cylinder_pretrained.pdparams
HydrogenSulfate marked this conversation as resolved.
Show resolved Hide resolved
```

=== "模型推理命令"

``` sh
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/transformer_physx/cylinder_training.hdf5 -P ./datasets/
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/transformer_physx/cylinder_valid.hdf5 -P ./datasets/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/transformer_physx/cylinder_training.hdf5 --output ./datasets/cylinder_training.hdf5
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/transformer_physx/cylinder_valid.hdf5 --output ./datasets/cylinder_valid.hdf5
python train_transformer.py mode=infer
```

| 模型 | MSE |
| :-- | :-- |
Expand Down
18 changes: 18 additions & 0 deletions examples/cylinder/2d_unsteady/conf/cylinder2d_unsteady_Re100.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,21 @@ TRAIN:
EVAL:
batch_size: 10240
pretrained_model_path: null

# inference settings
INFER:
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/cylinder2d_unsteady_Re100/cylinder2d_unsteady_Re100_pretrained.pdparams
export_path: ./inference/cylinder2d_unsteady_Re100
pdmodel_path: ${INFER.export_path}.pdmodel
pdiparams_path: ${INFER.export_path}.pdiparams
onnx_path: ${INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 10240
num_cpu_threads: 10
batch_size: 10240
61 changes: 60 additions & 1 deletion examples/cylinder/2d_unsteady/cylinder2d_unsteady_Re100.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,59 @@ def evaluate(cfg: DictConfig):
solver.visualize()


def export(cfg: DictConfig):
# set model
model = ppsci.arch.MLP(**cfg.MODEL)

# initialize solver
solver = ppsci.solver.Solver(
model,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
# export model
from paddle.static import InputSpec

input_spec = [
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
]
solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
from deploy.python_infer import pinn_predictor

predictor = pinn_predictor.PINNPredictor(cfg)
HydrogenSulfate marked this conversation as resolved.
Show resolved Hide resolved
# set time-geometry
geom = {
"time_rect_eval": ppsci.geometry.PointCloud(
reader.load_csv_file(
cfg.DOMAIN_EVAL_PATH,
("t", "x", "y"),
),
("t", "x", "y"),
),
}
NPOINT_EVAL = (
cfg.NPOINT_PDE + cfg.NPOINT_INLET_CYLINDER + cfg.NPOINT_OUTLET
) * cfg.NUM_TIMESTAMPS
input_dict = geom["time_rect_eval"].sample_interior(NPOINT_EVAL, evenly=True)
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)

# mapping data to cfg.INFER.output_keys
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
}

ppsci.visualize.save_vtu_from_dict(
"./cylinder2d_unsteady_Re100_pred.vtu",
{**input_dict, **output_dict},
input_dict.keys(),
cfg.MODEL.output_keys,
cfg.NUM_TIMESTAMPS,
)


@hydra.main(
version_base=None,
config_path="./conf",
Expand All @@ -321,8 +374,14 @@ def main(cfg: DictConfig):
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ TRAIN_BLOCK_SIZE: 16
VALID_BLOCK_SIZE: 256
TRAIN_FILE_PATH: ./datasets/cylinder_training.hdf5
VALID_FILE_PATH: ./datasets/cylinder_valid.hdf5
log_freq: 20

# set working condition
EMBEDDING_MODEL_PATH: ./outputs_cylinder2d_unsteady_transformer_physx_enn/checkpoints/latest
Expand All @@ -37,7 +38,7 @@ VIS_DATA_NUMS: 1
MODEL:
input_keys: ["embeds"]
output_keys: ["pred_embeds"]
num_layers: 4
num_layers: 6
num_ctx: 16
embed_size: 128
num_heads: 4
Expand All @@ -63,3 +64,20 @@ TRAIN:
EVAL:
batch_size: 16
pretrained_model_path: null

INFER:
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/cylinder/cylinder_transformer_pretrained.pdparams
export_path: ./inference/cylinder_transformer
pdmodel_path: ${INFER.export_path}.pdmodel
pdiparams_path: ${INFER.export_path}.pdiparams
device: gpu
engine: native
precision: fp32
onnx_path: ${INFER.export_path}.onnx
ir_optim: false
min_subgraph_size: 10
gpu_mem: 4000
gpu_id: 0
max_batch_size: 64
num_cpu_threads: 4
batch_size: 16
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,91 @@ def evaluate(cfg: DictConfig):
solver.visualize()


def export(cfg: DictConfig):
# set model
embedding_model = build_embedding_model(cfg.EMBEDDING_MODEL_PATH)
model_cfg = {
**cfg.MODEL,
"embedding_model": embedding_model,
"input_keys": ["states"],
"output_keys": ["pred_states"],
}
model = ppsci.arch.PhysformerGPT2(**model_cfg)

# initialize solver
solver = ppsci.solver.Solver(
model,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
# export model
from paddle.static import InputSpec

input_spec = [
{
key: InputSpec([None, 16, 128], "float32", name=key)
for key in model.input_keys
},
]

solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
from deploy import python_infer

predictor = python_infer.GeneralPredictor(cfg)

dataset_cfg = {
"name": "CylinderDataset",
"file_path": cfg.VALID_FILE_PATH,
"input_keys": cfg.MODEL.input_keys,
"label_keys": cfg.MODEL.output_keys,
"block_size": cfg.VALID_BLOCK_SIZE,
"stride": 1024,
}

dataset = ppsci.data.dataset.build_dataset(dataset_cfg)

input_dict = {
"states": dataset.data[: cfg.VIS_DATA_NUMS, :-1],
}

output_dict = predictor.predict(input_dict)

# mapping data to cfg.INFER.output_keys
output_keys = ["pred_states"]
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(output_keys, output_dict.keys())
}

input_dict = {
"states": dataset.data[: cfg.VIS_DATA_NUMS, 1:],
}

data_dict = {**input_dict, **output_dict}
for i in range(cfg.VIS_DATA_NUMS):
ppsci.visualize.save_plot_from_3d_dict(
f"./cylinder_transformer_pred_{i}",
{key: value[i] for key, value in data_dict.items()},
("states", "pred_states"),
)


@hydra.main(version_base=None, config_path="./conf", config_name="transformer.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down
6 changes: 5 additions & 1 deletion ppsci/arch/physx_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from paddle.nn.initializer import Normal

from ppsci.arch import base
from ppsci.arch.embedding_koopman import CylinderEmbedding

zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)
Expand Down Expand Up @@ -387,7 +388,10 @@ def forward(self, x):
x = self._input_transform(x)
x_tensor = self.concat_to_tensor(x, self.input_keys, axis=-1)
if self.embedding_model is not None:
x_tensor = self.embedding_model.encoder(x_tensor)
if isinstance(self.embedding_model, CylinderEmbedding):
x_tensor = self.embedding_model.encoder(x_tensor, x["visc"])
else:
x_tensor = self.embedding_model.encoder(x_tensor)

if self.training:
y = self.forward_tensor(x_tensor)
Expand Down