Skip to content

Commit

Permalink
Add initial testing foundation
Browse files Browse the repository at this point in the history
Signed-off-by: Tobias Wolf <[email protected]>
  • Loading branch information
NotTheEvilOne committed Mar 26, 2024
1 parent d92f3b6 commit b441597
Show file tree
Hide file tree
Showing 10 changed files with 276 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- run: python -m pip install .[tests]
- uses: pre-commit/[email protected]
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ repos:
args: [--strict, --ignore-missing-imports, --check-untyped-defs]
additional_dependencies:
- types-PyYAML
- types-paramiko==3.4.0.*
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ classifiers = [

[project.urls]
Homepage = "https://scs.community"

[tool.pytest.ini_options]
pythonpath = [ "src" ]
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@ platforms = any

[options]
install_requires=file:requirements.txt

[options.extras_require]
tests =
pytest==8.0.2
41 changes: 41 additions & 0 deletions tests/mock_ceph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-

import json
from collections.abc import Callable
from rookify.modules.module import ModuleException
from threading import RLock
from typing import Any, Dict, List, Optional, Tuple


class MockCeph(object):
def __init__(self, config: Dict[str, Any]):
self._callback_handler: Optional[
Callable[[str, bytes], Tuple[int, bytes, str]]
] = None
self._thread_lock = RLock()

def handle_with_callback(
self, _callable: Callable[[str, bytes], Tuple[int, bytes, str]]
) -> None:
with self._thread_lock:
if self._callback_handler is not None:
raise RuntimeError("Callback handler already registered")

self._callback_handler = _callable

def mon_command(
self, command: str, inbuf: bytes, **kwargs: Any
) -> Dict[str, Any] | List[Any]:
if not callable(self._callback_handler):
raise RuntimeError("Handler function given is invalid")

ret, outbuf, outstr = self._callback_handler(command, inbuf, **kwargs)
if ret != 0:
raise ModuleException("Ceph did return an error: {0!r}".format(outbuf))

data = json.loads(outbuf)
assert isinstance(data, dict) or isinstance(data, list)
return data

def stop_handler(self) -> None:
self._callback_handler = None
147 changes: 147 additions & 0 deletions tests/mock_ssh_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-


from collections.abc import Callable
from socket import AF_INET, IPPROTO_TCP, SO_REUSEADDR, SOCK_STREAM, SOL_SOCKET, socket
from threading import Event, RLock
from typing import Any, Optional

from paramiko import ( # type: ignore[attr-defined]
AUTH_FAILED,
AUTH_SUCCESSFUL,
OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
OPEN_SUCCEEDED,
AutoAddPolicy,
Channel,
PKey,
RSAKey,
ServerInterface,
SSHClient,
Transport,
)


class MockSSHServer(ServerInterface):
"""An ssh server accepting the pre-generated key."""

ssh_username = "pytest"
ssh_key = RSAKey.generate(4096)

def __init__(self) -> None:
ServerInterface.__init__(self)

self._callback_handler: Optional[Callable[[bytes, Channel], None]] = None
self._channel: Any = None
self._client: Optional[SSHClient] = None
self._command: Optional[bytes] = None
self.event = Event()
self._server_transport: Optional[Transport] = None
self._thread_lock = RLock()

def __del__(self) -> None:
self.close()

@property
def client(self) -> SSHClient:
with self._thread_lock:
if self._client is None:
connection_event = Event()

server_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)
server_socket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
server_socket.bind(("127.0.0.1", 0))
server_socket.listen()

server_address = server_socket.getsockname()

client_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)
client_socket.connect(server_address)

(transport_socket, _) = server_socket.accept()

self._server_transport = Transport(transport_socket)
self._server_transport.add_server_key(self.__class__.ssh_key)
self._server_transport.start_server(connection_event, self)

self._client = SSHClient()
self._client.set_missing_host_key_policy(AutoAddPolicy())

self._client.connect(
server_address[0],
server_address[1],
username=self.__class__.ssh_username,
pkey=self.__class__.ssh_key,
sock=client_socket,
)

connection_event.wait()

return self._client

