Skip to content

Commit

Permalink
【PPSCI Export&Infer No.23】viv (PaddlePaddle#832)
Browse files Browse the repository at this point in the history
* eadd export and inference for viv

* add doc

* fix viv export&infer

* Rewriting function

* fix viv export&infer
  • Loading branch information
smallpoxscattered authored Apr 7, 2024
1 parent f0eaa6a commit d7e4991
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 2 deletions.
12 changes: 12 additions & 0 deletions docs/zh/examples/viv.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@
python viv.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams
```

=== "模型导出命令"

``` sh
python viv.py mode=export
```

=== "模型推理命令"

``` sh
python viv.py mode=infer
```

| 预训练模型 | 指标 |
|:--| :--|
| [viv_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/aneurysm/viv_pretrained.pdparams)<br>[viv_pretrained.pdeqn](https://paddle-org.bj.bcebos.com/paddlescience/models/aneurysm/viv_pretrained.pdeqn) | 'eta': 1.1416150300647132e-06<br>'f': 4.635014192899689e-06 |
Expand Down
22 changes: 22 additions & 0 deletions examples/fsi/conf/viv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ hydra:
- TRAIN.checkpoint_path
- TRAIN.pretrained_model_path
- EVAL.pretrained_model_path
- INFER.pretrained_model_path
- INFER.export_path
- mode
- output_dir
- log_freq
Expand Down Expand Up @@ -60,3 +62,23 @@ TRAIN:
EVAL:
pretrained_model_path: null
batch_size: 32

# inference settings
INFER:
pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams"
export_path: ./inference/viv
pdmodel_path: ${INFER.export_path}.pdmodel
pdpiparams_path: ${INFER.export_path}.pdiparams
input_keys: ${MODEL.input_keys}
output_keys: ["eta", "f"]
device: gpu
engine: native
precision: fp32
onnx_path: ${INFER.export_path}.onnx
ir_optim: true
min_subgraph_size: 10
gpu_mem: 4000
gpu_id: 0
max_batch_size: 64
num_cpu_threads: 4
batch_size: 16
76 changes: 75 additions & 1 deletion examples/fsi/viv.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,88 @@ def evaluate(cfg: DictConfig):
solver.visualize()


def export(cfg: DictConfig):
from paddle import nn
from paddle.static import InputSpec

# set model
model = ppsci.arch.MLP(**cfg.MODEL)
# initialize equation
equation = {"VIV": ppsci.equation.Vibration(2, -4, 0)}
# initialize solver
solver = ppsci.solver.Solver(
model,
equation=equation,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
# Convert equation to func
f_func = ppsci.lambdify(
solver.equation["VIV"].equations["f"],
solver.model,
list(solver.equation["VIV"].learnable_parameters),
)

class Wrapped_Model(nn.Layer):
def __init__(self, model, func):
super().__init__()
self.model = model
self.func = func

def forward(self, x):
model_out = self.model(x)
func_out = self.func(x)
return {**model_out, "f": func_out}

solver.model = Wrapped_Model(model, f_func)
# export models
input_spec = [
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
]
solver.export(input_spec, cfg.INFER.export_path, skip_prune_program=True)


def inference(cfg: DictConfig):
from deploy.python_infer import pinn_predictor

# set model predictor
predictor = pinn_predictor.PINNPredictor(cfg)

infer_mat = ppsci.utils.reader.load_mat_file(
cfg.VIV_DATA_PATH,
("t_f", "eta_gt", "f_gt"),
alias_dict={"eta_gt": "eta", "f_gt": "f"},
)

input_dict = {key: infer_mat[key] for key in cfg.INFER.input_keys}

output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)

# mapping data to cfg.INFER.output_keys
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(cfg.INFER.output_keys, output_dict.keys())
}
infer_mat.update(output_dict)

ppsci.visualize.plot.save_plot_from_1d_dict(
"./viv_pred", infer_mat, ("t_f",), ("eta", "eta_gt", "f", "f_gt")
)


@hydra.main(version_base=None, config_path="./conf", config_name="viv.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion ppsci/utils/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ class ComposedNode(nn.Layer):
def __init__(self, callable_nodes: List[Node]):
super().__init__()
assert len(callable_nodes)
self.callable_nodes = callable_nodes
self.callable_nodes = nn.LayerList(callable_nodes)

def forward(self, data_dict: DATA_DICT) -> paddle.Tensor:
# call all callable_nodes in order
Expand Down

0 comments on commit d7e4991

Please sign in to comment.