Skip to content

Commit

Permalink
add skip_prune_program arg for Solver.export (PaddlePaddle#835)
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate authored Apr 7, 2024
1 parent 9b62a10 commit f0eaa6a
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions ppsci/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os
import sys
from os import path as osp
from typing import TYPE_CHECKING
from typing import Callable
from typing import Dict
from typing import List
Expand All @@ -41,7 +42,6 @@
from paddle import optimizer as optim
from paddle.distributed import fleet
from paddle.framework import core
from paddle.static import InputSpec
from typing_extensions import Literal

import ppsci
Expand All @@ -51,6 +51,9 @@
from ppsci.utils import misc
from ppsci.utils import save_load

if TYPE_CHECKING:
from paddle.static import InputSpec


class Solver:
"""Class for solver.
Expand Down Expand Up @@ -729,7 +732,11 @@ def predict(

@misc.run_on_eval_mode
def export(
self, input_spec: List[InputSpec], export_path: str, with_onnx: bool = False
self,
input_spec: List["InputSpec"],
export_path: str,
with_onnx: bool = False,
skip_prune_program: bool = False,
):
"""
Convert model to static graph model and export to files.
Expand All @@ -739,7 +746,9 @@ def export(
of the model input.
export_path (str): The path prefix to save model.
with_onnx (bool, optional): Whether to export model into onnx after
paddle inference models are exported.
paddle inference models are exported. Defaults to False.
skip_prune_program (bool, optional): Whether prune program, pruning program
may cause unexpectable result, e.g. llm-inference. Defaults to False.
"""
jit.enable_to_static(True)

Expand All @@ -760,7 +769,7 @@ def export(
if len(osp.dirname(export_path)):
os.makedirs(osp.dirname(export_path), exist_ok=True)
try:
jit.save(static_model, export_path)
jit.save(static_model, export_path, skip_prune_program=skip_prune_program)
except Exception as e:
raise e
logger.message(
Expand Down

0 comments on commit f0eaa6a

Please sign in to comment.