Skip to content

Commit

Permalink
adjust to new params format with optional param in decorator (#82)
Browse files Browse the repository at this point in the history
Signed-off-by: Walter Martin <[email protected]>
  • Loading branch information
wamartin-aml authored Aug 28, 2023
1 parent 9c62fb5 commit d4a95a7
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 41 deletions.
6 changes: 4 additions & 2 deletions inference_schema/schema_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ._constants import INPUT_SCHEMA_ATTR, OUTPUT_SCHEMA_ATTR


def input_schema(param_name, param_type, convert_to_provided_type=True):
def input_schema(param_name, param_type, convert_to_provided_type=True, optional=False):
"""
Decorator to define an input schema model for a function parameter
The input schema is a representation of what type the function expects
Expand Down Expand Up @@ -46,14 +46,16 @@ def decorator_input(user_run, instance, args, kwargs):
if convert_to_provided_type:
args = list(args)

if param_name not in kwargs.keys():
if param_name not in kwargs.keys() and not optional:
decorators = _get_decorators(user_run)
arg_names = inspect.getfullargspec(decorators[-1]).args
if param_name not in arg_names:
raise Exception('Error, provided param_name "{}" '
'is not in the decorated function.'.format(param_name))
param_position = arg_names.index(param_name)
args[param_position] = _deserialize_input_argument(args[param_position], param_type, param_name)
elif param_name not in kwargs.keys() and optional:
pass
else:
kwargs[param_name] = _deserialize_input_argument(kwargs[param_name], param_type, param_name)

Expand Down
19 changes: 8 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,17 +186,14 @@ def pandas_url_func(param):

@pytest.fixture(scope="session")
def decorated_pandas_func_parameters(pandas_sample_input_for_params, sample_param_dict):
@input_schema('input_data', StandardPythonParameterType({
'split_df': PandasParameterType(pandas_sample_input_for_params, orient='split'),
'parameters': StandardPythonParameterType(sample_param_dict)
}))
def pandas_params_func(input_data):
assert type(input_data) is dict
assert type(input_data["split_df"]) is pd.DataFrame
if 'parameters' in input_data:
assert type(input_data["parameters"]) is dict
beams = input_data['parameters']['num_beams'] if 'parameters' in input_data else 0
return input_data["split_df"]["sentence1"], beams
@input_schema('input_data', PandasParameterType(pandas_sample_input_for_params, orient='split'))
@input_schema('params', StandardPythonParameterType(sample_param_dict), optional=True)
def pandas_params_func(input_data, params=None):
assert type(input_data) is pd.DataFrame
if params is not None:
assert type(params) is dict
beams = params['num_beams'] if params is not None else 0
return input_data["sentence1"], beams

return pandas_params_func

Expand Down
52 changes: 24 additions & 28 deletions tests/test_pandas_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,38 +79,34 @@ def test_pandas_categorical_handling(self, decorated_pandas_categorical_func):
assert categorical == result

def test_pandas_params_handling(self, decorated_pandas_func_parameters):
pandas_input_data = {"input_data": {
"split_df": {
"columns": [
"sentence1"
],
"data": [
["this is a string starting with"]
],
"index": [0]
},
"parameters": {
"num_beams": 2,
"max_length": 512
}
}}
result = decorated_pandas_func_parameters(**pandas_input_data)
pandas_input_data = {
"columns": [
"sentence1"
],
"data": [
["this is a string starting with"]
],
"index": [0]
}
parameters = {
"num_beams": 2,
"max_length": 512
}
result = decorated_pandas_func_parameters(pandas_input_data, params=parameters)
assert result[0][0] == "this is a string starting with"
assert result[1] == 2

def test_pandas_params_handling_without_params(self, decorated_pandas_func_parameters):
pandas_input_data = {"input_data": {
"split_df": {
"columns": [
"sentence1"
],
"data": [
["this is a string starting with"]
],
"index": [0]
}
}}
result = decorated_pandas_func_parameters(**pandas_input_data)
pandas_input_data = {
"columns": [
"sentence1"
],
"data": [
["this is a string starting with"]
],
"index": [0]
}
result = decorated_pandas_func_parameters(pandas_input_data)
assert result[0][0] == "this is a string starting with"
assert result[1] == 0

Expand Down

0 comments on commit d4a95a7

Please sign in to comment.