update jit level
daiyuxin0511 committed Jun 17, 2024
1 parent 70b44d5 commit e71d064
Showing 2 changed files with 10 additions and 8 deletions.
6 changes: 3 additions & 3 deletions examples/conformer/predict.py
@@ -4,11 +4,11 @@
 """
 
 import os
+import mindspore
 
 import numpy as np
 from asr_model import creadte_asr_model
 from dataset import create_asr_predict_dataset, load_language_dict
-from mindspore import context
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 
@@ -44,8 +44,8 @@ def main():
     os.makedirs(decode_dir, exist_ok=True)
     result_file = open(os.path.join(decode_dir, "result.txt"), "w")
 
-    context.set_context(
-        mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id()
+    mindspore.set_context(
+        mode=0, device_target="Ascend", device_id=get_device_id(), jit_config={"jit_level": "O2"}
     )
 
     # load test data
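
The change above replaces the older `from mindspore import context` / `context.set_context` style with the top-level `mindspore.set_context` call and pins the graph-compilation level through `jit_config`. Note that `mode=0` is just the integer form of `mindspore.GRAPH_MODE`, so the execution mode itself is unchanged. A minimal sketch of the resulting call, assuming a MindSpore release that accepts `jit_config` in `set_context` (2.3+); `device_id=0` stands in for the script's `get_device_id()`:

import mindspore

# mode=0 is GRAPH_MODE (1 would be PYNATIVE_MODE); jit_level "O2" asks the
# backend for its most aggressive whole-graph optimization on Ascend.
mindspore.set_context(
    mode=0,
    device_target="Ascend",
    device_id=0,  # placeholder; the real script passes get_device_id()
    jit_config={"jit_level": "O2"},
)
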
12 changes: 7 additions & 5 deletions examples/conformer/train.py
@@ -4,10 +4,11 @@
"""

import os
import mindspore

from asr_model import creadte_asr_model, create_asr_eval_net
from dataset import create_dataset
from mindspore import ParameterTuple, context, set_seed
from mindspore import ParameterTuple, set_seed
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore.nn.optim import Adam
@@ -57,21 +58,22 @@ def train():
     model_dir = os.path.join(exp_dir, "model")
     graph_dir = os.path.join(exp_dir, "graph")
     summary_dir = os.path.join(exp_dir, "summary")
-    context.set_context(
-        mode=context.GRAPH_MODE,
+    mindspore.set_context(
+        mode=0,
         device_target="Ascend",
         device_id=get_device_id(),
         save_graphs=config.save_graphs,
         save_graphs_path=graph_dir,
+        jit_config={"jit_level": "O2"}
     )
 
     device_num = get_device_num()
     rank = get_rank_id()
     # configurations for distributed training
     if config.is_distributed:
         init()
-        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(
+        mindspore.reset_auto_parallel_context()
+        mindspore.set_auto_parallel_context(
             parallel_mode=ParallelMode.DATA_PARALLEL,
             gradients_mean=True,
             device_num=device_num,
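
The distributed branch gets the same migration: `context.reset_auto_parallel_context` / `context.set_auto_parallel_context` become top-level `mindspore` calls. A hedged sketch of the resulting data-parallel setup, with `device_num` hard-coded where the script calls `get_device_num()`; `init()` presumes the process was launched inside a distributed job (e.g. via msrun or mpirun):

import mindspore
from mindspore.communication.management import init
from mindspore.context import ParallelMode

init()  # initialize collective communication (HCCL on Ascend) for this process
mindspore.reset_auto_parallel_context()  # clear any previously set parallel options
mindspore.set_auto_parallel_context(
    parallel_mode=ParallelMode.DATA_PARALLEL,  # replicate the model, split each batch
    gradients_mean=True,  # average, rather than sum, gradients across devices
    device_num=8,  # placeholder; the real script passes get_device_num()
)

Setting `gradients_mean=True` keeps the effective learning rate independent of how many devices participate, which is why it is the usual choice for data parallelism.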
