From 30d645e40fbfa4b085f3f4a746410fd5b65065fd Mon Sep 17 00:00:00 2001 From: husichao Date: Mon, 11 Mar 2024 16:39:28 +0800 Subject: [PATCH] update ci --- mindpet/delta/adapter.py | 9 +++++++ mindpet/delta/lora.py | 13 ++++++++-- mindpet/utils/version_control.py | 5 +--- requirements/ci_requirements.txt | 2 +- test/unit_test/delta/test_adapter.py | 14 ----------- test/unit_test/delta/test_low_rank_adapter.py | 24 ------------------- test/unit_test/graph/test_save_ckpt.py | 20 ---------------- 7 files changed, 22 insertions(+), 65 deletions(-) diff --git a/mindpet/delta/adapter.py b/mindpet/delta/adapter.py index bfc21d4..9eacc26 100755 --- a/mindpet/delta/adapter.py +++ b/mindpet/delta/adapter.py @@ -8,6 +8,8 @@ import mindspore.common.dtype as mstype from mindspore.ops import operations as P from mindspore.ops import functional as F +from mindspore.nn.cell import Cell +from mindspore.ops.primitive import Primitive try: from mindspore.nn.transformer.layers import _Linear, _args_type_validator_check, _valid_value_checks from mindspore._checkparam import Validator @@ -207,6 +209,13 @@ def __init__(self, self.cast = P.Cast() self.act_name = activation + self.activation = get_activation(activation) if isinstance( + activation, str) else activation + if activation is not None and not isinstance(self.activation, (Cell, Primitive)): + raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, but got " + f"{type(activation).__name__}.") + self.activation_flag = self.activation is not None + def construct(self, input_tensor): """Foward""" # get input_x info diff --git a/mindpet/delta/lora.py b/mindpet/delta/lora.py index 2dd6ddc..ecb6e36 100755 --- a/mindpet/delta/lora.py +++ b/mindpet/delta/lora.py @@ -17,9 +17,10 @@ import mindspore._checkparam as Validator INC_LEFT = Validator.INC_LEFT from mindspore.common.initializer import initializer, HeUniform +from mindspore.nn.cell import Cell +from mindspore.ops.primitive import Primitive from mindpet.delta.delta_constants import VALID_TENSOR_DATATYPE -from mindpet.utils.version_control import get_dropout - +from mindpet.utils.version_control import get_dropout, get_activation class LoRADense(nn.Dense): """Define a dense layer with LoRA structure. @@ -74,6 +75,14 @@ def __init__( self.lora_a_matmul = P.MatMul(transpose_b=True) self.lora_b_matmul = P.MatMul(transpose_b=True) + activation = kwargs.pop("activation", None) + self.activation = get_activation(activation) if isinstance( + activation, str) else activation + if activation is not None and not isinstance(self.activation, (Cell, Primitive)): + raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, but got " + f"{type(activation).__name__}.") + self.activation_flag = self.activation is not None + def construct(self, input_tensor): """Foward""" # Data type operation diff --git a/mindpet/utils/version_control.py b/mindpet/utils/version_control.py index d336694..e031a8e 100755 --- a/mindpet/utils/version_control.py +++ b/mindpet/utils/version_control.py @@ -5,10 +5,7 @@ import mindspore as ms from mindspore import nn from .version_utils import is_version_ge -if is_version_ge(ms.__version__, '2.0.0'): - from ..layers.activation import get_activation, _activation -else: - from mindspore.nn.layer.activation import get_activation, _activation +from ..layers.activation import get_activation, _activation # pylint: disable=W0127 _activation = _activation diff --git a/requirements/ci_requirements.txt b/requirements/ci_requirements.txt index d14f0ac..d03d3b8 100755 --- a/requirements/ci_requirements.txt +++ b/requirements/ci_requirements.txt @@ -3,7 +3,7 @@ tqdm pylint pytest click == 8.1.3 -mindspore == 2.0.0 +mindspore == 2.2.11 Pillow == 9.5.0 mindformers requests \ No newline at end of file diff --git a/test/unit_test/delta/test_adapter.py b/test/unit_test/delta/test_adapter.py index 701c378..3dd323a 100755 --- a/test/unit_test/delta/test_adapter.py +++ b/test/unit_test/delta/test_adapter.py @@ -141,13 +141,6 @@ def test_check_type_of_data_with_no_legal_weight_init(self): AdapterDense(in_channels=1, out_channels=1, bottleneck_size=8, weight_init='a') logging.error(ex.exception) logging.info('Finish test_check_type_of_data_with_no_legal_weight_init.') - - def test_check_type_of_data_with_none_weight_init(self): - logging.info('Start test_check_type_of_data_with_None_weight_init.') - with self.assertRaises(TypeError) as ex: - AdapterDense(in_channels=1, out_channels=1, bottleneck_size=8, weight_init=None) - logging.error(ex.exception) - logging.info('Finish test_check_type_of_data_with_None_weight_init.') # check bias_init def test_check_type_of_data_with_no_legal_bias_init(self): @@ -156,13 +149,6 @@ def test_check_type_of_data_with_no_legal_bias_init(self): AdapterDense(in_channels=1, out_channels=1, bottleneck_size=8, bias_init='a') logging.error(ex.exception) logging.info('Finish test_check_type_of_data_with_no_legal_bias_init.') - - def test_check_type_of_data_with_none_bias_init(self): - logging.info('Start test_check_type_of_data_with_None_bias_init.') - with self.assertRaises(TypeError) as ex: - AdapterDense(in_channels=1, out_channels=1, bottleneck_size=8, bias_init=None) - logging.error(ex.exception) - logging.info('Finish test_check_type_of_data_with_None_bias_init.') # check non_linearity def test_check_type_of_data_with_no_legal_non_linearity(self): diff --git a/test/unit_test/delta/test_low_rank_adapter.py b/test/unit_test/delta/test_low_rank_adapter.py index 9fec36d..9eec0f6 100755 --- a/test/unit_test/delta/test_low_rank_adapter.py +++ b/test/unit_test/delta/test_low_rank_adapter.py @@ -155,18 +155,6 @@ def test_dense_check_type_of_data_with_illegal_str_weight_init(self): logging.info( 'Finish test_dense_check_type_of_data_with_illegal_str_weight_init.') - def test_dense_check_type_of_data_with_none_weight_init(self): - logging.info( - 'Start test_dense_check_type_of_data_with_None_weight_init.') - with self.assertRaises(TypeError) as ex: - net = LowRankAdapterDense(in_channels=5, - out_channels=4, - weight_init=None, - reduction_factor=2) - logging.error(ex.exception) - logging.info( - 'Finish test_dense_check_type_of_data_with_None_weight_init.') - def test_dense_check_init_with_weight_init_shape_length_not_equal_to_two(self): logging.info( 'Start test_dense_check_init_with_weight_init_shape_length_not_equal_to_two.') @@ -219,18 +207,6 @@ def test_dense_check_type_of_data_with_illegal_str_bias_init(self): logging.info( 'Finish test_dense_check_type_of_data_with_illegal_str_bias_init.') - def test_dense_check_type_of_data_with_none_bias_init(self): - logging.info( - 'Start test_dense_check_type_of_data_with_none_bias_init.') - with self.assertRaises(TypeError) as ex: - net = LowRankAdapterDense(in_channels=5, - out_channels=4, - bias_init=None, - reduction_factor=2) - logging.error(ex.exception) - logging.info( - 'Finish test_dense_check_type_of_data_with_none_bias_init.') - def test_dense_check_init_with_bias_init_shape_length_not_equal_to_one(self): logging.info( 'Start test_dense_check_init_with_bias_init_shape_length_not_equal_to_one.') diff --git a/test/unit_test/graph/test_save_ckpt.py b/test/unit_test/graph/test_save_ckpt.py index ec2302b..f04ea9d 100755 --- a/test/unit_test/graph/test_save_ckpt.py +++ b/test/unit_test/graph/test_save_ckpt.py @@ -403,26 +403,6 @@ def test_enable_ge(self, mock_api, mock_func): logging.info('Finish test_enable_ge.') - def test_saved_network(self): - logging.info('Start test_saved_network.') - - class TestCheckpoint(TrainableParamsCheckPoint): - def step_end(self, run_context): - super(TestCheckpoint, self).step_end(run_context) - cb_params = run_context.original_args() - cb_params.train_network = dict() - - ckpt_path = os.path.join(cur_dir, "per_min") - config_ck = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=1) - params_check_point = TestCheckpoint(prefix='checkpoint_delta', - directory=ckpt_path, config=config_ck) - with self.assertRaises(TypeError): - train(params_check_point, enable=True) - - shutil.rmtree(ckpt_path, ignore_errors=True) - - logging.info('Start test_saved_network.') - if __name__ == '__main__': pytest.main(['-s', os.path.abspath(__file__)])