From 019c7a7d0642024b2dc875cc3ed546f3ae6c74e5 Mon Sep 17 00:00:00 2001 From: Rus Date: Fri, 30 Apr 2021 16:49:02 +0300 Subject: [PATCH 1/4] async udp query support --- mcstatus/protocol/connection.py | 109 ++++++++++++++++++--------- mcstatus/querier.py | 60 ++++++++++----- mcstatus/server.py | 33 +++++++- mcstatus/tests/test_async_querier.py | 85 +++++++++++++++++++++ 4 files changed, 232 insertions(+), 55 deletions(-) create mode 100644 mcstatus/tests/test_async_querier.py diff --git a/mcstatus/protocol/connection.py b/mcstatus/protocol/connection.py index 15f06ab..2abba8e 100644 --- a/mcstatus/protocol/connection.py +++ b/mcstatus/protocol/connection.py @@ -1,6 +1,7 @@ import socket import struct import asyncio +import asyncio_dgram from ..scripts.address_tools import ip_type @@ -126,6 +127,51 @@ def write_buffer(self, buffer): self.write(data) +class AsyncReadConnection(Connection): + async def read_varint(self): + result = 0 + for i in range(5): + part = ord(await self.read(1)) + result |= (part & 0x7F) << 7 * i + if not part & 0x80: + return result + raise IOError("Server sent a varint that was too big!") + + async def read_utf(self): + length = await self.read_varint() + return self.read(length).decode("utf8") + + async def read_ascii(self): + result = bytearray() + while len(result) == 0 or result[-1] != 0: + result.extend(await self.read(1)) + return result[:-1].decode("ISO-8859-1") + + async def read_short(self): + return self._unpack("h", await self.read(2)) + + async def read_ushort(self): + return self._unpack("H", await self.read(2)) + + async def read_int(self): + return self._unpack("i", await self.read(4)) + + async def read_uint(self): + return self._unpack("I", await self.read(4)) + + async def read_long(self): + return self._unpack("q", await self.read(8)) + + async def read_ulong(self): + return self._unpack("Q", await self.read(8)) + + async def read_buffer(self): + length = await self.read_varint() + result = Connection() + result.receive(await self.read(length)) + return result + + class TCPSocketConnection(Connection): def __init__(self, addr, timeout=3): Connection.__init__(self) @@ -194,7 +240,10 @@ def __del__(self): pass -class TCPAsyncSocketConnection(Connection): +class TCPAsyncSocketConnection(AsyncReadConnection): + reader = None + writer = None + def __init__(self): super().__init__() @@ -214,45 +263,35 @@ async def read(self, length): def write(self, data): self.writer.write(data) - async def read_varint(self): - result = 0 - for i in range(5): - part = ord(await self.read(1)) - result |= (part & 0x7F) << 7 * i - if not part & 0x80: - return result - raise IOError("Server sent a varint that was too big!") - - async def read_utf(self): - length = await self.read_varint() - return self.read(length).decode("utf8") - async def read_ascii(self): - result = bytearray() - while len(result) == 0 or result[-1] != 0: - result.extend(await self.read(1)) - return result[:-1].decode("ISO-8859-1") +class UDPAsyncSocketConnection(AsyncReadConnection): + def __init__(self): + super().__init__() + self.stream = None - async def read_short(self): - return self._unpack("h", await self.read(2)) + async def connect(self, addr, timeout=3): + self.stream = await asyncio_dgram.connect((addr[0], addr[1])) - async def read_ushort(self): - return self._unpack("H", await self.read(2)) + def flush(self): + raise TypeError("UDPSocketConnection does not support flush()") - async def read_int(self): - return self._unpack("i", await self.read(4)) + def receive(self, data): + raise TypeError("UDPSocketConnection does not support receive()") - async def read_uint(self): - return self._unpack("I", await self.read(4)) + def remaining(self): + return 65535 - async def read_long(self): - return self._unpack("q", await self.read(8)) + async def read(self, length): + data, remote_addr = await self.stream.recv() + return data - async def read_ulong(self): - return self._unpack("Q", await self.read(8)) + async def write(self, data): + if isinstance(data, Connection): + data = bytearray(data.flush()) + await self.stream.send(data) - async def read_buffer(self): - length = await self.read_varint() - result = Connection() - result.receive(await self.read(length)) - return result + def __del__(self): + try: + self.stream.close() + except: + pass diff --git a/mcstatus/querier.py b/mcstatus/querier.py index 4a55829..0807c10 100644 --- a/mcstatus/querier.py +++ b/mcstatus/querier.py @@ -38,27 +38,29 @@ def read_query(self): self.connection.write(request) response = self._read_packet() - response.read(len("splitnum") + 1 + 1 + 1) - data = {} - players = [] + return parse_response(response) - while True: - key = response.read_ascii() - if len(key) == 0: - response.read(1) - break - value = response.read_ascii() - data[key] = value - response.read(len("player_") + 1 + 1) +class AsyncServerQuerier(ServerQuerier): + async def _read_packet(self): + packet = Connection() + packet.receive(await self.connection.read(self.connection.remaining())) + packet.read(1 + 4) + return packet + + async def handshake(self): + await self.connection.write(self._create_packet(self.PACKET_TYPE_CHALLENGE)) - while True: - name = response.read_ascii() - if len(name) == 0: - break - players.append(name) + packet = await self._read_packet() + self.challenge = int(packet.read_ascii()) + + async def read_query(self): + request = self._create_packet(self.PACKET_TYPE_QUERY) + request.write_uint(0) + await self.connection.write(request) - return QueryResponse(data, players) + response = await self._read_packet() + return parse_response(response) class QueryResponse: @@ -87,3 +89,27 @@ def __init__(self, raw, players): self.map = raw["map"] self.players = QueryResponse.Players(raw["numplayers"], raw["maxplayers"], players) self.software = QueryResponse.Software(raw["version"], raw["plugins"]) + + +def parse_response(response: Connection) -> QueryResponse: + response.read(len("splitnum") + 1 + 1 + 1) + data = {} + players = [] + + while True: + key = response.read_ascii() + if len(key) == 0: + response.read(1) + break + value = response.read_ascii() + data[key] = value + + response.read(len("player_") + 1 + 1) + + while True: + name = response.read_ascii() + if len(name) == 0: + break + players.append(name) + + return QueryResponse(data, players) diff --git a/mcstatus/server.py b/mcstatus/server.py index c2292ba..6c8c934 100644 --- a/mcstatus/server.py +++ b/mcstatus/server.py @@ -1,6 +1,7 @@ from mcstatus.pinger import ServerPinger, AsyncServerPinger -from mcstatus.protocol.connection import TCPSocketConnection, UDPSocketConnection, TCPAsyncSocketConnection -from mcstatus.querier import ServerQuerier +from mcstatus.protocol.connection import TCPSocketConnection, UDPSocketConnection, TCPAsyncSocketConnection, \ + UDPAsyncSocketConnection +from mcstatus.querier import ServerQuerier, AsyncServerQuerier from mcstatus.bedrock_status import BedrockServerStatus from mcstatus.scripts.address_tools import parse_address import dns.resolver @@ -161,7 +162,33 @@ def query(self, tries: int = 3): raise exception async def async_query(self, tries: int = 3): - raise NotImplementedError # TODO: '-' + """Asynchronously checks the status of a Minecraft Java Edition server via the query protocol. + + :param int tries: How many times to retry if it fails. + :return: Query status information in a `QueryResponse` instance. + :rtype: QueryResponse + """ + + exception = None + host = self.host + try: + answers = dns.resolver.query(host, "A") + if len(answers): + answer = answers[0] + host = str(answer).rstrip(".") + except Exception as e: + pass + for attempt in range(tries): + try: + connection = UDPAsyncSocketConnection() + await connection.connect((host, self.port)) + querier = AsyncServerQuerier(connection) + await querier.handshake() + return await querier.read_query() + except Exception as e: + exception = e + else: + raise exception class MinecraftBedrockServer: diff --git a/mcstatus/tests/test_async_querier.py b/mcstatus/tests/test_async_querier.py new file mode 100644 index 0000000..db1540c --- /dev/null +++ b/mcstatus/tests/test_async_querier.py @@ -0,0 +1,85 @@ +from mcstatus.protocol.connection import Connection +from mcstatus.querier import ServerQuerier, QueryResponse + + +class TestMinecraftQuerier: + def setup_method(self): + self.querier = ServerQuerier(Connection()) + + def test_handshake(self): + self.querier.connection.receive(bytearray.fromhex("090000000035373033353037373800")) + self.querier.handshake() + assert self.querier.connection.flush() == bytearray.fromhex("FEFD090000000000000000") + assert self.querier.challenge == 570350778 + + def test_query(self): + self.querier.connection.receive( + bytearray.fromhex( + "00000000000000000000000000000000686f73746e616d650041204d696e656372616674205365727665720067616d657479706500534d500067616d655f6964004d494e4543524146540076657273696f6e00312e3800706c7567696e7300006d617000776f726c64006e756d706c61796572730033006d6178706c617965727300323000686f7374706f727400323535363500686f73746970003139322e3136382e35362e31000001706c617965725f000044696e6e6572626f6e6500446a696e6e69626f6e650053746576650000" + ) + ) + response = self.querier.read_query() + assert self.querier.connection.flush() == bytearray.fromhex("FEFD00000000000000000000000000") + assert response.raw == { + "hostname": "A Minecraft Server", + "gametype": "SMP", + "game_id": "MINECRAFT", + "version": "1.8", + "plugins": "", + "map": "world", + "numplayers": "3", + "maxplayers": "20", + "hostport": "25565", + "hostip": "192.168.56.1", + } + assert response.players.names == ["Dinnerbone", "Djinnibone", "Steve"] + + +class TestQueryResponse: + def setup_method(self): + self.raw = { + "hostname": "A Minecraft Server", + "gametype": "SMP", + "game_id": "MINECRAFT", + "version": "1.8", + "plugins": "", + "map": "world", + "numplayers": "3", + "maxplayers": "20", + "hostport": "25565", + "hostip": "192.168.56.1", + } + self.players = ["Dinnerbone", "Djinnibone", "Steve"] + + def test_valid(self): + response = QueryResponse(self.raw, self.players) + assert response.motd == "A Minecraft Server" + assert response.map == "world" + assert response.players.online == 3 + assert response.players.max == 20 + assert response.players.names == ["Dinnerbone", "Djinnibone", "Steve"] + assert response.software.brand == "vanilla" + assert response.software.version == "1.8" + assert response.software.plugins == [] + + def test_valid(self): + players = QueryResponse.Players(5, 20, ["Dinnerbone", "Djinnibone", "Steve"]) + assert players.online == 5 + assert players.max == 20 + assert players.names == ["Dinnerbone", "Djinnibone", "Steve"] + + def test_vanilla(self): + software = QueryResponse.Software("1.8", "") + assert software.brand == "vanilla" + assert software.version == "1.8" + assert software.plugins == [] + + def test_modded(self): + software = QueryResponse.Software("1.8", "A modded server: Foo 1.0; Bar 2.0; Baz 3.0") + assert software.brand == "A modded server" + assert software.plugins == ["Foo 1.0", "Bar 2.0", "Baz 3.0"] + + def test_modded_no_plugins(self): + software = QueryResponse.Software("1.8", "A modded server") + assert software.brand == "A modded server" + assert software.plugins == [] From 21a750964ddd5cc6426fe342668e45e891b6272d Mon Sep 17 00:00:00 2001 From: Rus Date: Fri, 30 Apr 2021 16:50:48 +0300 Subject: [PATCH 2/4] async tests --- mcstatus/tests/test_async_querier.py | 17 +++++++++++++---- mcstatus/tests/test_async_support.py | 12 +++++++++++- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/mcstatus/tests/test_async_querier.py b/mcstatus/tests/test_async_querier.py index db1540c..bbe1f7a 100644 --- a/mcstatus/tests/test_async_querier.py +++ b/mcstatus/tests/test_async_querier.py @@ -1,14 +1,23 @@ from mcstatus.protocol.connection import Connection -from mcstatus.querier import ServerQuerier, QueryResponse +from mcstatus.querier import QueryResponse, AsyncServerQuerier +from mcstatus.tests.test_async_pinger import async_decorator + + +class FakeUDPAsyncConnection(Connection): + async def read(self, length): + return super().read(length) + + async def write(self, data): + return super().write(data) class TestMinecraftQuerier: def setup_method(self): - self.querier = ServerQuerier(Connection()) + self.querier = AsyncServerQuerier(FakeUDPAsyncConnection()) def test_handshake(self): self.querier.connection.receive(bytearray.fromhex("090000000035373033353037373800")) - self.querier.handshake() + async_decorator(self.querier.handshake)() assert self.querier.connection.flush() == bytearray.fromhex("FEFD090000000000000000") assert self.querier.challenge == 570350778 @@ -18,7 +27,7 @@ def test_query(self): "00000000000000000000000000000000686f73746e616d650041204d696e656372616674205365727665720067616d657479706500534d500067616d655f6964004d494e4543524146540076657273696f6e00312e3800706c7567696e7300006d617000776f726c64006e756d706c61796572730033006d6178706c617965727300323000686f7374706f727400323535363500686f73746970003139322e3136382e35362e31000001706c617965725f000044696e6e6572626f6e6500446a696e6e69626f6e650053746576650000" ) ) - response = self.querier.read_query() + response = async_decorator(self.querier.read_query)() assert self.querier.connection.flush() == bytearray.fromhex("FEFD00000000000000000000000000") assert response.raw == { "hostname": "A Minecraft Server", diff --git a/mcstatus/tests/test_async_support.py b/mcstatus/tests/test_async_support.py index 3bfc79a..fda92a3 100644 --- a/mcstatus/tests/test_async_support.py +++ b/mcstatus/tests/test_async_support.py @@ -1,6 +1,6 @@ from inspect import iscoroutinefunction -from mcstatus.protocol.connection import TCPAsyncSocketConnection +from mcstatus.protocol.connection import TCPAsyncSocketConnection, UDPAsyncSocketConnection def test_is_completely_asynchronous(): @@ -11,3 +11,13 @@ def test_is_completely_asynchronous(): assert iscoroutinefunction(conn.__getattribute__(attribute)) assertions += 1 assert assertions > 0, "None of the read_* attributes were async" + + +def test_query_is_completely_asynchronous(): + conn = UDPAsyncSocketConnection() + assertions = 0 + for attribute in dir(conn): + if attribute.startswith("read_"): + assert iscoroutinefunction(conn.__getattribute__(attribute)) + assertions += 1 + assert assertions > 0, "None of the read_* attributes were async" From 51561e85869ac0298e900f5a198a925707aabc9b Mon Sep 17 00:00:00 2001 From: Rus Date: Fri, 30 Apr 2021 17:22:20 +0300 Subject: [PATCH 3/4] async query timeout --- mcstatus/protocol/connection.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mcstatus/protocol/connection.py b/mcstatus/protocol/connection.py index 2abba8e..9563427 100644 --- a/mcstatus/protocol/connection.py +++ b/mcstatus/protocol/connection.py @@ -265,12 +265,16 @@ def write(self, data): class UDPAsyncSocketConnection(AsyncReadConnection): + stream = None + timeout = None + def __init__(self): super().__init__() - self.stream = None async def connect(self, addr, timeout=3): - self.stream = await asyncio_dgram.connect((addr[0], addr[1])) + self.timeout = timeout + conn = asyncio_dgram.connect((addr[0], addr[1])) + self.stream = await asyncio.wait_for(conn, timeout=self.timeout) def flush(self): raise TypeError("UDPSocketConnection does not support flush()") @@ -282,7 +286,7 @@ def remaining(self): return 65535 async def read(self, length): - data, remote_addr = await self.stream.recv() + data, remote_addr = await asyncio.wait_for(self.stream.recv(), timeout=self.timeout) return data async def write(self, data): From 8c646e70e2b6aff89d5e38a6e0b1976f4899be92 Mon Sep 17 00:00:00 2001 From: Rus Date: Fri, 30 Apr 2021 21:36:27 +0300 Subject: [PATCH 4/4] accidentally duplicated some extra code --- mcstatus/tests/test_async_querier.py | 52 +--------------------------- 1 file changed, 1 insertion(+), 51 deletions(-) diff --git a/mcstatus/tests/test_async_querier.py b/mcstatus/tests/test_async_querier.py index bbe1f7a..cf54c94 100644 --- a/mcstatus/tests/test_async_querier.py +++ b/mcstatus/tests/test_async_querier.py @@ -11,7 +11,7 @@ async def write(self, data): return super().write(data) -class TestMinecraftQuerier: +class TestMinecraftAsyncQuerier: def setup_method(self): self.querier = AsyncServerQuerier(FakeUDPAsyncConnection()) @@ -42,53 +42,3 @@ def test_query(self): "hostip": "192.168.56.1", } assert response.players.names == ["Dinnerbone", "Djinnibone", "Steve"] - - -class TestQueryResponse: - def setup_method(self): - self.raw = { - "hostname": "A Minecraft Server", - "gametype": "SMP", - "game_id": "MINECRAFT", - "version": "1.8", - "plugins": "", - "map": "world", - "numplayers": "3", - "maxplayers": "20", - "hostport": "25565", - "hostip": "192.168.56.1", - } - self.players = ["Dinnerbone", "Djinnibone", "Steve"] - - def test_valid(self): - response = QueryResponse(self.raw, self.players) - assert response.motd == "A Minecraft Server" - assert response.map == "world" - assert response.players.online == 3 - assert response.players.max == 20 - assert response.players.names == ["Dinnerbone", "Djinnibone", "Steve"] - assert response.software.brand == "vanilla" - assert response.software.version == "1.8" - assert response.software.plugins == [] - - def test_valid(self): - players = QueryResponse.Players(5, 20, ["Dinnerbone", "Djinnibone", "Steve"]) - assert players.online == 5 - assert players.max == 20 - assert players.names == ["Dinnerbone", "Djinnibone", "Steve"] - - def test_vanilla(self): - software = QueryResponse.Software("1.8", "") - assert software.brand == "vanilla" - assert software.version == "1.8" - assert software.plugins == [] - - def test_modded(self): - software = QueryResponse.Software("1.8", "A modded server: Foo 1.0; Bar 2.0; Baz 3.0") - assert software.brand == "A modded server" - assert software.plugins == ["Foo 1.0", "Bar 2.0", "Baz 3.0"] - - def test_modded_no_plugins(self): - software = QueryResponse.Software("1.8", "A modded server") - assert software.brand == "A modded server" - assert software.plugins == []