Skip to content

Commit

Permalink
Merge pull request #125 from MeRuslan/master
Browse files Browse the repository at this point in the history
Async query
  • Loading branch information
kevinkjt2000 authored May 1, 2021
2 parents f7de90e + 8c646e7 commit 0604bbc
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 55 deletions.
111 changes: 77 additions & 34 deletions mcstatus/protocol/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import socket
import struct
import asyncio
import asyncio_dgram

from ..scripts.address_tools import ip_type

Expand Down Expand Up @@ -126,6 +127,51 @@ def write_buffer(self, buffer):
self.write(data)


class AsyncReadConnection(Connection):
async def read_varint(self):
result = 0
for i in range(5):
part = ord(await self.read(1))
result |= (part & 0x7F) << 7 * i
if not part & 0x80:
return result
raise IOError("Server sent a varint that was too big!")

async def read_utf(self):
length = await self.read_varint()
return self.read(length).decode("utf8")

async def read_ascii(self):
result = bytearray()
while len(result) == 0 or result[-1] != 0:
result.extend(await self.read(1))
return result[:-1].decode("ISO-8859-1")

async def read_short(self):
return self._unpack("h", await self.read(2))

async def read_ushort(self):
return self._unpack("H", await self.read(2))

async def read_int(self):
return self._unpack("i", await self.read(4))

async def read_uint(self):
return self._unpack("I", await self.read(4))

async def read_long(self):
return self._unpack("q", await self.read(8))

async def read_ulong(self):
return self._unpack("Q", await self.read(8))

async def read_buffer(self):
length = await self.read_varint()
result = Connection()
result.receive(await self.read(length))
return result


class TCPSocketConnection(Connection):
def __init__(self, addr, timeout=3):
Connection.__init__(self)
Expand Down Expand Up @@ -194,7 +240,10 @@ def __del__(self):
pass


class TCPAsyncSocketConnection(Connection):
class TCPAsyncSocketConnection(AsyncReadConnection):
reader = None
writer = None

def __init__(self):
super().__init__()

Expand All @@ -214,45 +263,39 @@ async def read(self, length):
def write(self, data):
self.writer.write(data)

async def read_varint(self):
result = 0
for i in range(5):
part = ord(await self.read(1))
result |= (part & 0x7F) << 7 * i
if not part & 0x80:
return result
raise IOError("Server sent a varint that was too big!")

async def read_utf(self):
length = await self.read_varint()
return self.read(length).decode("utf8")
class UDPAsyncSocketConnection(AsyncReadConnection):
stream = None
timeout = None

async def read_ascii(self):
result = bytearray()
while len(result) == 0 or result[-1] != 0:
result.extend(await self.read(1))
return result[:-1].decode("ISO-8859-1")
def __init__(self):
super().__init__()

async def read_short(self):
return self._unpack("h", await self.read(2))
async def connect(self, addr, timeout=3):
self.timeout = timeout
conn = asyncio_dgram.connect((addr[0], addr[1]))
self.stream = await asyncio.wait_for(conn, timeout=self.timeout)

async def read_ushort(self):
return self._unpack("H", await self.read(2))
def flush(self):
raise TypeError("UDPSocketConnection does not support flush()")

async def read_int(self):
return self._unpack("i", await self.read(4))
def receive(self, data):
raise TypeError("UDPSocketConnection does not support receive()")

async def read_uint(self):
return self._unpack("I", await self.read(4))
def remaining(self):
return 65535

async def read_long(self):
return self._unpack("q", await self.read(8))
async def read(self, length):
data, remote_addr = await asyncio.wait_for(self.stream.recv(), timeout=self.timeout)
return data

async def read_ulong(self):
return self._unpack("Q", await self.read(8))
async def write(self, data):
if isinstance(data, Connection):
data = bytearray(data.flush())
await self.stream.send(data)

async def read_buffer(self):
length = await self.read_varint()
result = Connection()
result.receive(await self.read(length))
return result
def __del__(self):
try:
self.stream.close()
except:
pass
60 changes: 43 additions & 17 deletions mcstatus/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,27 +38,29 @@ def read_query(self):
self.connection.write(request)

response = self._read_packet()
response.read(len("splitnum") + 1 + 1 + 1)
data = {}
players = []
return parse_response(response)

while True:
key = response.read_ascii()
if len(key) == 0:
response.read(1)
break
value = response.read_ascii()
data[key] = value

response.read(len("player_") + 1 + 1)
class AsyncServerQuerier(ServerQuerier):
async def _read_packet(self):
packet = Connection()
packet.receive(await self.connection.read(self.connection.remaining()))
packet.read(1 + 4)
return packet

async def handshake(self):
await self.connection.write(self._create_packet(self.PACKET_TYPE_CHALLENGE))

while True:
name = response.read_ascii()
if len(name) == 0:
break
players.append(name)
packet = await self._read_packet()
self.challenge = int(packet.read_ascii())

async def read_query(self):
request = self._create_packet(self.PACKET_TYPE_QUERY)
request.write_uint(0)
await self.connection.write(request)

return QueryResponse(data, players)
response = await self._read_packet()
return parse_response(response)


