diff --git a/irctest/basecontrollers.py b/irctest/basecontrollers.py index b1464ac4..cbc56aae 100644 --- a/irctest/basecontrollers.py +++ b/irctest/basecontrollers.py @@ -210,9 +210,10 @@ def registerUser( case: irctest.cases.BaseServerTestCase, # type: ignore username: str, password: Optional[str] = None, + **kwargs: Any, ) -> None: if self.services_controller is not None: - self.services_controller.registerUser(case, username, password) + self.services_controller.registerUser(case, username, password, **kwargs) else: raise NotImplementedByController("account registration") @@ -293,7 +294,7 @@ def wait_for_services(self) -> None: timeout = time.time() + 5 while True: c.sendLine(f"PRIVMSG {self.server_controller.nickserv} :HELP") - msgs = self.getNickServResponse(c) + msgs = self.getServiceResponse(c) for msg in msgs: if msg.command == "401": # NickServ not available yet @@ -319,7 +320,7 @@ def wait_for_services(self) -> None: c.disconnect() self.services_up = True - def getNickServResponse(self, client: Any) -> List[Message]: + def getServiceResponse(self, client: Any) -> List[Message]: """Wrapper aroung getMessages() that waits longer, because NickServ is queried asynchronously.""" msgs: List[Message] = [] @@ -333,11 +334,14 @@ def registerUser( case: irctest.cases.BaseServerTestCase, # type: ignore username: str, password: Optional[str] = None, + **kwargs: Any, ) -> None: if not case.run_services: raise ValueError( "Attempted to register a nick, but `run_services` it not True." ) + if kwargs: + raise NotImplementedByController(", ".join(kwargs)) assert password client = case.addClient(show_io=True) case.sendLine(client, "NICK " + username) @@ -350,7 +354,7 @@ def registerUser( f"PRIVMSG {self.server_controller.nickserv} " f":REGISTER {password} foo@example.org", ) - msgs = self.getNickServResponse(case.clients[client]) + msgs = self.getServiceResponse(case.clients[client]) if self.server_controller.software_name == "inspircd": assert "900" in {msg.command for msg in msgs}, msgs assert "NOTICE" in {msg.command for msg in msgs}, msgs diff --git a/irctest/controllers/anope_services.py b/irctest/controllers/anope_services.py index f620dd9c..410da5a6 100644 --- a/irctest/controllers/anope_services.py +++ b/irctest/controllers/anope_services.py @@ -1,9 +1,12 @@ import os import shutil import subprocess -from typing import Type +from typing import Any, Optional, Type +from irctest import cases, runner from irctest.basecontrollers import BaseServicesController, DirectoryBasedController +from irctest.client_mock import ClientMock +from irctest.irc_utils.sasl import sasl_plain_blob TEMPLATE_CONFIG = """ serverinfo {{ @@ -30,12 +33,17 @@ userlen = 10 hostlen = 64 chanlen = 32 + vhost_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.-" }} mail {{ usemail = no }} +/************************ + * NickServ: + */ + service {{ nick = "NickServ" user = "services" @@ -57,6 +65,29 @@ }} command {{ service = "NickServ"; name = "REGISTER"; command = "nickserv/register"; }} + +/************************ + * HostServ: + */ + +service {{ + nick = "HostServ" + user = "services" + host = "services.host" + gecos = "vHost Service" +}} +module {{ + name = "hostserv" + client = "HostServ" +}} +module {{ name = "hs_set" }} +command {{ service = "HostServ"; name = "SET"; command = "hostserv/set"; }} + + +/************************ + * Misc: + */ + options {{ casemap = "ascii" readtimeout = 5s @@ -66,7 +97,6 @@ module {{ name = "m_sasl" }} module {{ name = "enc_sha256" }} module {{ name = "ns_cert" }} - """ @@ -121,6 +151,39 @@ def run(self, protocol: str, server_hostname: str, server_port: int) -> None: # stderr=subprocess.DEVNULL, ) + def registerUser( + self, + case: cases.BaseServerTestCase, # type: ignore + username: str, + password: Optional[str] = None, + vhost: Optional[str] = None, + **kwargs: Any, + ) -> None: + super().registerUser(case, username, password) + + if vhost: + if not password: + raise runner.NotImplementedByController( + "vHost for users with no password" + ) + c = ClientMock(name="setVhost", show_io=True) + c.connect(self.server_controller.hostname, self.server_controller.port) + c.sendLine("CAP REQ :sasl") + c.sendLine("NICK " + username) + c.sendLine("USER r e g :user") + while c.getMessage(synchronize=False).command != "CAP": + pass + c.sendLine("AUTHENTICATE PLAIN") + while c.getMessage(synchronize=False).command != "AUTHENTICATE": + pass + c.sendLine(sasl_plain_blob(username, password)) + c.sendLine("CAP END") + while c.getMessage(synchronize=False).command != "001": + pass + c.getMessages() + c.sendLine(f"PRIVMSG HostServ :SET {username} {vhost}") + self.getServiceResponse(c) + def get_irctest_controller_class() -> Type[AnopeController]: return AnopeController diff --git a/irctest/controllers/inspircd.py b/irctest/controllers/inspircd.py index 1cd40ba2..e1a869ea 100644 --- a/irctest/controllers/inspircd.py +++ b/irctest/controllers/inspircd.py @@ -61,10 +61,12 @@ # Protocol: + + diff --git a/irctest/server_tests/chghost.py b/irctest/server_tests/chghost.py new file mode 100644 index 00000000..0b52e740 --- /dev/null +++ b/irctest/server_tests/chghost.py @@ -0,0 +1,46 @@ +""" + +""" + +from irctest import cases +from irctest.irc_utils.sasl import sasl_plain_blob +from irctest.patma import ANYSTR, StrRe + + +@cases.mark_services +class ChghostServicesTestCase(cases.BaseServerTestCase, cases.OptionalityHelper): + def testChghostFromServices(self): + self.connectClient("observer", capabilities=["chghost"], skip_if_cap_nak=True) + self.connectClient("oldclient") + + self.controller.registerUser( + self, "vhostuser", "sesame", vhost="vhost.example.com" + ) + self.connectClient("vhost-user", capabilities=["sasl"], skip_if_cap_nak=True) + + for i in (1, 2, 3): + self.sendLine(i, "JOIN #chan") + self.getMessages(i) + + for i in (1, 2, 3): + self.getMessages(i) + + self.sendLine(3, "AUTHENTICATE PLAIN") + self.assertMessageMatch( + self.getRegistrationMessage(3), + command="AUTHENTICATE", + params=["+"], + ) + self.sendLine(3, sasl_plain_blob("vhostuser", "sesame")) + self.assertMessageMatch( + self.getRegistrationMessage(3), + command="900", + ) + + self.assertMessageMatch( + self.getMessage(1), + prefix=StrRe("vhost-user!.*@(?!vhost-user.example)"), + command="CHGHOST", + params=[ANYSTR, "vhost.example.com"], + ) + self.assertEqual(self.getMessages(2), []) # cycle?