Skip to content

Commit

Permalink
Satisfy "mypy"
Browse files Browse the repository at this point in the history
This doesn't really improve code quality or readability but satisfies "mypy" enforced changes.

Signed-off-by: Tobias Wolf <[email protected]>
  • Loading branch information
NotTheEvilOne committed Mar 19, 2024
1 parent 18cce26 commit 6164de7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 17 deletions.
33 changes: 18 additions & 15 deletions tests/mock_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,35 @@
from collections.abc import Callable
from socket import AF_INET, IPPROTO_TCP, SO_REUSEADDR, SOCK_STREAM, SOL_SOCKET, socket
from threading import Event, RLock
from typing import Optional

from paramiko import (
from paramiko import ( # type: ignore[attr-defined]
AUTH_FAILED,
AUTH_SUCCESSFUL,
OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
OPEN_SUCCEEDED,
AutoAddPolicy,
Channel,
PKey,
RSAKey,
ServerInterface,
SSHClient,
Transport,
)


class MockSSHServer(ServerInterface):
"""An ssh server accepting the pre-generated key."""

ssh_username = "pytest"
ssh_key = RSAKey.generate(4096)

def __init__(self) -> None:
self._channel = None
self._client = None
self._command = None
self._channel: Optional[Channel] = None
self._client: Optional[SSHClient] = None
self._command: Optional[bytes] = None
self.event = Event()
self._server_transport = None
self._server_transport: Optional[Transport] = None
self._thread_lock = RLock()

def __del__(self) -> None:
Expand Down Expand Up @@ -71,20 +74,20 @@ def client(self) -> SSHClient:

return self._client

def check_channel_request(self, kind: str, chanid: str) -> int:
def check_channel_request(self, kind: str, chanid: int) -> int:
if kind == "session":
return OPEN_SUCCEEDED
return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
return OPEN_SUCCEEDED # type: ignore[no-any-return]
return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED # type: ignore[no-any-return]

def check_auth_password(self, username: str, password: str) -> int:
return AUTH_FAILED
return AUTH_FAILED # type: ignore[no-any-return]

def check_auth_publickey(self, username: str, key: str) -> int:
def check_auth_publickey(self, username: str, key: PKey) -> int:
if username == self.__class__.ssh_username and key == self.__class__.ssh_key:
return AUTH_SUCCESSFUL
return AUTH_FAILED
return AUTH_SUCCESSFUL # type: ignore[no-any-return]
return AUTH_FAILED # type: ignore[no-any-return]

def check_channel_exec_request(self, channel: str, command: str) -> bool:
def check_channel_exec_request(self, channel: Channel, command: bytes) -> bool:
if self.event.is_set():
return False

Expand All @@ -105,11 +108,11 @@ def get_allowed_auths(self, username: str) -> str:
return "publickey"
return ""

def handle_exec_request(self, _callable: Callable[[str, Channel], None]) -> None:
def handle_exec_request(self, _callable: Callable[[bytes, Channel], None]) -> None:
if not callable(_callable):
raise RuntimeError("Handler function given is invalid")

_callable(self._command, self._channel)
_callable(self._command, self._channel) # type: ignore[arg-type]

self._channel = None
self._client = None
Expand Down
5 changes: 3 additions & 2 deletions tests/modules/test_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import pytest

from rookify.modules.example import ExampleHandler
from rookify.modules.example.main import ExampleHandler
from rookify.modules.module import ModuleException

def test_preflight_check():

def test_preflight() -> None:
with pytest.raises(ModuleException):
ExampleHandler({}, {}).preflight_check()

0 comments on commit 6164de7

Please sign in to comment.