diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0559961..d31ac64 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -3,6 +3,12 @@ Version Next * Port test runner to pytest. * Fix compatibility issue with Werkezeug 3 related to deprecated ``request.charset``. +Version 0.40 + + * Port test runner to pytest. + * Fix compatibility issue with Werkezeug 3 related to deprecated ``request.charset``. + * Drop service doubles and mocks, and ``acceptable doubles``. Use the OpenAPI spec with tools such as Connexion instead. + Version 0.39 * Assorted fixes and improvements in generated OpenAPI specs diff --git a/README.rst b/README.rst index f5ea5ce..72f8c6e 100644 --- a/README.rst +++ b/README.rst @@ -20,8 +20,6 @@ Design Goals: - Make it easy to generate API documentation. -- Tools for generating testing doubles from the API metadata. - Usage ----- diff --git a/acceptable/__main__.py b/acceptable/__main__.py index f3f410b..07555c0 100644 --- a/acceptable/__main__.py +++ b/acceptable/__main__.py @@ -141,23 +141,6 @@ def parse_args(raw_args=None, parser_cls=None, stdin=None, stdout=None): lint_parser.set_defaults(func=lint_cmd) - doubles_parser = subparser.add_parser("doubles", help="Generate test doubles") - doubles_parser.add_argument( - "metadata", - nargs="?", - type=argparse.FileType("r"), - default=stdin, - help="metadata file path, uses stdin if omitted", - ) - doubles_parser.add_argument( - "-n", - "--new-style", - action="store_true", - default=False, - help="Generate new style ServiceFactory mocks", - ) - doubles_parser.set_defaults(func=doubles_cmd) - version_parser = subparser.add_parser( "api-version", help="Get the current API version from JSON meta, and " @@ -362,18 +345,6 @@ def lint_cmd(cli_args, stream=sys.stdout): return 1 if has_errors else 0 -def doubles_cmd(cli_args, stream=sys.stdout): - metadata = json.load(cli_args.metadata) - if cli_args.new_style: - from . import generate_mocks - - generate_mocks.generate_service_factory(metadata, stream=stream) - else: - from . import generate_doubles - - generate_doubles.generate_service_mock_doubles(metadata, stream=stream) - - def version_cmd(cli_args, stream=sys.stdout): metadata = load_metadata(cli_args.metadata) json_version = metadata["$version"] diff --git a/acceptable/_build_doubles.py b/acceptable/_build_doubles.py deleted file mode 100644 index 903fab4..0000000 --- a/acceptable/_build_doubles.py +++ /dev/null @@ -1,350 +0,0 @@ -# Copyright 2017-2018 Canonical Ltd. This software is licensed under the -# GNU Lesser General Public License version 3 (see the file LICENSE). - -"""Build Service Doubles: - -This module contains the entry point used to extract schemas from python source -files. The `main` function is installed as a console_script, and has several -modes of operation: - - - The 'scan_file' command allows a user to scan a random python source file - and inspect what service doubles would be extracted from it. This is useful - for ensuring that service_doubles can be extracted from a python source file - before committing it. - - - The 'build' command takes a config file containing service names and - locations, and builds a set of service_doubles based on that config. - -In both cases, the service doubles are built by doing an AST parse of the -python source file in question and extracting calls to acceptable functions. -""" -import argparse -import ast -import collections -import json -import logging -import os.path -import subprocess -import sys -import tempfile -import textwrap - -try: - FileNotFoundError -except NameError: - PermissionError = FileNotFoundError = IOError - - -def main(): - args = parse_args() - args.func(args) - - -def parse_args(arg_list=None, parser_class=None): - parser = parser_class() if parser_class else argparse.ArgumentParser() - subparser = parser.add_subparsers(dest="cmd") - subparser.required = True - scan_file_parser = subparser.add_parser( - "scan-file", help="Scan a file, print extracted service doubles." - ) - scan_file_parser.add_argument("file", type=str) - scan_file_parser.set_defaults(func=scan_file) - - build_parser = subparser.add_parser("build", help="build service doubles.") - build_parser.add_argument("config_file", type=str) - build_parser.set_defaults(func=build_service_doubles) - - return parser.parse_args(arg_list) - - -def scan_file(args): - service_schemas = extract_schemas_from_file(args.file) - print(render_service_double("UNKNOWN", service_schemas, "scan-file %s" % args.file)) - - -def build_service_doubles(args): - with tempfile.TemporaryDirectory() as workdir: - service_config = read_service_config_file(args.config_file) - target_root = os.path.dirname(args.config_file) - for service_name in service_config["services"]: - service = service_config["services"][service_name] - source_url = service["git_source"] - branch = service.get("git_branch") - service_dir = fetch_service_source( - workdir, service_name, source_url, branch - ) - service_schemas = [] - for scan_path in service["scan_paths"]: - abs_path = os.path.join(service_dir, scan_path) - service_schemas.extend(extract_schemas_from_file(abs_path)) - rendered = render_service_double( - service_name, service_schemas, "build %s" % args.config_file - ) - write_service_double_file(target_root, service_name, rendered) - print( - "Rendered schemas file for %s service: %d schemas" - % (service, len(service_schemas)) - ) - - -def read_service_config_file(config_path): - with open(config_path, "r") as config_file: - return json.load(config_file) - - -def fetch_service_source(workdir, service_name, source_url, branch=None): - print("Cloning source for %s service." % service_name) - target_dir = os.path.join(workdir, service_name) - cmd = ["git", "clone"] - if branch is not None: - cmd.extend(["-b", branch]) - cmd.extend([source_url, target_dir]) - subprocess.check_call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - return target_dir - - -# ViewSchema contains all the information for a flask view... -ViewSchema = collections.namedtuple( - "ViewSchema", - [ - "view_name", # The name of the view function. - "version", # The version the view was introduced at. - "input_schema", # The schema for requests to the service. - "output_schema", # The schema for responses from the service - "methods", # The methods this view supports. - "url", # The URL this view is mounted at. - "doc", # The documentation for this url - ], -) - - -def extract_schemas_from_file(source_path): - """Extract schemas from 'source_path'. - - :returns: a list of ViewSchema objects on success, None if no schemas - could be extracted. - """ - logging.info("Extracting schemas from %s", source_path) - try: - with open(source_path, "r") as source_file: - source = source_file.read() - except (FileNotFoundError, PermissionError) as e: - logging.error("Cannot extract schemas: %s", e.strerror) - else: - try: - schemas = extract_schemas_from_source(source, source_path) - except SyntaxError as e: - logging.error("Cannot extract schemas: %s", str(e)) - else: - logging.info( - "Extracted %d %s", - len(schemas), - "schema" if len(schemas) == 1 else "schemas", - ) - return schemas - - -def _get_simple_assignments(tree): - """Get simple assignments from node tree.""" - result = {} - for node in ast.walk(tree): - if isinstance(node, ast.Assign): - for target in node.targets: - if isinstance(target, ast.Name): - result[target.id] = node.value - return result - - -class _SimpleNamesResolver(ast.NodeTransformer): - def __init__(self, names_values): - super().__init__() - self.names_values = names_values - - def visit_Name(self, node): - if node.id in self.names_values: - node = self.names_values[node.id] - return node - - -def extract_schemas_from_source(source, filename=""): - """Extract schemas from 'source'. - - The 'source' parameter must be a string, and should be valid python - source. - - If 'source' is not valid python source, a SyntaxError will be raised. - - :returns: a list of ViewSchema objects. - """ - # Track which acceptable services have been configured. - acceptable_services = set() - # Track which acceptable views have been configured: - acceptable_views = {} - schemas_found = [] - ast_tree = ast.parse(source, filename) - simple_names = _get_simple_assignments(ast_tree) - - assigns = [n for n in ast_tree.body if isinstance(n, ast.Assign)] - call_assigns = [n for n in assigns if isinstance(n.value, ast.Call)] - - # We need to extract the AcceptableService-related views. We parse the - # assignations twice: The first time to extract the AcceptableService - # instances, the second to extract the views created on those services. - for assign in call_assigns: - if isinstance(assign.value.func, ast.Attribute): - continue - if assign.value.func.id == "AcceptableService": - for target in assign.targets: - acceptable_services.add(target.id) - - for assign in call_assigns: - # only consider calls which are attribute accesses, AND - # calls where the object being accessed is in acceptable_services, AND - # calls where the attribute being accessed is the 'api' method. - if ( - isinstance(assign.value.func, ast.Attribute) - and assign.value.func.value.id in acceptable_services - and assign.value.func.attr == "api" - ): - # this is a view. We need to extract the url and methods specified. - # they may be specified positionally or via a keyword. - url = None - name = None - # methods has a default value: - methods = ["GET"] - - # This is a view - the URL is the first positional argument: - args = assign.value.args - if len(args) >= 1: - url = ast.literal_eval(args[0]) - if len(args) >= 2: - name = ast.literal_eval(args[1]) - kwargs = assign.value.keywords - for kwarg in kwargs: - if kwarg.arg == "url": - url = ast.literal_eval(kwarg.value) - if kwarg.arg == "methods": - methods = ast.literal_eval(kwarg.value) - if kwarg.arg == "view_name": - name = ast.literal_eval(kwarg.value) - if url and name: - for target in assign.targets: - acceptable_views[target.id] = { - "url": url, - "name": name, - "methods": methods, - } - - # iterate over all functions, attempting to find the views. - functions = [n for n in ast_tree.body if isinstance(n, ast.FunctionDef)] - for function in functions: - input_schema = None - output_schema = None - doc = ast.get_docstring(function) - api_options_list = [] - for decorator in function.decorator_list: - if not isinstance(decorator, ast.Call): - continue - if isinstance(decorator.func, ast.Attribute): - decorator_name = decorator.func.value.id - # extract version this view was introduced at, which can be - # specified as an arg or a kwarg: - version = None - for kwarg in decorator.keywords: - if kwarg.arg == "introduced_at": - version = ast.literal_eval(kwarg.value) - break - if len(decorator.args) == 1: - version = ast.literal_eval(decorator.args[0]) - - if decorator_name in acceptable_views: - api_options = acceptable_views[decorator_name] - api_options["version"] = version - api_options_list.append(api_options) - else: - decorator_name = decorator.func.id - if decorator_name == "validate_body": - _SimpleNamesResolver(simple_names).visit(decorator.args[0]) - input_schema = ast.literal_eval(decorator.args[0]) - if decorator_name == "validate_output": - _SimpleNamesResolver(simple_names).visit(decorator.args[0]) - output_schema = ast.literal_eval(decorator.args[0]) - for api_options in api_options_list: - schema = ViewSchema( - view_name=api_options["name"], - version=api_options["version"], - input_schema=input_schema, - output_schema=output_schema, - methods=api_options["methods"], - url=api_options["url"], - doc=doc, - ) - schemas_found.append(schema) - return schemas_found - - -def render_value(value): - """Render a value, ensuring that any nested dicts are sorted by key.""" - if isinstance(value, list): - return "[" + ", ".join(render_value(v) for v in value) + "]" - elif isinstance(value, dict): - return ( - "{" - + ", ".join( - "{k!r}: {v}".format(k=k, v=render_value(v)) - for k, v in sorted(value.items()) - ) - + "}" - ) - else: - return repr(value) - - -def render_service_double(service_name, schemas, regenerate_args): - header = textwrap.dedent( - """\ - # This file is AUTO GENERATED. Do not edit this file directly. Instead, - # re-generate it by running '{progname} {regenerate_args}'. - - from acceptable._doubles import service_mock - """.format( - progname=os.path.basename(sys.argv[0]), regenerate_args=regenerate_args - ) - ) - - rendered_schemas = [] - for schema in schemas: - double_name = "%s_%s" % (schema.view_name, schema.version.replace(".", "_")) - rendered_schema = textwrap.dedent( - """\ - {double_name} = service_mock( - service={service_name!r}, - methods={schema.methods!r}, - url={schema.url!r}, - input_schema={input_schema}, - output_schema={output_schema}, - ) - """ - ).format( - double_name=double_name, - schema=schema, - service_name=service_name, - input_schema=render_value(schema.input_schema), - output_schema=render_value(schema.output_schema), - ) - - rendered_schemas.append(rendered_schema) - - rendered_file = "{header}\n\n{schemas}\n".format( - header=header, schemas="\n\n".join(rendered_schemas) - ) - return rendered_file - - -def write_service_double_file(target_root, service_name, rendered): - """Render syntactically valid python service double code.""" - target_path = os.path.join( - target_root, "snapstore_schemas", "service_doubles", "%s.py" % service_name - ) - with open(target_path, "w") as target_file: - target_file.write(rendered) diff --git a/acceptable/_doubles.py b/acceptable/_doubles.py deleted file mode 100644 index 18df368..0000000 --- a/acceptable/_doubles.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2017 Canonical Ltd. This software is licensed under the -# GNU Lesser General Public License version 3 (see the file LICENSE). - -"""Service Double implementation. - -The ServiceMock class in this file is used at test-run-time to mock out a call -to a remote service API view. -""" -import functools -import json -from urllib.parse import urljoin - -import responses -from fixtures import Fixture - -from acceptable._validation import validate -from acceptable.mocks import responses_manager - - -def service_mock(service, methods, url, input_schema, output_schema): - return functools.partial( - ServiceMock, service, methods, url, input_schema, output_schema - ) - - -SERVICE_LOCATIONS = {} - - -def set_service_locations(service_locations): - global SERVICE_LOCATIONS - SERVICE_LOCATIONS = service_locations - - -def get_service_locations(): - global SERVICE_LOCATIONS - return SERVICE_LOCATIONS - - -class ServiceMock(Fixture): - # Kept for backwards compatibility - _requests_mock = responses.mock - - def __init__( - self, - service, - methods, - url, - input_schema, - output_schema, - output, - output_status=200, - output_headers=None, - ): - super().__init__() - self._service = service - self._methods = methods - self._url = url - self._input_schema = input_schema - self._output_schema = output_schema - self._output = output - self._output_status = output_status - self._output_headers = output_headers.copy() if output_headers else {} - self._output_headers.setdefault("Content-Type", "application/json") - - def _setUp(self): - if self._output_schema and self._output_status < 300: - error_list = validate(self._output, self._output_schema) - if error_list: - msg = ( - "While setting up a service mock for the '{s._service}' " - "service's '{s._url}' endpoint, the specified output " - "does not match the service's endpoint output schema.\n\n" - "The errors are:\n{errors}\n\n" - ).format(s=self, errors="\n".join(error_list)) - raise AssertionError(msg) - - config = get_service_locations() - service_location = config.get(self._service) - if service_location is None: - raise AssertionError( - "A service mock for the '%s' service was requested, but the " - "mock has not been configured with a location for that " - "service. Ensure set_service_locations has been " - "called before the mock is required, and that the locations " - "dictionary contains a key for the '%s' service." - % (self._service, self._service) - ) - - full_url = urljoin(service_location, self._url) - - def _callback(request): - if self._input_schema: - payload = json.loads(request.body.decode()) - error_list = validate(payload, self._input_schema) - if error_list: - # TODO: raise AssertionError here, since this is in a test? - return ( - 400, - {"Content-Type": "application/json"}, - json.dumps(error_list), - ) - - return (self._output_status, self._output_headers, json.dumps(self._output)) - - responses_manager.attach() - self.addCleanup(responses_manager.detach) - for method in self._methods: - responses.mock.add_callback(method, full_url, _callback) - - @property - def calls(self): - return responses.mock.calls diff --git a/acceptable/generate_doubles.py b/acceptable/generate_doubles.py deleted file mode 100644 index 876aacd..0000000 --- a/acceptable/generate_doubles.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2019 Canonical Ltd. This software is licensed under the -# GNU Lesser General Public License version 3 (see the file LICENSE). -import textwrap -from sys import stdout - - -def write_double(stream, api_name, api_data): - double_name = "%s_%s" % (api_name, "1_0") - double_text = textwrap.dedent( - """\ - - {double_name} = service_mock( - service={service!r}, - methods={methods!r}, - url={url!r}, - input_schema={request_schema!r}, - output_schema={response_schema!r}, - ) - - """ - ).format(double_name=double_name, **api_data) - stream.write(double_text) - - -HEADER_TEXT = """\ -# This file is AUTO GENERATED. Do not edit this file directly. Instead, -# re-generate it. - -from acceptable._doubles import service_mock - -""" - - -def write_header(stream): - stream.write(HEADER_TEXT) - - -def generate_service_mock_doubles(metadata, stream=stdout): - write_header(stream) - for module_name, module_data in metadata.items(): - if module_name.startswith("$"): - continue - for api_name, api_data in module_data["apis"].items(): - write_double(stream, api_name, api_data) diff --git a/acceptable/generate_mocks.py b/acceptable/generate_mocks.py deleted file mode 100644 index 0aa2694..0000000 --- a/acceptable/generate_mocks.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2019 Canonical Ltd. This software is licensed under the -# GNU Lesser General Public License version 3 (see the file LICENSE). -from collections import defaultdict -from sys import stdout - - -def write_service_begin(stream, service_name): - begin = ( - "\n" - "{service_name} = ServiceFactory(\n" - " {service_name!r},\n" - " [\n" - ).format(service_name=service_name) - stream.write(begin) - - -def write_service_end(stream): - stream.write(" ]\n)\n\n") - - -def write_endpoint_spec(stream, api_data): - endpoint = ( - " EndpointSpec(\n" - " {api_name!r},\n" - " {url!r},\n" - " {methods!r},\n" - " {request_schema!r},\n" - " {response_schema!r}\n" - " ),\n" - ).format(**api_data) - stream.write(endpoint) - - -HEADER_TEXT = """\ -# This file is AUTO GENERATED. Do not edit this file directly. Instead, -# re-generate it. - -from acceptable.mocks import ServiceFactory, EndpointSpec - -""" - - -def write_header(stream): - stream.write(HEADER_TEXT) - - -def generate_service_factory(metadata, stream=stdout): - write_header(stream) - services = defaultdict(list) - for module_data in metadata.values(): - if not isinstance(module_data, dict): - continue - for api_data in module_data["apis"].values(): - services[api_data["service"]].append(api_data) - for service_name, apis in services.items(): - write_service_begin(stream, service_name) - for api in apis: - write_endpoint_spec(stream, api) - write_service_end(stream) diff --git a/acceptable/mocks.py b/acceptable/mocks.py deleted file mode 100644 index 13509e7..0000000 --- a/acceptable/mocks.py +++ /dev/null @@ -1,469 +0,0 @@ -# Copyright 2019 Canonical Ltd. This software is licensed under the -# GNU Lesser General Public License version 3 (see the file LICENSE). -import re -from collections import namedtuple -from json import dumps as json_dumps -from json import loads as json_loads -from urllib.parse import urljoin - -import responses -from requests.utils import CaseInsensitiveDict - -from acceptable._validation import validate - -from .responses import responses_manager - - -class Attrs(object): - """A utility class allowing the creation of namespaces from a dict. - Also provides an iterator over the items of the original dict. - - This is used by both Service and ServiceMock to create their - endpoints attributes. - - e.g: - a = Attrs(dict(b=1, c=2)) - assert a.b == 1 - assert a.c == 2 - assert dir(a) == ['b', 'c'] - """ - - def __init__(self, attrs): - # I think python name mangling is ok here to help avoid collisions - # between instance attributes and names in attrs - self.__attrs = dict(attrs) - - def __dir__(self): - return list(self.__attrs) - - def __getattr__(self, name): - try: - return self.__attrs[name] - except KeyError: - raise AttributeError(name) - - def __iter__(self): - return iter(self.__attrs.items()) - - -Call = namedtuple("Call", "request response error".split()) - - -class CallRecorder(object): - def __init__(self): - self._calls = [] - - def record(self, mock, request, response, error): - self._calls.append((mock, Call(request, response, error))) - - def get_calls(self): - return [c for m, c in self._calls] - - def get_calls_for(self, mock): - return [c for m, c in self._calls if m == mock] - - def get_calls_for_matching(self, mock, pattern): - if not hasattr(pattern, "search"): - pattern = re.compile(pattern) - return [c for c in self.get_calls_for(mock) if pattern.search(c.request.url)] - - def get_calls_matching(self, pattern): - if not hasattr(pattern, "search"): - pattern = re.compile(pattern) - return [c for m, c in self._calls if pattern.search(c.request.url)] - - -EndpointSpec = namedtuple( - "EndpointSpec", ["name", "location", "methods", "request_schema", "response_schema"] -) - - -VALIDATION_ERROR_TEXT = """ -{}: -{!r} - -did not match schema: -{!r} - -for service {!r} endpoint {!r} on url {!r} errors where: -{!r} -""" - - -class EndpointMock(object): - """Provides methods to check calls made to this endpoint mock""" - - def __init__( - self, - call_recorder, - service_name, - name, - methods, - url, - request_schema, - response_schema, - response_callback, - ): - self._call_recorder = call_recorder - self._service_name = service_name - self._name = name - self._methods = methods - self._url = url - self._request_schema = request_schema - self._response_schema = response_schema - self._response_callback = response_callback - - @property - def service_name(self): - return self._service_name - - @property - def name(self): - return self._name - - @property - def url(self): - return self._url - - @property - def methods(self): - return list(self._methods) - - @property - def call_recorder(self): - return self._call_recorder - - def _validate(self, data_source_name, body, schema): - if schema is not None: - if body is None: - error_list = ["Missing body"] - else: - if isinstance(body, bytes): - body = body.decode("utf-8") - try: - data = json_loads(body) - except ValueError as e: - error_list = ["JSON decoding error: {}".format(e)] - else: - error_list = validate(data, schema) - if error_list: - raise AssertionError( - VALIDATION_ERROR_TEXT.format( - data_source_name, - body, - schema, - self._service_name, - self._name, - self._url, - error_list, - ) - ) - - def _validate_request(self, request): - self._validate("request data", request.body, self._request_schema) - - def _validate_response(self, response_body): - self._validate("response data", response_body, self._response_schema) - - def _record_response( - self, request, response_status, response_headers, response_body - ): - # Shenanigans to get a response object like responses would - # record in calls list - def tmp_callback(request): - return response_status, response_headers, response_body - - callback_response = responses.CallbackResponse( - request.method, self._url, tmp_callback - ) - response = callback_response.get_response(request) - self._call_recorder.record(self, request, response, None) - - def _callback(self, request): - try: - self._validate_request(request) - response_status, response_headers, response_body = self._response_callback( - request - ) - self._validate_response(response_body) - except Exception as exc: - self._call_recorder.record(self, request, None, exc) - raise exc - else: - self._record_response( - request, response_status, response_headers, response_body - ) - return response_status, response_headers, response_body - - def get_calls(self): - return self._call_recorder.get_calls_for(self) - - def get_last_call(self): - return self.get_calls()[-1] - - def get_calls_matching(self, pattern): - return self._call_recorder.get_calls_for_matching(self, pattern) - - def get_call_count(self): - return len(self.get_calls()) - - def was_called(self): - return self.get_call_count() > 0 - - -class EndpointMockContextManager(object): - def __init__( - self, - methods, - call_recorder, - service_name, - name, - url, - request_schema, - response_schema, - response_callback, - ): - self._methods = methods - self._mock = EndpointMock( - call_recorder, - service_name, - name, - methods, - url, - request_schema, - response_schema, - response_callback, - ) - - def _start(self): - responses_manager.attach_callback( - self._methods, self._mock._url, self._mock._callback - ) - - def _stop(self): - responses_manager.detach_callback( - self._methods, self._mock._url, self._mock._callback - ) - - def __enter__(self): - self._start() - return self._mock - - def __exit__(self, *args): - self._stop() - - -def response_callback_factory(status=200, headers=None, body=None, json=None): - if headers is None: - headers = CaseInsensitiveDict() - else: - headers = CaseInsensitiveDict(headers) - if json is not None: - assert body is None - body = json_dumps(json).encode("utf-8") - if "Content-Type" not in headers: - headers["Content-Type"] = "application/json" - - def response_callback(request): - return status, headers, body - - return response_callback - - -ok_no_content_response_callback = response_callback_factory() - - -class Endpoint(object): - """Configurable endpoint. - - Callable to create a context manager which activates and returns a mock - for this endpoint. - """ - - def __init__(self, base_url, service_name, endpoint_spec, response_callback=None): - if isinstance(endpoint_spec.location, str): - self._url = urljoin(base_url, endpoint_spec.location) - if self._url.find("<") > 0 and self._url.find("<") > 0: - # we know that there are variable references in the url, so - # let's change the url in a regexp - self._url = ( - self._url.replace("<", "(?P<") - .replace(">", r">\S+)") - .replace("int:", "") - .replace("path:", "") - ) - self._url = re.compile(self._url) - else: - self._url = endpoint_spec.location - self._service_name = service_name - self._name = endpoint_spec.name - self._methods = list(endpoint_spec.methods) - self._request_schema = endpoint_spec.request_schema - self._response_schema = endpoint_spec.response_schema - self._response_callback = response_callback - - @property - def service_name(self): - return self._service_name - - @property - def name(self): - return self._name - - @property - def url(self): - return self._url - - @property - def methods(self): - return list(self._methods) - - def disable_request_validation(self): - self._request_schema = None - - def disable_response_validation(self): - self._response_schema = None - - def disable_validation(self): - self.disable_request_validation() - self.disable_response_validation() - - def set_request_schema(self, schema): - self._request_schema = schema - - def set_response_schema(self, schema): - self._response_schema = schema - - def set_response_callback(self, callback): - self._response_callback = callback - - def set_response(self, status=200, headers=None, body=None, json=None): - self._response_callback = response_callback_factory(status, headers, body, json) - - def __call__(self, response_callback=None, call_recorder=None): - if call_recorder is None: - call_recorder = CallRecorder() - if response_callback is None: - response_callback = self._response_callback - if response_callback is None: - response_callback = ok_no_content_response_callback - return EndpointMockContextManager( - self._methods, - call_recorder, - self._service_name, - self._name, - self._url, - self._request_schema, - self._response_schema, - response_callback, - ) - - -class ServiceMock(object): - """Provides access to the endpoint mocks for this service and some functions - to get calls made to the services endpoints. - """ - - def __init__(self, call_recorder, endpoints): - self._call_recorder = call_recorder - mocks = {} - self._endpoint_context_managers = [] - for name, endpoint in endpoints.items(): - ecm = endpoint(call_recorder=call_recorder) - mocks[name] = ecm._mock - self._endpoint_context_managers.append(ecm) - self.endpoints = Attrs(mocks) - - def get_calls(self): - return self._call_recorder.get_calls() - - def get_calls_matching(self, pattern): - return self._call_recorder.get_calls_matching(pattern) - - def get_call_count(self): - return len(self.get_calls()) - - def was_called(self): - return self.get_call_count() > 0 - - def _start(self): - for ecm in self._endpoint_context_managers: - ecm._start() - - def _stop(self): - for ecm in self._endpoint_context_managers: - ecm._stop() - - -class ServiceMockContextManager(object): - def __init__(self, call_recorder, endpoints): - self._mock = ServiceMock(call_recorder, endpoints) - - def __enter__(self): - self._mock._start() - return self._mock - - def __exit__(self, *args): - self._mock._stop() - - -class Service(object): - """Has configurable endpoints (.endpoints.*). - - Callable to create a context manager which will mock all the endpoints on - the service. - - Endpoints can also be individually called to return a context manager - which just mocks that endpoint. - """ - - def __init__(self, base_url, name, endpoint_specs): - self._base_url = base_url - self._name = name - endpoints = {} - for endpoint_spec in endpoint_specs: - endpoints[endpoint_spec.name] = Endpoint( - self._base_url, self._name, endpoint_spec - ) - self.endpoints = Attrs(endpoints) - - @property - def name(self): - return self._name - - @property - def base_url(self): - return self._base_url - - def __call__(self, call_recorder=None): - if call_recorder is None: - call_recorder = CallRecorder() - return ServiceMockContextManager(call_recorder, dict(self.endpoints)) - - -class ServiceFactory(object): - """Callable to create Service instances. - - You can create multiple instances of a Service and configure each - independently. - """ - - def __init__(self, name, endpoint_specs): - self._name = name - self._endpoint_specs = endpoint_specs - - @property - def name(self): - return self._name - - def __call__(self, base_url): - return Service(base_url, self.name, self._endpoint_specs) - - -__ALL__ = [ - "responses_mock_context", - "response_callback_factory", - "ServiceFactory", - "EndpointSpec", - "Endpoint", -] diff --git a/acceptable/tests/test_build_doubles.py b/acceptable/tests/test_build_doubles.py deleted file mode 100644 index cead9d0..0000000 --- a/acceptable/tests/test_build_doubles.py +++ /dev/null @@ -1,584 +0,0 @@ -# Copyright 2017-2018 Canonical Ltd. This software is licensed under the -# GNU Lesser General Public License version 3 (see the file LICENSE). -import argparse -import ast -import os.path -import sys -from textwrap import dedent - -import fixtures -from testtools import TestCase -from testtools.matchers import Contains - -from acceptable import _build_doubles - - -class BuildDoubleTestCase(TestCase): - def setUp(self): - super().setUp() - - -class ExtractSchemasFromSourceTests(BuildDoubleTestCase): - def test_invalid_source(self): - self.assertRaises( - SyntaxError, - _build_doubles.extract_schemas_from_source, - "This is not valid python source!", - ) - - def test_returns_empty_list_on_empty_source(self): - self.assertEqual([], _build_doubles.extract_schemas_from_source("")) - - def test_ignores_undecorated_functions(self): - observed = _build_doubles.extract_schemas_from_source( - dedent( - """ - def my_view(): - pass - """ - ) - ) - self.assertEqual([], observed) - - def test_can_extract_acceptable_view(self): - [schema] = _build_doubles.extract_schemas_from_source( - dedent( - ''' - - service = AcceptableService('vendor') - - root_api = service.api('/', 'root') - - @root_api.view(introduced_at='1.0') - def my_view(): - """Documentation.""" - ''' - ) - ) - - self.assertEqual("root", schema.view_name) - self.assertEqual("/", schema.url) - self.assertEqual("1.0", schema.version) - self.assertEqual(["GET"], schema.methods) - self.assertEqual("Documentation.", schema.doc) - self.assertEqual(None, schema.input_schema) - self.assertEqual(None, schema.output_schema) - - def test_can_extract_acceptable_view_no_docstring(self): - [schema] = _build_doubles.extract_schemas_from_source( - dedent( - """ - - service = AcceptableService('vendor') - - root_api = service.api('/', 'root') - - @root_api.view(introduced_at='1.0') - def my_view(): - pass - """ - ) - ) - - self.assertEqual("root", schema.view_name) - self.assertEqual("/", schema.url) - self.assertEqual("1.0", schema.version) - self.assertEqual(["GET"], schema.methods) - self.assertEqual(None, schema.doc) - self.assertEqual(None, schema.input_schema) - self.assertEqual(None, schema.output_schema) - - def test_can_extract_acceptable_view_multiline_docstring(self): - [schema] = _build_doubles.extract_schemas_from_source( - dedent( - ''' - - service = AcceptableService('vendor') - - root_api = service.api('/', 'root') - - @root_api.view(introduced_at='1.0') - def my_view(): - """Documentation. - - More Documentation. - """ - ''' - ) - ) - - self.assertEqual("root", schema.view_name) - self.assertEqual("/", schema.url) - self.assertEqual("1.0", schema.version) - self.assertEqual(["GET"], schema.methods) - self.assertEqual("Documentation.\n\nMore Documentation.", schema.doc) - self.assertEqual(None, schema.input_schema) - self.assertEqual(None, schema.output_schema) - - def test_can_extract_schema_with_input_schema(self): - [schema] = _build_doubles.extract_schemas_from_source( - dedent( - ''' - - service = AcceptableService('vendor') - - root_api = service.api('/', 'root') - - @root_api.view(introduced_at='1.0') - @validate_body({'type': 'object'}) - def my_view(): - """Documentation.""" - ''' - ) - ) - - self.assertEqual("root", schema.view_name) - self.assertEqual("/", schema.url) - self.assertEqual("1.0", schema.version) - self.assertEqual(["GET"], schema.methods) - self.assertEqual({"type": "object"}, schema.input_schema) - self.assertEqual("Documentation.", schema.doc) - self.assertEqual(None, schema.output_schema) - - def test_can_extract_schema_with_output_schema(self): - [schema] = _build_doubles.extract_schemas_from_source( - dedent( - ''' - - service = AcceptableService('vendor') - - root_api = service.api('/', 'root') - - @root_api.view(introduced_at='1.0') - @validate_output({'type': 'object'}) - def my_view(): - """Documentation.""" - ''' - ) - ) - - self.assertEqual("root", schema.view_name) - self.assertEqual("/", schema.url) - self.assertEqual("1.0", schema.version) - self.assertEqual(["GET"], schema.methods) - self.assertEqual("Documentation.", schema.doc) - self.assertEqual(None, schema.input_schema) - self.assertEqual({"type": "object"}, schema.output_schema) - - def test_can_extract_schema_with_methods(self): - [schema] = _build_doubles.extract_schemas_from_source( - dedent( - ''' - - service = AcceptableService('vendor') - - root_api = service.api('/', 'root', methods=['POST', 'PUT']) - - @root_api.view(introduced_at='1.0') - def my_view(): - """Documentation.""" - ''' - ) - ) - - self.assertEqual("root", schema.view_name) - self.assertEqual("/", schema.url) - self.assertEqual("1.0", schema.version) - self.assertEqual(["POST", "PUT"], schema.methods) - self.assertEqual("Documentation.", schema.doc) - self.assertEqual(None, schema.input_schema) - self.assertEqual(None, schema.output_schema) - - def test_url_can_be_specified_with_kwarg(self): - [schema] = _build_doubles.extract_schemas_from_source( - dedent( - ''' - - service = AcceptableService('vendor') - - root_api = service.api(url='/foo', view_name='root') - - @root_api.view(introduced_at='1.0') - def my_view(): - """Documentation.""" - ''' - ) - ) - - self.assertEqual("root", schema.view_name) - self.assertEqual("/foo", schema.url) - self.assertEqual("1.0", schema.version) - self.assertEqual(["GET"], schema.methods) - self.assertEqual("Documentation.", schema.doc) - self.assertEqual(None, schema.input_schema) - self.assertEqual(None, schema.output_schema) - - def test_can_extract_version_with_kwarg(self): - [schema] = _build_doubles.extract_schemas_from_source( - dedent( - ''' - - service = AcceptableService('vendor') - - root_api = service.api('/foo', 'root') - - @root_api.view(introduced_at='1.1') - def my_view(): - """Documentation.""" - ''' - ) - ) - - self.assertEqual("root", schema.view_name) - self.assertEqual("/foo", schema.url) - self.assertEqual("1.1", schema.version) - self.assertEqual(["GET"], schema.methods) - self.assertEqual("Documentation.", schema.doc) - self.assertEqual(None, schema.input_schema) - self.assertEqual(None, schema.output_schema) - - def test_can_extract_multiple_versioned_schemas(self): - [schema1, schema2] = _build_doubles.extract_schemas_from_source( - dedent( - ''' - - service = AcceptableService('vendor') - - root_api = service.api('/foo', 'root') - - @root_api.view(introduced_at='1.1') - def my_view(): - """Documentation.""" - - - @root_api.view(introduced_at='1.2') - def my_view(): - """Documentation.""" - ''' - ) - ) - - self.assertEqual("1.1", schema1.version) - - def test_can_extract_multiple_names_for_one_view(self): - # This is helpful when in the process of renaming a view. - [schema1, schema2] = _build_doubles.extract_schemas_from_source( - dedent( - ''' - - service = AcceptableService('vendor') - - old_api = service.api('/old', 'old') - new_api = service.api('/new', 'new') - - @old_api.view(introduced_at='1.0') - @new_api.view(introduced_at='1.0') - @validate_body({'type': 'object'}) - @validate_output({'type': 'array'}) - def my_view(): - """Documentation.""" - ''' - ) - ) - - self.assertEqual("old", schema1.view_name) - self.assertEqual("/old", schema1.url) - self.assertEqual("1.0", schema1.version) - self.assertEqual(["GET"], schema1.methods) - self.assertEqual("Documentation.", schema1.doc) - self.assertEqual({"type": "object"}, schema1.input_schema) - self.assertEqual({"type": "array"}, schema1.output_schema) - self.assertEqual("new", schema2.view_name) - self.assertEqual("/new", schema2.url) - self.assertEqual("1.0", schema2.version) - self.assertEqual(["GET"], schema2.methods) - self.assertEqual("Documentation.", schema2.doc) - self.assertEqual({"type": "object"}, schema2.input_schema) - self.assertEqual({"type": "array"}, schema2.output_schema) - - def test_can_specify_version_as_arg(self): - [schema] = _build_doubles.extract_schemas_from_source( - dedent( - ''' - - service = AcceptableService('vendor') - - root_api = service.api('/foo', 'root') - - @root_api.view('1.5') - def my_view(): - """Documentation.""" - ''' - ) - ) - - self.assertEqual("1.5", schema.version) - - def test_handles_other_assignments(self): - self.assertEqual([], _build_doubles.extract_schemas_from_source("foo = {}")) - - def test_can_extract_schema_with_input_name(self): - [schema] = _build_doubles.extract_schemas_from_source( - dedent( - ''' - - FOOBAR = 'object' - - service = AcceptableService('vendor') - - root_api = service.api('/', 'root') - - @root_api.view(introduced_at='1.0') - @validate_body({'type': FOOBAR}) - def my_view(): - """Documentation.""" - ''' - ) - ) - - self.assertEqual("root", schema.view_name) - self.assertEqual("/", schema.url) - self.assertEqual("1.0", schema.version) - self.assertEqual(["GET"], schema.methods) - self.assertEqual({"type": "object"}, schema.input_schema) - self.assertEqual("Documentation.", schema.doc) - self.assertEqual(None, schema.output_schema) - - def test_can_extract_schema_with_output_name(self): - [schema] = _build_doubles.extract_schemas_from_source( - dedent( - ''' - - FOOBAR = 'object' - - service = AcceptableService('vendor') - - root_api = service.api('/', 'root') - - @root_api.view(introduced_at='1.0') - @validate_output({'type': FOOBAR}) - def my_view(): - """Documentation.""" - ''' - ) - ) - - self.assertEqual("root", schema.view_name) - self.assertEqual("/", schema.url) - self.assertEqual("1.0", schema.version) - self.assertEqual(["GET"], schema.methods) - self.assertEqual("Documentation.", schema.doc) - self.assertEqual(None, schema.input_schema) - self.assertEqual({"type": "object"}, schema.output_schema) - - -class ExtractSchemasFromFileTests(BuildDoubleTestCase): - def test_logs_on_missing_file(self): - workdir = self.useFixture(fixtures.TempDir()) - fake_logger = self.useFixture(fixtures.FakeLogger()) - - bad_path = os.path.join(workdir.path, "path_does_not_exist") - result = _build_doubles.extract_schemas_from_file(bad_path) - - self.assertIsNone(result) - self.assertThat( - fake_logger.output, Contains("Extracting schemas from %s" % bad_path) - ) - self.assertThat( - fake_logger.output, - Contains("Cannot extract schemas: No such file or directory"), - ) - - def test_logs_on_no_permissions(self): - workdir = self.useFixture(fixtures.TempDir()) - fake_logger = self.useFixture(fixtures.FakeLogger()) - - bad_path = os.path.join(workdir.path, "path_not_readable") - with open(bad_path, "w") as f: - f.write("# You can't read me") - os.chmod(bad_path, 0) - result = _build_doubles.extract_schemas_from_file(bad_path) - - self.assertIsNone(result) - self.assertThat( - fake_logger.output, Contains("Extracting schemas from %s" % bad_path) - ) - self.assertThat( - fake_logger.output, Contains("Cannot extract schemas: Permission denied") - ) - - def test_logs_on_syntax_error(self): - workdir = self.useFixture(fixtures.TempDir()) - fake_logger = self.useFixture(fixtures.FakeLogger()) - - bad_path = os.path.join(workdir.path, "foo.py") - with open(bad_path, "w") as f: - f.write("not valid pyton") - - result = _build_doubles.extract_schemas_from_file(bad_path) - - self.assertIsNone(result) - self.assertThat( - fake_logger.output, Contains("Extracting schemas from %s" % bad_path) - ) - self.assertThat( - fake_logger.output, - Contains("Cannot extract schemas: invalid syntax (foo.py, line 1)"), - ) - - def test_logs_on_schema_extraction(self): - workdir = self.useFixture(fixtures.TempDir()) - fake_logger = self.useFixture(fixtures.FakeLogger()) - - good_path = os.path.join(workdir.path, "my.py") - with open(good_path, "w") as f: - f.write( - dedent( - ''' - service = AcceptableService('vendor') - - root_api = service.api('/', 'root') - - @root_api.view(introduced_at='1.0') - def my_view(): - """Documentation.""" - ''' - ) - ) - [schema] = _build_doubles.extract_schemas_from_file(good_path) - - self.assertEqual("root", schema.view_name) - self.assertEqual("1.0", schema.version) - self.assertEqual(["GET"], schema.methods) - self.assertEqual("Documentation.", schema.doc) - self.assertEqual(None, schema.input_schema) - self.assertEqual(None, schema.output_schema) - - self.assertThat( - fake_logger.output, Contains("Extracting schemas from %s" % good_path) - ) - self.assertThat(fake_logger.output, Contains("Extracted 1 schema")) - - -# To support testing, we need a version of ArgumentParser that doesn't call -# sys.exit on error, but rather throws an exception, so we can catch that in -# our tests: -class SaneArgumentParser(argparse.ArgumentParser): - def error(self, message): - raise RuntimeError(message) - - -class ParseArgsTests(BuildDoubleTestCase): - def test_error_with_no_args(self): - self.assertRaises( - RuntimeError, _build_doubles.parse_args, [], SaneArgumentParser - ) - - def test_scan_file_requires_file(self): - self.assertRaises( - RuntimeError, _build_doubles.parse_args, ["scan-file"], SaneArgumentParser - ) - - def test_can_scan_file(self): - args = _build_doubles.parse_args(["scan-file", "some-path"]) - self.assertEqual("some-path", args.file) - self.assertEqual(_build_doubles.scan_file, args.func) - - def test_build_requires_file(self): - self.assertRaises( - RuntimeError, _build_doubles.parse_args, ["build"], SaneArgumentParser - ) - - def test_can_build(self): - args = _build_doubles.parse_args(["build", "config-file"]) - self.assertEqual("config-file", args.config_file) - self.assertEqual(_build_doubles.build_service_doubles, args.func) - - -class RenderValueTests(BuildDoubleTestCase): - def test_plain(self): - self.assertEqual("'foo'", _build_doubles.render_value("foo")) - - def test_list(self): - value = [{"type": "object", "properties": {}}, {"type": "string"}] - rendered = "[{'properties': {}, 'type': 'object'}, {'type': 'string'}]" - self.assertEqual(rendered, _build_doubles.render_value(value)) - - def test_dict(self): - value = { - "type": "object", - "properties": {"foo": {"type": "string"}, "bar": {"type": "integer"}}, - "required": ["foo"], - } - rendered = ( - "{" - "'properties': " - "{'bar': {'type': 'integer'}, 'foo': {'type': 'string'}}, " - "'required': ['foo'], " - "'type': 'object'" - "}" - ) - self.assertEqual(rendered, _build_doubles.render_value(value)) - - -class RenderServiceDoubleTests(BuildDoubleTestCase): - def assertIsValidPython(self, source): - try: - ast.parse(source) - except SyntaxError as e: - self.fail(str(e)) - - def test_renders_for_empty_schema_list(self): - source = _build_doubles.render_service_double("foo", [], "build config-file") - self.assertIsValidPython(source) - - def test_renders_for_single_schema(self): - schema = _build_doubles.ViewSchema( - view_name="some_view", - version="1.3", - input_schema=None, - output_schema=None, - methods=["GET"], - url="/foo", - doc=None, - ) - source = _build_doubles.render_service_double( - "foo", [schema], "build config-file" - ) - self.assertIsValidPython(source) - - def test_autogenerated_message(self): - source = _build_doubles.render_service_double("foo", [], "build config-file") - self.assertIn( - "re-generate it by running '%s build config-file'" - % (os.path.basename(sys.argv[0])), - source, - ) - - def test_input_and_output_schemas_are_sorted(self): - schema = _build_doubles.ViewSchema( - view_name="some_view", - version="1.3", - input_schema={"type": "object", "properties": {"item": {"type": "string"}}}, - output_schema={ - "type": "object", - "properties": {"item": {"type": "string"}}, - }, - methods=["GET"], - url="/foo", - doc=None, - ) - source = _build_doubles.render_service_double( - "foo", [schema], "build config-file" - ) - self.assertIsValidPython(source) - self.assertIn( - "input_schema={'properties': {'item': {'type': 'string'}}, " - "'type': 'object'}", - source, - ) - self.assertIn( - "output_schema={'properties': {'item': {'type': 'string'}}, " - "'type': 'object'}", - source, - ) diff --git a/acceptable/tests/test_doubles.py b/acceptable/tests/test_doubles.py deleted file mode 100644 index f6f7c3d..0000000 --- a/acceptable/tests/test_doubles.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright 2017 Canonical Ltd. This software is licensed under the -# GNU Lesser General Public License version 3 (see the file LICENSE). -import json - -import requests -from testtools import TestCase -from testtools.matchers import Contains - -from acceptable._doubles import ServiceMock, service_mock, set_service_locations - - -class ServiceMockTests(TestCase): - def setUp(self): - super().setUp() - # service locations are cached between tests. This should eventually - # be fixed, but until then it's easier to set them to an empty dict at - # the start of every test: - set_service_locations({}) - - def test_raises_on_incompatible_output_value(self): - double = ServiceMock( - service="foo", - methods=["POST"], - url="/", - input_schema=None, - output_schema={"type": "object"}, - output=[], - ) - # Note that we can't use 'double.setUp' as the method here, since the - # fixture catches any exceptions raised by _setUp and re-raises a - # different exception instance. - e = self.assertRaises(AssertionError, double._setUp) - self.assertThat( - str(e), - Contains( - "While setting up a service mock for the 'foo' service's '/' " - "endpoint, the specified output does not match the service's " - "endpoint output schema." - ), - ) - - def test_raises_when_service_location_has_not_been_set(self): - double = ServiceMock( - service="foo", - methods=["POST"], - url="/", - input_schema=None, - output_schema={"type": "array"}, - output=[], - ) - e = self.assertRaises(AssertionError, double._setUp) - self.assertEqual( - "A service mock for the 'foo' service was requested, but the mock " - "has not been configured with a location for that service. " - "Ensure set_service_locations has been called before the mock is " - "required, and that the locations dictionary contains a key for " - "the 'foo' service.", - str(e), - ) - - def test_can_construct_double_with_output_schema(self): - double = ServiceMock( - service="foo", - methods=["POST"], - url="/", - input_schema=None, - output_schema={"type": "array"}, - output=[], - ) - set_service_locations(dict(foo="http://localhost:1234/")) - - self.useFixture(double) - - resp = requests.post("http://localhost:1234/") - - self.assertEqual(200, resp.status_code) - self.assertEqual([], resp.json()) - - def test_can_construct_double_with_input_schema_and_invalid_payload(self): - double = ServiceMock( - service="foo", - methods=["POST"], - url="/", - input_schema={"type": "object"}, - output_schema={"type": "array"}, - output=[], - ) - set_service_locations(dict(foo="http://localhost:1234/")) - - self.useFixture(double) - - resp = requests.post("http://localhost:1234/", json=[]) - - self.assertEqual(400, resp.status_code) - self.assertEqual(["[] is not of type 'object' at /"], resp.json()) - - def test_can_construct_double_with_input_schema_and_valid_payload(self): - double = ServiceMock( - service="foo", - methods=["POST"], - url="/", - input_schema={"type": "object"}, - output_schema={"type": "array"}, - output=[], - ) - set_service_locations(dict(foo="http://localhost:1234/")) - - self.useFixture(double) - - resp = requests.post("http://localhost:1234/", json={}) - - self.assertEqual(200, resp.status_code) - self.assertEqual([], resp.json()) - - def test_can_construct_double_with_error_and_different_output_schema(self): - error = {"error_list": {"code": "test"}} - double = ServiceMock( - service="foo", - methods=["POST"], - url="/", - input_schema=None, - output_schema={"type": "object"}, - output_status=400, - output=error, - ) - set_service_locations(dict(foo="http://localhost:1234/")) - - self.useFixture(double) - - resp = requests.post("http://localhost:1234/") - - self.assertEqual(400, resp.status_code) - self.assertEqual(error, resp.json()) - - def test_can_construct_double_with_custom_headers(self): - custom = {"Cool-Header": "What a wonderful life"} - double = ServiceMock( - service="foo", - methods=["POST"], - url="/", - input_schema=None, - output_schema={"type": "object"}, - output_headers=custom, - output={"ok": True}, - ) - set_service_locations(dict(foo="http://localhost:1234/")) - - self.useFixture(double) - - resp = requests.post("http://localhost:1234/") - - self.assertEqual(200, resp.status_code) - self.assertEqual({"ok": True}, resp.json()) - custom["Content-Type"] = "application/json" - self.assertEqual(custom, resp.headers) - - def test_can_construct_double_given_content_type_respected(self): - custom = {"Content-Type": "not-json"} - double = ServiceMock( - service="foo", - methods=["POST"], - url="/", - input_schema=None, - output_schema={"type": "object"}, - output_headers=custom, - output={"ok": True}, - ) - set_service_locations(dict(foo="http://localhost:1234/")) - - self.useFixture(double) - - resp = requests.post("http://localhost:1234/") - - self.assertEqual(200, resp.status_code) - self.assertEqual({"ok": True}, resp.json()) - self.assertEqual(custom, resp.headers) - - def test_mock_records_calls(self): - double = ServiceMock( - service="foo", - methods=["POST"], - url="/", - input_schema={"type": "object"}, - output_schema={"type": "array"}, - output=[], - ) - set_service_locations(dict(foo="http://localhost:1234/")) - - self.useFixture(double) - - requests.post("http://localhost:1234/", json={"call": 1}) - requests.post("http://localhost:1234/", json={"call": 2}) - - call1, call2 = double.calls - self.assertEqual(json.loads(call1.request.body.decode()), {"call": 1}) - self.assertEqual(json.loads(call2.request.body.decode()), {"call": 2}) - - def test_mock_regards_url(self): - double = ServiceMock( - service="foo", - methods=["POST"], - url="/foo", - input_schema={"type": "object"}, - output_schema={"type": "array"}, - output=[], - ) - set_service_locations(dict(foo="http://localhost:1234/")) - - self.useFixture(double) - - self.assertRaises( - requests.exceptions.ConnectionError, - requests.post, - "http://localhost:1234/bar", - json={}, - ) - - def test_mock_regards_method(self): - double = ServiceMock( - service="foo", - methods=["GET"], - url="/foo", - input_schema={"type": "object"}, - output_schema={"type": "array"}, - output=[], - ) - set_service_locations(dict(foo="http://localhost:1234/")) - - self.useFixture(double) - - self.assertRaises( - requests.exceptions.ConnectionError, - requests.post, - "http://localhost:1234/bar", - json={}, - ) - - def test_mock_works_with_multiple_methods(self): - double = ServiceMock( - service="foo", - methods=["GET", "POST", "PATCH"], - url="/foo", - input_schema={"type": "object"}, - output_schema={"type": "array"}, - output=[], - ) - set_service_locations(dict(foo="http://localhost:1234/")) - - self.useFixture(double) - - self.assertEqual( - 200, requests.post("http://localhost:1234/foo", json={}).status_code - ) - self.assertEqual( - 200, requests.get("http://localhost:1234/foo", json={}).status_code - ) - self.assertEqual( - 200, requests.patch("http://localhost:1234/foo", json={}).status_code - ) - - def test_mock_output_status(self): - double = ServiceMock( - service="foo", - methods=["POST"], - url="/foo", - input_schema={"type": "object"}, - output_schema={"type": "array"}, - output=[], - output_status=201, - ) - set_service_locations(dict(foo="http://localhost:1234/")) - - self.useFixture(double) - - self.assertEqual( - 201, requests.post("http://localhost:1234/foo", json={}).status_code - ) - - def test_service_mock(self): - double_factory = service_mock( - service="foo", - methods=["GET"], - url="/foo", - input_schema={"type": "object"}, - output_schema={"type": "array"}, - ) - double = double_factory([]) - - self.assertEqual("foo", double._service) - self.assertEqual(["GET"], double._methods) - self.assertEqual("/foo", double._url) - self.assertEqual({"type": "object"}, double._input_schema) - self.assertEqual({"type": "array"}, double._output_schema) - self.assertEqual([], double._output) diff --git a/acceptable/tests/test_generate_doubles.py b/acceptable/tests/test_generate_doubles.py deleted file mode 100644 index b90eb4e..0000000 --- a/acceptable/tests/test_generate_doubles.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2019 Canonical Ltd. This software is licensed under the -# GNU Lesser General Public License version 3 (see the file LICENSE). -import json -from io import StringIO - -import testtools - -from acceptable import generate_doubles - - -class GenerateDoublesTests(testtools.TestCase): - def test_generate_service_mock_doubles_from_example(self): - stream = StringIO() - with open("examples/current_api.json") as f: - metadata = json.load(f) - generate_doubles.generate_service_mock_doubles(metadata, stream=stream) - self.assertIn("foo_1_0 = service_mock", stream.getvalue()) - # check generated code is valid - exec(stream.getvalue(), {}, {}) diff --git a/acceptable/tests/test_mocks.py b/acceptable/tests/test_mocks.py deleted file mode 100644 index 16368e4..0000000 --- a/acceptable/tests/test_mocks.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2019 Canonical Ltd. This software is licensed under the -# GNU Lesser General Public License version 3 (see the file LICENSE). -import requests -import testtools -from testtools import ExpectedException -from testtools.assertions import assert_that -from testtools.matchers import Equals, HasLength - -from acceptable.mocks import ( - CallRecorder, - Endpoint, - EndpointMock, - EndpointSpec, - ServiceFactory, -) -from acceptable.responses import responses_mock_context - - -class EventMockTests(testtools.TestCase): - def test_successful_event_mock(self): - call_recorder = CallRecorder() - response_tuple = (200, {}, b"999\n") - mock = EndpointMock( - call_recorder, - "service", - "api", - ["GET"], - "http://example.com", - request_schema={"type": "string"}, - response_schema={"type": "number"}, - response_callback=lambda req: response_tuple, - ) - request = requests.Request("GET", "http://example.com", json="hello").prepare() - assert_that(mock._callback(request), Equals(response_tuple)) - - def test_validation_failure_event_mock(self): - call_recorder = CallRecorder() - mock = EndpointMock( - call_recorder, - "service", - "api", - ["GET"], - "http://example.com", - request_schema={"type": "number"}, - response_schema=None, - response_callback=lambda req: (200, {}, b""), - ) - request = requests.Request("GET", "http://example.com", json="hello").prepare() - with ExpectedException(AssertionError): - mock._callback(request) - - -class ServiceTests(testtools.TestCase): - def make_test_service(self): - endpoints = [ - EndpointSpec( - "test_endpoint", - "test-endpoint", - ["GET", "POST"], - {"type": "number"}, - {"type": "number"}, - ), - EndpointSpec("no_validation", "no-validation", ["GET"], None, None), - ] - service_factory = ServiceFactory("test-service", endpoints) - service = service_factory("http://example.com") - service.endpoints.test_endpoint.set_response(json=999) - return service - - def test_simple_endpoint_cm(self): - service = self.make_test_service() - with service.endpoints.test_endpoint() as mock: - requests.get("http://example.com/test-endpoint", json=888) - assert_that(mock.get_calls(), HasLength(1)) - - def test_service_cm(self): - service = self.make_test_service() - with service() as mock: - requests.get("http://example.com/test-endpoint", json=888) - assert_that(mock.endpoints.test_endpoint.get_calls(), HasLength(1)) - - def test_validation_failure(self): - service = self.make_test_service() - with service() as mock: - with ExpectedException(AssertionError): - requests.get("http://example.com/test-endpoint", json="string") - assert_that(mock.endpoints.test_endpoint.get_calls(), HasLength(1)) - - def test_responses_manager_resets_responses_mock(self): - service = self.make_test_service() - with service(): - with responses_mock_context() as responses_mock: - responses_mock.add("GET", "http://example.com/responses-test", b"test") - requests.get("http://example.com/test-endpoint", json=888) - requests.get("http://example.com/responses-test") - assert_that( - responses_mock.calls, - HasLength(2), - "RequestMock call count inside 2", - ) - assert_that( - responses_mock.calls, HasLength(2), "RequestMock call count inside 1" - ) - assert_that( - responses_mock.calls, HasLength(0), "RequestMock call count outside" - ) - - def test_calls_matching(self): - service = self.make_test_service() - with service() as service_mock: - requests.get("http://example.com/test-endpoint", json=888) - requests.get("http://example.com/no-validation") - assert_that(service_mock.get_calls_matching("no-validation$"), HasLength(1)) - - def test_endpoint_missing_body(self): - ep = Endpoint( - "http://example.com", - "test", - EndpointSpec("endpoint", "endpoint", ["GET"], True, None), - ) - with ep(): - with ExpectedException(AssertionError): - requests.get("http://example.com/endpoint") - - def test_endpoint_empty_body_json_decoding_error(self): - ep = Endpoint( - "http://example.com", - "test", - EndpointSpec("endpoint", "endpoint", ["GET"], True, None), - ) - with ep(): - with ExpectedException(AssertionError): - requests.get("http://example.com/endpoint", data=b"") diff --git a/setup.py b/setup.py index f58b181..22d3330 100755 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ from setuptools import find_packages, setup -VERSION = "0.39" +VERSION = "0.40" setup( name="acceptable", @@ -22,10 +22,5 @@ extras_require=dict(flask=["Flask"], django=["django>=2.1,<3"]), test_suite="acceptable.tests", include_package_data=True, - entry_points={ - "console_scripts": [ - "build_service_doubles = acceptable._build_doubles:main", - "acceptable = acceptable.__main__:main", - ] - }, + entry_points={"console_scripts": ["acceptable = acceptable.__main__:main"]}, )