Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand Client with a method for sending arbitrary commands. #395

Merged
merged 1 commit into from
Jun 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions pymemcache/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import errno
from functools import partial
import platform
import socket
from typing import Tuple, Union
Expand Down Expand Up @@ -867,6 +868,27 @@ def version(self):
raise MemcacheUnknownError("Received unexpected response: %s" % results[0])
return after

def raw_command(self, command, end_tokens="\r\n"):
"""
Sends an arbitrary command to the server and parses the response until a
specified token is encountered.

Args:
command: str|bytes: The command to send.
end_tokens: str|bytes: The token expected at the end of the
response. If the `end_token` is not found, the client will wait
until the timeout specified in the constructor.

Returns:
The response from the server, with the `end_token` removed.
"""
encoding = "utf8" if self.allow_unicode_keys else "ascii"
command = command.encode(encoding) if isinstance(command, str) else command
end_tokens = (
end_tokens.encode(encoding) if isinstance(end_tokens, str) else end_tokens
)
return self._misc_cmd([b"" + command + b"\r\n"], command, False, end_tokens)[0]

def flush_all(self, delay=0, noreply=None):
"""
The memcached "flush_all" command.
Expand Down Expand Up @@ -1126,7 +1148,15 @@ def _store_cmd(self, name, values, expire, noreply, flags=None, cas=None):
self.close()
raise

def _misc_cmd(self, cmds, cmd_name, noreply):
def _misc_cmd(self, cmds, cmd_name, noreply, end_tokens=None):

# If no end_tokens have been given, just assume standard memcached
# operations, which end in "\r\n", use regular code for that.
if end_tokens:
_reader = partial(_readsegment, end_tokens=end_tokens)
else:
_reader = _readline

if self.sock is None:
self._connect()

Expand All @@ -1141,7 +1171,7 @@ def _misc_cmd(self, cmds, cmd_name, noreply):
line = None
for cmd in cmds:
try:
buf, line = _readline(self.sock, buf)
buf, line = _reader(self.sock, buf)
except MemcacheUnexpectedCloseError:
self.close()
raise
Expand Down Expand Up @@ -1396,6 +1426,10 @@ def shutdown(self, graceful=False):
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
client.shutdown(graceful)

def raw_command(self, command, end_tokens=b"\r\n"):
with self.client_pool.get_and_release(destroy_on_fail=True) as client:
return client.raw_command(command, end_tokens)

def __setitem__(self, key, value):
self.set(key, value, noreply=True)

Expand Down Expand Up @@ -1505,6 +1539,42 @@ def _readvalue(sock, buf, size):
return buf[rlen:], b"".join(chunks)


def _readsegment(sock, buf, end_tokens):
"""Read a segment from the socket.

Read a segment from the socket, up to the first end_token sub-string/bytes,
and return that segment.

Args:
sock: Socket object, should be connected.
buf: bytes, zero or more bytes, returned from an earlier
call to _readline, _readsegment or _readvalue (pass an empty
byte-string on the first call).
end_tokens: bytes, indicates the end of the segment, generally this is
b"\\r\\n" for memcached.

Returns:
A tuple of (buf, line) where line is the full line read from the
socket (minus the end_tokens bytes) and buf is any trailing
characters read after the end_tokens was found (which may be an empty
bytes object).

"""
result = bytes()

while True:

tokens_pos = buf.find(end_tokens)
if tokens_pos != -1:
before, after = buf[:tokens_pos], buf[tokens_pos + len(end_tokens) :]
result += before
return after, result

buf = _recv(sock, RECV_SIZE)
if not buf:
raise MemcacheUnexpectedCloseError()


def _recv(sock, size):
"""sock.recv() with retry on EINTR"""
while True:
Expand Down
57 changes: 57 additions & 0 deletions pymemcache/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,63 @@ def test_version_exception(self):
with pytest.raises(MemcacheUnknownError):
client.version()

def test_raw_command_default_end_tokens(self):
client = self.make_client([b"REPLY\r\n", b"REPLY\r\nLEFTOVER"])
result = client.raw_command(b"misc")
assert result == b"REPLY"
result = client.raw_command(b"misc")
assert result == b"REPLY"

def test_raw_command_custom_end_tokens(self):
client = self.make_client(
[
b"REPLY\r\nEND\r\n",
b"REPLY\r\nEND\r\nLEFTOVER",
b"REPLYEND\r\nLEFTOVER",
b"REPLY\nLEFTOVER",
]
)
end_tokens = b"END\r\n"
result = client.raw_command(b"misc", end_tokens)
assert result == b"REPLY\r\n"
result = client.raw_command(b"misc", end_tokens)
assert result == b"REPLY\r\n"
result = client.raw_command(b"misc", end_tokens)
assert result == b"REPLY"
result = client.raw_command(b"misc", b"\n")
assert result == b"REPLY"

def test_raw_command_missing_end_tokens(self):
client = self.make_client([b"REPLY", b"REPLY"])
with pytest.raises(IndexError):
client.raw_command(b"misc")
with pytest.raises(IndexError):
client.raw_command(b"misc", b"END\r\n")

def test_raw_command_empty_end_tokens(self):
client = self.make_client([b"REPLY"])

with pytest.raises(IndexError):
client.raw_command(b"misc", b"")

def test_raw_command_types(self):
client = self.make_client(
[b"REPLY\r\n", b"REPLY\r\n", b"REPLY\r\nLEFTOVER", b"REPLY\r\nLEFTOVER"]
)
assert client.raw_command("key") == b"REPLY"
assert client.raw_command(b"key") == b"REPLY"
assert client.raw_command("key") == b"REPLY"
assert client.raw_command(b"key") == b"REPLY"

def test_send_end_token_types(self):
client = self.make_client(
[b"REPLY\r\n", b"REPLY\r\n", b"REPLY\r\nLEFTOVER", b"REPLY\r\nLEFTOVER"]
)
assert client.raw_command("key", "\r\n") == b"REPLY"
assert client.raw_command(b"key", b"\r\n") == b"REPLY"
assert client.raw_command("key", "\r\n") == b"REPLY"
assert client.raw_command(b"key", b"\r\n") == b"REPLY"


@pytest.mark.unit()
class TestClientSocketConnect(unittest.TestCase):
Expand Down