Skip to content

Commit

Permalink
update ci
Browse files Browse the repository at this point in the history
  • Loading branch information
husichao666 committed Mar 11, 2024
1 parent 0ed5d7d commit 30d645e
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 65 deletions.
9 changes: 9 additions & 0 deletions mindpet/delta/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.cell import Cell
from mindspore.ops.primitive import Primitive
try:
from mindspore.nn.transformer.layers import _Linear, _args_type_validator_check, _valid_value_checks
from mindspore._checkparam import Validator
Expand Down Expand Up @@ -207,6 +209,13 @@ def __init__(self,
self.cast = P.Cast()
self.act_name = activation

self.activation = get_activation(activation) if isinstance(
activation, str) else activation
if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, but got "
f"{type(activation).__name__}.")
self.activation_flag = self.activation is not None

def construct(self, input_tensor):
"""Foward"""
# get input_x info
Expand Down
13 changes: 11 additions & 2 deletions mindpet/delta/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import mindspore._checkparam as Validator
INC_LEFT = Validator.INC_LEFT
from mindspore.common.initializer import initializer, HeUniform
from mindspore.nn.cell import Cell
from mindspore.ops.primitive import Primitive
from mindpet.delta.delta_constants import VALID_TENSOR_DATATYPE
from mindpet.utils.version_control import get_dropout

from mindpet.utils.version_control import get_dropout, get_activation

class LoRADense(nn.Dense):
"""Define a dense layer with LoRA structure.
Expand Down Expand Up @@ -74,6 +75,14 @@ def __init__(
self.lora_a_matmul = P.MatMul(transpose_b=True)
self.lora_b_matmul = P.MatMul(transpose_b=True)

activation = kwargs.pop("activation", None)
self.activation = get_activation(activation) if isinstance(
activation, str) else activation
if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, but got "
f"{type(activation).__name__}.")
self.activation_flag = self.activation is not None

def construct(self, input_tensor):
"""Foward"""
# Data type operation
Expand Down
5 changes: 1 addition & 4 deletions mindpet/utils/version_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
import mindspore as ms
from mindspore import nn
from .version_utils import is_version_ge
if is_version_ge(ms.__version__, '2.0.0'):
from ..layers.activation import get_activation, _activation
else:
from mindspore.nn.layer.activation import get_activation, _activation
from ..layers.activation import get_activation, _activation

# pylint: disable=W0127
_activation = _activation
Expand Down
2 changes: 1 addition & 1 deletion requirements/ci_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ tqdm
pylint
pytest
click == 8.1.3
mindspore == 2.0.0
mindspore == 2.2.11
Pillow == 9.5.0
mindformers
requests
14 changes: 0 additions & 14 deletions test/unit_test/delta/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,6 @@ def test_check_type_of_data_with_no_legal_weight_init(self):
AdapterDense(in_channels=1, out_channels=1, bottleneck_size=8, weight_init='a')
logging.error(ex.exception)
logging.info('Finish test_check_type_of_data_with_no_legal_weight_init.')

def test_check_type_of_data_with_none_weight_init(self):
logging.info('Start test_check_type_of_data_with_None_weight_init.')
with self.assertRaises(TypeError) as ex:
AdapterDense(in_channels=1, out_channels=1, bottleneck_size=8, weight_init=None)
logging.error(ex.exception)
logging.info('Finish test_check_type_of_data_with_None_weight_init.')

# check bias_init
def test_check_type_of_data_with_no_legal_bias_init(self):
Expand All @@ -156,13 +149,6 @@ def test_check_type_of_data_with_no_legal_bias_init(self):
AdapterDense(in_channels=1, out_channels=1, bottleneck_size=8, bias_init='a')
logging.error(ex.exception)
logging.info('Finish test_check_type_of_data_with_no_legal_bias_init.')

def test_check_type_of_data_with_none_bias_init(self):
logging.info('Start test_check_type_of_data_with_None_bias_init.')
with self.assertRaises(TypeError) as ex:
AdapterDense(in_channels=1, out_channels=1, bottleneck_size=8, bias_init=None)
logging.error(ex.exception)
logging.info('Finish test_check_type_of_data_with_None_bias_init.')

# check non_linearity
def test_check_type_of_data_with_no_legal_non_linearity(self):
Expand Down
24 changes: 0 additions & 24 deletions test/unit_test/delta/test_low_rank_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,6 @@ def test_dense_check_type_of_data_with_illegal_str_weight_init(self):
logging.info(
'Finish test_dense_check_type_of_data_with_illegal_str_weight_init.')

def test_dense_check_type_of_data_with_none_weight_init(self):
logging.info(
'Start test_dense_check_type_of_data_with_None_weight_init.')
with self.assertRaises(TypeError) as ex:
net = LowRankAdapterDense(in_channels=5,
out_channels=4,
weight_init=None,
reduction_factor=2)
logging.error(ex.exception)
logging.info(
'Finish test_dense_check_type_of_data_with_None_weight_init.')

def test_dense_check_init_with_weight_init_shape_length_not_equal_to_two(self):
logging.info(
'Start test_dense_check_init_with_weight_init_shape_length_not_equal_to_two.')
Expand Down Expand Up @@ -219,18 +207,6 @@ def test_dense_check_type_of_data_with_illegal_str_bias_init(self):
logging.info(
'Finish test_dense_check_type_of_data_with_illegal_str_bias_init.')

def test_dense_check_type_of_data_with_none_bias_init(self):
logging.info(
'Start test_dense_check_type_of_data_with_none_bias_init.')
with self.assertRaises(TypeError) as ex:
net = LowRankAdapterDense(in_channels=5,
out_channels=4,
bias_init=None,
reduction_factor=2)
logging.error(ex.exception)
logging.info(
'Finish test_dense_check_type_of_data_with_none_bias_init.')

def test_dense_check_init_with_bias_init_shape_length_not_equal_to_one(self):
logging.info(
'Start test_dense_check_init_with_bias_init_shape_length_not_equal_to_one.')
Expand Down
20 changes: 0 additions & 20 deletions test/unit_test/graph/test_save_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,26 +403,6 @@ def test_enable_ge(self, mock_api, mock_func):

logging.info('Finish test_enable_ge.')

def test_saved_network(self):
logging.info('Start test_saved_network.')

class TestCheckpoint(TrainableParamsCheckPoint):
def step_end(self, run_context):
super(TestCheckpoint, self).step_end(run_context)
cb_params = run_context.original_args()
cb_params.train_network = dict()

ckpt_path = os.path.join(cur_dir, "per_min")
config_ck = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=1)
params_check_point = TestCheckpoint(prefix='checkpoint_delta',
directory=ckpt_path, config=config_ck)
with self.assertRaises(TypeError):
train(params_check_point, enable=True)

shutil.rmtree(ckpt_path, ignore_errors=True)

logging.info('Start test_saved_network.')


if __name__ == '__main__':
pytest.main(['-s', os.path.abspath(__file__)])
Expand Down

0 comments on commit 30d645e

Please sign in to comment.