diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 3a1f150082..3fb77fd9dd 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -20,7 +20,17 @@ # import time import urllib.parse -from typing import Any, Collection, Dict, List, Optional, Tuple, Union +from typing import ( + Any, + BinaryIO, + Callable, + Collection, + Dict, + List, + Optional, + Tuple, + Union, +) from unittest.mock import Mock from urllib.parse import urlencode @@ -34,8 +44,9 @@ from synapse.api.constants import ApprovalNoticeMedium, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService +from synapse.http.client import RawHeaders from synapse.module_api import ModuleApi -from synapse.rest.client import devices, login, logout, register +from synapse.rest.client import account, devices, login, logout, profile, register from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.server import HomeServer @@ -48,6 +59,7 @@ from tests.rest.client.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser +from tests.test_utils.oidc import FakeOidcServer from tests.unittest import HomeserverTestCase, override_config, skip_unless try: @@ -1421,7 +1433,19 @@ def test_login_appservice_no_token(self) -> None: class UsernamePickerTestCase(HomeserverTestCase): """Tests for the username picker flow of SSO login""" - servlets = [login.register_servlets] + servlets = [ + login.register_servlets, + profile.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.http_client = Mock(spec=["get_file"]) + self.http_client.get_file.side_effect = mock_get_file + hs = self.setup_test_homeserver( + proxied_blocklisted_http_client=self.http_client + ) + return hs def default_config(self) -> Dict[str, Any]: config = super().default_config() @@ -1430,7 +1454,11 @@ def default_config(self) -> Dict[str, Any]: config["oidc_config"] = {} config["oidc_config"].update(TEST_OIDC_CONFIG) config["oidc_config"]["user_mapping_provider"] = { - "config": {"display_name_template": "{{ user.displayname }}"} + "config": { + "display_name_template": "{{ user.displayname }}", + "email_template": "{{ user.email }}", + "picture_template": "{{ user.picture }}", + } } # whitelist this client URI so we redirect straight to it rather than @@ -1443,15 +1471,22 @@ def create_resource_dict(self) -> Dict[str, Resource]: d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_username_picker(self) -> None: - """Test the happy path of a username picker flow.""" - - fake_oidc_server = self.helper.fake_oidc_server() - + def proceed_to_username_picker_page( + self, + fake_oidc_server: FakeOidcServer, + displayname: str, + email: str, + picture: str, + ) -> Tuple[str, str]: # do the start of the login flow channel, _ = self.helper.auth_via_oidc( fake_oidc_server, - {"sub": "tester", "displayname": "Jonny"}, + { + "sub": "tester", + "displayname": displayname, + "picture": picture, + "email": email, + }, TEST_CLIENT_REDIRECT_URL, ) @@ -1478,16 +1513,132 @@ def test_username_picker(self) -> None: ) session = username_mapping_sessions[session_id] self.assertEqual(session.remote_user_id, "tester") - self.assertEqual(session.display_name, "Jonny") + self.assertEqual(session.display_name, displayname) + self.assertEqual(session.emails, [email]) + self.assertEqual(session.avatar_url, picture) self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) # the expiry time should be about 15 minutes away expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) + return picker_url, session_id + + def test_username_picker_use_displayname_avatar_and_email(self) -> None: + """Test the happy path of a username picker flow with using displayname, avatar and email.""" + + fake_oidc_server = self.helper.fake_oidc_server() + + mxid = "@bobby:test" + displayname = "Jonny" + email = "bobby@test.com" + picture = "mxc://test/avatar_url" + + picker_url, session_id = self.proceed_to_username_picker_page( + fake_oidc_server, displayname, email, picture + ) + + # Now, submit a username to the username picker, which should serve a redirect + # to the completion page. + # Also specify that we should use the provided displayname, avatar and email. + content = urlencode( + { + b"username": b"bobby", + b"use_display_name": b"true", + b"use_avatar": b"true", + b"use_email": email, + } + ).encode("utf8") + chan = self.make_request( + "POST", + path=picker_url, + content=content, + content_is_form=True, + custom_headers=[ + ("Cookie", "username_mapping_session=" + session_id), + # old versions of twisted don't do form-parsing without a valid + # content-length header. + ("Content-Length", str(len(content))), + ], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + assert location_headers + + # send a request to the completion page, which should 302 to the client redirectUrl + chan = self.make_request( + "GET", + path=location_headers[0], + custom_headers=[("Cookie", "username_mapping_session=" + session_id)], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + assert location_headers + + # ensure that the returned location matches the requested redirect URL + path, query = location_headers[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # fish the login token out of the returned redirect uri + login_token = params[2][1] + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + chan = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], mxid) + + # ensure the displayname and avatar from the OIDC response have been configured for the user. + channel = self.make_request( + "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"] + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertIn("mxc://test", channel.json_body["avatar_url"]) + self.assertEqual(displayname, channel.json_body["displayname"]) + + # ensure the email from the OIDC response has been configured for the user. + channel = self.make_request( + "GET", "/account/3pid", access_token=chan.json_body["access_token"] + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(email, channel.json_body["threepids"][0]["address"]) + + def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None: + """Test the happy path of a username picker flow without using displayname, avatar or email.""" + + fake_oidc_server = self.helper.fake_oidc_server() + + mxid = "@bobby:test" + displayname = "Jonny" + email = "bobby@test.com" + picture = "mxc://test/avatar_url" + username = "bobby" + + picker_url, session_id = self.proceed_to_username_picker_page( + fake_oidc_server, displayname, email, picture + ) + # Now, submit a username to the username picker, which should serve a redirect - # to the completion page - content = urlencode({b"username": b"bobby"}).encode("utf8") + # to the completion page. + # Also specify that we should not use the provided displayname, avatar or email. + content = urlencode( + { + b"username": username, + b"use_display_name": b"false", + b"use_avatar": b"false", + } + ).encode("utf8") chan = self.make_request( "POST", path=picker_url, @@ -1536,4 +1687,29 @@ def test_username_picker(self) -> None: content={"type": "m.login.token", "token": login_token}, ) self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], "@bobby:test") + self.assertEqual(chan.json_body["user_id"], mxid) + + # ensure the displayname and avatar from the OIDC response have not been configured for the user. + channel = self.make_request( + "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"] + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertNotIn("avatar_url", channel.json_body) + self.assertEqual(username, channel.json_body["displayname"]) + + # ensure the email from the OIDC response has not been configured for the user. + channel = self.make_request( + "GET", "/account/3pid", access_token=chan.json_body["access_token"] + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertListEqual([], channel.json_body["threepids"]) + + +async def mock_get_file( + url: str, + output_stream: BinaryIO, + max_size: Optional[int] = None, + headers: Optional[RawHeaders] = None, + is_allowed_content_type: Optional[Callable[[str], bool]] = None, +) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: + return 0, {b"Content-Type": [b"image/png"]}, "", 200