From 8293daf18ad024d0f39f1955c5c8edc348acab48 Mon Sep 17 00:00:00 2001 From: daiyuxin0511 <455472400@qq.com> Date: Mon, 17 Jun 2024 16:35:52 +0800 Subject: [PATCH] update jit level --- examples/conformer/predict.py | 9 ++++++--- examples/conformer/train.py | 12 +++++++----- requirements-dev.txt | 2 +- requirements.txt | 2 +- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/examples/conformer/predict.py b/examples/conformer/predict.py index c429136..03e373b 100644 --- a/examples/conformer/predict.py +++ b/examples/conformer/predict.py @@ -5,10 +5,10 @@ import os +import mindspore import numpy as np from asr_model import creadte_asr_model from dataset import create_asr_predict_dataset, load_language_dict -from mindspore import context from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net @@ -44,8 +44,11 @@ def main(): os.makedirs(decode_dir, exist_ok=True) result_file = open(os.path.join(decode_dir, "result.txt"), "w") - context.set_context( - mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id() + mindspore.set_context( + mode=0, + device_target="Ascend", + device_id=get_device_id(), + jit_config={"jit_level": "O2"}, ) # load test data diff --git a/examples/conformer/train.py b/examples/conformer/train.py index 60cf24f..d944b6f 100644 --- a/examples/conformer/train.py +++ b/examples/conformer/train.py @@ -5,9 +5,10 @@ import os +import mindspore from asr_model import creadte_asr_model, create_asr_eval_net from dataset import create_dataset -from mindspore import ParameterTuple, context, set_seed +from mindspore import ParameterTuple, set_seed from mindspore.communication.management import init from mindspore.context import ParallelMode from mindspore.nn.optim import Adam @@ -57,12 +58,13 @@ def train(): model_dir = os.path.join(exp_dir, "model") graph_dir = os.path.join(exp_dir, "graph") summary_dir = os.path.join(exp_dir, "summary") - context.set_context( - mode=context.GRAPH_MODE, + mindspore.set_context( + mode=0, device_target="Ascend", device_id=get_device_id(), save_graphs=config.save_graphs, save_graphs_path=graph_dir, + jit_config={"jit_level": "O2"}, ) device_num = get_device_num() @@ -70,8 +72,8 @@ def train(): # configurations for distributed training if config.is_distributed: init() - context.reset_auto_parallel_context() - context.set_auto_parallel_context( + mindspore.reset_auto_parallel_context() + mindspore.set_auto_parallel_context( parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=device_num, diff --git a/requirements-dev.txt b/requirements-dev.txt index cc6cae0..6eb672e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ mindspore==2.0.0 -numpy>=1.17.0 +numpy>=1.17.0, <2 scipy>=1.6.0 pyyaml>=5.3 tqdm diff --git a/requirements.txt b/requirements.txt index 5296538..121944f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ mindspore==2.0.0 -numpy>=1.17.0 +numpy>=1.17.0, <2 scipy>=1.6.0 pyyaml>=5.3 tqdm