class QueryResponse:
Expand Down Expand Up @@ -87,3 +89,27 @@ def __init__(self, raw, players):
self.map = raw["map"]
self.players = QueryResponse.Players(raw["numplayers"], raw["maxplayers"], players)
self.software = QueryResponse.Software(raw["version"], raw["plugins"])


def parse_response(response: Connection) -> QueryResponse:
response.read(len("splitnum") + 1 + 1 + 1)
data = {}
players = []

while True:
key = response.read_ascii()
if len(key) == 0:
response.read(1)
break
value = response.read_ascii()
data[key] = value

response.read(len("player_") + 1 + 1)

while True:
name = response.read_ascii()
if len(name) == 0:
break
players.append(name)

return QueryResponse(data, players)
33 changes: 30 additions & 3 deletions mcstatus/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from mcstatus.pinger import ServerPinger, AsyncServerPinger
from mcstatus.protocol.connection import TCPSocketConnection, UDPSocketConnection, TCPAsyncSocketConnection
from mcstatus.querier import ServerQuerier
from mcstatus.protocol.connection import TCPSocketConnection, UDPSocketConnection, TCPAsyncSocketConnection, \
UDPAsyncSocketConnection
from mcstatus.querier import ServerQuerier, AsyncServerQuerier
from mcstatus.bedrock_status import BedrockServerStatus
from mcstatus.scripts.address_tools import parse_address
import dns.resolver
Expand Down Expand Up @@ -161,7 +162,33 @@ def query(self, tries: int = 3):
raise exception

async def async_query(self, tries: int = 3):
raise NotImplementedError # TODO: '-'
"""Asynchronously checks the status of a Minecraft Java Edition server via the query protocol.
:param int tries: How many times to retry if it fails.
:return: Query status information in a `QueryResponse` instance.
:rtype: QueryResponse
"""

exception = None
host = self.host
try:
answers = dns.resolver.query(host, "A")
if len(answers):
answer = answers[0]
host = str(answer).rstrip(".")
except Exception as e:
pass
for attempt in range(tries):
try:
connection = UDPAsyncSocketConnection()
await connection.connect((host, self.port))
querier = AsyncServerQuerier(connection)
await querier.handshake()
return await querier.read_query()
except Exception as e:
exception = e
else:
raise exception


class MinecraftBedrockServer:
Expand Down
44 changes: 44 additions & 0 deletions mcstatus/tests/test_async_querier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from mcstatus.protocol.connection import Connection
from mcstatus.querier import QueryResponse, AsyncServerQuerier
from mcstatus.tests.test_async_pinger import async_decorator


class FakeUDPAsyncConnection(Connection):
async def read(self, length):
return super().read(length)

async def write(self, data):
return super().write(data)


class TestMinecraftAsyncQuerier:
def setup_method(self):
self.querier = AsyncServerQuerier(FakeUDPAsyncConnection())

def test_handshake(self):
self.querier.connection.receive(bytearray.fromhex("090000000035373033353037373800"))
async_decorator(self.querier.handshake)()
assert self.querier.connection.flush() == bytearray.fromhex("FEFD090000000000000000")
assert self.querier.challenge == 570350778

def test_query(self):
self.querier.connection.receive(
bytearray.fromhex(
"00000000000000000000000000000000686f73746e616d650041204d696e656372616674205365727665720067616d657479706500534d500067616d655f6964004d494e4543524146540076657273696f6e00312e3800706c7567696e7300006d617000776f726c64006e756d706c61796572730033006d6178706c617965727300323000686f7374706f727400323535363500686f73746970003139322e3136382e35362e31000001706c617965725f000044696e6e6572626f6e6500446a696e6e69626f6e650053746576650000"
)
)
response = async_decorator(self.querier.read_query)()
assert self.querier.connection.flush() == bytearray.fromhex("FEFD00000000000000000000000000")
assert response.raw == {
"hostname": "A Minecraft Server",
"gametype": "SMP",
"game_id": "MINECRAFT",
"version": "1.8",
"plugins": "",
"map": "world",
"numplayers": "3",
"maxplayers": "20",
"hostport": "25565",
"hostip": "192.168.56.1",
}
assert response.players.names == ["Dinnerbone", "Djinnibone", "Steve"]
12 changes: 11 additions & 1 deletion mcstatus/tests/test_async_support.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from inspect import iscoroutinefunction

from mcstatus.protocol.connection import TCPAsyncSocketConnection
from mcstatus.protocol.connection import TCPAsyncSocketConnection, UDPAsyncSocketConnection


def test_is_completely_asynchronous():
Expand All @@ -11,3 +11,13 @@ def test_is_completely_asynchronous():
assert iscoroutinefunction(conn.__getattribute__(attribute))
assertions += 1
assert assertions > 0, "None of the read_* attributes were async"


def test_query_is_completely_asynchronous():
conn = UDPAsyncSocketConnection()
assertions = 0
for attribute in dir(conn):
if attribute.startswith("read_"):
assert iscoroutinefunction(conn.__getattribute__(attribute))
assertions += 1
assert assertions > 0, "None of the read_* attributes were async"

0 comments on commit 0604bbc

Please sign in to comment.