diff --git a/setup.cfg b/setup.cfg index d4e57b23a..48b42548d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,3 +9,4 @@ addopts = --verbose --color=yes testpaths = tests junit_family = xunit2 junit_logging = all +asyncio_mode = strict diff --git a/setup.py b/setup.py index 55df6a34a..15faebd20 100644 --- a/setup.py +++ b/setup.py @@ -91,7 +91,7 @@ def str_from_file(name): "boto3==1.10.32", ] -tests_require = ["ujson", "pytest==6.2.5", "pytest-benchmark==3.2.2", "pytest-asyncio==0.16.0"] +tests_require = ["ujson", "pytest==6.2.5", "pytest-benchmark==3.2.2", "pytest-asyncio==0.18.1"] # These packages are only required when developing Rally develop_require = [ diff --git a/tests/driver/driver_test.py b/tests/driver/driver_test.py index cdf8a0a90..7e27b3427 100644 --- a/tests/driver/driver_test.py +++ b/tests/driver/driver_test.py @@ -22,9 +22,9 @@ import time import unittest.mock as mock from datetime import datetime -from unittest import TestCase import elasticsearch +import pytest from esrally import config, exceptions, metrics, track from esrally.driver import driver, runner, scheduler @@ -58,24 +58,19 @@ def params(self): return self._params -class DriverTests(TestCase): +class TestDriver: class Holder: def __init__(self, all_hosts=None, all_client_options=None): self.all_hosts = all_hosts self.all_client_options = all_client_options self.uses_static_responses = False - def __init__(self, methodName="runTest"): - super().__init__(methodName) - self.cfg = None - self.track = None - class StaticClientFactory: PATCHER = None def __init__(self, *args, **kwargs): - DriverTests.StaticClientFactory.PATCHER = mock.patch("elasticsearch.Elasticsearch") - self.es = DriverTests.StaticClientFactory.PATCHER.start() + TestDriver.StaticClientFactory.PATCHER = mock.patch("elasticsearch.Elasticsearch") + self.es = TestDriver.StaticClientFactory.PATCHER.start() self.es.indices.stats.return_value = {"mocked": True} self.es.cat.master.return_value = {"mocked": True} @@ -84,9 +79,9 @@ def create(self): @classmethod def close(cls): - DriverTests.StaticClientFactory.PATCHER.stop() + TestDriver.StaticClientFactory.PATCHER.stop() - def setUp(self): + def setup_method(self, method): self.cfg = config.Config() self.cfg.add(config.Scope.application, "system", "env.name", "unittest") self.cfg.add(config.Scope.application, "system", "time.start", datetime(year=2017, month=8, day=20, hour=1, minute=0, second=0)) @@ -100,8 +95,8 @@ def setUp(self): self.cfg.add(config.Scope.application, "telemetry", "params", {}) self.cfg.add(config.Scope.application, "mechanic", "car.names", ["default"]) self.cfg.add(config.Scope.application, "mechanic", "skip.rest.api.check", True) - self.cfg.add(config.Scope.application, "client", "hosts", DriverTests.Holder(all_hosts={"default": ["localhost:9200"]})) - self.cfg.add(config.Scope.application, "client", "options", DriverTests.Holder(all_client_options={"default": {}})) + self.cfg.add(config.Scope.application, "client", "hosts", self.Holder(all_hosts={"default": ["localhost:9200"]})) + self.cfg.add(config.Scope.application, "client", "options", self.Holder(all_client_options={"default": {}})) self.cfg.add(config.Scope.application, "driver", "load_driver_hosts", ["localhost"]) self.cfg.add(config.Scope.application, "reporting", "datastore.type", "in-memory") @@ -113,8 +108,8 @@ def setUp(self): another_challenge = track.Challenge("other", default=False) self.track = track.Track(name="unittest", description="unittest track", challenges=[another_challenge, default_challenge]) - def 
tearDown(self): - DriverTests.StaticClientFactory.close() + def teardown_method(self): + self.StaticClientFactory.close() def create_test_driver_target(self): client = "client_marker" @@ -128,7 +123,7 @@ def test_start_benchmark_and_prepare_track(self, resolve): resolve.side_effect = ["10.5.5.1", "10.5.5.2"] target = self.create_test_driver_target() - d = driver.Driver(target, self.cfg, es_client_factory_class=DriverTests.StaticClientFactory) + d = driver.Driver(target, self.cfg, es_client_factory_class=self.StaticClientFactory) d.prepare_benchmark(t=self.track) target.prepare_track.assert_called_once_with(["10.5.5.1", "10.5.5.2"], self.cfg, self.track) @@ -144,11 +139,11 @@ def test_start_benchmark_and_prepare_track(self, resolve): ) # Did we start all load generators? There is no specific mock assert for this... - self.assertEqual(4, target.start_worker.call_count) + assert target.start_worker.call_count == 4 def test_assign_drivers_round_robin(self): target = self.create_test_driver_target() - d = driver.Driver(target, self.cfg, es_client_factory_class=DriverTests.StaticClientFactory) + d = driver.Driver(target, self.cfg, es_client_factory_class=self.StaticClientFactory) d.prepare_benchmark(t=self.track) @@ -166,34 +161,34 @@ def test_assign_drivers_round_robin(self): ) # Did we start all load generators? There is no specific mock assert for this... - self.assertEqual(4, target.start_worker.call_count) + assert target.start_worker.call_count == 4 def test_client_reaches_join_point_others_still_executing(self): target = self.create_test_driver_target() - d = driver.Driver(target, self.cfg, es_client_factory_class=DriverTests.StaticClientFactory) + d = driver.Driver(target, self.cfg, es_client_factory_class=self.StaticClientFactory) d.prepare_benchmark(t=self.track) d.start_benchmark() - self.assertEqual(0, len(d.workers_completed_current_step)) + assert len(d.workers_completed_current_step) == 0 d.joinpoint_reached( worker_id=0, worker_local_timestamp=10, task_allocations=[driver.ClientAllocation(client_id=0, task=driver.JoinPoint(id=0))] ) - self.assertEqual(1, len(d.workers_completed_current_step)) + assert len(d.workers_completed_current_step) == 1 - self.assertEqual(0, target.on_task_finished.call_count) - self.assertEqual(0, target.drive_at.call_count) + assert target.on_task_finished.call_count == 0 + assert target.drive_at.call_count == 0 def test_client_reaches_join_point_which_completes_parent(self): target = self.create_test_driver_target() - d = driver.Driver(target, self.cfg, es_client_factory_class=DriverTests.StaticClientFactory) + d = driver.Driver(target, self.cfg, es_client_factory_class=self.StaticClientFactory) d.prepare_benchmark(t=self.track) d.start_benchmark() - self.assertEqual(0, len(d.workers_completed_current_step)) + assert len(d.workers_completed_current_step) == 0 d.joinpoint_reached( worker_id=0, @@ -201,10 +196,10 @@ def test_client_reaches_join_point_which_completes_parent(self): task_allocations=[driver.ClientAllocation(client_id=0, task=driver.JoinPoint(id=0, clients_executing_completing_task=[0]))], ) - self.assertEqual(-1, d.current_step) - self.assertEqual(1, len(d.workers_completed_current_step)) + assert d.current_step == -1 + assert len(d.workers_completed_current_step) == 1 # notified all drivers that they should complete the current task ASAP - self.assertEqual(4, target.complete_current_task.call_count) + assert target.complete_current_task.call_count == 4 # awaiting responses of other clients d.joinpoint_reached( @@ -213,16 +208,16 @@ def 
test_client_reaches_join_point_which_completes_parent(self): task_allocations=[driver.ClientAllocation(client_id=1, task=driver.JoinPoint(id=0, clients_executing_completing_task=[0]))], ) - self.assertEqual(-1, d.current_step) - self.assertEqual(2, len(d.workers_completed_current_step)) + assert d.current_step == -1 + assert len(d.workers_completed_current_step) == 2 d.joinpoint_reached( worker_id=2, worker_local_timestamp=12, task_allocations=[driver.ClientAllocation(client_id=2, task=driver.JoinPoint(id=0, clients_executing_completing_task=[0]))], ) - self.assertEqual(-1, d.current_step) - self.assertEqual(3, len(d.workers_completed_current_step)) + assert d.current_step == -1 + assert len(d.workers_completed_current_step) == 3 d.joinpoint_reached( worker_id=3, @@ -231,20 +226,20 @@ def test_client_reaches_join_point_which_completes_parent(self): ) # by now the previous step should be considered completed and we are at the next one - self.assertEqual(0, d.current_step) - self.assertEqual(0, len(d.workers_completed_current_step)) + assert d.current_step == 0 + assert len(d.workers_completed_current_step) == 0 # this requires at least Python 3.6 # target.on_task_finished.assert_called_once() - self.assertEqual(1, target.on_task_finished.call_count) - self.assertEqual(4, target.drive_at.call_count) + assert target.on_task_finished.call_count == 1 + assert target.drive_at.call_count == 4 def op(name, operation_type): return track.Operation(name, operation_type, param_source="driver-test-param-source") -class SamplePostprocessorTests(TestCase): +class TestSamplePostprocessor: def throughput(self, absolute_time, relative_time, value): return mock.call( name="throughput", @@ -372,7 +367,7 @@ def test_dependent_samples(self, metrics_store): metrics_store.put_value_cluster_level.assert_has_calls(calls) -class WorkerAssignmentTests(TestCase): +class TestWorkerAssignment: def test_single_host_assignment_clients_matches_cores(self): host_configs = [ { @@ -383,20 +378,17 @@ def test_single_host_assignment_clients_matches_cores(self): assignments = driver.calculate_worker_assignments(host_configs, client_count=4) - self.assertEqual( - [ - { - "host": "localhost", - "workers": [ - [0], - [1], - [2], - [3], - ], - } - ], - assignments, - ) + assert assignments == [ + { + "host": "localhost", + "workers": [ + [0], + [1], + [2], + [3], + ], + } + ] def test_single_host_assignment_more_clients_than_cores(self): host_configs = [ @@ -408,20 +400,17 @@ def test_single_host_assignment_more_clients_than_cores(self): assignments = driver.calculate_worker_assignments(host_configs, client_count=6) - self.assertEqual( - [ - { - "host": "localhost", - "workers": [ - [0, 1], - [2, 3], - [4], - [5], - ], - } - ], - assignments, - ) + assert assignments == [ + { + "host": "localhost", + "workers": [ + [0, 1], + [2, 3], + [4], + [5], + ], + } + ] def test_single_host_assignment_less_clients_than_cores(self): host_configs = [ @@ -433,20 +422,17 @@ def test_single_host_assignment_less_clients_than_cores(self): assignments = driver.calculate_worker_assignments(host_configs, client_count=2) - self.assertEqual( - [ - { - "host": "localhost", - "workers": [ - [0], - [1], - [], - [], - ], - } - ], - assignments, - ) + assert assignments == [ + { + "host": "localhost", + "workers": [ + [0], + [1], + [], + [], + ], + } + ] def test_multiple_host_assignment_more_clients_than_cores(self): host_configs = [ @@ -462,29 +448,26 @@ def test_multiple_host_assignment_more_clients_than_cores(self): assignments = 
driver.calculate_worker_assignments(host_configs, client_count=16) - self.assertEqual( - [ - { - "host": "host-a", - "workers": [ - [0, 1], - [2, 3], - [4, 5], - [6, 7], - ], - }, - { - "host": "host-b", - "workers": [ - [8, 9], - [10, 11], - [12, 13], - [14, 15], - ], - }, - ], - assignments, - ) + assert assignments == [ + { + "host": "host-a", + "workers": [ + [0, 1], + [2, 3], + [4, 5], + [6, 7], + ], + }, + { + "host": "host-b", + "workers": [ + [8, 9], + [10, 11], + [12, 13], + [14, 15], + ], + }, + ] def test_multiple_host_assignment_less_clients_than_cores(self): host_configs = [ @@ -500,29 +483,26 @@ def test_multiple_host_assignment_less_clients_than_cores(self): assignments = driver.calculate_worker_assignments(host_configs, client_count=4) - self.assertEqual( - [ - { - "host": "host-a", - "workers": [ - [0], - [1], - [], - [], - ], - }, - { - "host": "host-b", - "workers": [ - [2], - [3], - [], - [], - ], - }, - ], - assignments, - ) + assert assignments == [ + { + "host": "host-a", + "workers": [ + [0], + [1], + [], + [], + ], + }, + { + "host": "host-b", + "workers": [ + [2], + [3], + [], + [], + ], + }, + ] def test_uneven_assignment_across_hosts(self): host_configs = [ @@ -542,42 +522,39 @@ def test_uneven_assignment_across_hosts(self): assignments = driver.calculate_worker_assignments(host_configs, client_count=17) - self.assertEqual( - [ - { - "host": "host-a", - "workers": [ - [0, 1], - [2, 3], - [4], - [5], - ], - }, - { - "host": "host-b", - "workers": [ - [6, 7], - [8, 9], - [10], - [11], - ], - }, - { - "host": "host-c", - "workers": [ - [12, 13], - [14], - [15], - [16], - ], - }, - ], - assignments, - ) + assert assignments == [ + { + "host": "host-a", + "workers": [ + [0, 1], + [2, 3], + [4], + [5], + ], + }, + { + "host": "host-b", + "workers": [ + [6, 7], + [8, 9], + [10], + [11], + ], + }, + { + "host": "host-c", + "workers": [ + [12, 13], + [14], + [15], + [16], + ], + }, + ] -class AllocatorTests(TestCase): - def setUp(self): +class TestAllocator: + def setup_method(self, method): params.register_param_source_for_name("driver-test-param-source", DriverTestParamSource) def ta(self, task, client_index_in_task, global_client_index=None, total_clients=None): @@ -593,35 +570,35 @@ def test_allocates_one_task(self): allocator = driver.Allocator([task]) - self.assertEqual(1, allocator.clients) - self.assertEqual(3, len(allocator.allocations[0])) - self.assertEqual(2, len(allocator.join_points)) - self.assertEqual([{task}], allocator.tasks_per_joinpoint) + assert allocator.clients == 1 + assert len(allocator.allocations[0]) == 3 + assert len(allocator.join_points) == 2 + assert allocator.tasks_per_joinpoint == [{task}] def test_allocates_two_serial_tasks(self): task = track.Task("index", op("index", track.OperationType.Bulk)) allocator = driver.Allocator([task, task]) - self.assertEqual(1, allocator.clients) + assert allocator.clients == 1 # we have two operations and three join points - self.assertEqual(5, len(allocator.allocations[0])) - self.assertEqual(3, len(allocator.join_points)) - self.assertEqual([{task}, {task}], allocator.tasks_per_joinpoint) + assert len(allocator.allocations[0]) == 5 + assert len(allocator.join_points) == 3 + assert allocator.tasks_per_joinpoint == [{task}, {task}] def test_allocates_two_parallel_tasks(self): task = track.Task("index", op("index", track.OperationType.Bulk)) allocator = driver.Allocator([track.Parallel([task, task])]) - self.assertEqual(2, allocator.clients) - self.assertEqual(3, len(allocator.allocations[0])) - 
self.assertEqual(3, len(allocator.allocations[1])) - self.assertEqual(2, len(allocator.join_points)) - self.assertEqual([{task}], allocator.tasks_per_joinpoint) + assert allocator.clients == 2 + assert len(allocator.allocations[0]) == 3 + assert len(allocator.allocations[1]) == 3 + assert len(allocator.join_points) == 2 + assert allocator.tasks_per_joinpoint == [{task}] for join_point in allocator.join_points: - self.assertFalse(join_point.preceding_task_completes_parent) - self.assertEqual(0, join_point.num_clients_executing_completing_task) + assert not join_point.preceding_task_completes_parent + assert join_point.num_clients_executing_completing_task == 0 def test_a_task_completes_the_parallel_structure(self): taskA = track.Task("index-completing", op("index", track.OperationType.Bulk), completes_parent=True) @@ -629,15 +606,15 @@ def test_a_task_completes_the_parallel_structure(self): allocator = driver.Allocator([track.Parallel([taskA, taskB])]) - self.assertEqual(2, allocator.clients) - self.assertEqual(3, len(allocator.allocations[0])) - self.assertEqual(3, len(allocator.allocations[1])) - self.assertEqual(2, len(allocator.join_points)) - self.assertEqual([{taskA, taskB}], allocator.tasks_per_joinpoint) + assert allocator.clients == 2 + assert len(allocator.allocations[0]) == 3 + assert len(allocator.allocations[1]) == 3 + assert len(allocator.join_points) == 2 + assert allocator.tasks_per_joinpoint == [{taskA, taskB}] final_join_point = allocator.join_points[1] - self.assertTrue(final_join_point.preceding_task_completes_parent) - self.assertEqual(1, final_join_point.num_clients_executing_completing_task) - self.assertEqual([0], final_join_point.clients_executing_completing_task) + assert final_join_point.preceding_task_completes_parent is True + assert final_join_point.num_clients_executing_completing_task == 1 + assert final_join_point.clients_executing_completing_task == [0] def test_any_task_completes_the_parallel_structure(self): taskA = track.Task("index-completing", op("index", track.OperationType.Bulk), any_completes_parent=True) @@ -645,14 +622,14 @@ def test_any_task_completes_the_parallel_structure(self): # Both tasks can complete the parent allocator = driver.Allocator([track.Parallel([taskA, taskB])]) - self.assertEqual(2, allocator.clients) - self.assertEqual(3, len(allocator.allocations[0])) - self.assertEqual(3, len(allocator.allocations[1])) - self.assertEqual(2, len(allocator.join_points)) - self.assertEqual([{taskA, taskB}], allocator.tasks_per_joinpoint) + assert allocator.clients == 2 + assert len(allocator.allocations[0]) == 3 + assert len(allocator.allocations[1]) == 3 + assert len(allocator.join_points) == 2 + assert allocator.tasks_per_joinpoint == [{taskA, taskB}] final_join_point = allocator.join_points[-1] - self.assertEqual(2, len(final_join_point.any_task_completes_parent)) - self.assertEqual([0, 1], final_join_point.any_task_completes_parent) + assert len(final_join_point.any_task_completes_parent) == 2 + assert final_join_point.any_task_completes_parent == [0, 1] def test_allocates_mixed_tasks(self): index = track.Task("index", op("index", track.OperationType.Bulk)) @@ -661,17 +638,17 @@ def test_allocates_mixed_tasks(self): allocator = driver.Allocator([index, track.Parallel([index, stats, stats]), index, index, track.Parallel([search, search, search])]) - self.assertEqual(3, allocator.clients) + assert allocator.clients == 3 # 1 join point, 1 op, 1 jp, 1 (parallel) op, 1 jp, 1 op, 1 jp, 1 op, 1 jp, 1 (parallel) op, 1 jp -
self.assertEqual(11, len(allocator.allocations[0])) - self.assertEqual(11, len(allocator.allocations[1])) - self.assertEqual(11, len(allocator.allocations[2])) - self.assertEqual(6, len(allocator.join_points)) - self.assertEqual([{index}, {index, stats}, {index}, {index}, {search}], allocator.tasks_per_joinpoint) + assert len(allocator.allocations[0]) == 11 + assert len(allocator.allocations[1]) == 11 + assert len(allocator.allocations[2]) == 11 + assert len(allocator.join_points) == 6 + assert allocator.tasks_per_joinpoint == [{index}, {index, stats}, {index}, {index}, {search}] for join_point in allocator.join_points: - self.assertFalse(join_point.preceding_task_completes_parent) - self.assertEqual(0, join_point.num_clients_executing_completing_task) + assert not join_point.preceding_task_completes_parent + assert join_point.num_clients_executing_completing_task == 0 # TODO (follow-up PR): We should probably forbid this def test_allocates_more_tasks_than_clients(self): @@ -683,44 +660,38 @@ allocator = driver.Allocator([track.Parallel(tasks=[index_a, index_b, index_c, index_d, index_e], clients=2)]) - self.assertEqual(2, allocator.clients) + assert allocator.clients == 2 allocations = allocator.allocations # 2 clients - self.assertEqual(2, len(allocations)) + assert len(allocations) == 2 # join_point, index_a, index_c, index_e, join_point - self.assertEqual(5, len(allocations[0])) + assert len(allocations[0]) == 5 # we really have no chance to extract the join point so we just take what is there... - self.assertEqual( - [ - allocations[0][0], - self.ta(index_a, client_index_in_task=0, global_client_index=0, total_clients=2), - self.ta(index_c, client_index_in_task=0, global_client_index=2, total_clients=2), - self.ta(index_e, client_index_in_task=0, global_client_index=4, total_clients=2), - allocations[0][4], - ], - allocations[0], - ) + assert allocations[0] == [ + allocations[0][0], + self.ta(index_a, client_index_in_task=0, global_client_index=0, total_clients=2), + self.ta(index_c, client_index_in_task=0, global_client_index=2, total_clients=2), + self.ta(index_e, client_index_in_task=0, global_client_index=4, total_clients=2), + allocations[0][4], + ] # join_point, index_a, index_c, None, join_point - self.assertEqual(5, len(allocator.allocations[1])) - self.assertEqual( - [ - allocations[1][0], - self.ta(index_b, client_index_in_task=0, global_client_index=1, total_clients=2), - self.ta(index_d, client_index_in_task=0, global_client_index=3, total_clients=2), - None, - allocations[1][4], - ], - allocations[1], - ) + assert len(allocator.allocations[1]) == 5 + assert allocations[1] == [ + allocations[1][0], + self.ta(index_b, client_index_in_task=0, global_client_index=1, total_clients=2), + self.ta(index_d, client_index_in_task=0, global_client_index=3, total_clients=2), + None, + allocations[1][4], + ] - self.assertEqual([{index_a, index_b, index_c, index_d, index_e}], allocator.tasks_per_joinpoint) - self.assertEqual(2, len(allocator.join_points)) + assert allocator.tasks_per_joinpoint == [{index_a, index_b, index_c, index_d, index_e}] + assert len(allocator.join_points) == 2 final_join_point = allocator.join_points[1] - self.assertTrue(final_join_point.preceding_task_completes_parent) - self.assertEqual(1, final_join_point.num_clients_executing_completing_task) - self.assertEqual([1], final_join_point.clients_executing_completing_task) + assert final_join_point.preceding_task_completes_parent is True + assert
final_join_point.num_clients_executing_completing_task == 1 + assert final_join_point.clients_executing_completing_task == [1] # TODO (follow-up PR): We should probably forbid this def test_considers_number_of_clients_per_subtask(self): @@ -730,64 +701,55 @@ def test_considers_number_of_clients_per_subtask(self): allocator = driver.Allocator([track.Parallel(tasks=[index_a, index_b, index_c], clients=3)]) - self.assertEqual(3, allocator.clients) + assert allocator.clients == 3 allocations = allocator.allocations # 3 clients - self.assertEqual(3, len(allocations)) + assert len(allocations) == 3 # tasks that client 0 will execute: # join_point, index_a, index_c, join_point - self.assertEqual(4, len(allocations[0])) + assert len(allocations[0]) == 4 # we really have no chance to extract the join point so we just take what is there... - self.assertEqual( - [ - allocations[0][0], - self.ta(index_a, client_index_in_task=0, global_client_index=0, total_clients=3), - self.ta(index_c, client_index_in_task=1, global_client_index=3, total_clients=3), - allocations[0][3], - ], - allocations[0], - ) + assert allocations[0] == [ + allocations[0][0], + self.ta(index_a, client_index_in_task=0, global_client_index=0, total_clients=3), + self.ta(index_c, client_index_in_task=1, global_client_index=3, total_clients=3), + allocations[0][3], + ] # task that client 1 will execute: # join_point, index_b, None, join_point - self.assertEqual(4, len(allocator.allocations[1])) - self.assertEqual( - [ - allocations[1][0], - self.ta(index_b, client_index_in_task=0, global_client_index=1, total_clients=3), - None, - allocations[1][3], - ], - allocations[1], - ) + assert len(allocator.allocations[1]) == 4 + assert allocations[1] == [ + allocations[1][0], + self.ta(index_b, client_index_in_task=0, global_client_index=1, total_clients=3), + None, + allocations[1][3], + ] # tasks that client 2 will execute: - self.assertEqual(4, len(allocator.allocations[2])) - self.assertEqual( - [ - allocations[2][0], - self.ta(index_c, client_index_in_task=0, global_client_index=2, total_clients=3), - None, - allocations[2][3], - ], - allocations[2], - ) + assert len(allocator.allocations[2]) == 4 + assert allocations[2] == [ + allocations[2][0], + self.ta(index_c, client_index_in_task=0, global_client_index=2, total_clients=3), + None, + allocations[2][3], + ] - self.assertEqual([{index_a, index_b, index_c}], allocator.tasks_per_joinpoint) + assert [{index_a, index_b, index_c}] == allocator.tasks_per_joinpoint - self.assertEqual(2, len(allocator.join_points)) + assert len(allocator.join_points) == 2 final_join_point = allocator.join_points[1] - self.assertTrue(final_join_point.preceding_task_completes_parent) + assert final_join_point.preceding_task_completes_parent is True # task index_c has two clients, hence we have to wait for two clients to finish - self.assertEqual(2, final_join_point.num_clients_executing_completing_task) - self.assertEqual([2, 0], final_join_point.clients_executing_completing_task) + assert final_join_point.num_clients_executing_completing_task == 2 + assert final_join_point.clients_executing_completing_task == [2, 0] -class MetricsAggregationTests(TestCase): - def setUp(self): +class TestMetricsAggregation: + def setup_method(self, method): params.register_param_source_for_name("driver-test-param-source", DriverTestParamSource) def test_different_sample_types(self): @@ -800,13 +762,13 @@ def test_different_sample_types(self): aggregated = self.calculate_global_throughput(samples) - self.assertIn(op, aggregated) 
- self.assertEqual(1, len(aggregated)) + assert op in aggregated + assert len(aggregated) == 1 throughput = aggregated[op] - self.assertEqual(2, len(throughput)) - self.assertEqual((1470838595, 21, metrics.SampleType.Warmup, 3000, "docs/s"), throughput[0]) - self.assertEqual((1470838595.5, 21.5, metrics.SampleType.Normal, 3666.6666666666665, "docs/s"), throughput[1]) + assert len(throughput) == 2 + assert throughput[0] == (1470838595, 21, metrics.SampleType.Warmup, 3000, "docs/s") + assert throughput[1] == (1470838595.5, 21.5, metrics.SampleType.Normal, 3666.6666666666665, "docs/s") def test_single_metrics_aggregation(self): op = track.Operation("index", track.OperationType.Bulk, param_source="driver-test-param-source") @@ -825,18 +787,17 @@ def test_single_metrics_aggregation(self): aggregated = self.calculate_global_throughput(samples) - self.assertIn(op, aggregated) - self.assertEqual(1, len(aggregated)) + assert op in aggregated + assert len(aggregated) == 1 throughput = aggregated[op] - self.assertEqual(6, len(throughput)) - self.assertEqual((38595, 21, metrics.SampleType.Normal, 5000, "docs/s"), throughput[0]) - self.assertEqual((38596, 22, metrics.SampleType.Normal, 5000, "docs/s"), throughput[1]) - self.assertEqual((38597, 23, metrics.SampleType.Normal, 5000, "docs/s"), throughput[2]) - self.assertEqual((38598, 24, metrics.SampleType.Normal, 5000, "docs/s"), throughput[3]) - self.assertEqual((38599, 25, metrics.SampleType.Normal, 6000, "docs/s"), throughput[4]) - self.assertEqual((38600, 26, metrics.SampleType.Normal, 6666.666666666667, "docs/s"), throughput[5]) - # self.assertEqual((1470838600.5, 26.5, metrics.SampleType.Normal, 10000), throughput[6]) + assert len(throughput) == 6 + assert throughput[0] == (38595, 21, metrics.SampleType.Normal, 5000, "docs/s") + assert throughput[1] == (38596, 22, metrics.SampleType.Normal, 5000, "docs/s") + assert throughput[2] == (38597, 23, metrics.SampleType.Normal, 5000, "docs/s") + assert throughput[3] == (38598, 24, metrics.SampleType.Normal, 5000, "docs/s") + assert throughput[4] == (38599, 25, metrics.SampleType.Normal, 6000, "docs/s") + assert throughput[5] == (38600, 26, metrics.SampleType.Normal, 6666.666666666667, "docs/s") def test_use_provided_throughput(self): op = track.Operation("index-recovery", track.OperationType.WaitForRecovery, param_source="driver-test-param-source") @@ -849,20 +810,20 @@ def test_use_provided_throughput(self): aggregated = self.calculate_global_throughput(samples) - self.assertIn(op, aggregated) - self.assertEqual(1, len(aggregated)) + assert op in aggregated + assert len(aggregated) == 1 throughput = aggregated[op] - self.assertEqual(3, len(throughput)) - self.assertEqual((38595, 21, metrics.SampleType.Normal, 8000, "byte/s"), throughput[0]) - self.assertEqual((38596, 22, metrics.SampleType.Normal, 8000, "byte/s"), throughput[1]) - self.assertEqual((38597, 23, metrics.SampleType.Normal, 8000, "byte/s"), throughput[2]) + assert len(throughput) == 3 + assert throughput[0] == (38595, 21, metrics.SampleType.Normal, 8000, "byte/s") + assert throughput[1] == (38596, 22, metrics.SampleType.Normal, 8000, "byte/s") + assert throughput[2] == (38597, 23, metrics.SampleType.Normal, 8000, "byte/s") def calculate_global_throughput(self, samples): return driver.ThroughputCalculator().calculate(samples) -class SchedulerTests(TestCase): +class TestScheduler: class RunnerWithProgress: def __init__(self, complete_after=3): self.completed = False @@ -899,11 +860,11 @@ async def assert_schedule(self, expected_schedule, 
schedule_handle, infinite_sch async for invocation_time, sample_type, progress_percent, runner, params in schedule_handle(): schedule_handle.before_request(now=idx) exp_invocation_time, exp_sample_type, exp_progress_percent, exp_params = expected_schedule[idx] - self.assertAlmostEqual(exp_invocation_time, invocation_time, msg="Invocation time for sample at index %d does not match" % idx) - self.assertEqual(exp_sample_type, sample_type, "Sample type for sample at index %d does not match" % idx) - self.assertEqual(exp_progress_percent, progress_percent, "Current progress for sample at index %d does not match" % idx) - self.assertIsNotNone(runner, "runner must be defined") - self.assertEqual(exp_params, params, "Parameters do not match") + assert round(abs(exp_invocation_time - invocation_time), 7) == 0, "Invocation time for sample at index %d does not match" % idx + assert sample_type == exp_sample_type, "Sample type for sample at index %d does not match" % idx + assert progress_percent == exp_progress_percent, "Current progress for sample at index %d does not match" % idx + assert runner is not None, "runner must be defined" + assert params == exp_params idx += 1 # for infinite schedules we only check the first few elements if infinite_schedule and idx == len(expected_schedule): @@ -911,17 +872,17 @@ async def assert_schedule(self, expected_schedule, schedule_handle, infinite_sch # simulate that the request is done - we only support throttling based on request count (ops). schedule_handle.after_request(now=idx, weight=1, unit="ops", request_meta_data=None) if not infinite_schedule: - self.assertEqual(len(expected_schedule), idx, msg="Number of elements in the schedules do not match") + assert len(expected_schedule) == idx, "Number of elements in the schedules do not match" - def setUp(self): + def setup_method(self, method): self.test_track = track.Track(name="unittest") - self.runner_with_progress = SchedulerTests.RunnerWithProgress() + self.runner_with_progress = self.RunnerWithProgress() params.register_param_source_for_name("driver-test-param-source", DriverTestParamSource) runner.register_default_runners() runner.register_runner("driver-test-runner-with-completion", self.runner_with_progress, async_runner=True) - scheduler.register_scheduler("custom-complex-scheduler", SchedulerTests.CustomComplexScheduler) + scheduler.register_scheduler("custom-complex-scheduler", self.CustomComplexScheduler) - def tearDown(self): + def teardown_method(self, method): runner.remove_runner("driver-test-runner-with-completion") scheduler.remove_scheduler("custom-complex-scheduler") @@ -940,8 +901,8 @@ def test_injects_parameter_source_into_scheduler(self): param_source = track.operation_parameters(self.test_track, task) schedule = driver.schedule_for(task_allocation, param_source) - self.assertIsNotNone(schedule.sched.parameter_source, "Parameter source has not been injected into scheduler") - self.assertEqual(param_source, schedule.sched.parameter_source) + assert schedule.sched.parameter_source is not None, "Parameter source has not been injected into scheduler" + assert schedule.sched.parameter_source == param_source @run_async async def test_search_task_one_client(self): @@ -1223,24 +1184,24 @@ async def test_schedule_for_time_based(self): schedule_handle = driver.schedule_for(task_allocation, param_source) schedule_handle.start() # first client does not wait - self.assertEqual(0.0, schedule_handle.ramp_up_wait_time) + assert schedule_handle.ramp_up_wait_time == 0.0 schedule = schedule_handle() last_progress = -1 async for invocation_time, sample_type, progress_percent, runner, params in schedule: # we're not throughput throttled - self.assertEqual(0, invocation_time) + assert invocation_time == 0 if progress_percent <= 0.5: -
self.assertEqual(metrics.SampleType.Warmup, sample_type) + assert metrics.SampleType.Warmup == sample_type else: - self.assertEqual(metrics.SampleType.Normal, sample_type) - self.assertTrue(last_progress < progress_percent) + assert metrics.SampleType.Normal == sample_type + assert last_progress < progress_percent last_progress = progress_percent - self.assertTrue(round(progress_percent, 2) >= 0.0, "progress should be >= 0.0 but was [%f]" % progress_percent) - self.assertTrue(round(progress_percent, 2) <= 1.0, "progress should be <= 1.0 but was [%f]" % progress_percent) - self.assertIsNotNone(runner, "runner must be defined") - self.assertEqual({"body": ["a"], "operation-type": "bulk", "size": 11}, params) + assert round(progress_percent, 2) >= 0.0 + assert round(progress_percent, 2) <= 1.0 + assert runner is not None + assert params == {"body": ["a"], "operation-type": "bulk", "size": 11} @run_async async def test_schedule_for_time_based_with_multiple_clients(self): @@ -1270,27 +1231,27 @@ async def test_schedule_for_time_based_with_multiple_clients(self): schedule_handle = driver.schedule_for(task_allocation, param_source) schedule_handle.start() # client number 4 out of 8 -> 0.1 * (4 / 8) = 0.05 - self.assertEqual(0.05, schedule_handle.ramp_up_wait_time) + assert schedule_handle.ramp_up_wait_time == 0.05 schedule = schedule_handle() last_progress = -1 async for invocation_time, sample_type, progress_percent, runner, params in schedule: # we're not throughput throttled - self.assertEqual(0, invocation_time) + assert invocation_time == 0 if progress_percent <= 0.5: - self.assertEqual(metrics.SampleType.Warmup, sample_type) + assert metrics.SampleType.Warmup == sample_type else: - self.assertEqual(metrics.SampleType.Normal, sample_type) - self.assertTrue(last_progress < progress_percent) + assert metrics.SampleType.Normal == sample_type + assert last_progress < progress_percent last_progress = progress_percent - self.assertTrue(round(progress_percent, 2) >= 0.0, "progress should be >= 0.0 but was [%f]" % progress_percent) - self.assertTrue(round(progress_percent, 2) <= 1.0, "progress should be <= 1.0 but was [%f]" % progress_percent) - self.assertIsNotNone(runner, "runner must be defined") - self.assertEqual({"body": ["a"], "operation-type": "bulk", "size": 11}, params) + assert round(progress_percent, 2) >= 0.0 + assert round(progress_percent, 2) <= 1.0 + assert runner is not None + assert params == {"body": ["a"], "operation-type": "bulk", "size": 11} -class AsyncExecutorTests(TestCase): +class TestAsyncExecutor: class NoopContextManager: def __init__(self, mock): self.mock = mock @@ -1348,18 +1309,14 @@ class RunnerOverridingThroughput: async def __call__(self, es, params): return {"weight": 1, "unit": "ops", "throughput": 1.23} - def __init__(self, methodName): - super().__init__(methodName) - self.runner_with_progress = None - @staticmethod def context_managed(mock): - return AsyncExecutorTests.NoopContextManager(mock) + return TestAsyncExecutor.NoopContextManager(mock) - def setUp(self): + def setup_method(self, method): runner.register_default_runners() - self.runner_with_progress = AsyncExecutorTests.RunnerWithProgress() - self.runner_overriding_throughput = AsyncExecutorTests.RunnerOverridingThroughput() + self.runner_with_progress = self.RunnerWithProgress() + self.runner_overriding_throughput = self.RunnerOverridingThroughput() runner.register_runner("unit-test-recovery", self.runner_with_progress, async_runner=True) runner.register_runner("override-throughput", 
self.runner_overriding_throughput, async_runner=True) @@ -1367,7 +1324,7 @@ def setUp(self): @run_async async def test_execute_schedule_in_throughput_mode(self, es): task_start = time.perf_counter() - es.new_request_context.return_value = AsyncExecutorTests.StaticRequestTiming(task_start=task_start) + es.new_request_context.return_value = self.StaticRequestTiming(task_start=task_start) es.bulk = mock.AsyncMock(return_value=io.StringIO('{"errors": false, "took": 8}')) @@ -1416,29 +1373,29 @@ async def test_execute_schedule_in_throughput_mode(self, es): samples = sampler.samples - self.assertTrue(len(samples) > 0) - self.assertFalse(complete.is_set(), "Executor should not auto-complete a normal task") + assert len(samples) > 0 + assert not complete.is_set(), "Executor should not auto-complete a normal task" previous_absolute_time = -1.0 previous_relative_time = -1.0 for sample in samples: - self.assertEqual(2, sample.client_id) - self.assertEqual(task, sample.task) - self.assertLess(previous_absolute_time, sample.absolute_time) + assert sample.client_id == 2 + assert sample.task == task + assert previous_absolute_time < sample.absolute_time previous_absolute_time = sample.absolute_time - self.assertLess(previous_relative_time, sample.relative_time) + assert previous_relative_time < sample.relative_time previous_relative_time = sample.relative_time # we don't have any warmup time period - self.assertEqual(metrics.SampleType.Normal, sample.sample_type) + assert metrics.SampleType.Normal == sample.sample_type # latency equals service time in throughput mode - self.assertEqual(sample.latency, sample.service_time) - self.assertEqual(1, sample.total_ops) - self.assertEqual("docs", sample.total_ops_unit) + assert sample.latency == sample.service_time + assert sample.total_ops == 1 + assert sample.total_ops_unit == "docs" @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_execute_schedule_with_progress_determined_by_runner(self, es): task_start = time.perf_counter() - es.new_request_context.return_value = AsyncExecutorTests.StaticRequestTiming(task_start=task_start) + es.new_request_context.return_value = self.StaticRequestTiming(task_start=task_start) params.register_param_source_for_name("driver-test-param-source", DriverTestParamSource) test_track = track.Track(name="unittest", description="unittest track", indices=None, challenges=None) @@ -1480,33 +1437,33 @@ async def test_execute_schedule_with_progress_determined_by_runner(self, es): samples = sampler.samples - self.assertEqual(5, len(samples)) - self.assertTrue(self.runner_with_progress.completed) - self.assertEqual(1.0, self.runner_with_progress.percent_completed) - self.assertFalse(complete.is_set(), "Executor should not auto-complete a normal task") + assert len(samples) == 5 + assert self.runner_with_progress.completed is True + assert self.runner_with_progress.percent_completed == 1.0 + assert not complete.is_set(), "Executor should not auto-complete a normal task" previous_absolute_time = -1.0 previous_relative_time = -1.0 for sample in samples: - self.assertEqual(2, sample.client_id) - self.assertEqual(task, sample.task) - self.assertLess(previous_absolute_time, sample.absolute_time) + assert sample.client_id == 2 + assert sample.task == task + assert previous_absolute_time < sample.absolute_time previous_absolute_time = sample.absolute_time - self.assertLess(previous_relative_time, sample.relative_time) + assert previous_relative_time < sample.relative_time previous_relative_time = sample.relative_time # we don't 
have any warmup time period - self.assertEqual(metrics.SampleType.Normal, sample.sample_type) + assert metrics.SampleType.Normal == sample.sample_type # throughput is not overridden and will be calculated later - self.assertIsNone(sample.throughput) + assert sample.throughput is None # latency equals service time in throughput mode - self.assertEqual(sample.latency, sample.service_time) - self.assertEqual(1, sample.total_ops) - self.assertEqual("ops", sample.total_ops_unit) + assert sample.latency == sample.service_time + assert sample.total_ops == 1 + assert sample.total_ops_unit == "ops" @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_execute_schedule_runner_overrides_times(self, es): task_start = time.perf_counter() - es.new_request_context.return_value = AsyncExecutorTests.StaticRequestTiming(task_start=task_start) + es.new_request_context.return_value = self.StaticRequestTiming(task_start=task_start) params.register_param_source_for_name("driver-test-param-source", DriverTestParamSource) test_track = track.Track(name="unittest", description="unittest track", indices=None, challenges=None) @@ -1549,19 +1506,19 @@ async def test_execute_schedule_runner_overrides_times(self, es): samples = sampler.samples - self.assertFalse(complete.is_set(), "Executor should not auto-complete a normal task") - self.assertEqual(1, len(samples)) + assert not complete.is_set(), "Executor should not auto-complete a normal task" + assert len(samples) == 1 sample = samples[0] - self.assertEqual(0, sample.client_id) - self.assertEqual(task, sample.task) + assert sample.client_id == 0 + assert sample.task == task # we don't have any warmup samples - self.assertEqual(metrics.SampleType.Normal, sample.sample_type) - self.assertEqual(sample.latency, sample.service_time) - self.assertEqual(1, sample.total_ops) - self.assertEqual("ops", sample.total_ops_unit) - self.assertEqual(1.23, sample.throughput) - self.assertIsNotNone(sample.service_time) - self.assertIsNotNone(sample.time_period) + assert metrics.SampleType.Normal == sample.sample_type + assert sample.latency == sample.service_time + assert sample.total_ops == 1 + assert sample.total_ops_unit == "ops" + assert sample.throughput == 1.23 + assert sample.service_time is not None + assert sample.time_period is not None @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -1625,11 +1582,8 @@ async def perform_request(*args, **kwargs): sample_size = len(samples) lower_bound = bounds[0] upper_bound = bounds[1] - self.assertTrue( - lower_bound <= sample_size <= upper_bound, - msg="Expected sample size to be between %d and %d but was %d" % (lower_bound, upper_bound, sample_size), - ) - self.assertTrue(complete.is_set(), "Executor should auto-complete a task that terminates its parent") + assert lower_bound <= sample_size <= upper_bound + assert complete.is_set(), "Executor should auto-complete a task that terminates its parent" @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -1680,7 +1634,7 @@ async def test_cancel_execute_schedule(self, es): samples = sampler.samples sample_size = len(samples) - self.assertEqual(0, sample_size) + assert sample_size == 0 @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -1706,7 +1660,7 @@ def start(self): pass async def __call__(self): - invocations = [(0, metrics.SampleType.Warmup, 0, AsyncExecutorTests.context_managed(run), None)] + invocations = [(0, metrics.SampleType.Warmup, 0, TestAsyncExecutor.context_managed(run), None)] for invocation in invocations: yield invocation @@ -1733,10 
+1687,10 @@ async def __call__(self): on_error="continue", ) - with self.assertRaisesRegex(exceptions.RallyError, r"Cannot run task \[no-op\]: expected unit test exception"): + with pytest.raises(exceptions.RallyError, match=r"Cannot run task \[no-op\]: expected unit test exception"): await execute_schedule() - self.assertEqual(0, es.call_count) + assert es.call_count == 0 @run_async async def test_execute_single_no_return_value(self): @@ -1746,9 +1700,9 @@ async def test_execute_single_no_return_value(self): ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params, on_error="continue") - self.assertEqual(1, ops) - self.assertEqual("ops", unit) - self.assertEqual({"success": True}, request_meta_data) + assert ops == 1 + assert unit == "ops" + assert request_meta_data == {"success": True} @run_async async def test_execute_single_tuple(self): @@ -1758,9 +1712,9 @@ async def test_execute_single_tuple(self): ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params, on_error="continue") - self.assertEqual(500, ops) - self.assertEqual("MB", unit) - self.assertEqual({"success": True}, request_meta_data) + assert ops == 500 + assert unit == "MB" + assert request_meta_data == {"success": True} @run_async async def test_execute_single_dict(self): @@ -1777,29 +1731,25 @@ async def test_execute_single_dict(self): ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params, on_error="continue") - self.assertEqual(50, ops) - self.assertEqual("docs", unit) - self.assertEqual( - { - "some-custom-meta-data": "valid", - "http-status": 200, - "success": True, - }, - request_meta_data, - ) + assert ops == 50 + assert unit == "docs" + assert request_meta_data == { + "some-custom-meta-data": "valid", + "http-status": 200, + "success": True, + } + + @pytest.mark.parametrize("on_error", ["abort", "continue"]) + @pytest.mark.asyncio + async def test_execute_single_with_connection_error_always_aborts(self, on_error): + es = None + params = None + # ES client uses pseudo-status "N/A" in this case... + runner = mock.AsyncMock(side_effect=elasticsearch.ConnectionError("N/A", "no route to host", None)) - @run_async - async def test_execute_single_with_connection_error_always_aborts(self): - for on_error in ["abort", "continue"]: - with self.subTest(): - es = None - params = None - # ES client uses pseudo-status "N/A" in this case... - runner = mock.AsyncMock(side_effect=elasticsearch.ConnectionError("N/A", "no route to host", None)) - - with self.assertRaises(exceptions.RallyAssertionError) as ctx: - await driver.execute_single(self.context_managed(runner), es, params, on_error=on_error) - self.assertEqual("Request returned an error. Error type: transport, Description: no route to host", ctx.exception.args[0]) + with pytest.raises(exceptions.RallyAssertionError) as exc: + await driver.execute_single(self.context_managed(runner), es, params, on_error=on_error) + assert exc.value.args[0] == "Request returned an error. 
Error type: transport, Description: no route to host" @run_async async def test_execute_single_with_http_400_aborts_when_specified(self): @@ -1807,11 +1757,10 @@ async def test_execute_single_with_http_400_aborts_when_specified(self): params = None runner = mock.AsyncMock(side_effect=elasticsearch.NotFoundError(404, "not found", "the requested document could not be found")) - with self.assertRaises(exceptions.RallyAssertionError) as ctx: + with pytest.raises(exceptions.RallyAssertionError) as exc: await driver.execute_single(self.context_managed(runner), es, params, on_error="abort") - self.assertEqual( - "Request returned an error. Error type: transport, Description: not found (the requested document could not be found)", - ctx.exception.args[0], + assert exc.value.args[0] == ( + "Request returned an error. Error type: transport, Description: not found (the requested document could not be found)" ) @run_async @@ -1822,17 +1771,14 @@ async def test_execute_single_with_http_400(self): ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params, on_error="continue") - self.assertEqual(0, ops) - self.assertEqual("ops", unit) - self.assertEqual( - { - "http-status": 404, - "error-type": "transport", - "error-description": "not found (the requested document could not be found)", - "success": False, - }, - request_meta_data, - ) + assert ops == 0 + assert unit == "ops" + assert request_meta_data == { + "http-status": 404, + "error-type": "transport", + "error-description": "not found (the requested document could not be found)", + "success": False, + } @run_async async def test_execute_single_with_http_413(self): @@ -1842,17 +1788,14 @@ async def test_execute_single_with_http_413(self): ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params, on_error="continue") - self.assertEqual(0, ops) - self.assertEqual("ops", unit) - self.assertEqual( - { - "http-status": 413, - "error-type": "transport", - "error-description": "", - "success": False, - }, - request_meta_data, - ) + assert ops == 0 + assert unit == "ops" + assert request_meta_data == { + "http-status": 413, + "error-type": "transport", + "error-description": "", + "success": False, + } @run_async async def test_execute_single_with_key_error(self): @@ -1870,15 +1813,14 @@ def __str__(self): params["mode"] = "append" runner = FailingRunner() - with self.assertRaises(exceptions.SystemSetupError) as ctx: + with pytest.raises(exceptions.SystemSetupError) as exc: await driver.execute_single(self.context_managed(runner), es, params, on_error="continue") - self.assertEqual( - "Cannot execute [failing_mock_runner]. Provided parameters are: ['bulk', 'mode']. Error: ['bulk-size missing'].", - ctx.exception.args[0], + assert exc.value.args[0] == ( + "Cannot execute [failing_mock_runner]. Provided parameters are: ['bulk', 'mode']. Error: ['bulk-size missing']." ) -class AsyncProfilerTests(TestCase): +class TestAsyncProfiler: @run_async async def test_profiler_is_a_transparent_wrapper(self): async def f(x): @@ -1890,6 +1832,6 @@ async def f(x): # this should take roughly 1 second and should return something return_value = await profiler(1) end = time.perf_counter() - self.assertEqual(2, return_value) + assert return_value == 2 duration = end - start - self.assertTrue(0.9 <= duration <= 1.2, "Should sleep for roughly 1 second but took [%.2f] seconds." % duration) + assert 0.9 <= duration <= 1.2, "Should sleep for roughly 1 second but took [%.2f] seconds." 
% duration diff --git a/tests/driver/runner_test.py b/tests/driver/runner_test.py index 502d0b895..24a888b2d 100644 --- a/tests/driver/runner_test.py +++ b/tests/driver/runner_test.py @@ -16,14 +16,16 @@ # under the License. import asyncio +import collections import copy import io import json +import math import random import unittest.mock as mock -from unittest import TestCase import elasticsearch +import pytest from esrally import client, exceptions from esrally.driver import runner @@ -40,8 +42,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): return False -class RegisterRunnerTests(TestCase): - def tearDown(self): +class TestRegisterRunner: + def teardown_method(self, method): runner.remove_runner("unit_test") @run_async @@ -51,11 +53,9 @@ async def runner_function(*args): runner.register_runner(operation_type="unit_test", runner=runner_function, async_runner=True) returned_runner = runner.runner_for("unit_test") - self.assertIsInstance(returned_runner, runner.NoCompletion) - self.assertEqual("user-defined runner for [runner_function]", repr(returned_runner)) - self.assertEqual( - ("default_client", "param"), await returned_runner({"default": "default_client", "other": "other_client"}, "param") - ) + assert isinstance(returned_runner, runner.NoCompletion) + assert repr(returned_runner) == "user-defined runner for [runner_function]" + assert await returned_runner({"default": "default_client", "other": "other_client"}, "param") == ("default_client", "param") @run_async async def test_single_cluster_runner_class_with_context_manager_should_be_wrapped_with_context_manager_enabled(self): @@ -69,17 +69,13 @@ def __str__(self): test_runner = UnitTestSingleClusterContextManagerRunner() runner.register_runner(operation_type="unit_test", runner=test_runner, async_runner=True) returned_runner = runner.runner_for("unit_test") - self.assertIsInstance(returned_runner, runner.NoCompletion) - self.assertEqual( - "user-defined context-manager enabled runner for [UnitTestSingleClusterContextManagerRunner]", repr(returned_runner) - ) + assert isinstance(returned_runner, runner.NoCompletion) + assert repr(returned_runner) == "user-defined context-manager enabled runner for [UnitTestSingleClusterContextManagerRunner]" # test that context_manager functionality gets preserved after wrapping async with returned_runner: - self.assertEqual( - ("default_client", "param"), await returned_runner({"default": "default_client", "other": "other_client"}, "param") - ) + assert await returned_runner({"default": "default_client", "other": "other_client"}, "param") == ("default_client", "param") # check that the context manager interface of our inner runner has been respected. 
- self.assertTrue(test_runner.fp.closed) + assert test_runner.fp.closed @run_async async def test_multi_cluster_runner_class_with_context_manager_should_be_wrapped_with_context_manager_enabled(self): @@ -95,17 +91,15 @@ def __str__(self): test_runner = UnitTestMultiClusterContextManagerRunner() runner.register_runner(operation_type="unit_test", runner=test_runner, async_runner=True) returned_runner = runner.runner_for("unit_test") - self.assertIsInstance(returned_runner, runner.NoCompletion) - self.assertEqual( - "user-defined context-manager enabled runner for [UnitTestMultiClusterContextManagerRunner]", repr(returned_runner) - ) + assert isinstance(returned_runner, runner.NoCompletion) + assert repr(returned_runner) == "user-defined context-manager enabled runner for [UnitTestMultiClusterContextManagerRunner]" # test that context_manager functionality gets preserved after wrapping all_clients = {"default": "default_client", "other": "other_client"} async with returned_runner: - self.assertEqual((all_clients, "param1", "param2"), await returned_runner(all_clients, "param1", "param2")) + assert await returned_runner(all_clients, "param1", "param2") == (all_clients, "param1", "param2") # check that the context manager interface of our inner runner has been respected. - self.assertTrue(test_runner.fp.closed) + assert test_runner.fp.closed @run_async async def test_single_cluster_runner_class_should_be_wrapped(self): @@ -119,11 +113,9 @@ def __str__(self): test_runner = UnitTestSingleClusterRunner() runner.register_runner(operation_type="unit_test", runner=test_runner, async_runner=True) returned_runner = runner.runner_for("unit_test") - self.assertIsInstance(returned_runner, runner.NoCompletion) - self.assertEqual("user-defined runner for [UnitTestSingleClusterRunner]", repr(returned_runner)) - self.assertEqual( - ("default_client", "param"), await returned_runner({"default": "default_client", "other": "other_client"}, "param") - ) + assert isinstance(returned_runner, runner.NoCompletion) + assert repr(returned_runner) == "user-defined runner for [UnitTestSingleClusterRunner]" + assert await returned_runner({"default": "default_client", "other": "other_client"}, "param") == ("default_client", "param") @run_async async def test_multi_cluster_runner_class_should_be_wrapped(self): @@ -139,17 +131,17 @@ def __str__(self): test_runner = UnitTestMultiClusterRunner() runner.register_runner(operation_type="unit_test", runner=test_runner, async_runner=True) returned_runner = runner.runner_for("unit_test") - self.assertIsInstance(returned_runner, runner.NoCompletion) - self.assertEqual("user-defined runner for [UnitTestMultiClusterRunner]", repr(returned_runner)) + assert isinstance(returned_runner, runner.NoCompletion) + assert repr(returned_runner) == "user-defined runner for [UnitTestMultiClusterRunner]" all_clients = {"default": "default_client", "other": "other_client"} - self.assertEqual((all_clients, "some_param"), await returned_runner(all_clients, "some_param")) + assert await returned_runner(all_clients, "some_param") == (all_clients, "some_param") -class AssertingRunnerTests(TestCase): - def setUp(self): +class TestAssertingRunner: + def setup_method(self, method): runner.enable_assertions(True) - def tearDown(self): + def teardown_method(self, method): runner.enable_assertions(False) @run_async @@ -177,7 +169,7 @@ async def test_asserts_equal_succeeds(self): }, ) - self.assertEqual(response, final_response) + assert final_response == response @run_async async def 
test_asserts_equal_fails(self): @@ -192,8 +184,8 @@ async def test_asserts_equal_fails(self): } delegate = mock.AsyncMock(return_value=response) r = runner.AssertingRunner(delegate) - with self.assertRaisesRegex( - exceptions.RallyTaskAssertionError, r"Expected \[hits.hits.relation\] in \[test-task\] to be == \[eq\] but was \[gte\]." + with pytest.raises( + exceptions.RallyTaskAssertionError, match=r"Expected \[hits.hits.relation\] in \[test-task\] to be == \[eq\] but was \[gte\]." ): async with r: await r( @@ -228,11 +220,11 @@ async def test_skips_asserts_for_non_dicts(self): }, ) # still passes response as is - self.assertEqual(response, final_response) + assert final_response == response def test_predicates(self): r = runner.AssertingRunner(delegate=None) - self.assertEqual(5, len(r.predicates)) + assert len(r.predicates) == 5 predicate_success = { # predicate: (expected, actual) @@ -245,7 +237,7 @@ def test_predicates(self): for predicate, vals in predicate_success.items(): expected, actual = vals - self.assertTrue(r.predicates[predicate](expected, actual), f"Expected [{expected} {predicate} {actual}] to succeed.") + assert r.predicates[predicate](expected, actual), f"Expected [{expected} {predicate} {actual}] to succeed." predicate_fail = { # predicate: (expected, actual) @@ -258,10 +250,10 @@ def test_predicates(self): for predicate, vals in predicate_fail.items(): expected, actual = vals - self.assertFalse(r.predicates[predicate](expected, actual), f"Expected [{expected} {predicate} {actual}] to fail.") + assert not r.predicates[predicate](expected, actual), f"Expected [{expected} {predicate} {actual}] to fail." -class SelectiveJsonParserTests(TestCase): +class TestSelectiveJsonParser: def doc_as_text(self, doc): return io.StringIO(json.dumps(doc)) @@ -290,9 +282,10 @@ def test_parse_all_expected(self): ], ) - self.assertEqual("Hello", parsed.get("title")) - self.assertEqual(2000, parsed.get("meta.date.year")) - self.assertNotIn("meta.date.month", parsed) + assert parsed == { + "title": "Hello", + "meta.date.year": 2000, + } def test_list_length(self): doc = self.doc_as_text( @@ -336,21 +329,20 @@ def test_list_length(self): ["authors", "readers", "supporters"], ) - self.assertEqual("Hello", parsed.get("title")) - self.assertEqual(2000, parsed.get("meta.date.year")) - self.assertNotIn("meta.date.month", parsed) - - # lists - self.assertFalse(parsed.get("authors")) - self.assertFalse(parsed.get("readers")) - self.assertTrue(parsed.get("supporters")) + assert parsed == { + "title": "Hello", + "meta.date.year": 2000, + "authors": False, + "readers": False, + "supporters": True, + } def _build_bulk_body(*lines): return "".join(line + "\n" for line in lines) -class BulkIndexRunnerTests(TestCase): +class TestBulkIndexRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_bulk_index_missing_params(self, es): @@ -373,12 +365,11 @@ async def test_bulk_index_missing_params(self, es): ) } - with self.assertRaises(exceptions.DataError) as ctx: + with pytest.raises(exceptions.DataError) as exc: await bulk(es, bulk_params) - self.assertEqual( + assert exc.value.args[0] == ( "Parameter source for operation 'bulk-index' did not provide the mandatory parameter 'action-metadata-present'. " - "Add it to your parameter source and try again.", - ctx.exception.args[0], + "Add it to your parameter source and try again." 
@mock.patch("elasticsearch.Elasticsearch")
@@ -409,13 +400,15 @@ async def test_bulk_index_success_with_timeout(self, es):
result = await bulk(es, bulk_params)
- self.assertEqual(8, result["took"])
- self.assertIsNone(result["index"])
- self.assertEqual(3, result["weight"])
- self.assertEqual("docs", result["unit"])
- self.assertEqual(True, result["success"])
- self.assertEqual(0, result["error-count"])
- self.assertFalse("error-type" in result)
+ assert result == {
+ "took": 8,
+ "index": None,
+ "weight": 3,
+ "unit": "docs",
+ "success": True,
+ "success-count": 3,
+ "error-count": 0,
+ }
es.bulk.assert_awaited_with(body=bulk_params["body"], params={"timeout": "1m"})
@@ -446,13 +439,15 @@ async def test_bulk_index_success_with_metadata(self, es):
result = await bulk(es, bulk_params)
- self.assertEqual(8, result["took"])
- self.assertIsNone(result["index"])
- self.assertEqual(3, result["weight"])
- self.assertEqual("docs", result["unit"])
- self.assertEqual(True, result["success"])
- self.assertEqual(0, result["error-count"])
- self.assertFalse("error-type" in result)
+ assert result == {
+ "took": 8,
+ "index": None,
+ "weight": 3,
+ "unit": "docs",
+ "success": True,
+ "success-count": 3,
+ "error-count": 0,
+ }
es.bulk.assert_awaited_with(body=bulk_params["body"], params={})
@@ -485,12 +480,15 @@ async def test_simple_bulk_with_timeout_and_headers(self, es):
result = await bulk(es, bulk_params)
- self.assertEqual(8, result["took"])
- self.assertEqual(3, result["weight"])
- self.assertEqual("docs", result["unit"])
- self.assertEqual(True, result["success"])
- self.assertEqual(0, result["error-count"])
- self.assertFalse("error-type" in result)
+ assert result == {
+ "took": 8,
+ "index": "test1",
+ "weight": 3,
+ "unit": "docs",
+ "success": True,
+ "success-count": 3,
+ "error-count": 0,
+ }
es.bulk.assert_awaited_with(
doc_type="_doc",
@@ -527,13 +525,15 @@ async def test_bulk_index_success_without_metadata_with_doc_type(self, es):
result = await bulk(es, bulk_params)
- self.assertEqual(8, result["took"])
- self.assertEqual("test-index", result["index"])
- self.assertEqual(3, result["weight"])
- self.assertEqual("docs", result["unit"])
- self.assertEqual(True, result["success"])
- self.assertEqual(0, result["error-count"])
- self.assertFalse("error-type" in result)
+ assert result == {
+ "took": 8,
+ "index": "test-index",
+ "weight": 3,
+ "unit": "docs",
+ "success": True,
+ "success-count": 3,
+ "error-count": 0,
+ }
es.bulk.assert_awaited_with(body=bulk_params["body"], index="test-index", doc_type="_doc", params={})
@@ -561,13 +561,15 @@ async def test_bulk_index_success_without_metadata_and_without_doc_type(self, es
result = await bulk(es, bulk_params)
- self.assertEqual(8, result["took"])
- self.assertEqual("test-index", result["index"])
- self.assertEqual(3, result["weight"])
- self.assertEqual("docs", result["unit"])
- self.assertEqual(True, result["success"])
- self.assertEqual(0, result["error-count"])
- self.assertFalse("error-type" in result)
+ assert result == {
+ "took": 8,
+ "index": "test-index",
+ "weight": 3,
+ "unit": "docs",
+ "success": True,
+ "success-count": 3,
+ "error-count": 0,
+ }
es.bulk.assert_awaited_with(body=bulk_params["body"], index="test-index", doc_type=None, params={})
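Throughout these tests the Elasticsearch client is a mock whose async endpoints are unittest.mock.AsyncMock instances, so the verifications use the awaited variants (assert_awaited_with, assert_awaited_once_with, await_count) rather than the called ones, and a single dict comparison replaces the old chain of per-field assertEqual calls. A minimal self-contained sketch of the pattern; the response dict here is invented for illustration:

    import asyncio
    import unittest.mock as mock

    async def main():
        es = mock.Mock()
        # AsyncMock returns an awaitable and records each await.
        es.bulk = mock.AsyncMock(return_value={"took": 8, "errors": False})
        result = await es.bulk(body="...", params={"timeout": "1m"})
        # One whole-dict assertion; on failure pytest reports a key-by-key diff.
        assert result == {"took": 8, "errors": False}
        es.bulk.assert_awaited_with(body="...", params={"timeout": "1m"})
        assert es.bulk.await_count == 1

    asyncio.run(main())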
@@ -605,13 +607,17 @@ async def test_bulk_index_error(self, es):
result = await bulk(es, bulk_params)
- self.assertEqual("test", result["index"])
- self.assertEqual(5, result["took"])
- self.assertEqual(3, result["weight"])
- self.assertEqual("docs", result["unit"])
- self.assertEqual(False, result["success"])
- self.assertEqual(2, result["error-count"])
- self.assertEqual("bulk", result["error-type"])
+ result.pop("error-description") # TODO not deterministic
+ assert result == {
+ "took": 5,
+ "index": "test",
+ "weight": 3,
+ "unit": "docs",
+ "success": False,
+ "success-count": 1,
+ "error-count": 2,
+ "error-type": "bulk",
+ }
es.bulk.assert_awaited_with(body=bulk_params["body"], params={})
@@ -674,13 +680,17 @@ async def test_bulk_index_error_no_shards(self, es):
result = await bulk(es, bulk_params)
- self.assertEqual("test", result["index"])
- self.assertEqual(20, result["took"])
- self.assertEqual(3, result["weight"])
- self.assertEqual("docs", result["unit"])
- self.assertEqual(False, result["success"])
- self.assertEqual(3, result["error-count"])
- self.assertEqual("bulk", result["error-type"])
+ result.pop("error-description") # TODO not deterministic
+ assert result == {
+ "took": 20,
+ "index": "test",
+ "weight": 3,
+ "unit": "docs",
+ "success": False,
+ "success-count": 0,
+ "error-count": 3,
+ "error-type": "bulk",
+ }
es.bulk.assert_awaited_with(body=bulk_params["body"], params={})
@@ -767,14 +777,18 @@ async def test_mixed_bulk_with_simple_stats(self, es):
result = await bulk(es, bulk_params)
- self.assertEqual("test", result["index"])
- self.assertEqual(30, result["took"])
- self.assertNotIn("ingest_took", result, "ingest_took is not extracted with simple stats")
- self.assertEqual(4, result["weight"])
- self.assertEqual("docs", result["unit"])
- self.assertEqual(False, result["success"])
- self.assertEqual(2, result["error-count"])
- self.assertEqual("bulk", result["error-type"])
+ result.pop("error-description") # TODO not deterministic
+ assert result == {
+ "took": 30,
+ "index": "test",
+ "weight": 4,
+ "unit": "docs",
+ "success": False,
+ "success-count": 2,
+ "error-count": 2,
+ "error-type": "bulk",
+ }
+ assert "ingest_took" not in result, "ingest_took is not extracted with simple stats"
es.bulk.assert_awaited_with(body=bulk_params["body"], params={})
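The result.pop("error-description") lines handle a field whose text depends on which failing bulk item happens to be reported, so it is removed before the remaining whole-dict comparison; note that dict.pop without a default also doubles as an existence check, since it raises KeyError when the key is absent. A sketch of the pattern with an invented result dict and helper name:

    def check_error_result(result):
        # Remove the nondeterministic field first (raises KeyError if it is missing) ...
        error_description = result.pop("error-description")
        assert "failed" in error_description
        # ... then compare everything that is stable in one assertion.
        assert result == {"success": False, "error-count": 2, "error-type": "bulk"}

    check_error_result({
        "success": False,
        "error-count": 2,
        "error-type": "bulk",
        "error-description": "2 items failed",
    })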
"bulk-request-size-bytes": 582, + "total-document-size-bytes": 234, + } es.bulk.assert_awaited_with(body=bulk_params["body"], params={}) es.bulk.return_value.pop("ingest_took") result = await bulk(es, bulk_params) - self.assertNotIn("ingest_took", result) + assert "ingest_took" not in result @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -965,36 +977,33 @@ async def test_simple_bulk_with_detailed_stats_body_as_list(self, es): result = await bulk(es, bulk_params) - self.assertEqual("test", result["index"]) - self.assertEqual(30, result["took"]) - self.assertEqual(20, result["ingest_took"]) - self.assertEqual(1, result["weight"]) - self.assertEqual("docs", result["unit"]) - self.assertEqual(True, result["success"]) - self.assertEqual(0, result["error-count"]) - self.assertEqual( - { - "index": {"item-count": 1, "created": 1}, + assert result == { + "took": 30, + "ingest_took": 20, + "index": "test", + "weight": 1, + "unit": "docs", + "success": True, + "success-count": 1, + "error-count": 0, + "ops": { + "index": collections.Counter({"item-count": 1, "created": 1}), }, - result["ops"], - ) - self.assertEqual( - [ + "shards_histogram": [ { "item-count": 1, "shards": {"total": 2, "successful": 1, "failed": 0}, } ], - result["shards_histogram"], - ) - self.assertEqual(93, result["bulk-request-size-bytes"]) - self.assertEqual(39, result["total-document-size-bytes"]) + "bulk-request-size-bytes": 93, + "total-document-size-bytes": 39, + } es.bulk.assert_awaited_with(body=bulk_params["body"], params={}) es.bulk.return_value.pop("ingest_took") result = await bulk(es, bulk_params) - self.assertNotIn("ingest_took", result) + assert "ingest_took" not in result @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -1034,36 +1043,33 @@ async def test_simple_bulk_with_detailed_stats_body_as_bytes(self, es): result = await bulk(es, bulk_params) - self.assertEqual("test", result["index"]) - self.assertEqual(30, result["took"]) - self.assertEqual(20, result["ingest_took"]) - self.assertEqual(1, result["weight"]) - self.assertEqual("docs", result["unit"]) - self.assertEqual(True, result["success"]) - self.assertEqual(0, result["error-count"]) - self.assertEqual( - { - "index": {"item-count": 1, "created": 1}, + assert result == { + "took": 30, + "ingest_took": 20, + "index": "test", + "weight": 1, + "unit": "docs", + "success": True, + "success-count": 1, + "error-count": 0, + "ops": { + "index": collections.Counter({"item-count": 1, "created": 1}), }, - result["ops"], - ) - self.assertEqual( - [ + "shards_histogram": [ { "item-count": 1, "shards": {"total": 1, "successful": 1, "failed": 0}, } ], - result["shards_histogram"], - ) - self.assertEqual(83, result["bulk-request-size-bytes"]) - self.assertEqual(27, result["total-document-size-bytes"]) + "bulk-request-size-bytes": 83, + "total-document-size-bytes": 27, + } es.bulk.assert_awaited_with(body=bulk_params["body"], params={}) es.bulk.return_value.pop("ingest_took") result = await bulk(es, bulk_params) - self.assertNotIn("ingest_took", result) + assert "ingest_took" not in result @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -1103,7 +1109,7 @@ async def test_simple_bulk_with_detailed_stats_body_as_unrecognized_type(self, e "index": "test", } - with self.assertRaisesRegex(exceptions.DataError, "bulk body is not of type bytes, string, or list"): + with pytest.raises(exceptions.DataError, match="bulk body is not of type bytes, string, or list"): await bulk(es, bulk_params) es.bulk.assert_awaited_with(body=bulk_params["body"], 
params={}) @@ -1151,23 +1157,29 @@ async def test_bulk_index_error_logs_warning_with_detailed_stats_body(self, es): result = await bulk(es, bulk_params) mocked_warning_logger.assert_has_calls([mock.call("Bulk request failed: [%s]", result["error-description"])]) - self.assertEqual("test", result["index"]) - self.assertEqual(5, result["took"]) - self.assertEqual(1, result["weight"]) - self.assertEqual("docs", result["unit"]) - self.assertEqual(False, result["success"]) - self.assertEqual(1, result["error-count"]) - self.assertEqual("bulk", result["error-type"]) - self.assertEqual( - "HTTP status: 429, message: index [test] blocked by: [TOO_MANY_REQUESTS/12/disk usage " - "exceeded flood-stage watermark, index has read-only-allow-delete block];", - result["error-description"], - ) + assert result == { + "took": 5, + "index": "test", + "weight": 1, + "unit": "docs", + "success": False, + "success-count": 0, + "error-count": 1, + "error-type": "bulk", + "error-description": ( + "HTTP status: 429, message: index [test] blocked by: [TOO_MANY_REQUESTS/12/disk usage " + "exceeded flood-stage watermark, index has read-only-allow-delete block];" + ), + "ops": {"create": collections.Counter({"item-count": 1})}, + "shards_histogram": [], + "total-document-size-bytes": 27, + "bulk-request-size-bytes": 80, + } es.bulk.assert_awaited_with(body=bulk_params["body"], params={}) -class ForceMergeRunnerTests(TestCase): +class TestForceMergeRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_force_merge_with_defaults(self, es): @@ -1322,16 +1334,18 @@ async def test_force_merge_with_polling_and_params(self, es): es.indices.forcemerge.assert_awaited_once_with(index="_all", max_num_segments=1, request_timeout=50000) -class IndicesStatsRunnerTests(TestCase): +class TestIndicesStatsRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_indices_stats_without_parameters(self, es): es.indices.stats = mock.AsyncMock(return_value={}) indices_stats = runner.IndicesStats() result = await indices_stats(es, params={}) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertTrue(result["success"]) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.indices.stats.assert_awaited_once_with(index="_all", metric="_all") @@ -1348,9 +1362,11 @@ async def test_indices_stats_with_timeout_and_headers(self, es): "opaque-id": "test-id1", }, ) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertTrue(result["success"]) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.indices.stats.assert_awaited_once_with( index="_all", metric="_all", headers={"header1": "value1"}, opaque_id="test-id1", request_timeout=3.0 @@ -1377,17 +1393,16 @@ async def test_indices_stats_with_failed_condition(self, es): result = await indices_stats( es, params={"index": "logs-*", "condition": {"path": "_all.total.merges.current", "expected-value": 0}} ) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertFalse(result["success"]) - self.assertDictEqual( - { + assert result == { + "weight": 1, + "unit": "ops", + "success": False, + "condition": { "path": "_all.total.merges.current", "actual-value": "2", "expected-value": "0", }, - result["condition"], - ) + } es.indices.stats.assert_awaited_once_with(index="logs-*", metric="_all") @@ -1419,17 +1434,16 @@ async def test_indices_stats_with_successful_condition(self, es): }, }, ) - 
self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertTrue(result["success"]) - self.assertDictEqual( - { + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + "condition": { "path": "_all.total.merges.current", "actual-value": "0", "expected-value": "0", }, - result["condition"], - ) + } es.indices.stats.assert_awaited_once_with(index="logs-*", metric="_all") @@ -1451,22 +1465,21 @@ async def test_indices_stats_with_non_existing_path(self, es): }, }, ) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertFalse(result["success"]) - self.assertDictEqual( - { + assert result == { + "weight": 1, + "unit": "ops", + "success": False, + "condition": { "path": "indices.my_index.total.docs.count", "actual-value": None, "expected-value": "0", }, - result["condition"], - ) + } es.indices.stats.assert_awaited_once_with(index="logs-*", metric="_all") -class QueryRunnerTests(TestCase): +class TestQueryRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_query_match_only_request_body_defined(self, es): @@ -1503,13 +1516,15 @@ async def test_query_match_only_request_body_defined(self, es): async with query_runner: result = await query_runner(es, params) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertEqual(1, result["hits"]) - self.assertEqual("gte", result["hits_relation"]) - self.assertFalse(result["timed_out"]) - self.assertEqual(5, result["took"]) - self.assertFalse("error-type" in result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + "hits": 1, + "hits_relation": "gte", + "timed_out": False, + "took": 5, + } es.transport.perform_request.assert_awaited_once_with( "GET", "/_all/_search", params={"request_cache": "true"}, body=params["body"], headers=None @@ -1548,13 +1563,15 @@ async def test_query_with_timeout_and_headers(self, es): async with query_runner: result = await query_runner(es, params) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertEqual(1, result["hits"]) - self.assertEqual("gte", result["hits_relation"]) - self.assertFalse(result["timed_out"]) - self.assertEqual(5, result["took"]) - self.assertFalse("error-type" in result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + "hits": 1, + "hits_relation": "gte", + "timed_out": False, + "took": 5, + } es.transport.perform_request.assert_awaited_once_with( "GET", @@ -1599,13 +1616,15 @@ async def test_query_match_using_request_params(self, es): async with query_runner: result = await query_runner(es, params) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertEqual(2, result["hits"]) - self.assertEqual("eq", result["hits_relation"]) - self.assertFalse(result["timed_out"]) - self.assertEqual(62, result["took"]) - self.assertFalse("error-type" in result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + "hits": 2, + "hits_relation": "eq", + "timed_out": False, + "took": 62, + } es.transport.perform_request.assert_awaited_once_with( "GET", @@ -1650,13 +1669,16 @@ async def test_query_no_detailed_results(self, es): async with query_runner: result = await query_runner(es, params) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertNotIn("hits", result) - self.assertNotIn("hits_relation", result) - self.assertNotIn("timed_out", result) - self.assertNotIn("took", result) - 
self.assertNotIn("error-type", result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } + assert "hits" not in result + assert "hits_relation" not in result + assert "timed_out" not in result + assert "took" not in result + assert "error-type" not in result es.transport.perform_request.assert_awaited_once_with( "GET", @@ -1702,13 +1724,15 @@ async def test_query_hits_total_as_number(self, es): async with query_runner: result = await query_runner(es, params) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertEqual(2, result["hits"]) - self.assertEqual("eq", result["hits_relation"]) - self.assertFalse(result["timed_out"]) - self.assertEqual(5, result["took"]) - self.assertFalse("error-type" in result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + "hits": 2, + "hits_relation": "eq", + "timed_out": False, + "took": 5, + } es.transport.perform_request.assert_awaited_once_with( "GET", @@ -1757,13 +1781,15 @@ async def test_query_match_all(self, es): async with query_runner: result = await query_runner(es, params) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertEqual(2, result["hits"]) - self.assertEqual("eq", result["hits_relation"]) - self.assertFalse(result["timed_out"]) - self.assertEqual(5, result["took"]) - self.assertFalse("error-type" in result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + "hits": 2, + "hits_relation": "eq", + "timed_out": False, + "took": 5, + } es.transport.perform_request.assert_awaited_once_with( "GET", @@ -1811,13 +1837,15 @@ async def test_query_match_all_doc_type_fallback(self, es): async with query_runner: result = await query_runner(es, params) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertEqual(2, result["hits"]) - self.assertEqual("eq", result["hits_relation"]) - self.assertFalse(result["timed_out"]) - self.assertEqual(5, result["took"]) - self.assertFalse("error-type" in result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + "hits": 2, + "hits_relation": "eq", + "timed_out": False, + "took": 5, + } es.transport.perform_request.assert_awaited_once_with( "GET", @@ -1866,14 +1894,16 @@ async def test_scroll_query_only_one_page(self, es): async with query_runner: results = await query_runner(es, params) - self.assertEqual(1, results["weight"]) - self.assertEqual(1, results["pages"]) - self.assertEqual(2, results["hits"]) - self.assertEqual("eq", results["hits_relation"]) - self.assertEqual(4, results["took"]) - self.assertEqual("pages", results["unit"]) - self.assertFalse(results["timed_out"]) - self.assertFalse("error-type" in results) + assert results == { + "weight": 1, + "pages": 1, + "unit": "pages", + "hits": 2, + "hits_relation": "eq", + "timed_out": False, + "took": 4, + } + assert "error-type" not in results es.transport.perform_request.assert_awaited_once_with( "GET", @@ -1927,14 +1957,16 @@ async def test_scroll_query_no_request_cache(self, es): async with query_runner: results = await query_runner(es, params) - self.assertEqual(1, results["weight"]) - self.assertEqual(1, results["pages"]) - self.assertEqual(2, results["hits"]) - self.assertEqual("eq", results["hits_relation"]) - self.assertEqual(4, results["took"]) - self.assertEqual("pages", results["unit"]) - self.assertFalse(results["timed_out"]) - self.assertFalse("error-type" in results) + assert results == { + "weight": 1, + "pages": 1, + "unit": 
"pages", + "hits": 2, + "hits_relation": "eq", + "timed_out": False, + "took": 4, + } + assert "error-type" not in results es.transport.perform_request.assert_awaited_once_with( "GET", @@ -1982,14 +2014,16 @@ async def test_scroll_query_only_one_page_only_request_body_defined(self, es): async with query_runner: results = await query_runner(es, params) - self.assertEqual(1, results["weight"]) - self.assertEqual(1, results["pages"]) - self.assertEqual(2, results["hits"]) - self.assertEqual("eq", results["hits_relation"]) - self.assertEqual(4, results["took"]) - self.assertEqual("pages", results["unit"]) - self.assertFalse(results["timed_out"]) - self.assertFalse("error-type" in results) + assert results == { + "weight": 1, + "pages": 1, + "unit": "pages", + "hits": 2, + "hits_relation": "eq", + "timed_out": False, + "took": 4, + } + assert "error-type" not in results es.transport.perform_request.assert_awaited_once_with( "GET", @@ -2065,14 +2099,16 @@ async def test_scroll_query_with_explicit_number_of_pages(self, es): async with query_runner: results = await query_runner(es, params) - self.assertEqual(2, results["weight"]) - self.assertEqual(2, results["pages"]) - self.assertEqual(3, results["hits"]) - self.assertEqual("eq", results["hits_relation"]) - self.assertEqual(79, results["took"]) - self.assertEqual("pages", results["unit"]) - self.assertTrue(results["timed_out"]) - self.assertFalse("error-type" in results) + assert results == { + "weight": 2, + "pages": 2, + "unit": "pages", + "hits": 3, + "hits_relation": "eq", + "timed_out": True, + "took": 79, + } + assert "error-type" not in results es.clear_scroll.assert_awaited_once_with(body={"scroll_id": ["some-scroll-id"]}) @@ -2113,13 +2149,16 @@ async def test_scroll_query_cannot_clear_scroll(self, es): async with query_runner: results = await query_runner(es, params) - self.assertEqual(1, results["weight"]) - self.assertEqual(1, results["pages"]) - self.assertEqual(1, results["hits"]) - self.assertEqual("eq", results["hits_relation"]) - self.assertEqual("pages", results["unit"]) - self.assertEqual(53, results["took"]) - self.assertFalse("error-type" in results) + assert results == { + "weight": 1, + "pages": 1, + "unit": "pages", + "hits": 1, + "hits_relation": "eq", + "timed_out": False, + "took": 53, + } + assert "error-type" not in results es.clear_scroll.assert_awaited_once_with(body={"scroll_id": ["some-scroll-id"]}) @@ -2178,14 +2217,16 @@ async def test_scroll_query_request_all_pages(self, es): async with query_runner: results = await query_runner(es, params) - self.assertEqual(2, results["weight"]) - self.assertEqual(2, results["pages"]) - self.assertEqual(4, results["hits"]) - self.assertEqual("gte", results["hits_relation"]) - self.assertEqual(878, results["took"]) - self.assertEqual("pages", results["unit"]) - self.assertFalse(results["timed_out"]) - self.assertFalse("error-type" in results) + assert results == { + "weight": 2, + "pages": 2, + "unit": "pages", + "hits": 4, + "hits_relation": "gte", + "timed_out": False, + "took": 878, + } + assert "error-type" not in results es.clear_scroll.assert_awaited_once_with(body={"scroll_id": ["some-scroll-id"]}) @@ -2235,14 +2276,16 @@ async def test_query_runner_search_with_pages_logs_warning_and_executes(self, es ] ) - self.assertEqual(1, results["weight"]) - self.assertEqual(1, results["pages"]) - self.assertEqual(2, results["hits"]) - self.assertEqual("eq", results["hits_relation"]) - self.assertEqual(4, results["took"]) - self.assertEqual("pages", results["unit"]) - 
self.assertFalse(results["timed_out"]) - self.assertFalse("error-type" in results) + assert results == { + "weight": 1, + "pages": 1, + "unit": "pages", + "hits": 2, + "hits_relation": "eq", + "timed_out": False, + "took": 4, + } + assert "error-type" not in results @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -2259,15 +2302,12 @@ async def test_query_runner_fails_with_unknown_operation_type(self, es): }, } - with self.assertRaises(exceptions.RallyError) as ctx: + with pytest.raises(exceptions.RallyError) as exc: await query_runner(es, params) - self.assertEqual( - "No runner available for operation-type: [unknown]", - ctx.exception.args[0], - ) + assert exc.value.args[0] == "No runner available for operation-type: [unknown]" -class PutPipelineRunnerTests(TestCase): +class TestPutPipelineRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_create_pipeline(self, es): @@ -2299,14 +2339,14 @@ async def test_param_body_mandatory(self, es): r = runner.PutPipeline() params = {"id": "rename"} - with self.assertRaisesRegex( + with pytest.raises( exceptions.DataError, - "Parameter source for operation 'put-pipeline' did not provide the mandatory parameter 'body'. " + match="Parameter source for operation 'put-pipeline' did not provide the mandatory parameter 'body'. " "Add it to your parameter source and try again.", ): await r(es, params) - self.assertEqual(0, es.ingest.put_pipeline.await_count) + assert es.ingest.put_pipeline.await_count == 0 @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -2316,17 +2356,17 @@ async def test_param_id_mandatory(self, es): r = runner.PutPipeline() params = {"body": {}} - with self.assertRaisesRegex( + with pytest.raises( exceptions.DataError, - "Parameter source for operation 'put-pipeline' did not provide the mandatory parameter 'id'. " + match="Parameter source for operation 'put-pipeline' did not provide the mandatory parameter 'id'. 
" "Add it to your parameter source and try again.", ): await r(es, params) - self.assertEqual(0, es.ingest.put_pipeline.await_count) + assert es.ingest.put_pipeline.await_count == 0 -class ClusterHealthRunnerTests(TestCase): +class TestClusterHealthRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_waits_for_expected_cluster_status(self, es): @@ -2337,16 +2377,13 @@ async def test_waits_for_expected_cluster_status(self, es): result = await r(es, params) - self.assertDictEqual( - { - "weight": 1, - "unit": "ops", - "success": True, - "cluster-status": "green", - "relocating-shards": 0, - }, - result, - ) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + "cluster-status": "green", + "relocating-shards": 0, + } es.cluster.health.assert_awaited_once_with(params={"wait_for_status": "green"}) @@ -2360,16 +2397,13 @@ async def test_accepts_better_cluster_status(self, es): result = await r(es, params) - self.assertDictEqual( - { - "weight": 1, - "unit": "ops", - "success": True, - "cluster-status": "green", - "relocating-shards": 0, - }, - result, - ) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + "cluster-status": "green", + "relocating-shards": 0, + } es.cluster.health.assert_awaited_once_with(params={"wait_for_status": "yellow"}) @@ -2388,16 +2422,13 @@ async def test_cluster_health_with_timeout_and_headers(self, es): result = await cluster_health_runner(es, params) - self.assertDictEqual( - { - "weight": 1, - "unit": "ops", - "success": True, - "cluster-status": "green", - "relocating-shards": 0, - }, - result, - ) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + "cluster-status": "green", + "relocating-shards": 0, + } es.cluster.health.assert_awaited_once_with( headers={"header1": "value1"}, opaque_id="testid-1", params={"wait_for_status": "yellow"}, request_timeout=3.0 @@ -2419,16 +2450,13 @@ async def test_rejects_relocating_shards(self, es): result = await r(es, params) - self.assertDictEqual( - { - "weight": 1, - "unit": "ops", - "success": False, - "cluster-status": "yellow", - "relocating-shards": 3, - }, - result, - ) + assert result == { + "weight": 1, + "unit": "ops", + "success": False, + "cluster-status": "yellow", + "relocating-shards": 3, + } es.cluster.health.assert_awaited_once_with(index="logs-*", params={"wait_for_status": "red", "wait_for_no_relocating_shards": True}) @@ -2442,21 +2470,18 @@ async def test_rejects_unknown_cluster_status(self, es): result = await r(es, params) - self.assertDictEqual( - { - "weight": 1, - "unit": "ops", - "success": False, - "cluster-status": None, - "relocating-shards": 0, - }, - result, - ) + assert result == { + "weight": 1, + "unit": "ops", + "success": False, + "cluster-status": None, + "relocating-shards": 0, + } es.cluster.health.assert_awaited_once_with(params={"wait_for_status": "green"}) -class CreateIndexRunnerTests(TestCase): +class TestCreateIndexRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_creates_multiple_indices(self, es): @@ -2476,7 +2501,11 @@ async def test_creates_multiple_indices(self, es): result = await r(es, params) - self.assertDictEqual({"weight": 2, "unit": "ops", "success": True}, result) + assert result == { + "weight": 2, + "unit": "ops", + "success": True, + } es.indices.create.assert_has_awaits( [ @@ -2506,7 +2535,11 @@ async def test_create_with_timeout_and_headers(self, es): result = await create_index_runner(es, params) - self.assertDictEqual({"weight": 1, "unit": "ops", 
"success": True}, result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.indices.create.assert_awaited_once_with( index="indexA", @@ -2537,7 +2570,11 @@ async def test_ignore_invalid_params(self, es): result = await r(es, params) - self.assertDictEqual({"weight": 1, "unit": "ops", "success": True}, result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.indices.create.assert_awaited_once_with(index="indexA", body={"settings": {}}, params={"wait_for_active_shards": "true"}) @@ -2549,17 +2586,17 @@ async def test_param_indices_mandatory(self, es): r = runner.CreateIndex() params = {} - with self.assertRaisesRegex( + with pytest.raises( exceptions.DataError, - "Parameter source for operation 'create-index' did not provide the mandatory parameter 'indices'. " + match="Parameter source for operation 'create-index' did not provide the mandatory parameter 'indices'. " "Add it to your parameter source and try again.", ): await r(es, params) - self.assertEqual(0, es.indices.create.await_count) + assert es.indices.create.await_count == 0 -class CreateDataStreamRunnerTests(TestCase): +class TestCreateDataStreamRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_creates_multiple_data_streams(self, es): @@ -2579,7 +2616,11 @@ async def test_creates_multiple_data_streams(self, es): result = await r(es, params) - self.assertDictEqual({"weight": 2, "unit": "ops", "success": True}, result) + assert result == { + "weight": 2, + "unit": "ops", + "success": True, + } es.indices.create_data_stream.assert_has_awaits( [ @@ -2596,17 +2637,17 @@ async def test_param_data_streams_mandatory(self, es): r = runner.CreateDataStream() params = {} - with self.assertRaisesRegex( + with pytest.raises( exceptions.DataError, - "Parameter source for operation 'create-data-stream' did not provide the " + match="Parameter source for operation 'create-data-stream' did not provide the " "mandatory parameter 'data-streams'. 
Add it to your parameter source and try again.", ): await r(es, params) - self.assertEqual(0, es.indices.create_data_stream.await_count) + assert es.indices.create_data_stream.await_count == 0 -class DeleteIndexRunnerTests(TestCase): +class TestDeleteIndexRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_deletes_existing_indices(self, es): @@ -2620,7 +2661,11 @@ async def test_deletes_existing_indices(self, es): result = await r(es, params) - self.assertDictEqual({"weight": 1, "unit": "ops", "success": True}, result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.cluster.put_settings.assert_has_awaits( [ @@ -2646,7 +2691,11 @@ async def test_deletes_all_indices(self, es): result = await r(es, params) - self.assertDictEqual({"weight": 2, "unit": "ops", "success": True}, result) + assert result == { + "weight": 2, + "unit": "ops", + "success": True, + } es.cluster.put_settings.assert_has_awaits( [ @@ -2660,10 +2709,10 @@ async def test_deletes_all_indices(self, es): mock.call(index="indexB", params=params["request-params"]), ] ) - self.assertEqual(0, es.indices.exists.call_count) + assert es.indices.exists.call_count == 0 -class DeleteDataStreamRunnerTests(TestCase): +class TestDeleteDataStreamRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_deletes_existing_data_streams(self, es): @@ -2676,7 +2725,11 @@ async def test_deletes_existing_data_streams(self, es): result = await r(es, params) - self.assertDictEqual({"weight": 1, "unit": "ops", "success": True}, result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.indices.delete_data_stream.assert_awaited_once_with("data-stream-B", params={}) @@ -2696,7 +2749,11 @@ async def test_deletes_all_data_streams(self, es): result = await r(es, params) - self.assertDictEqual({"weight": 2, "unit": "ops", "success": True}, result) + assert result == { + "weight": 2, + "unit": "ops", + "success": True, + } es.indices.delete_data_stream.assert_has_awaits( [ @@ -2704,10 +2761,10 @@ async def test_deletes_all_data_streams(self, es): mock.call("data-stream-B", ignore=[404], params=params["request-params"]), ] ) - self.assertEqual(0, es.indices.exists.await_count) + assert es.indices.exists.await_count == 0 -class CreateIndexTemplateRunnerTests(TestCase): +class TestCreateIndexTemplateRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_create_index_templates(self, es): @@ -2725,7 +2782,11 @@ async def test_create_index_templates(self, es): result = await r(es, params) - self.assertDictEqual({"weight": 2, "unit": "ops", "success": True}, result) + assert result == { + "weight": 2, + "unit": "ops", + "success": True, + } es.indices.put_template.assert_has_awaits( [ @@ -2742,17 +2803,17 @@ async def test_param_templates_mandatory(self, es): r = runner.CreateIndexTemplate() params = {} - with self.assertRaisesRegex( + with pytest.raises( exceptions.DataError, - "Parameter source for operation 'create-index-template' did not provide the mandatory parameter " + match="Parameter source for operation 'create-index-template' did not provide the mandatory parameter " "'templates'. 
Add it to your parameter source and try again.", ): await r(es, params) - self.assertEqual(0, es.indices.put_template.await_count) + assert es.indices.put_template.await_count == 0 -class DeleteIndexTemplateRunnerTests(TestCase): +class TestDeleteIndexTemplateRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_deletes_all_index_templates(self, es): @@ -2771,7 +2832,11 @@ async def test_deletes_all_index_templates(self, es): result = await r(es, params) # 2 times delete index template, one time delete matching indices - self.assertDictEqual({"weight": 3, "unit": "ops", "success": True}, result) + assert result == { + "weight": 3, + "unit": "ops", + "success": True, + } es.indices.delete_template.assert_has_awaits( [mock.call(name="templateA", params=params["request-params"]), mock.call(name="templateB", params=params["request-params"])] @@ -2799,11 +2864,15 @@ async def test_deletes_only_existing_index_templates(self, es): result = await r(es, params) # 2 times delete index template, one time delete matching indices - self.assertDictEqual({"weight": 1, "unit": "ops", "success": True}, result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.indices.delete_template.assert_awaited_once_with(name="templateB", params=params["request-params"]) # not called because the matching index is empty. - self.assertEqual(0, es.indices.delete.await_count) + assert es.indices.delete.await_count == 0 @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -2812,17 +2881,17 @@ async def test_param_templates_mandatory(self, es): r = runner.DeleteIndexTemplate() params = {} - with self.assertRaisesRegex( + with pytest.raises( exceptions.DataError, - "Parameter source for operation 'delete-index-template' did not provide the mandatory parameter " + match="Parameter source for operation 'delete-index-template' did not provide the mandatory parameter " "'templates'. Add it to your parameter source and try again.", ): await r(es, params) - self.assertEqual(0, es.indices.delete_template.await_count) + assert es.indices.delete_template.await_count == 0 -class CreateComponentTemplateRunnerTests(TestCase): +class TestCreateComponentTemplateRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_create_index_templates(self, es): @@ -2837,7 +2906,11 @@ async def test_create_index_templates(self, es): } result = await r(es, params) - self.assertDictEqual({"weight": 2, "unit": "ops", "success": True}, result) + assert result == { + "weight": 2, + "unit": "ops", + "success": True, + } es.cluster.put_component_template.assert_has_awaits( [ mock.call( @@ -2861,17 +2934,17 @@ async def test_param_templates_mandatory(self, es): r = runner.CreateComponentTemplate() params = {} - with self.assertRaisesRegex( + with pytest.raises( exceptions.DataError, - "Parameter source for operation 'create-component-template' did not provide the mandatory parameter " + match="Parameter source for operation 'create-component-template' did not provide the mandatory parameter " "'templates'. 
Add it to your parameter source and try again.", ): await r(es, params) - self.assertEqual(0, es.cluster.put_component_template.await_count) + assert es.cluster.put_component_template.await_count == 0 -class DeleteComponentTemplateRunnerTests(TestCase): +class TestDeleteComponentTemplateRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_deletes_all_index_templates(self, es): @@ -2888,7 +2961,11 @@ async def test_deletes_all_index_templates(self, es): "only-if-exists": False, } result = await r(es, params) - self.assertDictEqual({"weight": 2, "unit": "ops", "success": True}, result) + assert result == { + "weight": 2, + "unit": "ops", + "success": True, + } es.cluster.delete_component_template.assert_has_awaits( [ @@ -2919,7 +2996,11 @@ async def _side_effect(http_method, path): } result = await r(es, params) - self.assertDictEqual({"weight": 1, "unit": "ops", "success": True}, result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.cluster.delete_component_template.assert_awaited_once_with(name="templateB", params=params["request-params"]) @@ -2930,17 +3011,17 @@ async def test_param_templates_mandatory(self, es): r = runner.DeleteComponentTemplate() params = {} - with self.assertRaisesRegex( + with pytest.raises( exceptions.DataError, - "Parameter source for operation 'delete-component-template' did not provide the mandatory parameter " + match="Parameter source for operation 'delete-component-template' did not provide the mandatory parameter " "'templates'. Add it to your parameter source and try again.", ): await r(es, params) - self.assertEqual(0, es.indices.delete_template.await_count) + assert es.indices.delete_template.await_count == 0 -class CreateComposableTemplateRunnerTests(TestCase): +class TestCreateComposableTemplateRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_create_index_templates(self, es): @@ -2969,14 +3050,11 @@ async def test_create_index_templates(self, es): } result = await r(es, params) - self.assertDictEqual( - { - "weight": 2, - "unit": "ops", - "success": True, - }, - result, - ) + assert result == { + "weight": 2, + "unit": "ops", + "success": True, + } es.indices.put_index_template.assert_has_awaits( [ mock.call( @@ -3016,17 +3094,17 @@ async def test_param_templates_mandatory(self, es): r = runner.CreateComposableTemplate() params = {} - with self.assertRaisesRegex( + with pytest.raises( exceptions.DataError, - "Parameter source for operation 'create-composable-template' did not provide the mandatory parameter " + match="Parameter source for operation 'create-composable-template' did not provide the mandatory parameter " "'templates'. 
Add it to your parameter source and try again.", ): await r(es, params) - self.assertEqual(0, es.indices.put_index_template.await_count) + assert es.indices.put_index_template.await_count == 0 -class DeleteComposableTemplateRunnerTests(TestCase): +class TestDeleteComposableTemplateRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_deletes_all_index_templates(self, es): @@ -3046,7 +3124,11 @@ async def test_deletes_all_index_templates(self, es): result = await r(es, params) # 2 times delete index template, one time delete matching indices - self.assertDictEqual({"weight": 3, "unit": "ops", "success": True}, result) + assert result == { + "weight": 3, + "unit": "ops", + "success": True, + } es.indices.delete_index_template.assert_has_awaits( [ @@ -3076,11 +3158,15 @@ async def test_deletes_only_existing_index_templates(self, es): result = await r(es, params) # 2 times delete index template, one time delete matching indices - self.assertDictEqual({"weight": 1, "unit": "ops", "success": True}, result) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.indices.delete_index_template.assert_awaited_once_with(name="templateB", params=params["request-params"]) # not called because the matching index is empty. - self.assertEqual(0, es.indices.delete.call_count) + assert es.indices.delete.call_count == 0 @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -3088,17 +3174,17 @@ async def test_param_templates_mandatory(self, es): r = runner.DeleteComposableTemplate() params = {} - with self.assertRaisesRegex( + with pytest.raises( exceptions.DataError, - "Parameter source for operation 'delete-composable-template' did not provide the mandatory parameter " + match="Parameter source for operation 'delete-composable-template' did not provide the mandatory parameter " "'templates'. 
Add it to your parameter source and try again.", ): await r(es, params) - self.assertEqual(0, es.indices.delete_index_template.call_count) + assert es.indices.delete_index_template.call_count == 0 -class CreateMlDatafeedTests(TestCase): +class TestCreateMlDatafeed: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_create_ml_datafeed(self, es): @@ -3126,7 +3212,7 @@ async def test_create_ml_datafeed_fallback(self, es): es.transport.perform_request.assert_awaited_once_with("PUT", f"/_xpack/ml/datafeeds/{datafeed_id}", body=body) -class DeleteMlDatafeedTests(TestCase): +class TestDeleteMlDatafeed: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_delete_ml_datafeed(self, es): @@ -3158,7 +3244,7 @@ async def test_delete_ml_datafeed_fallback(self, es): ) -class StartMlDatafeedTests(TestCase): +class TestStartMlDatafeed: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_start_ml_datafeed_with_body(self, es): @@ -3204,7 +3290,7 @@ async def test_start_ml_datafeed_with_params(self, es): ) -class StopMlDatafeedTests(TestCase): +class TestStopMlDatafeed: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_stop_ml_datafeed(self, es): @@ -3244,7 +3330,7 @@ async def test_stop_ml_datafeed_fallback(self, es): ) -class CreateMlJobTests(TestCase): +class TestCreateMlJob: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_create_ml_job(self, es): @@ -3301,7 +3387,7 @@ async def test_create_ml_job_fallback(self, es): es.transport.perform_request.assert_awaited_once_with("PUT", f"/_xpack/ml/anomaly_detectors/{params['job-id']}", body=body) -class DeleteMlJobTests(TestCase): +class TestDeleteMlJob: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_delete_ml_job(self, es): @@ -3332,7 +3418,7 @@ async def test_delete_ml_job_fallback(self, es): ) -class OpenMlJobTests(TestCase): +class TestOpenMlJob: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_open_ml_job(self, es): @@ -3361,7 +3447,7 @@ async def test_open_ml_job_fallback(self, es): es.transport.perform_request.assert_awaited_once_with("POST", f"/_xpack/ml/anomaly_detectors/{params['job-id']}/_open") -class CloseMlJobTests(TestCase): +class TestCloseMlJob: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_close_ml_job(self, es): @@ -3399,7 +3485,7 @@ async def test_close_ml_job_fallback(self, es): ) -class RawRequestRunnerTests(TestCase): +class TestRawRequestRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_raises_missing_slash(self, es): @@ -3409,9 +3495,9 @@ async def test_raises_missing_slash(self, es): params = {"path": "_cat/count"} with mock.patch.object(r.logger, "error") as mocked_error_logger: - with self.assertRaises(exceptions.RallyAssertionError) as ctx: + with pytest.raises(exceptions.RallyAssertionError) as exc: await r(es, params) - self.assertEqual("RawRequest [_cat/count] failed. Path parameter must begin with a '/'.", ctx.exception.args[0]) + assert exc.value.args[0] == "RawRequest [_cat/count] failed. Path parameter must begin with a '/'." mocked_error_logger.assert_has_calls( [mock.call("RawRequest failed. 
Path parameter: [%s] must begin with a '/'.", params["path"])] ) @@ -3535,24 +3621,24 @@ async def test_raw_with_timeout_and_opaqueid(self, es): ) -class SleepTests(TestCase): +class TestSleep: @mock.patch("elasticsearch.Elasticsearch") # To avoid real sleeps in unit tests @mock.patch("asyncio.sleep") @run_async async def test_missing_parameter(self, sleep, es): r = runner.Sleep() - with self.assertRaisesRegex( + with pytest.raises( exceptions.DataError, - "Parameter source for operation 'sleep' did not provide the mandatory parameter " + match="Parameter source for operation 'sleep' did not provide the mandatory parameter " "'duration'. Add it to your parameter source and try again.", ): await r(es, params={}) - self.assertEqual(0, es.call_count) - self.assertEqual(1, es.on_request_start.call_count) - self.assertEqual(1, es.on_request_end.call_count) - self.assertEqual(0, sleep.call_count) + assert es.call_count == 0 + assert es.on_request_start.call_count == 1 + assert es.on_request_end.call_count == 1 + assert sleep.call_count == 0 @mock.patch("elasticsearch.Elasticsearch") # To avoid real sleeps in unit tests @@ -3562,13 +3648,13 @@ async def test_sleep(self, sleep, es): r = runner.Sleep() await r(es, params={"duration": 4.3}) - self.assertEqual(0, es.call_count) - self.assertEqual(1, es.on_request_start.call_count) - self.assertEqual(1, es.on_request_end.call_count) + assert es.call_count == 0 + assert es.on_request_start.call_count == 1 + assert es.on_request_end.call_count == 1 sleep.assert_called_once_with(4.3) -class DeleteSnapshotRepositoryTests(TestCase): +class TestDeleteSnapshotRepository: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_delete_snapshot_repository(self, es): @@ -3581,7 +3667,7 @@ async def test_delete_snapshot_repository(self, es): es.snapshot.delete_repository.assert_called_once_with(repository="backups") -class CreateSnapshotRepositoryTests(TestCase): +class TestCreateSnapshotRepository: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_create_snapshot_repository(self, es): @@ -3604,7 +3690,7 @@ async def test_create_snapshot_repository(self, es): ) -class CreateSnapshotTests(TestCase): +class TestCreateSnapshot: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_create_snapshot_no_wait(self, es): @@ -3673,7 +3759,7 @@ async def test_create_snapshot_wait_for_completion(self, es): ) -class WaitForSnapshotCreateTests(TestCase): +class TestWaitForSnapshotCreate: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_wait_for_snapshot_create_entire_lifecycle(self, es): @@ -3769,21 +3855,18 @@ async def test_wait_for_snapshot_create_entire_lifecycle(self, es): es.snapshot.status.assert_awaited_with(repository="restore_speed", snapshot="restore_speed_snapshot", ignore_unavailable=True) - self.assertDictEqual( - { - "weight": 243468188055, - "unit": "byte", - "success": True, - "duration": 1113462, - "file_count": 204, - "throughput": 218658731.10622546, - "start_time_millis": 1597317564956, - "stop_time_millis": 1597317564956 + 1113462, - }, - result, - ) + assert result == { + "weight": 243468188055, + "unit": "byte", + "success": True, + "duration": 1113462, + "file_count": 204, + "throughput": 218658731.10622546, + "start_time_millis": 1597317564956, + "stop_time_millis": 1597317564956 + 1113462, + } - self.assertEqual(3, es.snapshot.status.await_count) + assert es.snapshot.status.await_count == 3 @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -3818,19 +3901,16 @@ async def 
test_wait_for_snapshot_create_immediate_success(self, es): r = runner.WaitForSnapshotCreate() result = await r(es, params) - self.assertDictEqual( - { - "weight": 9399505, - "unit": "byte", - "success": True, - "duration": 200, - "file_count": 70, - "throughput": 46997525.0, - "start_time_millis": 1591776481060, - "stop_time_millis": 1591776481060 + 200, - }, - result, - ) + assert result == { + "weight": 9399505, + "unit": "byte", + "success": True, + "duration": 200, + "file_count": 70, + "throughput": 46997525.0, + "start_time_millis": 1591776481060, + "stop_time_millis": 1591776481060 + 200, + } es.snapshot.status.assert_awaited_once_with(repository="backups", snapshot="snapshot-001", ignore_unavailable=True) @@ -3857,15 +3937,15 @@ async def test_wait_for_snapshot_create_failure(self, es): r = runner.WaitForSnapshotCreate() with mock.patch.object(r.logger, "error") as mocked_error_logger: - with self.assertRaises(exceptions.RallyAssertionError) as ctx: + with pytest.raises(exceptions.RallyAssertionError) as exc: await r(es, params) - self.assertEqual("Snapshot [snapshot-001] failed. Please check logs.", ctx.exception.args[0]) + assert exc.value.args[0] == "Snapshot [snapshot-001] failed. Please check logs." mocked_error_logger.assert_has_calls( [mock.call("Snapshot [%s] failed. Response:\n%s", "snapshot-001", json.dumps(snapshot_status, indent=2))] ) -class RestoreSnapshotTests(TestCase): +class TestRestoreSnapshot: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_restore_snapshot(self, es): @@ -3921,7 +4001,7 @@ async def test_restore_snapshot_with_body(self, es): ) -class IndicesRecoveryTests(TestCase): +class TestIndicesRecovery: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_waits_for_ongoing_indices_recovery(self, es): @@ -4054,20 +4134,22 @@ async def test_waits_for_ongoing_indices_recovery(self, es): result = await r(es, {"completion-recheck-wait-period": 0, "index": "index1"}) # sum of both shards - self.assertEqual(237783878, result["weight"]) - self.assertEqual("byte", result["unit"]) - self.assertTrue(result["success"]) - # bytes recovered within these 5 seconds - self.assertEqual(47556775.6, result["throughput"]) - self.assertEqual(1393244155000, result["start_time_millis"]) - self.assertEqual(1393244160000, result["stop_time_millis"]) + assert result == { + "weight": 237783878, + "unit": "byte", + "success": True, + # bytes recovered within these 5 seconds + "throughput": 47556775.6, + "start_time_millis": 1393244155000, + "stop_time_millis": 1393244160000, + } es.indices.recovery.assert_awaited_with(index="index1") # retries four times - self.assertEqual(4, es.indices.recovery.await_count) + assert es.indices.recovery.await_count == 4 -class ShrinkIndexTests(TestCase): +class TestShrinkIndex: @mock.patch("elasticsearch.Elasticsearch") # To avoid real sleeps in unit tests @mock.patch("asyncio.sleep") @@ -4282,7 +4364,7 @@ async def test_shrink_index_pattern_with_shrink_node(self, sleep, es): ) -class PutSettingsTests(TestCase): +class TestPutSettings: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_put_settings(self, es): @@ -4295,7 +4377,7 @@ async def test_put_settings(self, es): es.cluster.put_settings.assert_awaited_once_with(body={"transient": {"indices.recovery.max_bytes_per_sec": "20mb"}}) -class CreateTransformTests(TestCase): +class TestCreateTransform: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_create_transform(self, es): @@ -4323,7 +4405,7 @@ async def 
test_create_transform(self, es): ) -class StartTransformTests(TestCase): +class TestStartTransform: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_start_transform(self, es): @@ -4338,7 +4420,7 @@ async def test_start_transform(self, es): es.transform.start_transform.assert_awaited_once_with(transform_id=transform_id, timeout=params["timeout"]) -class WaitForTransformTests(TestCase): +class TestWaitForTransform: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_wait_for_transform(self, es): @@ -4383,15 +4465,20 @@ async def test_wait_for_transform(self, es): ) r = runner.WaitForTransform() - self.assertFalse(r.completed) - self.assertEqual(r.percent_completed, 0.0) + assert not r.completed + assert r.percent_completed == 0.0 result = await r(es, params) - self.assertTrue(r.completed) - self.assertEqual(r.percent_completed, 1.0) - self.assertEqual(2, result["weight"], 2) - self.assertEqual(result["unit"], "docs") + assert r.completed + assert r.percent_completed == 1.0 + assert result.pop("throughput") + assert result == { + "weight": 2, + "unit": "docs", + "success": True, + "transform-id": transform_id, + } es.transform.stop_transform.assert_awaited_once_with( transform_id=transform_id, @@ -4539,28 +4626,33 @@ async def test_wait_for_transform_progress(self, es): ) r = runner.WaitForTransform() - self.assertFalse(r.completed) - self.assertEqual(r.percent_completed, 0.0) + assert not r.completed + assert r.percent_completed == 0.0 total_calls = 0 while not r.completed: result = await r(es, params) total_calls += 1 if total_calls < 4: - self.assertAlmostEqual(r.percent_completed, (total_calls * 10.20) / 100.0) - - self.assertEqual(total_calls, 4) - self.assertTrue(r.completed) - self.assertEqual(r.percent_completed, 1.0) - self.assertEqual(result["weight"], 60000) - self.assertEqual(result["unit"], "docs") + assert round(abs(r.percent_completed - (total_calls * 10.20) / 100.0), 7) == 0 + + assert total_calls == 4 + assert r.completed + assert r.percent_completed == 1.0 + assert result.pop("throughput") + assert result == { + "weight": 60_000, + "unit": "docs", + "success": True, + "transform-id": "a-transform", + } es.transform.stop_transform.assert_awaited_once_with( transform_id=transform_id, force=params["force"], timeout=params["timeout"], wait_for_completion=False, wait_for_checkpoint=True ) -class DeleteTransformTests(TestCase): +class TestDeleteTransform: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_delete_transform(self, es): @@ -4575,7 +4667,7 @@ async def test_delete_transform(self, es): es.transform.delete_transform.assert_awaited_once_with(transform_id=transform_id, force=params["force"], ignore=[404]) -class TransformStatsRunnerTests(TestCase): +class TestTransformStatsRunner: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_transform_stats_with_timeout_and_headers(self, es): @@ -4591,9 +4683,11 @@ async def test_transform_stats_with_timeout_and_headers(self, es): "opaque-id": "test-id1", }, ) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertTrue(result["success"]) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.transform.get_transform_stats.assert_awaited_once_with( transform_id=transform_id, @@ -4635,18 +4729,16 @@ async def test_transform_stats_with_failed_condition(self, es): }, }, ) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertFalse(result["success"]) - 
self.assertDictEqual( - { + assert result == { + "weight": 1, + "unit": "ops", + "success": False, + "condition": { "path": "checkpointing.operations_behind", "actual-value": "10000", "expected-value": None, }, - result["condition"], - ) - + } es.transform.get_transform_stats.assert_awaited_once_with(transform_id=transform_id) @mock.patch("elasticsearch.Elasticsearch") @@ -4681,17 +4773,16 @@ async def test_transform_stats_with_successful_condition(self, es): }, }, ) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertTrue(result["success"]) - self.assertDictEqual( - { + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + "condition": { "path": "checkpointing.operations_behind", "actual-value": None, "expected-value": None, }, - result["condition"], - ) + } es.transform.get_transform_stats.assert_awaited_once_with(transform_id=transform_id) @@ -4727,22 +4818,21 @@ async def test_transform_stats_with_non_existing_path(self, es): }, }, ) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertFalse(result["success"]) - self.assertDictEqual( - { + assert result == { + "weight": 1, + "unit": "ops", + "success": False, + "condition": { "path": "checkpointing.last.checkpoint", "actual-value": None, "expected-value": "42", }, - result["condition"], - ) + } es.transform.get_transform_stats.assert_awaited_once_with(transform_id=transform_id) -class CreateIlmPolicyRunner(TestCase): +class TestCreateIlmPolicyRunner: params = { "policy-name": "my-ilm-policy", @@ -4760,9 +4850,12 @@ async def test_create_ilm_policy_with_request_params(self, es): es.ilm.put_lifecycle = mock.AsyncMock(return_value={}) create_ilm_policy = runner.CreateIlmPolicy() result = await create_ilm_policy(es, params=self.params) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertTrue(result["success"]) + + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.ilm.put_lifecycle.assert_awaited_once_with( policy=self.params["policy-name"], body=self.params["body"], params=self.params["request-params"] @@ -4776,14 +4869,16 @@ async def test_create_ilm_policy_without_request_params(self, es): params = copy.deepcopy(self.params) del params["request-params"] result = await create_ilm_policy(es, params=params) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertTrue(result["success"]) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.ilm.put_lifecycle.assert_awaited_once_with(policy=params["policy-name"], body=params["body"], params={}) -class DeleteIlmPolicyRunner(TestCase): +class TestDeleteIlmPolicyRunner: params = {"policy-name": "my-ilm-policy", "request-params": {"master_timeout": "30s", "timeout": "30s"}} @@ -4793,9 +4888,11 @@ async def test_delete_ilm_policy_with_request_params(self, es): es.ilm.delete_lifecycle = mock.AsyncMock(return_value={}) delete_ilm_policy = runner.DeleteIlmPolicy() result = await delete_ilm_policy(es, params=self.params) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertTrue(result["success"]) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.ilm.delete_lifecycle.assert_awaited_once_with(policy=self.params["policy-name"], params=self.params["request-params"]) @@ -4807,14 +4904,16 @@ async def test_delete_ilm_policy_without_request_params(self, es): params = copy.deepcopy(self.params) del 
params["request-params"] result = await delete_ilm_policy(es, params=params) - self.assertEqual(1, result["weight"]) - self.assertEqual("ops", result["unit"]) - self.assertTrue(result["success"]) + assert result == { + "weight": 1, + "unit": "ops", + "success": True, + } es.ilm.delete_lifecycle.assert_awaited_once_with(policy=params["policy-name"], params={}) -class SubmitAsyncSearchTests(TestCase): +class TestSubmitAsyncSearch: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_submit_async_search(self, es): @@ -4833,12 +4932,12 @@ async def test_submit_async_search(self, es): async with runner.CompositeContext(): await r(es, params) # search id is registered in context - self.assertEqual("12345", runner.CompositeContext.get("search-1")) + assert runner.CompositeContext.get("search-1") == "12345" es.async_search.submit.assert_awaited_once_with(body={"query": {"match_all": {}}}, index="_all", params={}) -class GetAsyncSearchTests(TestCase): +class TestGetAsyncSearch: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_get_async_search(self, es): @@ -4863,27 +4962,24 @@ async def test_get_async_search(self, es): async with runner.CompositeContext(): runner.CompositeContext.put("search-1", "12345") response = await r(es, params) - self.assertDictEqual( - response, - { - "weight": 1, - "unit": "ops", - "success": True, - "stats": { - "search-1": { - "hits": 1520, - "hits_relation": "eq", - "timed_out": False, - "took": 1122, - }, + assert response == { + "weight": 1, + "unit": "ops", + "success": True, + "stats": { + "search-1": { + "hits": 1520, + "hits_relation": "eq", + "timed_out": False, + "took": 1122, }, }, - ) + } es.async_search.get.assert_awaited_once_with(id="12345", params={}) -class DeleteAsyncSearchTests(TestCase): +class TestDeleteAsyncSearch: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_delete_async_search(self, es): @@ -4905,7 +5001,7 @@ async def test_delete_async_search(self, es): ) -class OpenPointInTimeTests(TestCase): +class TestOpenPointInTime: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_creates_point_in_time(self, es): @@ -4917,7 +5013,7 @@ async def test_creates_point_in_time(self, es): r = runner.OpenPointInTime() async with runner.CompositeContext(): await r(es, params) - self.assertEqual(pit_id, runner.CompositeContext.get("open-pit-test")) + assert runner.CompositeContext.get("open-pit-test") == pit_id @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -4928,13 +5024,13 @@ async def test_can_only_be_run_in_composite(self, es): es.open_point_in_time = mock.AsyncMock(return_value={"id": pit_id}) r = runner.OpenPointInTime() - with self.assertRaises(exceptions.RallyAssertionError) as ctx: + with pytest.raises(exceptions.RallyAssertionError) as exc: await r(es, params) - self.assertEqual("This operation is only allowed inside a composite operation.", ctx.exception.args[0]) + assert exc.value.args[0] == "This operation is only allowed inside a composite operation." 
-class ClosePointInTimeTests(TestCase): +class TestClosePointInTime: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_closes_point_in_time(self, es): @@ -4952,7 +5048,7 @@ async def test_closes_point_in_time(self, es): es.close_point_in_time.assert_awaited_once_with(body={"id": "0123456789abcdef"}, params={}, headers=None) -class QueryWithSearchAfterScrollTests(TestCase): +class TestQueryWithSearchAfterScroll: @mock.patch("elasticsearch.Elasticsearch") @run_async async def test_search_after_with_pit(self, es): @@ -5013,7 +5109,7 @@ async def test_search_after_with_pit(self, es): runner.CompositeContext.put(pit_op, pit_id) await r(es, params) # make sure pit_id is updated afterward - self.assertEqual("fedcba9876543211", runner.CompositeContext.get(pit_op)) + assert runner.CompositeContext.get(pit_op) == "fedcba9876543211" es.transport.perform_request.assert_has_awaits( [ @@ -5156,7 +5252,7 @@ async def test_search_after_without_pit(self, es): ) -class SearchAfterExtractorTests(TestCase): +class TestSearchAfterExtractor: response_text = """ { "pit_id": "fedcba9876543210", @@ -5185,16 +5281,16 @@ def test_extract_all_properties(self): props, last_sort = target(response=self.response, get_point_in_time=True, hits_total=None) expected_props = {"hits.total.relation": "eq", "hits.total.value": 2, "pit_id": "fedcba9876543210", "timed_out": False, "took": 10} expected_sort_value = [1609780186, "2"] - self.assertEqual(expected_props, props) - self.assertEqual(expected_sort_value, last_sort) + assert props == expected_props + assert last_sort == expected_sort_value def test_extract_ignore_point_in_time(self): target = runner.SearchAfterExtractor() props, last_sort = target(response=self.response, get_point_in_time=False, hits_total=None) expected_props = {"hits.total.relation": "eq", "hits.total.value": 2, "timed_out": False, "took": 10} expected_sort_value = [1609780186, "2"] - self.assertEqual(expected_props, props) - self.assertEqual(expected_sort_value, last_sort) + assert props == expected_props + assert last_sort == expected_sort_value def test_extract_uses_provided_hits_total(self): target = runner.SearchAfterExtractor() @@ -5202,17 +5298,17 @@ def test_extract_uses_provided_hits_total(self): props, last_sort = target(response=self.response, get_point_in_time=False, hits_total=10) expected_props = {"hits.total.relation": "eq", "hits.total.value": 10, "timed_out": False, "took": 10} expected_sort_value = [1609780186, "2"] - self.assertEqual(expected_props, props) - self.assertEqual(expected_sort_value, last_sort) + assert props == expected_props + assert last_sort == expected_sort_value def test_extract_missing_required_point_in_time(self): response_copy = json.loads(self.response_text) del response_copy["pit_id"] response_copy_bytesio = io.BytesIO(json.dumps(response_copy).encode()) target = runner.SearchAfterExtractor() - with self.assertRaises(exceptions.RallyAssertionError) as ctx: + with pytest.raises(exceptions.RallyAssertionError) as exc: target(response=response_copy_bytesio, get_point_in_time=True, hits_total=None) - self.assertEqual("Paginated query failure: pit_id was expected but not found in the response.", ctx.exception.args[0]) + assert exc.value.args[0] == "Paginated query failure: pit_id was expected but not found in the response." 
def test_extract_missing_ignored_point_in_time(self): response_copy = json.loads(self.response_text) @@ -5222,49 +5318,49 @@ def test_extract_missing_ignored_point_in_time(self): props, last_sort = target(response=response_copy_bytesio, get_point_in_time=False, hits_total=None) expected_props = {"hits.total.relation": "eq", "hits.total.value": 2, "timed_out": False, "took": 10} expected_sort_value = [1609780186, "2"] - self.assertEqual(expected_props, props) - self.assertEqual(expected_sort_value, last_sort) + assert props == expected_props + assert last_sort == expected_sort_value -class CompositeContextTests(TestCase): +class TestCompositeContext: def test_cannot_be_used_outside_of_composite(self): - with self.assertRaises(exceptions.RallyAssertionError) as ctx: + with pytest.raises(exceptions.RallyAssertionError) as exc: runner.CompositeContext.put("test", 1) - self.assertEqual("This operation is only allowed inside a composite operation.", ctx.exception.args[0]) + assert exc.value.args[0] == "This operation is only allowed inside a composite operation." @run_async async def test_put_get_and_remove(self): async with runner.CompositeContext(): runner.CompositeContext.put("test", 1) runner.CompositeContext.put("don't clear this key", 1) - self.assertEqual(runner.CompositeContext.get("test"), 1) + assert runner.CompositeContext.get("test") == 1 runner.CompositeContext.remove("test") # context is cleared properly async with runner.CompositeContext(): - with self.assertRaises(KeyError) as ctx: + with pytest.raises(KeyError) as exc: runner.CompositeContext.get("don't clear this key") - self.assertEqual("Unknown property [don't clear this key]. Currently recognized properties are [].", ctx.exception.args[0]) + assert exc.value.args[0] == "Unknown property [don't clear this key]. Currently recognized properties are []." @run_async async def test_fails_to_read_unknown_key(self): async with runner.CompositeContext(): - with self.assertRaises(KeyError) as ctx: + with pytest.raises(KeyError) as exc: runner.CompositeContext.put("test", 1) runner.CompositeContext.get("unknown") - self.assertEqual("Unknown property [unknown]. Currently recognized properties are [test].", ctx.exception.args[0]) + assert exc.value.args[0] == "Unknown property [unknown]. Currently recognized properties are [test]." @run_async async def test_fails_to_remove_unknown_key(self): async with runner.CompositeContext(): - with self.assertRaises(KeyError) as ctx: + with pytest.raises(KeyError) as exc: runner.CompositeContext.put("test", 1) runner.CompositeContext.remove("unknown") - self.assertEqual("Unknown property [unknown]. Currently recognized properties are [test].", ctx.exception.args[0]) + assert exc.value.args[0] == "Unknown property [unknown]. Currently recognized properties are [test]." 
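
The class below also switches its fixtures from unittest's setUp/tearDown to pytest's xunit-style setup_method/teardown_method hooks, which pytest invokes around every test method of a plain (non-TestCase) class. A minimal sketch of the convention, with a hypothetical MyRunner helper standing in for the real runners:

    class TestExample:
        def setup_method(self, method):
            # runs before each test method, like unittest's setUp();
            # "method" is the test function object about to be executed
            runner.register_runner("my-runner", MyRunner(), async_runner=True)

        def teardown_method(self, method):
            # runs after each test method, even if it failed, like tearDown()
            runner.remove_runner("my-runner")
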
-class CompositeTests(TestCase):
+class TestComposite:
     class CounterRunner:
         def __init__(self):
             self.max_value = 0
@@ -5292,15 +5388,15 @@ async def __call__(self, es, params):
             # wait for a short moment to ensure overlap
             await asyncio.sleep(0.1)
 
-    def setUp(self):
+    def setup_method(self, method):
         runner.register_default_runners()
-        self.counter_runner = CompositeTests.CounterRunner()
-        self.call_recorder_runner = CompositeTests.CallRecorderRunner()
+        self.counter_runner = self.CounterRunner()
+        self.call_recorder_runner = self.CallRecorderRunner()
         runner.register_runner("counter", self.counter_runner, async_runner=True)
         runner.register_runner("call-recorder", self.call_recorder_runner, async_runner=True)
         runner.enable_assertions(True)
 
-    def tearDown(self):
+    def teardown_method(self, method):
         runner.enable_assertions(False)
         runner.remove_runner("counter")
         runner.remove_runner("call-recorder")
@@ -5427,7 +5523,7 @@ async def test_propagates_violated_assertions(self, es):
         }
 
         r = runner.Composite()
-        with self.assertRaisesRegex(exceptions.RallyTaskAssertionError, r"Expected \[hits\] to be > \[0\] but was \[0\]."):
+        with pytest.raises(exceptions.RallyTaskAssertionError, match=r"Expected \[hits\] to be > \[0\] but was \[0\]."):
             await r(es, params)
 
         es.transport.perform_request.assert_has_awaits(
@@ -5504,20 +5600,17 @@ async def test_executes_tasks_in_specified_order(self, es):
         r.supported_op_types = ["call-recorder"]
         await r(es, params)
 
-        self.assertEqual(
-            [
-                "initial-call",
-                # stream-a and stream-b are concurrent
-                "stream-a",
-                "stream-b",
-                "call-after-stream-ab",
-                # stream-c and stream-d are concurrent
-                "stream-c",
-                "stream-d",
-                "call-after-stream-cd",
-            ],
-            self.call_recorder_runner.calls,
-        )
+        assert self.call_recorder_runner.calls == [
+            "initial-call",
+            # stream-a and stream-b are concurrent
+            "stream-a",
+            "stream-b",
+            "call-after-stream-ab",
+            # stream-c and stream-d are concurrent
+            "stream-c",
+            "stream-d",
+            "call-after-stream-cd",
+        ]
 
     @run_async
     async def test_adds_request_timings(self):
@@ -5556,27 +5649,27 @@ async def test_adds_request_timings(self):
         r = runner.Composite()
         response = await r(es, params)
 
-        self.assertEqual(1, response["weight"])
-        self.assertEqual("ops", response["unit"])
+        assert response["weight"] == 1
+        assert response["unit"] == "ops"
         timings = response["dependent_timing"]
-        self.assertEqual(3, len(timings))
+        assert len(timings) == 3
 
-        self.assertEqual("initial-call", timings[0]["operation"])
-        self.assertAlmostEqual(0.1, timings[0]["service_time"], delta=0.05)
+        assert timings[0]["operation"] == "initial-call"
+        assert math.isclose(timings[0]["service_time"], 0.1, abs_tol=0.05)
 
-        self.assertEqual("stream-a", timings[1]["operation"])
-        self.assertAlmostEqual(0.2, timings[1]["service_time"], delta=0.05)
+        assert timings[1]["operation"] == "stream-a"
+        assert math.isclose(timings[1]["service_time"], 0.2, abs_tol=0.05)
 
-        self.assertEqual("stream-b", timings[2]["operation"])
-        self.assertAlmostEqual(0.1, timings[2]["service_time"], delta=0.05)
+        assert timings[2]["operation"] == "stream-b"
+        assert math.isclose(timings[2]["service_time"], 0.1, abs_tol=0.05)
 
         # common properties
         for timing in timings:
-            self.assertEqual("sleep", timing["operation-type"])
-            self.assertIn("absolute_time", timing)
-            self.assertIn("request_start", timing)
-            self.assertIn("request_end", timing)
-            self.assertGreater(timing["request_end"], timing["request_start"])
+            assert timing["operation-type"] == "sleep"
+            assert "absolute_time" in timing
+            assert "request_start" in 
timing + assert "request_end" in timing + assert timing["request_end"] > timing["request_start"] @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -5613,7 +5706,7 @@ async def test_limits_connections(self, es): await r(es, params) # composite runner should limit to two concurrent connections - self.assertEqual(2, self.counter_runner.max_value) + assert self.counter_runner.max_value == 2 @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -5625,10 +5718,10 @@ async def test_rejects_invalid_stream(self, es): } r = runner.Composite() - with self.assertRaises(exceptions.RallyAssertionError) as ctx: + with pytest.raises(exceptions.RallyAssertionError) as exc: await r(es, params) - self.assertEqual("Requests structure must contain [stream] or [operation-type].", ctx.exception.args[0]) + assert exc.value.args[0] == "Requests structure must contain [stream] or [operation-type]." @mock.patch("elasticsearch.Elasticsearch") @run_async @@ -5636,17 +5729,16 @@ async def test_rejects_unsupported_operations(self, es): params = {"requests": [{"stream": [{"operation-type": "bulk"}]}]} r = runner.Composite() - with self.assertRaises(exceptions.RallyAssertionError) as ctx: + with pytest.raises(exceptions.RallyAssertionError) as exc: await r(es, params) - self.assertEqual( + assert exc.value.args[0] == ( "Unsupported operation-type [bulk]. Use one of [open-point-in-time, close-point-in-time, " - "search, paginated-search, raw-request, sleep, submit-async-search, get-async-search, delete-async-search].", - ctx.exception.args[0], + "search, paginated-search, raw-request, sleep, submit-async-search, get-async-search, delete-async-search]." ) -class RequestTimingTests(TestCase): +class TestRequestTiming: class StaticRequestTiming: def __init__(self, task_start): self.task_start = task_start @@ -5672,7 +5764,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): @run_async async def test_merges_timing_info(self, es): multi_cluster_client = {"default": es} - es.new_request_context.return_value = RequestTimingTests.StaticRequestTiming(task_start=2) + es.new_request_context.return_value = self.StaticRequestTiming(task_start=2) delegate = mock.AsyncMock(return_value={"weight": 5, "unit": "ops", "success": True}) params = {"name": "unit-test-operation", "operation-type": "test-op"} @@ -5680,17 +5772,19 @@ async def test_merges_timing_info(self, es): response = await timer(multi_cluster_client, params) - self.assertEqual(5, response["weight"]) - self.assertEqual("ops", response["unit"]) - self.assertTrue(response["success"]) - self.assertIn("dependent_timing", response) - timing = response["dependent_timing"] - self.assertEqual("unit-test-operation", timing["operation"]) - self.assertEqual("test-op", timing["operation-type"]) - self.assertIsNotNone(timing["absolute_time"]) - self.assertEqual(7, timing["request_start"]) - self.assertEqual(7.1, timing["request_end"]) - self.assertAlmostEqual(0.1, timing["service_time"]) + assert math.isclose(response["dependent_timing"].pop("service_time"), 0.1) + assert response["dependent_timing"].pop("absolute_time") is not None + assert response == { + "weight": 5, + "unit": "ops", + "success": True, + "dependent_timing": { + "operation": "unit-test-operation", + "operation-type": "test-op", + "request_start": 7, + "request_end": 7.1, + }, + } delegate.assert_called_once_with(multi_cluster_client, params) @@ -5698,7 +5792,7 @@ async def test_merges_timing_info(self, es): @run_async async def test_creates_new_timing_info(self, es): multi_cluster_client = {"default": es} 
- es.new_request_context.return_value = RequestTimingTests.StaticRequestTiming(task_start=2) + es.new_request_context.return_value = self.StaticRequestTiming(task_start=2) # a simple runner without a return value delegate = mock.AsyncMock() @@ -5707,24 +5801,25 @@ async def test_creates_new_timing_info(self, es): response = await timer(multi_cluster_client, params) - # defaults added by the timing runner - self.assertEqual(1, response["weight"]) - self.assertEqual("ops", response["unit"]) - self.assertTrue(response["success"]) - - self.assertIn("dependent_timing", response) - timing = response["dependent_timing"] - self.assertEqual("unit-test-operation", timing["operation"]) - self.assertEqual("test-op", timing["operation-type"]) - self.assertIsNotNone(timing["absolute_time"]) - self.assertEqual(7, timing["request_start"]) - self.assertEqual(7.1, timing["request_end"]) - self.assertAlmostEqual(0.1, timing["service_time"]) + assert math.isclose(response["dependent_timing"].pop("service_time"), 0.1) + assert response["dependent_timing"].pop("absolute_time") is not None + assert response == { + # defaults added by the timing runner + "weight": 1, + "unit": "ops", + "success": True, + "dependent_timing": { + "operation": "unit-test-operation", + "operation-type": "test-op", + "request_start": 7, + "request_end": 7.1, + }, + } delegate.assert_called_once_with(multi_cluster_client, params) -class RetryTests(TestCase): +class TestRetry: @run_async async def test_is_transparent_on_success_when_no_retries(self): delegate = mock.AsyncMock() @@ -5747,7 +5842,7 @@ async def test_is_transparent_on_exception_when_no_retries(self): } retrier = runner.Retry(delegate) - with self.assertRaises(elasticsearch.ConnectionError): + with pytest.raises(elasticsearch.ConnectionError): await retrier(es, params) delegate.assert_called_once_with(es, params) @@ -5765,7 +5860,7 @@ async def test_is_transparent_on_application_error_when_no_retries(self): result = await retrier(es, params) - self.assertEqual(original_return_value, result) + assert result == original_return_value delegate.assert_called_once_with(es, params) @run_async @@ -5793,7 +5888,7 @@ async def test_retries_on_timeout_if_wanted_and_raises_if_no_recovery(self): params = {"retries": 3, "retry-wait-period": 0.01, "retry-on-timeout": True, "retry-on-error": True} retrier = runner.Retry(delegate) - with self.assertRaises(elasticsearch.ConnectionError): + with pytest.raises(elasticsearch.ConnectionError): await retrier(es, params) delegate.assert_has_calls( @@ -5819,7 +5914,7 @@ async def test_retries_on_timeout_if_wanted_and_returns_first_call(self): retrier = runner.Retry(delegate) result = await retrier(es, params) - self.assertEqual(failed_return_value, result) + assert result == failed_return_value delegate.assert_has_calls( [ @@ -5857,7 +5952,7 @@ async def test_retries_mixed_timeout_and_application_errors(self): retrier = runner.Retry(delegate) result = await retrier(es, params) - self.assertEqual(success_return_value, result) + assert result == success_return_value delegate.assert_has_calls( [ @@ -5883,7 +5978,7 @@ async def test_does_not_retry_on_timeout_if_not_wanted(self): params = {"retries": 3, "retry-wait-period": 0.01, "retry-on-timeout": False, "retry-on-error": True} retrier = runner.Retry(delegate) - with self.assertRaises(elasticsearch.ConnectionTimeout): + with pytest.raises(elasticsearch.ConnectionTimeout): await retrier(es, params) delegate.assert_called_once_with(es, params) @@ -5900,7 +5995,7 @@ async def 
test_retries_on_application_error_if_wanted(self):
 
         result = await retrier(es, params)
 
-        self.assertEqual(success_return_value, result)
+        assert result == success_return_value
 
         delegate.assert_has_calls(
             [
@@ -5921,7 +6016,7 @@ async def test_does_not_retry_on_application_error_if_not_wanted(self):
 
         result = await retrier(es, params)
 
-        self.assertEqual(failed_return_value, result)
+        assert result == failed_return_value
 
         delegate.assert_called_once_with(es, params)
 
@@ -5934,7 +6029,7 @@ async def test_assumes_success_if_runner_returns_non_dict(self):
 
         result = await retrier(es, params)
 
-        self.assertEqual((1, "ops"), result)
+        assert result == (1, "ops")
 
         delegate.assert_called_once_with(es, params)
 
@@ -5956,19 +6051,19 @@ async def test_retries_until_success(self):
 
         result = await retrier(es, params)
 
-        self.assertEqual(success_return_value, result)
+        assert result == success_return_value
 
         delegate.assert_has_calls([mock.call(es, params) for _ in range(failure_count + 1)])
 
 
-class RemovePrefixTests(TestCase):
+class TestRemovePrefix:
     def test_remove_matching_prefix(self):
         suffix = runner.remove_prefix("index-20201117", "index")
-        self.assertEqual(suffix, "-20201117")
+        assert suffix == "-20201117"
 
     def test_prefix_doesnt_exit(self):
         index_name = "index-20201117"
         suffix = runner.remove_prefix(index_name, "unrelatedprefix")
-        self.assertEqual(suffix, index_name)
+        assert suffix == index_name
diff --git a/tests/driver/scheduler_test.py b/tests/driver/scheduler_test.py
index f005fff8e..1b975f703 100644
--- a/tests/driver/scheduler_test.py
+++ b/tests/driver/scheduler_test.py
@@ -17,59 +17,48 @@
 # pylint: disable=protected-access
 import random
-from unittest import TestCase
+
+import pytest
 
 from esrally import exceptions
 from esrally.driver import scheduler
 from esrally.track import track
 
 
-class SchedulerTestCase(TestCase):
+def assert_throughput(sched, expected_average_throughput, msg="", relative_delta=0.05):
     ITERATIONS = 10000
+    expected_average_rate = 1 / expected_average_throughput
+    total = 0
+    for _ in range(0, ITERATIONS):
+        tn = sched.next(0)
+        # schedule must not go backwards in time
+        assert tn >= 0, msg
+        total += tn
+    actual_average_rate = total / ITERATIONS
 
-    def assertThroughputEquals(self, sched, expected_average_throughput, msg="", relative_delta=0.05):
-        expected_average_rate = 1 / expected_average_throughput
-        sum = 0
-        for _ in range(0, SchedulerTestCase.ITERATIONS):
-            tn = sched.next(0)
-            # schedule must not go backwards in time
-            self.assertGreaterEqual(tn, 0, msg)
-            sum += tn
-        actual_average_rate = sum / SchedulerTestCase.ITERATIONS
-
-        expected_lower_bound = (1.0 - relative_delta) * expected_average_rate
-        expected_upper_bound = (1.0 + relative_delta) * expected_average_rate
-
-        self.assertGreaterEqual(
-            actual_average_rate,
-            expected_lower_bound,
-            f"{msg}: expected target rate to be >= [{expected_lower_bound}] but was [{actual_average_rate}].",
-        )
-        self.assertLessEqual(
-            actual_average_rate,
-            expected_upper_bound,
-            f"{msg}: expected target rate to be <= [{expected_upper_bound}] but was [{actual_average_rate}].",
-        )
+    expected_lower_bound = (1.0 - relative_delta) * expected_average_rate
+    expected_upper_bound = (1.0 + relative_delta) * expected_average_rate
+    assert expected_lower_bound <= actual_average_rate <= expected_upper_bound, msg
 
 
-class DeterministicSchedulerTests(SchedulerTestCase):
+class TestDeterministicScheduler:
     def test_schedule_matches_expected_target_throughput(self):
         target_throughput = random.randint(10, 1000)
         # this scheduler does not make use of the task, thus we 
won't specify it here s = scheduler.DeterministicScheduler(task=None, target_throughput=target_throughput) - self.assertThroughputEquals(s, target_throughput, f"target throughput=[{target_throughput}] ops/s") + assert_throughput(s, target_throughput, f"target throughput=[{target_throughput}] ops/s") -class PoissonSchedulerTests(SchedulerTestCase): +class TestPoissonScheduler: def test_schedule_matches_expected_target_throughput(self): target_throughput = random.randint(10, 1000) # this scheduler does not make use of the task, thus we won't specify it here s = scheduler.PoissonScheduler(task=None, target_throughput=target_throughput) - self.assertThroughputEquals(s, target_throughput, f"target throughput=[{target_throughput}] ops/s") + assert_throughput(s, target_throughput, f"target throughput=[{target_throughput}] ops/s") -class UnitAwareSchedulerTests(TestCase): +class TestUnitAwareScheduler: def test_scheduler_rejects_differing_throughput_units(self): task = track.Task( name="bulk-index", @@ -79,11 +68,10 @@ def test_scheduler_rejects_differing_throughput_units(self): ) s = scheduler.UnitAwareScheduler(task=task, scheduler_class=scheduler.DeterministicScheduler) - with self.assertRaises(exceptions.RallyAssertionError) as ex: + with pytest.raises(exceptions.RallyAssertionError) as exc: s.after_request(now=None, weight=1000, unit="docs", request_meta_data=None) - self.assertEqual( - "Target throughput for [bulk-index] is specified in [MB/s] but the task throughput is measured in [docs/s].", - ex.exception.args[0], + assert exc.value.args[0] == ( + "Target throughput for [bulk-index] is specified in [MB/s] but the task throughput is measured in [docs/s]." ) def test_scheduler_adapts_to_changed_weights(self): @@ -98,18 +86,18 @@ def test_scheduler_adapts_to_changed_weights(self): # first request is unthrottled # suppress pylint false positive # pylint: disable=not-callable - self.assertEqual(0, s.next(0)) + assert s.next(0) == 0 # we'll start with bulks of 1.000 docs, which corresponds to 5 requests per second for all clients s.after_request(now=None, weight=1000, unit="docs", request_meta_data=None) # suppress pylint false positive # pylint: disable=not-callable - self.assertEqual(1 / 5 * task.clients, s.next(0)) + assert s.next(0) == 1 / 5 * task.clients # bulk size changes to 10.000 docs, which means one request every two seconds for all clients s.after_request(now=None, weight=10000, unit="docs", request_meta_data=None) # suppress pylint false positive # pylint: disable=not-callable - self.assertEqual(2 * task.clients, s.next(0)) + assert s.next(0) == 2 * task.clients def test_scheduler_accepts_differing_units_pages_and_ops(self): task = track.Task( @@ -126,13 +114,13 @@ def test_scheduler_accepts_differing_units_pages_and_ops(self): # first request is unthrottled # suppress pylint false positive # pylint: disable=not-callable - self.assertEqual(0, s.next(0)) + assert s.next(0) == 0 # no exception despite differing units ... s.after_request(now=None, weight=20, unit="pages", request_meta_data=None) # ... 
and it is still throttled in ops/s # suppress pylint false positive # pylint: disable=not-callable - self.assertEqual(0.1 * task.clients, s.next(0)) + assert s.next(0) == 0.1 * task.clients def test_scheduler_does_not_change_throughput_for_empty_requests(self): task = track.Task( @@ -150,23 +138,23 @@ def test_scheduler_does_not_change_throughput_for_empty_requests(self): s.before_request(now=0) # suppress pylint false positive # pylint: disable=not-callable - self.assertEqual(0, s.next(0)) + assert s.next(0) == 0 # ... but it also produced an error (zero ops) s.after_request(now=1, weight=0, unit="ops", request_meta_data=None) # next request is still unthrottled s.before_request(now=1) # suppress pylint false positive # pylint: disable=not-callable - self.assertEqual(0, s.next(0)) + assert s.next(0) == 0 s.after_request(now=2, weight=1, unit="ops", request_meta_data=None) # now we throttle s.before_request(now=2) # suppress pylint false positive # pylint: disable=not-callable - self.assertEqual(0.1 * task.clients, s.next(0)) + assert s.next(0) == 0.1 * task.clients -class SchedulerCategorizationTests(TestCase): +class TestSchedulerCategorization: class LegacyScheduler: # pylint: disable=unused-variable def __init__(self, params): @@ -178,21 +166,21 @@ def __init__(self, params, my_default_param=True): pass def test_detects_legacy_scheduler(self): - self.assertTrue(scheduler.is_legacy_scheduler(SchedulerCategorizationTests.LegacyScheduler)) - self.assertTrue(scheduler.is_legacy_scheduler(SchedulerCategorizationTests.LegacySchedulerWithAdditionalArgs)) + assert scheduler.is_legacy_scheduler(self.LegacyScheduler) + assert scheduler.is_legacy_scheduler(self.LegacySchedulerWithAdditionalArgs) def test_a_regular_scheduler_is_not_a_legacy_scheduler(self): - self.assertFalse(scheduler.is_legacy_scheduler(scheduler.DeterministicScheduler)) - self.assertFalse(scheduler.is_legacy_scheduler(scheduler.UnitAwareScheduler)) + assert not scheduler.is_legacy_scheduler(scheduler.DeterministicScheduler) + assert not scheduler.is_legacy_scheduler(scheduler.UnitAwareScheduler) def test_is_simple_scheduler(self): - self.assertTrue(scheduler.is_simple_scheduler(scheduler.PoissonScheduler)) + assert scheduler.is_simple_scheduler(scheduler.PoissonScheduler) def test_is_not_simple_scheduler(self): - self.assertFalse(scheduler.is_simple_scheduler(scheduler.UnitAwareScheduler)) + assert not scheduler.is_simple_scheduler(scheduler.UnitAwareScheduler) -class SchedulerThrottlingTests(TestCase): +class TestSchedulerThrottling: def task(self, schedule=None, target_throughput=None, target_interval=None): op = track.Operation("bulk-index", track.OperationType.Bulk.to_hyphenated_string()) params = {} @@ -203,22 +191,22 @@ def task(self, schedule=None, target_throughput=None, target_interval=None): return track.Task("test", op, schedule=schedule, params=params) def test_throttled_by_target_throughput(self): - self.assertFalse(scheduler.run_unthrottled(self.task(target_throughput=4, schedule="deterministic"))) + assert not scheduler.run_unthrottled(self.task(target_throughput=4, schedule="deterministic")) def test_throttled_by_target_interval(self): - self.assertFalse(scheduler.run_unthrottled(self.task(target_interval=2))) + assert not scheduler.run_unthrottled(self.task(target_interval=2)) def test_throttled_by_custom_schedule(self): - self.assertFalse(scheduler.run_unthrottled(self.task(schedule="my-custom-schedule"))) + assert not scheduler.run_unthrottled(self.task(schedule="my-custom-schedule")) def 
test_unthrottled_by_target_throughput(self): - self.assertTrue(scheduler.run_unthrottled(self.task(target_throughput=None))) + assert scheduler.run_unthrottled(self.task(target_throughput=None)) def test_unthrottled_by_target_interval(self): - self.assertTrue(scheduler.run_unthrottled(self.task(target_interval=0, schedule="poisson"))) + assert scheduler.run_unthrottled(self.task(target_interval=0, schedule="poisson")) -class LegacyWrappingSchedulerTests(TestCase): +class TestLegacyWrappingScheduler: class SimpleLegacyScheduler: # pylint: disable=unused-variable def __init__(self, params): @@ -227,10 +215,10 @@ def __init__(self, params): def next(self, current): return current - def setUp(self): - scheduler.register_scheduler("simple", LegacyWrappingSchedulerTests.SimpleLegacyScheduler) + def setup_method(self, method): + scheduler.register_scheduler("simple", self.SimpleLegacyScheduler) - def tearDown(self): + def teardown_method(self, method): scheduler.remove_scheduler("simple") def test_legacy_scheduler(self): @@ -243,5 +231,5 @@ def test_legacy_scheduler(self): s = scheduler.scheduler_for(task) - self.assertEqual(0, s.next(0)) - self.assertEqual(0, s.next(0)) + assert s.next(0) == 0 + assert s.next(0) == 0
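
One caveat worth keeping in mind for the service-time assertions converted in runner_test.py above: unittest's assertAlmostEqual(first, second, delta=d) checks an absolute difference, while math.isclose defaults to a relative tolerance (rel_tol=1e-09, abs_tol=0.0), so a delta-based assertion has to be ported with an explicit abs_tol, or with pytest.approx(expected, abs=d), to keep the same semantics. A minimal sketch of the equivalences, using an illustrative measurement of 0.12 against an expected 0.1:

    import math

    import pytest

    measured_service_time = 0.12  # e.g. an asyncio.sleep(0.1) that overshot by 20 ms

    # unittest: self.assertAlmostEqual(0.1, measured_service_time, delta=0.05) -> passes
    assert math.isclose(measured_service_time, 0.1, abs_tol=0.05)       # passes: |0.12 - 0.1| <= 0.05
    assert measured_service_time == pytest.approx(0.1, abs=0.05)        # equivalent pytest idiom
    assert not math.isclose(measured_service_time, 0.1, rel_tol=0.05)   # rel_tol rejects it: 5% of 0.12 is only 0.006

pytest.approx is likewise the idiomatic replacement for the round(abs(a - b), 7) == 0 pattern that mechanical unittest-to-pytest conversions typically emit for plain assertAlmostEqual, as seen in the WaitForTransform progress test above.
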