Skip to content

Commit

Permalink
Get tests running
Browse files Browse the repository at this point in the history
  • Loading branch information
Moosems committed Sep 6, 2024
1 parent 980a1a4 commit ca3dadf
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 40 deletions.
4 changes: 2 additions & 2 deletions collegamento/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

beartype_this_package()

from .files_variant import FileClient, FileServer # noqa: F401, E402
from .simple_client_server import ( # noqa: F401, E402
from .client_server import ( # noqa: F401, E402
COMMANDS_MAPPING,
USER_FUNCTION,
Client,
Expand All @@ -14,3 +13,4 @@
ResponseQueueType,
Server,
)
from .files_variant import FileClient, FileServer # noqa: F401, E402
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
from time import sleep

from .server import Server
from .utils import USER_FUNCTION, CollegamentoError, Request, Response
from .utils import (
COMMANDS_MAPPING,
USER_FUNCTION,
CollegamentoError,
Request,
Response,
)


class Client:
Expand All @@ -32,7 +38,7 @@ class Client:

def __init__(
self,
commands: dict[str, USER_FUNCTION | tuple[USER_FUNCTION, bool]] = {},
commands: COMMANDS_MAPPING = {},
id_max: int = 15_000,
server_type: type = Server,
) -> None:
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
33 changes: 8 additions & 25 deletions collegamento/files_variant.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# TODO: Actually fix this section of the package
from logging import Logger
from typing import NotRequired

from .simple_client_server import (
from .client_server import (
COMMANDS_MAPPING,
USER_FUNCTION,
Client,
CollegamentoError,
Expand All @@ -22,13 +22,10 @@ def update_files(server: "FileServer", request: Request) -> None:
file: str = request["file"] # type: ignore

if request["remove"]: # type: ignore
server.logger.info(f"File {file} was requested for removal")
server.files.pop(file)
server.logger.info(f"File {file} has been removed")
else:
contents: str = request["contents"] # type: ignore
server.files[file] = contents
server.logger.info(f"File {file} has been updated with new contents")


class FileClient(Client):
Expand All @@ -38,7 +35,7 @@ class FileClient(Client):
"""

def __init__(
self, commands: dict[str, USER_FUNCTION], id_max: int = 15_000
self, commands: COMMANDS_MAPPING, id_max: int = 15_000
) -> None:
self.files: dict[str, str] = {}

Expand All @@ -51,12 +48,10 @@ def create_server(self) -> None:

super().create_server()

self.logger.info("Copying files to server")
files_copy = self.files.copy()
self.files = {}
for file, data in files_copy.items():
self.update_file(file, data)
self.logger.debug("Finished copying files to server")

def request(
self,
Expand All @@ -65,10 +60,7 @@ def request(
if "file" in request_details:
file = request_details["file"]
if file not in self.files:
self.logger.exception(
f"File {file} not in files! Files are {self.files.keys()}"
)
raise Exception(
raise CollegamentoError(
f"File {file} not in files! Files are {self.files.keys()}"
)

Expand All @@ -77,37 +69,30 @@ def request(
def update_file(self, file: str, current_state: str) -> None:
"""Updates files in the system - external API"""

self.logger.info(f"Updating file: {file}")
self.files[file] = current_state

self.logger.debug("Creating notification dict")
file_notification: dict = {
"command": "FileNotification",
"file": file,
"remove": False,
"contents": self.files[file],
}

self.logger.debug("Notifying server of file update")
super().request(file_notification)

def remove_file(self, file: str) -> None:
"""Removes a file from the main_server - external API"""
if file not in list(self.files.keys()):
self.logger.exception(
f"Cannot remove file {file} as file is not in file database!"
)
raise CollegamentoError(
f"Cannot remove file {file} as file is not in file database!"
)

self.logger.info("Notifying server of file deletion")
file_notification: dict = {
"command": "FileNotification",
"file": file,
"remove": True,
}
self.logger.debug("Notifying server of file removal")

super().request(file_notification)


Expand All @@ -116,18 +101,16 @@ class FileServer(Server):

def __init__(
self,
commands: dict[str, USER_FUNCTION],
response_queue: ResponseQueueType,
commands: dict[str, tuple[USER_FUNCTION, bool]],
requests_queue: RequestQueueType,
logger: Logger,
response_queue: ResponseQueueType,
) -> None:
self.files: dict[str, str] = {}

super().__init__(
commands,
response_queue,
requests_queue,
logger,
response_queue,
["FileNotification"],
)

Expand Down
11 changes: 5 additions & 6 deletions tests/test_file_variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,25 @@ def split_str(server: FileServer, arg: Request) -> list[str]:


def test_file_variants():
commands: dict[str, USER_FUNCTION] = {"test": func}
context = FileClient(commands)
context = FileClient({"test": func})

context.update_file("test", "test contents")
context.request({"command": "test"})

sleep(1)

output: Response | None = context.get_response("test")
output: list[Response] = context.get_response("test")
assert output is not None # noqa: E711
assert output["result"] is True # noqa: E712 # type: ignore
assert output[0]["result"] is True # noqa: E712 # type: ignore

context.add_command("test1", split_str)
context.request({"command": "test1", "file": "test"})

sleep(1)

output: Response | None = context.get_response("test1")
output = context.get_response("test1")
assert output is not None # noqa: E711
assert output["result"] == ["test", "contents"] # noqa: E712 # type: ignore
assert output[0]["result"] == ["test", "contents"] # noqa: E712 # type: ignore

assert context.all_ids == []

Expand Down
6 changes: 1 addition & 5 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def foo(server, request):
print("Foo called", request["id"])


def main():
def test_normal_client():
Client({"foo": foo})
x = Client({"foo": (foo, True), "foo2": foo})

Expand Down Expand Up @@ -47,7 +47,3 @@ def main():
Client().create_server()

sleep(1)


if __name__ == "__main__":
main()

0 comments on commit ca3dadf

Please sign in to comment.