Skip to content

Commit

Permalink
chathistory: Validate BATCH commands more strictly (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
progval authored Jun 3, 2023
1 parent 5a5dbdb commit e5f22e8
Showing 1 changed file with 46 additions and 45 deletions.
91 changes: 46 additions & 45 deletions irctest/server_tests/chathistory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from irctest import cases, runner
from irctest.irc_utils.junkdrawer import random_name
from irctest.patma import ANYSTR
from irctest.patma import ANYSTR, StrRe

CHATHISTORY_CAP = "draft/chathistory"
EVENT_PLAYBACK_CAP = "draft/event-playback"
Expand All @@ -21,28 +21,6 @@
MYSQL_PASSWORD = ""


def validate_chathistory_batch(msgs):
batch_tag = None
closed_batch_tag = None
result = []
for msg in msgs:
if msg.command == "BATCH":
batch_param = msg.params[0]
if batch_tag is None and batch_param[0] == "+":
batch_tag = batch_param[1:]
elif batch_param[0] == "-":
closed_batch_tag = batch_param[1:]
elif (
msg.command == "PRIVMSG"
and batch_tag is not None
and msg.tags.get("batch") == batch_tag
):
if not msg.prefix.startswith("HistServ!"): # FIXME: ergo-specific
result.append(msg.to_history_message())
assert batch_tag == closed_batch_tag
return result


def skip_ngircd(f):
@functools.wraps(f)
def newf(self, *args, **kwargs):
Expand All @@ -56,6 +34,26 @@ def newf(self, *args, **kwargs):
@cases.mark_specifications("IRCv3")
@cases.mark_services
class ChathistoryTestCase(cases.BaseServerTestCase):
def validate_chathistory_batch(self, msgs, target):
(start, *inner_msgs, end) = msgs

self.assertMessageMatch(
start, command="BATCH", params=[StrRe(r"\+.*"), "chathistory", target]
)
batch_tag = start.params[0][1:]
self.assertMessageMatch(end, command="BATCH", params=["-" + batch_tag])

result = []
for msg in inner_msgs:
if (
msg.command == "PRIVMSG"
and batch_tag is not None
and msg.tags.get("batch") == batch_tag
):
if not msg.prefix.startswith("HistServ!"): # FIXME: ergo-specific
result.append(msg.to_history_message())
return result

@staticmethod
def config() -> cases.TestCaseControllerConfig:
return cases.TestCaseControllerConfig(chathistory=True)
Expand Down Expand Up @@ -308,6 +306,9 @@ def testChathistoryDMs(self, subcommand):
)
time.sleep(0.002)

self.getMessages(1)
self.getMessages(2)

self.validate_echo_messages(NUM_MESSAGES, echo_messages)
self.validate_chathistory(subcommand, echo_messages, 1, c2)
self.validate_chathistory(subcommand, echo_messages, 2, c1)
Expand Down Expand Up @@ -401,31 +402,31 @@ def validate_chathistory(self, subcommand, echo_messages, user, chname):
def _validate_chathistory_LATEST(self, echo_messages, user, chname):
INCLUSIVE_LIMIT = len(echo_messages) * 2
self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, INCLUSIVE_LIMIT))
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages, result)

self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, 5))
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[-5:], result)

self.sendLine(user, "CHATHISTORY LATEST %s * %d" % (chname, 1))
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[-1:], result)

self.sendLine(
user,
"CHATHISTORY LATEST %s msgid=%s %d"
% (chname, echo_messages[4].msgid, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[5:], result)

self.sendLine(
user,
"CHATHISTORY LATEST %s timestamp=%s %d"
% (chname, echo_messages[4].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[5:], result)

def _validate_chathistory_BEFORE(self, echo_messages, user, chname):
Expand All @@ -435,23 +436,23 @@ def _validate_chathistory_BEFORE(self, echo_messages, user, chname):
"CHATHISTORY BEFORE %s msgid=%s %d"
% (chname, echo_messages[6].msgid, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[:6], result)

self.sendLine(
user,
"CHATHISTORY BEFORE %s timestamp=%s %d"
% (chname, echo_messages[6].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[:6], result)

self.sendLine(
user,
"CHATHISTORY BEFORE %s timestamp=%s %d"
% (chname, echo_messages[6].time, 2),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[4:6], result)

def _validate_chathistory_AFTER(self, echo_messages, user, chname):
Expand All @@ -461,22 +462,22 @@ def _validate_chathistory_AFTER(self, echo_messages, user, chname):
"CHATHISTORY AFTER %s msgid=%s %d"
% (chname, echo_messages[3].msgid, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[4:], result)

self.sendLine(
user,
"CHATHISTORY AFTER %s timestamp=%s %d"
% (chname, echo_messages[3].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[4:], result)

self.sendLine(
user,
"CHATHISTORY AFTER %s timestamp=%s %d" % (chname, echo_messages[3].time, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[4:7], result)

def _validate_chathistory_BETWEEN(self, echo_messages, user, chname):
Expand All @@ -492,7 +493,7 @@ def _validate_chathistory_BETWEEN(self, echo_messages, user, chname):
INCLUSIVE_LIMIT,
),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:-1], result)

self.sendLine(
Expand All @@ -505,7 +506,7 @@ def _validate_chathistory_BETWEEN(self, echo_messages, user, chname):
INCLUSIVE_LIMIT,
),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:-1], result)

# BETWEEN forwards and backwards with a limit, should get
Expand All @@ -515,15 +516,15 @@ def _validate_chathistory_BETWEEN(self, echo_messages, user, chname):
"CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d"
% (chname, echo_messages[0].msgid, echo_messages[-1].msgid, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:4], result)

self.sendLine(
user,
"CHATHISTORY BETWEEN %s msgid=%s msgid=%s %d"
% (chname, echo_messages[-1].msgid, echo_messages[0].msgid, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[-4:-1], result)

# same stuff again but with timestamps
Expand All @@ -532,51 +533,51 @@ def _validate_chathistory_BETWEEN(self, echo_messages, user, chname):
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[0].time, echo_messages[-1].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:-1], result)
self.sendLine(
user,
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[-1].time, echo_messages[0].time, INCLUSIVE_LIMIT),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:-1], result)
self.sendLine(
user,
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[0].time, echo_messages[-1].time, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[1:4], result)
self.sendLine(
user,
"CHATHISTORY BETWEEN %s timestamp=%s timestamp=%s %d"
% (chname, echo_messages[-1].time, echo_messages[0].time, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[-4:-1], result)

def _validate_chathistory_AROUND(self, echo_messages, user, chname):
self.sendLine(
user,
"CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 1),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual([echo_messages[7]], result)

self.sendLine(
user,
"CHATHISTORY AROUND %s msgid=%s %d" % (chname, echo_messages[7].msgid, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertEqual(echo_messages[6:9], result)

self.sendLine(
user,
"CHATHISTORY AROUND %s timestamp=%s %d"
% (chname, echo_messages[7].time, 3),
)
result = validate_chathistory_batch(self.getMessages(user))
result = self.validate_chathistory_batch(self.getMessages(user), chname)
self.assertIn(echo_messages[7], result)

@pytest.mark.arbitrary_client_tags
Expand Down

0 comments on commit e5f22e8

Please sign in to comment.