Skip to content

Commit

Permalink
refactor: add tests to get
Browse files Browse the repository at this point in the history
  • Loading branch information
Amazia Gur authored and Amazia Gur committed Oct 29, 2024
1 parent 29686e9 commit dbf372f
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 26 deletions.
21 changes: 17 additions & 4 deletions mockingbird/handler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import http.server
import json
import logging
from typing import Dict, Tuple, Any
from typing import Dict, Tuple, Any, Optional, Callable


class MockingbirdHandler(http.server.SimpleHTTPRequestHandler):
stubs: Dict[str, Dict[str, Tuple[int, Any]]] = {}
stubs: Dict[str, Dict[str, Tuple[int, Any, Optional[Callable[[], Tuple[int, Any]]]]]] = {}

def do_GET(self):
self._handle_request("GET")
Expand All @@ -21,11 +21,24 @@ def do_DELETE(self):

def _handle_request(self, method: str):
if self.path in self.stubs[method]:
status_code, response = self.stubs[method][self.path]
status_code, response, response_func = self.stubs[method][self.path]
if response_func:
# If a response function is defined, call it to get status and response
status_code, response = response_func()

self.send_response(status_code)
self.send_header('Content-Type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(response).encode('utf-8'))

if isinstance(response, dict):
self.wfile.write(json.dumps(response).encode('utf-8'))
elif isinstance(response, str):
self.wfile.write(response.encode('utf-8'))
elif isinstance(response, bytes):
self.wfile.write(response)
else:
self.wfile.write(str(response).encode('utf-8'))

self.log_message("Responded with %d for %s request to %s", status_code, method, self.path)
else:
self.send_response(404)
Expand Down
29 changes: 19 additions & 10 deletions mockingbird/mockingbird.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import atexit
import http.server
import json
import logging
import socketserver
import threading
from typing import Dict, Tuple, Any
from typing import Dict, Tuple, Any, Callable, Optional

from mockingbird.handler import MockingbirdHandler
from mockingbird.route import Route
Expand All @@ -26,15 +29,16 @@ def delete(path: str) -> Route:


class MockingbirdServer:
def __init__(self, port: int):
self.stubs: Dict[str, Dict[str, Tuple[int, Any]]] = {
def __init__(self, port: int = 8080):
self.stubs: Dict[str, Dict[str, Tuple[int, Any, Optional[Callable[[], Tuple[int, Any]]]]]] = {
"GET": {},
"POST": {},
"PUT": {},
"DELETE": {}
}
self.server = socketserver.TCPServer(("", port), self._handler_factory)
self._thread = None
self._thread = threading.Thread(target=self.server.serve_forever, daemon=True)
atexit.register(self.shutdown)

def _handler_factory(self, *args):
handler = MockingbirdHandler
Expand All @@ -46,20 +50,25 @@ def routes(self, *routes: Route):
route_config = route.build()
method = route_config["method"]
path = route_config["path"]
self.stubs[method][path] = (route_config["status"], route_config["body"])
self.stubs[method][path] = (
route_config["status"],
route_config["body"],
route_config["response_func"]
)
return self

def start(self):
logging.info("MockingbirdServer starting on port %s", self.server.server_address[1])
self._thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self._thread.start()
return self

def shutdown(self):
logging.info("MockingbirdServer shutting down.")
self.server.shutdown()
if self._thread:
if self._thread.is_alive():
self._thread.join()


def mockingbird(port: int) -> MockingbirdServer:
return MockingbirdServer(port)
def mockingbird(port: int = 8080) -> MockingbirdServer:
server = MockingbirdServer(port).start()
return server

10 changes: 8 additions & 2 deletions mockingbird/route.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Any
from typing import Optional, Callable, Tuple, Any, Dict


class Route:
Expand All @@ -7,6 +7,7 @@ def __init__(self, method: str, path: str):
self.path = path
self._body = {}
self._status = 200
self._response_func: Optional[Callable[[], Tuple[int, Any]]] = None

def body(self, response: Dict[str, Any]):
self._body = response
Expand All @@ -16,10 +17,15 @@ def status(self, status_code: int):
self._status = status_code
return self

def response_func(self, func: Callable[[], Tuple[int, Any]]):
self._response_func = func
return self

def build(self):
return {
"method": self.method,
"path": self.path,
"body": self._body,
"status": self._status
"status": self._status,
"response_func": self._response_func
}
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest
import threading

from mockingbird.mockingbird import mockingbird


@pytest.fixture(scope="session")
def mockingbird_server():
server = mockingbird(port=8080)

yield server

server.shutdown()
33 changes: 23 additions & 10 deletions tests/test_mockingbird_get.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,30 @@
import requests
from hamcrest import assert_that, is_
from hamcrest import assert_that, is_, has_entry, equal_to

from mockingbird.mockingbird import mockingbird, get
from mockingbird.mockingbird import get


def test_should_get_status_code():
server = mockingbird(8080).routes(
get("/hello")
.body({"hello": "mockingbird"})
.status(201)
def test_get_default_200_status_code(mockingbird_server):
mockingbird_server.routes(
get("/hello").body({"message": "Hello, World!"})
)
server.start()
response = requests.get('http://localhost:8080/hello')
assert_that(response.status_code, is_(200))

assert_that(response.status_code, is_(201))
server.shutdown()

def test_get_default_application_json_header(mockingbird_server):
mockingbird_server.routes(
get("/hello").body({"message": "Hello, World!"}).status(200)
)
response = requests.get('http://localhost:8080/hello')
assert_that(response.headers, has_entry("Content-Type", "application/json"))


def test_get_json_response(mockingbird_server):
mockingbird_server.routes(
get("/hello").body({"message": "Hello, World!"}).status(200)
)
response = requests.get('http://localhost:8080/hello')
assert_that(response.headers, has_entry("Content-Type", "application/json"))

assert_that(response.json(), equal_to({"message": "Hello, World!"}))

0 comments on commit dbf372f

Please sign in to comment.