diff --git a/osbenchmark/benchmark.py b/osbenchmark/benchmark.py
index e70d4568c..ffa7170ce 100644
--- a/osbenchmark/benchmark.py
+++ b/osbenchmark/benchmark.py
@@ -566,6 +566,22 @@ def add_workload_source(subparser):
         f"(default: {metrics.GlobalStatsCalculator.DEFAULT_THROUGHPUT_PERCENTILES}).",
         default=metrics.GlobalStatsCalculator.DEFAULT_THROUGHPUT_PERCENTILES
     )
+    test_execution_parser.add_argument(
+        "--randomization-enabled",
+        help="Runs the given workload with query randomization enabled (default: false).",
+        default=False,
+        action="store_true")
+    test_execution_parser.add_argument(
+        "--randomization-repeat-frequency",
+        "-rf",
+        help=f"The repeat_frequency for query randomization. Ignored if randomization is off "
+             f"(default: {workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_RF}).",
+        default=workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_RF)
+    test_execution_parser.add_argument(
+        "--randomization-n",
+        help=f"The number of standard values to generate for each field for query randomization. "
+             f"Ignored if randomization is off (default: {workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_N}).",
+        default=workload.loader.QueryRandomizerWorkloadProcessor.DEFAULT_N)

 ###############################################################################
 #
@@ -876,6 +892,9 @@ def dispatch_sub_command(arg_parser, args, cfg):
             cfg.add(config.Scope.applicationOverride, "workload", "test.mode.enabled", args.test_mode)
             cfg.add(config.Scope.applicationOverride, "workload", "latency.percentiles", args.latency_percentiles)
             cfg.add(config.Scope.applicationOverride, "workload", "throughput.percentiles", args.throughput_percentiles)
+            cfg.add(config.Scope.applicationOverride, "workload", "randomization.enabled", args.randomization_enabled)
+            cfg.add(config.Scope.applicationOverride, "workload", "randomization.repeat_frequency", args.randomization_repeat_frequency)
+            cfg.add(config.Scope.applicationOverride, "workload", "randomization.n", args.randomization_n)
             configure_workload_params(arg_parser, args, cfg)
             configure_connection_params(arg_parser, args, cfg)
             configure_telemetry_params(args, cfg)
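With these flags wired through to the workload config scope, a test execution can opt in from the command line, e.g. `opensearch-benchmark execute-test --workload=nyc_taxis --randomization-enabled -rf 0.5 --randomization-n 10000` (workload name illustrative; check `--help` on your install for the exact subcommand spelling).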
diff --git a/osbenchmark/workload/loader.py b/osbenchmark/workload/loader.py
index d9a99bb1f..4efc4822f 100644
--- a/osbenchmark/workload/loader.py
+++ b/osbenchmark/workload/loader.py
@@ -25,6 +25,7 @@
 import json
 import logging
 import os
+import random
 import re
 import sys
 import tempfile
@@ -48,7 +49,7 @@ class WorkloadSyntaxError(exceptions.InvalidSyntax):


 class WorkloadProcessor:
-    def on_after_load_workload(self, workload):
+    def on_after_load_workload(self, input_workload, **kwargs):
         """
         This method is called by Benchmark after a workload has been loaded. Implementations are expected to modify the
         provided workload object in place.
@@ -74,7 +75,7 @@ def on_prepare_workload(self, workload, data_root_dir):

 class WorkloadProcessorRegistry:
     def __init__(self, cfg):
-        self.required_processors = [TaskFilterWorkloadProcessor(cfg), TestModeWorkloadProcessor(cfg)]
+        self.required_processors = [TaskFilterWorkloadProcessor(cfg), TestModeWorkloadProcessor(cfg), QueryRandomizerWorkloadProcessor(cfg)]
         self.workload_processors = []
         self.offline = cfg.opts("system", "offline.mode")
         self.test_mode = cfg.opts("workload", "test.mode.enabled", mandatory=False, default_value=False)
@@ -824,11 +825,11 @@ def _filter_out_match(self, task):
             return self.exclude
         return not self.exclude

-    def on_after_load_workload(self, workload):
+    def on_after_load_workload(self, input_workload, **kwargs):
         if not self.filters:
-            return workload
+            return input_workload

-        for test_procedure in workload.test_procedures:
+        for test_procedure in input_workload.test_procedures:
             # don't modify the schedule while iterating over it
             tasks_to_remove = []
             for task in test_procedure.schedule:
@@ -847,7 +848,7 @@ def on_after_load_workload(self, workload):
                 self.logger.info("Removing task [%s] from test_procedure [%s] due to task filter.", task, test_procedure)
                 test_procedure.remove_task(task)

-        return workload
+        return input_workload


 class TestModeWorkloadProcessor(WorkloadProcessor):
@@ -855,11 +856,11 @@ def __init__(self, cfg):
         self.test_mode_enabled = cfg.opts("workload", "test.mode.enabled", mandatory=False, default_value=False)
         self.logger = logging.getLogger(__name__)

-    def on_after_load_workload(self, workload):
+    def on_after_load_workload(self, input_workload, **kwargs):
         if not self.test_mode_enabled:
-            return workload
-        self.logger.info("Preparing workload [%s] for test mode.", str(workload))
-        for corpus in workload.corpora:
+            return input_workload
+        self.logger.info("Preparing workload [%s] for test mode.", str(input_workload))
+        for corpus in input_workload.corpora:
             if self.logger.isEnabledFor(logging.DEBUG):
                 self.logger.debug("Reducing corpus size to 1000 documents for [%s]", corpus.name)
             for document_set in corpus.documents:
@@ -884,7 +885,7 @@ def on_after_load_workload(self, workload):
                     document_set.compressed_size_in_bytes = None
                     document_set.uncompressed_size_in_bytes = None

-        for test_procedure in workload.test_procedures:
+        for test_procedure in input_workload.test_procedures:
             for task in test_procedure.schedule:
                 # we need iterate over leaf tasks and await iterating over possible intermediate 'parallel' elements
                 for leaf_task in task:
@@ -918,8 +919,189 @@ def on_after_load_workload(self, workload):
                         leaf_task.params.pop("target-interval", None)
                         leaf_task.params["target-throughput"] = f"{sys.maxsize} {original_throughput.unit}"

-        return workload
+        return input_workload
+
+
+class QueryRandomizerWorkloadProcessor(WorkloadProcessor):
+    DEFAULT_RF = 0.3
+    DEFAULT_N = 5000
+
+    def __init__(self, cfg):
+        self.randomization_enabled = cfg.opts("workload", "randomization.enabled", mandatory=False, default_value=False)
+        self.rf = float(cfg.opts("workload", "randomization.repeat_frequency", mandatory=False, default_value=self.DEFAULT_RF))
+        self.logger = logging.getLogger(__name__)
+        self.N = int(cfg.opts("workload", "randomization.n", mandatory=False, default_value=self.DEFAULT_N))
+        self.zipf_alpha = 1
+        self.H_list = self.precompute_H(self.N, self.zipf_alpha)
+
+    # Helper functions for computing the Zipf distribution
+    def H(self, i, H_list):
+        # the i-th generalized harmonic number H_{i,m} = sum over j from 1 to i of (1 / j^m), precomputed in H_list
+        return H_list[i-1]
+    def precompute_H(self, n, m):
+        H_list = [1]
+        for j in range(2, n+1):
+            H_list.append(H_list[-1] + 1 / (j ** m))
+        return H_list
+
+    def zipf_cdf_inverse(self, u, H_list):
+        # To map a uniformly distributed u from [0, 1) to some probability distribution, we plug it into its inverse CDF.
+        # As the Zipf CDF is discontinuous there is no true inverse, but we can use this approach:
+        # https://math.stackexchange.com/questions/53671/how-to-calculate-the-inverse-cdf-for-the-zipf-distribution
+        # Precompute all values H_{i,alpha} for a fixed alpha and pass them in as H_list.
+        if u < 0 or u >= 1:
+            raise exceptions.ExecutorError(
+                "Input u must have 0 <= u < 1. This error shouldn't appear, please raise an issue if it does")
+        n = len(H_list)
+        candidate_return = 1
+        denominator = self.H(n, H_list)
+        numerator = 0
+        while candidate_return < n:
+            numerator = self.H(candidate_return, H_list)
+            if u < numerator / denominator:
+                return candidate_return
+            candidate_return += 1
+        return n
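+    # Illustrative sketch of the sampling behavior (values not used by the code): with N = 5 and
+    # alpha = 1, precompute_H gives H_list = [1, 1.5, 1.833..., 2.083..., 2.283...], so
+    # zipf_cdf_inverse(random.random(), H_list) returns rank k with probability (1/k) / H_5:
+    # rank 1 about 44% of the time, rank 5 only about 9% of the time. Lower-indexed standard
+    # values are therefore repeated far more often than higher-indexed ones.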
+    def get_dict_from_previous_path(self, root, current_path):
+        curr = root
+        for value in current_path:
+            curr = curr[value]
+        return curr
+
+    def extract_fields_helper(self, root, current_path):
+        # Recursively called to find the location of ranges in an OpenSearch range query.
+        # Return the field and the current path if we're currently scanning the field name in a range query,
+        # otherwise return an empty list.
+        fields = []  # pairs of (field, path_to_field)
+        curr = self.get_dict_from_previous_path(root, current_path)
+        if isinstance(curr, dict) and curr != {}:
+            if len(current_path) > 0 and current_path[-1] == "range":
+                for key in curr.keys():
+                    if isinstance(curr[key], dict):
+                        if ("gte" in curr[key] or "gt" in curr[key]) and ("lte" in curr[key] or "lt" in curr[key]):
+                            fields.append((key, current_path))
+                return fields
+            else:
+                for key in curr.keys():
+                    fields += self.extract_fields_helper(root, current_path + [key])
+                return fields
+        elif isinstance(curr, list) and curr != []:
+            for i in range(len(curr)):
+                fields += self.extract_fields_helper(root, current_path + [i])
+            return fields
+        else:
+            # leaf node
+            return []
+
+    def extract_fields_and_paths(self, params):
+        # Search for fields used in range queries, and the paths to those fields.
+        # Returns pairs of (field, path_to_field).
+        # TODO: Maybe only do this the first time, and assume that for a given task, the same query structure is used.
+        # We could achieve this by passing in the task name to get_randomized_values as a kwarg.
+        try:
+            root = params["body"]["query"]
+        except KeyError:
+            raise exceptions.SystemSetupError(
+                f"Cannot extract range query fields from these params: {params}\n, missing params[\"body\"][\"query\"]\n"
+                f"Make sure the operation in operations/default.json is well-formed")
+        fields_and_paths = self.extract_fields_helper(root, [])
+        return fields_and_paths
+
+    def set_range(self, params, fields_and_paths, new_values):
+        assert len(fields_and_paths) == len(new_values)
+        for field_and_path, new_value in zip(fields_and_paths, new_values):
+            field = field_and_path[0]
+            path = field_and_path[1]
+            # get the section of the query corresponding to the field name
+            range_section = self.get_dict_from_previous_path(params["body"]["query"], path)[field]
+            for greater_than in ["gte", "gt"]:
+                if greater_than in range_section:
+                    range_section[greater_than] = new_value["gte"]
+            for less_than in ["lte", "lt"]:
+                if less_than in range_section:
+                    range_section[less_than] = new_value["lte"]
+            if "format" in new_value:
+                range_section["format"] = new_value["format"]
+        return params
+
+    def get_repeated_value_index(self):
+        # minus 1 to map [1, N] onto the list indices [0, N-1]
+        return self.zipf_cdf_inverse(random.random(), self.H_list) - 1
+
+    def get_randomized_values(self, input_workload, input_params,
+                              get_standard_value=params.get_standard_value,
+                              get_standard_value_source=params.get_standard_value_source,  # made configurable for simpler unit tests
+                              **kwargs):
+        # The queries as listed in operations/default.json don't have the index param,
+        # unlike the custom ones you would specify in workload.py, so we have to add it ourselves.
+        if "index" not in input_params:
+            input_params["index"] = params.get_target(input_workload, input_params)
+
+        fields_and_paths = self.extract_fields_and_paths(input_params)
+
+        if random.random() < self.rf:
+            # Draw a potentially repeated value from the saved standard values.
+            # Use the same index for all fields in one query, otherwise the probability
+            # of repeats in a multi-field query would be very low.
+            index = self.get_repeated_value_index()
+            new_values = [get_standard_value(kwargs["op_name"], field_and_path[0], index) for field_and_path in fields_and_paths]
+        else:
+            # Generate a new random value from the standard value source function. This will be new (a cache miss).
+            new_values = [get_standard_value_source(kwargs["op_name"], field_and_path[0])() for field_and_path in fields_and_paths]
+        input_params = self.set_range(input_params, fields_and_paths, new_values)
+        return input_params
+
+    def create_param_source_lambda(self, op_name, get_standard_value, get_standard_value_source):
+        return lambda w, p, **kwargs: self.get_randomized_values(w, p,
+                                                                 get_standard_value=get_standard_value,
+                                                                 get_standard_value_source=get_standard_value_source,
+                                                                 op_name=op_name, **kwargs)
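+    # Worked example of the two helpers together (values illustrative): given
+    #   {"body": {"query": {"bool": {"filter": {"range": {"trip_distance": {"lt": 50, "gte": 0}}}}}}}
+    # extract_fields_and_paths returns [("trip_distance", ["bool", "filter", "range"])], and
+    # set_range(params, that_list, [{"gte": 10, "lte": 40}]) rewrites the bounds in place,
+    # keeping whichever of lt/lte and gt/gte the original query used:
+    #   {... "range": {"trip_distance": {"lt": 40, "gte": 10}} ...}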
+    def on_after_load_workload(self, input_workload, **kwargs):
+        if not self.randomization_enabled:
+            self.logger.info("Query randomization is disabled.")
+            return input_workload
+        self.logger.info("Query randomization is enabled, with repeat frequency = %s, n = %d", self.rf, self.N)
+
+        # By default, use params for standard values and generate new standard values the first time an op/field is seen.
+        # In unit tests, we should be able to supply our own sources independent of params.
+        # This is done in kwargs because pylint didn't like having specific keyword args that weren't in the parent method.
+        generate_new_standard_values = False
+        if "get_standard_value" not in kwargs:
+            kwargs["get_standard_value"] = params.get_standard_value
+            generate_new_standard_values = True
+        if "get_standard_value_source" not in kwargs:
+            kwargs["get_standard_value_source"] = params.get_standard_value_source
+            generate_new_standard_values = True
+
+        default_test_procedure = None
+        for test_procedure in input_workload.test_procedures:
+            if test_procedure.default:
+                default_test_procedure = test_procedure
+                break
+
+        for task in default_test_procedure.schedule:
+            for leaf_task in task:
+                try:
+                    op_type = workload.OperationType.from_hyphenated_string(leaf_task.operation.type)
+                except KeyError:
+                    op_type = None
+                    self.logger.info(
+                        "Found operation %s in default schedule with type %s, which couldn't be converted to a known OperationType",
+                        leaf_task.operation.name, leaf_task.operation.type)
+                if op_type == workload.OperationType.Search:
+                    op_name = leaf_task.operation.name
+                    param_source_name = op_name + "-randomized"
+                    params.register_param_source_for_name(
+                        param_source_name,
+                        self.create_param_source_lambda(op_name, get_standard_value=kwargs["get_standard_value"],
+                                                        get_standard_value_source=kwargs["get_standard_value_source"]))
+                    leaf_task.operation.param_source = param_source_name
+                    # Generate the right number of standard values for this field, if not already present
+                    for field_and_path in self.extract_fields_and_paths(leaf_task.operation.params):
+                        if generate_new_standard_values:
+                            params.generate_standard_values_if_absent(op_name, field_and_path[0], self.N)
+        return input_workload
+

 class CompleteWorkloadParams:
     def __init__(self, user_specified_workload_params=None):
@@ -1097,6 +1279,10 @@ def register_workload_processor(self, workload_processor):
         if self.workload_processor_registry:
             self.workload_processor_registry(workload_processor)

+    def register_standard_value_source(self, op_name, field_name, standard_value_source):
+        # Define a value source for parameters for a given operation name and field name, for use in randomization
+        params.register_standard_value_source(op_name, field_name, standard_value_source)
+
     @property
     def meta_data(self):
         return {
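The registry hook added above is what a workload author would call from the workload's `workload.py`. A minimal sketch, assuming the usual `register(registry)` plugin entry point (operation and field names are illustrative):

```python
import random

def random_trip_distance_range():
    # standard value source: each call yields one fresh {"gte": ..., "lte": ...} range
    lower = random.uniform(0, 10)
    return {"gte": lower, "lte": lower + random.uniform(1, 40)}

def register(registry):
    registry.register_standard_value_source(
        "distance_amount_agg", "trip_distance", random_trip_distance_range)
```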
diff --git a/osbenchmark/workload/params.py b/osbenchmark/workload/params.py
index c0daef588..dba4ffc2d 100644
--- a/osbenchmark/workload/params.py
+++ b/osbenchmark/workload/params.py
@@ -46,6 +46,8 @@

 __PARAM_SOURCES_BY_OP = {}
 __PARAM_SOURCES_BY_NAME = {}
+__STANDARD_VALUE_SOURCES = {}
+__STANDARD_VALUES = {}

 def param_source_for_operation(op_type, workload, params, task_name):
     try:
@@ -63,6 +65,14 @@ def param_source_for_name(name, workload, params):
     else:
         return param_source(workload, params)

+def get_standard_value_source(op_name, field_name):
+    try:
+        return __STANDARD_VALUE_SOURCES[op_name][field_name]
+    except KeyError:
+        raise exceptions.SystemSetupError(
+            "Could not find standard value source for operation {}, field {}! Make sure this is registered in workload.py"
+            .format(op_name, field_name))
+

 def ensure_valid_param_source(param_source):
     if not inspect.isfunction(param_source) and not inspect.isclass(param_source):
@@ -78,6 +88,37 @@ def register_param_source_for_name(name, param_source_class):
     ensure_valid_param_source(param_source_class)
     __PARAM_SOURCES_BY_NAME[name] = param_source_class

+def register_standard_value_source(op_name, field_name, standard_value_source):
+    if op_name in __STANDARD_VALUE_SOURCES:
+        __STANDARD_VALUE_SOURCES[op_name][field_name] = standard_value_source
+        # We have to allow re-registration for the same op/field, since plugins are loaded many times when a workload is run
+    else:
+        __STANDARD_VALUE_SOURCES[op_name] = {field_name: standard_value_source}
+
+def generate_standard_values_if_absent(op_name, field_name, n):
+    if op_name not in __STANDARD_VALUES:
+        __STANDARD_VALUES[op_name] = {}
+    if field_name not in __STANDARD_VALUES[op_name]:
+        __STANDARD_VALUES[op_name][field_name] = []
+        try:
+            standard_value_source = __STANDARD_VALUE_SOURCES[op_name][field_name]
+        except KeyError:
+            raise exceptions.SystemSetupError(
+                "Cannot generate standard values for operation {}, field {}. Standard value source is missing"
+                .format(op_name, field_name))
+        for _i in range(n):
+            __STANDARD_VALUES[op_name][field_name].append(standard_value_source())
+
+def get_standard_value(op_name, field_name, i):
+    try:
+        return __STANDARD_VALUES[op_name][field_name][i]
+    except KeyError:
+        raise exceptions.SystemSetupError("No standard values generated for operation {}, field {}".format(op_name, field_name))
+    except IndexError:
+        raise exceptions.SystemSetupError(
+            "Standard value index {} out of range for operation {}, field name {} ({} values total)"
+            .format(i, op_name, field_name, len(__STANDARD_VALUES[op_name][field_name])))
+

 # only intended for tests
 def _unregister_param_source_for_name(name):
@@ -85,6 +126,10 @@ def _unregister_param_source_for_name(name):
     # something is fishy with the test and we'd rather know early.
     __PARAM_SOURCES_BY_NAME.pop(name)

+# only intended for tests
+def _clear_standard_values():
+    global __STANDARD_VALUES, __STANDARD_VALUE_SOURCES
+    __STANDARD_VALUES = {}
+    __STANDARD_VALUE_SOURCES = {}
+

 # Default
 class ParamSource:
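Taken together, the three `params.py` additions implement a memoized pool of "standard" values per operation/field. A sketch of the intended call sequence (names illustrative):

```python
from osbenchmark.workload import params

params.register_standard_value_source("op-1", "field-1", lambda: {"gte": 0, "lte": 1})
params.generate_standard_values_if_absent("op-1", "field-1", 5000)  # fills the pool once
value = params.get_standard_value("op-1", "field-1", 42)            # draws a saved value by index
```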
diff --git a/tests/workload/loader_test.py b/tests/workload/loader_test.py
index f3c7554a3..6696814bd 100644
--- a/tests/workload/loader_test.py
+++ b/tests/workload/loader_test.py
@@ -1682,6 +1682,305 @@ def test_unmatched_include_runs_nothing(self):
         schedule = filtered.test_procedures[0].schedule
         self.assertEqual(expected_schedule, schedule)

+class WorkloadRandomizationTests(TestCase):
+
+    # Helper class used to set up queries with mock standard values for testing.
+    # We want >1 op to ensure the logic for giving different ops their own lambdas is working properly.
+    class StandardValueHelper:
+        def __init__(self):
+            self.op_name_1 = "op-name-1"
+            self.op_name_2 = "op-name-2"
+            self.field_name_1 = "dummy_field_1"
+            self.field_name_2 = "dummy_field_2"
+            self.index_name = "dummy_index"
+
+            # Make the saved standard values different from the functions generating the new values,
+            # to be able to distinguish when we generate a new value vs draw an "existing" one.
+            # In actual usage, these would come from the same function with some randomness in it.
+            self.saved_values = {
+                self.op_name_1: {
+                    self.field_name_1: {"lte": 40, "gte": 30},
+                    self.field_name_2: {"lte": "06/06/2016", "gte": "05/05/2016", "format": "dd/MM/yyyy"}
+                },
+                self.op_name_2: {
+                    self.field_name_1: {"lte": 11, "gte": 10}
+                }
+            }
+
+            # Used to generate new values, in the source function
+            self.new_values = {
+                self.op_name_1: {
+                    self.field_name_1: {"lte": 41, "gte": 31},
+                    self.field_name_2: {"lte": "04/04/2016", "gte": "03/03/2016", "format": "dd/MM/yyyy"}
+                },
+                self.op_name_2: {
+                    self.field_name_1: {"lte": 15, "gte": 14},
+                }
+            }
+
+            self.op_1_query = {
+                "name": self.op_name_1,
+                "operation-type": "search",
+                "body": {
+                    "size": 0,
+                    "query": {
+                        "bool": {
+                            "filter": {
+                                "range": {
+                                    self.field_name_1: {
+                                        "lt": 50,
+                                        "gte": 0
+                                    }
+                                },
+                                "must": [
+                                    {
+                                        "range": {
+                                            self.field_name_2: {
+                                                "gte": "01/01/2015",
+                                                "lte": "21/01/2015",
+                                                "format": "dd/MM/yyyy"
+                                            }
+                                        }
+                                    }
+                                ]
+                            }
+                        }
+                    }
+                }
+            }
+
+            self.op_2_query = {
+                "name": self.op_name_2,
+                "operation-type": "search",
+                "body": {
+                    "size": 0,
+                    "query": {
+                        "range": {
+                            self.field_name_1: {
+                                "lt": 50,
+                                "gte": 0
+                            }
+                        }
+                    }
+                }
+            }
+
+        def get_simple_workload(self):
+            # Modified from test_filters_tasks
+            workload_specification = {
+                "description": "description for unit test",
+                "indices": [{"name": self.index_name, "auto-managed": False}],
+                "operations": [
+                    {
+                        "name": "create-index",
+                        "operation-type": "create-index"
+                    },
+                    self.op_1_query,
+                    self.op_2_query
+                ],
+                "test_procedures": [
+                    {
+                        "name": "default-test_procedure",
+                        "schedule": [
+                            {
+                                "operation": "create-index"
+                            },
+                            {
+                                "name": "dummy-task-name-1",
+                                "operation": self.op_name_1,
+                            },
+                            {
+                                "name": "dummy-task-name-2",
+                                "operation": self.op_name_2,
+                            },
+                        ]
+                    }
+                ]
+            }
+            reader = loader.WorkloadSpecificationReader()
+            full_workload = reader("unittest", workload_specification, "/mappings")
+            return full_workload
+
+        def get_standard_value_source(self, op_name, field_name):
+            # Passed to the processor, to be able to find the standard value sources for all ops/fields.
+            # The actual source functions for the op/field pairs, which in a real application
+            # would be defined in the workload's workload.py and involve some randomization.
+            return lambda: self.new_values[op_name][field_name]
+
+        def get_standard_value(self, op_name, field_name, index):
+            # Passed to the processor, to be able to retrieve the saved standard values for all ops/fields.
+            return self.saved_values[op_name][field_name]
+
+    def test_range_finding_function(self):
+        cfg = config.Config()
+        processor = loader.QueryRandomizerWorkloadProcessor(cfg)
+        single_range_query = {
+            "name": "distance_amount_agg",
+            "operation-type": "search",
+            "body": {
+                "size": 0,
+                "query": {
+                    "bool": {
+                        "filter": {
+                            "range": {
+                                "trip_distance": {
+                                    "lt": 50,
+                                    "gte": 0
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        single_range_query_result = processor.extract_fields_and_paths(single_range_query)
+        single_range_query_expected = [("trip_distance", ["bool", "filter", "range"])]
+        self.assertEqual(single_range_query_result, single_range_query_expected)
+
+        multiple_nested_range_query = {
+            "name": "date_histogram_agg",
+            "operation-type": "search",
+            "body": {
+                "size": 0,
+                "query": {
+                    "range": {
+                        "dropoff_datetime": {
+                            "gte": "01/01/2015",
+                            "lte": "21/01/2015",
+                            "format": "dd/MM/yyyy"
+                        }
+                    },
+                    "bool": {
+                        "filter": {
+                            "range": {
+                                "dummy_field": {
+                                    "lte": 50,
+                                    "gt": 0
+                                }
+                            }
+                        },
+                        "must": [
+                            {
+                                "range": {
+                                    "dummy_field_2": {
+                                        "gte": "1998-05-01T00:00:00Z",
+                                        "lt": "1998-05-02T00:00:00Z"
+                                    }
+                                }
+                            },
+                            {
+                                "match": {
+                                    "status": "400"
+                                }
+                            },
+                            {
+                                "range": {
+                                    "dummy_field_3": {
+                                        "gt": 10,
+                                        "lt": 11
+                                    }
+                                }
+                            }
+                        ]
+                    }
+                }
+            }
+        }
+        multiple_nested_range_query_result = processor.extract_fields_and_paths(multiple_nested_range_query)
+        multiple_nested_range_query_expected = [
+            ("dropoff_datetime", ["range"]),
+            ("dummy_field", ["bool", "filter", "range"]),
+            ("dummy_field_2", ["bool", "must", 0, "range"]),
+            ("dummy_field_3", ["bool", "must", 2, "range"])
+        ]
+        self.assertEqual(multiple_nested_range_query_result, multiple_nested_range_query_expected)
+
+        with self.assertRaises(exceptions.SystemSetupError) as ctx:
+            params = {"body": {"contents": ["not_a_valid_query"]}}
+            _ = processor.extract_fields_and_paths(params)
+        self.assertEqual(
+            f"Cannot extract range query fields from these params: {params}\n, missing params[\"body\"][\"query\"]\n"
+            f"Make sure the operation in operations/default.json is well-formed",
+            ctx.exception.args[0])
+
+    def test_get_randomized_values(self):
+        helper = self.StandardValueHelper()
+
+        for rf, expected_values_dict in zip([1.0, 0.0], [helper.saved_values, helper.new_values]):
+            # first test where we always draw a saved value, next test where we always draw a new random value.
+            # We've made them distinct, to be able to tell which codepath is taken
+            cfg = config.Config()
+            cfg.add(config.Scope.application, "workload", "randomization.repeat_frequency", rf)
+            processor = loader.QueryRandomizerWorkloadProcessor(cfg)
+            self.assertAlmostEqual(processor.rf, rf)
+
+            # Test resulting params for operation 1
+            workload = helper.get_simple_workload()
+            modified_params = processor.get_randomized_values(workload, helper.op_1_query, op_name=helper.op_name_1,
+                                                              get_standard_value=helper.get_standard_value,
+                                                              get_standard_value_source=helper.get_standard_value_source)
+            modified_range_1 = modified_params["body"]["query"]["bool"]["filter"]["range"][helper.field_name_1]
+            modified_range_2 = modified_params["body"]["query"]["bool"]["filter"]["must"][0]["range"][helper.field_name_2]
+            # Note it should keep whichever of lt/lte it found in the original query
+            self.assertEqual(modified_range_1["lt"], expected_values_dict[helper.op_name_1][helper.field_name_1]["lte"])
+            self.assertEqual(modified_range_1["gte"], expected_values_dict[helper.op_name_1][helper.field_name_1]["gte"])
+
+            self.assertEqual(modified_range_2["lte"], expected_values_dict[helper.op_name_1][helper.field_name_2]["lte"])
+            self.assertEqual(modified_range_2["gte"], expected_values_dict[helper.op_name_1][helper.field_name_2]["gte"])
+            self.assertEqual(modified_range_2["format"], expected_values_dict[helper.op_name_1][helper.field_name_2]["format"])
+
+            self.assertEqual(modified_params["index"], helper.index_name)
+
+            # Test resulting params for operation 2
+            workload = helper.get_simple_workload()
+            modified_params = processor.get_randomized_values(workload, helper.op_2_query, op_name=helper.op_name_2,
+                                                              get_standard_value=helper.get_standard_value,
+                                                              get_standard_value_source=helper.get_standard_value_source)
+            modified_range_1 = modified_params["body"]["query"]["range"][helper.field_name_1]
+
+            self.assertEqual(modified_range_1["lt"], expected_values_dict[helper.op_name_2][helper.field_name_1]["lte"])
+            self.assertEqual(modified_range_1["gte"], expected_values_dict[helper.op_name_2][helper.field_name_1]["gte"])
+            self.assertEqual(modified_params["index"], helper.index_name)
+
+    def test_on_after_load_workload(self):
+        cfg = config.Config()
+        processor = loader.QueryRandomizerWorkloadProcessor(cfg)
+        # Do nothing with default config, as randomization.enabled is false
+        helper = self.StandardValueHelper()
+        input_workload = helper.get_simple_workload()
+        # It seems that comparing the workloads directly will incorrectly call them equal, even if they have differences,
+        # so compare their string representations instead
+        self.assertEqual(
+            repr(input_workload),
+            repr(processor.on_after_load_workload(input_workload, get_standard_value=helper.get_standard_value,
+                                                  get_standard_value_source=helper.get_standard_value_source)))
+
+        cfg = config.Config()
+        cfg.add(config.Scope.application, "workload", "randomization.enabled", True)
+        processor = loader.QueryRandomizerWorkloadProcessor(cfg)
+        self.assertEqual(processor.randomization_enabled, True)
+        self.assertEqual(processor.N, loader.QueryRandomizerWorkloadProcessor.DEFAULT_N)
+        self.assertEqual(type(processor.N), int)
+        self.assertEqual(processor.rf, loader.QueryRandomizerWorkloadProcessor.DEFAULT_RF)
+        self.assertEqual(type(processor.rf), float)
+        input_workload = helper.get_simple_workload()
+        self.assertNotEqual(
+            repr(input_workload),
+            repr(processor.on_after_load_workload(input_workload, get_standard_value=helper.get_standard_value,
+                                                  get_standard_value_source=helper.get_standard_value_source)))
+        for test_procedure in input_workload.test_procedures:
+            for task in test_procedure.schedule:
+                for leaf_task in task:
+                    try:
+                        op_type = workload.OperationType.from_hyphenated_string(leaf_task.operation.type)
+                    except KeyError:
+                        op_type = None
+                    if op_type == workload.OperationType.Search:
+                        self.assertIsNotNone(leaf_task.operation.param_source)
+

 # pylint: disable=too-many-public-methods
 class WorkloadSpecificationReaderTests(TestCase):
@@ -3558,6 +3857,7 @@ def test_default_workload_processors(self):
         expected_defaults = [
             loader.TaskFilterWorkloadProcessor,
             loader.TestModeWorkloadProcessor,
+            loader.QueryRandomizerWorkloadProcessor,
             loader.DefaultWorkloadPreparator
         ]
         actual_defaults = [proc.__class__ for proc in tpr.processors]
@@ -3573,6 +3873,7 @@ def test_override_default_preparator(self):
         expected_processors = [
             loader.TaskFilterWorkloadProcessor,
             loader.TestModeWorkloadProcessor,
+            loader.QueryRandomizerWorkloadProcessor,
             MyMockWorkloadProcessor
         ]
         actual_processors = [proc.__class__ for proc in tpr.processors]
@@ -3589,6 +3890,7 @@ def test_allow_to_specify_default_preparator(self):
         expected_processors = [
             loader.TaskFilterWorkloadProcessor,
             loader.TestModeWorkloadProcessor,
+            loader.QueryRandomizerWorkloadProcessor,
             MyMockWorkloadProcessor,
             loader.DefaultWorkloadPreparator
         ]
diff --git a/tests/workload/params_test.py b/tests/workload/params_test.py
index 5188a9645..73a740783 100644
--- a/tests/workload/params_test.py
+++ b/tests/workload/params_test.py
@@ -1402,6 +1402,62 @@ def test_cannot_register_an_instance_as_param_source(self):
                                     "Parameter source \\[test param source\\] must be either a function or a class\\."):
             params.register_param_source_for_name(source_name, ParamsRegistrationTests.ParamSourceClass())

+class StandardValueSourceRegistrationTests(TestCase):
+    def get_mock_standard_value_source(self, gte, lte):
+        return lambda: {"gte": gte, "lte": lte}
+
+    def test_register_standard_value_source(self):
+        # Test the sequence: register standard value source -> generate saved standard values
+        # -> retrieve those values or generate new values from source
+        op_name = "op-1"
+        field_name_1 = "field-1"
+        field_name_2 = "field-2"
+        n = 100
+
+        gte_field_1 = 0
+        lte_field_1 = 1
+        gte_field_2 = 2
+        lte_field_2 = 3
+
+        params._clear_standard_values()
+
+        params.register_standard_value_source(op_name, field_name_1, self.get_mock_standard_value_source(gte_field_1, lte_field_1))
+
+        self.assertEqual(params.get_standard_value_source(op_name, field_name_1)(), {"gte": gte_field_1, "lte": lte_field_1})
+
+        with self.assertRaises(exceptions.SystemSetupError) as ctx:
+            _ = params.get_standard_value_source(op_name, field_name_2)
+        self.assertEqual(
+            "Could not find standard value source for operation {}, field {}! Make sure this is registered in workload.py"
+            .format(op_name, field_name_2), ctx.exception.args[0])
+
+        with self.assertRaises(exceptions.SystemSetupError) as ctx:
+            _ = params.get_standard_value(op_name, field_name_1, 0)
+        self.assertEqual("No standard values generated for operation {}, field {}".format(op_name, field_name_1), ctx.exception.args[0])
+
+        params.generate_standard_values_if_absent(op_name, field_name_1, n)
+        self.assertEqual(params.get_standard_value(op_name, field_name_1, 0), {"gte": gte_field_1, "lte": lte_field_1})
+
+        # Check that running generate_standard_values_if_absent on the same inputs does nothing:
+        # we can tell it to generate 2*n values, but it won't, because values are already present
+        params.generate_standard_values_if_absent(op_name, field_name_1, 2*n)
+        with self.assertRaises(exceptions.SystemSetupError) as ctx:
+            _ = params.get_standard_value(op_name, field_name_1, n + 1)
+        self.assertEqual(
+            "Standard value index {} out of range for operation {}, field name {} ({} values total)"
+            .format(n+1, op_name, field_name_1, n), ctx.exception.args[0])
+
+        with self.assertRaises(exceptions.SystemSetupError) as ctx:
+            params.generate_standard_values_if_absent(op_name, field_name_2, n)
+        self.assertEqual(
+            "Cannot generate standard values for operation {}, field {}. Standard value source is missing"
+            .format(op_name, field_name_2), ctx.exception.args[0])
+
+        params.register_standard_value_source(op_name, field_name_2, self.get_mock_standard_value_source(gte_field_2, lte_field_2))
+        self.assertEqual(params.get_standard_value_source(op_name, field_name_2)(), {"gte": gte_field_2, "lte": lte_field_2})
+        self.assertEqual(params.get_standard_value_source(op_name, field_name_1)(), {"gte": gte_field_1, "lte": lte_field_1})
+
+        params._clear_standard_values()
+

 class SleepParamSourceTests(TestCase):
     def test_missing_duration_parameter(self):
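Assuming the repository's usual pytest-based test setup, the new suites can be run in isolation with `pytest tests/workload/loader_test.py::WorkloadRandomizationTests tests/workload/params_test.py::StandardValueSourceRegistrationTests`.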