def check_channel_request(self, kind: str, chanid: int) -> int:
if kind == "session":
return OPEN_SUCCEEDED # type: ignore[no-any-return]
return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED # type: ignore[no-any-return]

def check_auth_password(self, username: str, password: str) -> int:
return AUTH_FAILED # type: ignore[no-any-return]

def check_auth_publickey(self, username: str, key: PKey) -> int:
if username == self.__class__.ssh_username and key == self.__class__.ssh_key:
return AUTH_SUCCESSFUL # type: ignore[no-any-return]
return AUTH_FAILED # type: ignore[no-any-return]

def check_channel_exec_request(self, channel: Channel, command: bytes) -> bool:
if self.event.is_set():
return False

self.event.set()

with self._thread_lock:
self._channel = channel
self._command = command

if self._callback_handler is not None:
self.handle_exec_request(self._callback_handler)

return True

def close(self) -> None:
self.stop_exec_requests_handler()

if self._server_transport is not None:
self._server_transport.close()
self._server_transport = None

def get_allowed_auths(self, username: str) -> str:
if username == self.__class__.ssh_username:
return "publickey"
return ""

def handle_exec_request(self, _callable: Callable[[bytes, Channel], None]) -> None:
if not callable(_callable):
raise RuntimeError("Handler function given is invalid")

_callable(self._command, self._channel) # type: ignore[arg-type]

if self._channel.recv_ready() is not True:
self._channel.send(
bytes("Command {0!r} invalid\n".format(self._command), "utf-8")
)

self._channel = None
self._client = None

self.event.clear()

def handle_exec_requests_with_callback(
self, _callable: Callable[[bytes, Channel], None]
) -> None:
with self._thread_lock:
if self._callback_handler is not None:
raise RuntimeError("Callback handler already registered")

self._callback_handler = _callable

def stop_exec_requests_handler(self) -> None:
self._callback_handler = None
Empty file added tests/modules/__init__.py
Empty file.
11 changes: 11 additions & 0 deletions tests/modules/test_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-

import pytest

from rookify.modules.example.main import ExampleHandler
from rookify.modules.module import ModuleException


def test_preflight() -> None:
with pytest.raises(ModuleException):
ExampleHandler({}, {}, "").preflight()
31 changes: 31 additions & 0 deletions tests/test_mock_ceph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-

from typing import Any, Dict, Tuple
from unittest import TestCase

from .mock_ceph import MockCeph


class TestMockCeph(TestCase):
ceph: Any = None

@classmethod
def setUpClass(cls) -> None:
cls.ceph = MockCeph({})

def setUp(self) -> None:
self.__class__.ceph.handle_with_callback(self._command_callback)

def tearDown(self) -> None:
self.__class__.ceph.stop_handler()

def _command_callback(
self, command: str, inbuf: bytes, **kwargs: Dict[Any, Any]
) -> Tuple[int, bytes, str]:
if command == "test":
return 0, b'["ok"]', ""
return -1, b'["Command not found"]', ""

def test_self(self) -> None:
res = self.__class__.ceph.mon_command("test", b"")
self.assertEqual(res, ["ok"])
37 changes: 37 additions & 0 deletions tests/test_mock_ssh_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-

from paramiko import Channel
from typing import Any
from unittest import TestCase

from .mock_ssh_server import MockSSHServer


class TestMockSSHServer(TestCase):
ssh_client: Any = None
ssh_server: Any = None

@classmethod
def setUpClass(cls) -> None:
cls.ssh_server = MockSSHServer()
cls.ssh_client = cls.ssh_server.client

@classmethod
def tearDownClass(cls) -> None:
cls.ssh_server.close()

def setUp(self) -> None:
self.__class__.ssh_server.handle_exec_requests_with_callback(
self._command_callback
)

def tearDown(self) -> None:
self.__class__.ssh_server.stop_exec_requests_handler()

def _command_callback(self, command: bytes, channel: Channel) -> None:
if command == b"test":
channel.send(b"ok\n")

def test_self(self) -> None:
_, stdout, _ = self.__class__.ssh_client.exec_command("test")
self.assertEqual(stdout.readline(), "ok\n")

0 comments on commit b441597

Please sign in to comment.