diff --git a/changelog.d/17718.misc b/changelog.d/17718.misc new file mode 100644 index 0000000000..ea73a03f53 --- /dev/null +++ b/changelog.d/17718.misc @@ -0,0 +1 @@ +Slight optimization when fetching state/events for Sliding Sync. diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py index 39dba4ff98..a1a6728fb9 100644 --- a/synapse/handlers/sliding_sync/__init__.py +++ b/synapse/handlers/sliding_sync/__init__.py @@ -452,13 +452,11 @@ async def get_current_state_at( to_token=to_token, ) - event_map = await self.store.get_events(list(state_ids.values())) + events = await self.store.get_events_as_list(list(state_ids.values())) state_map = {} - for key, event_id in state_ids.items(): - event = event_map.get(event_id) - if event: - state_map[key] = event + for event in events: + state_map[(event.type, event.state_key)] = event return state_map diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index c029228422..403407068c 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -61,7 +61,13 @@ current_context, make_deferred_yieldable, ) -from synapse.logging.opentracing import start_active_span, tag_args, trace +from synapse.logging.opentracing import ( + SynapseTags, + set_tag, + start_active_span, + tag_args, + trace, +) from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -525,6 +531,7 @@ async def get_event( return event + @trace async def get_events( self, event_ids: Collection[str], @@ -556,6 +563,11 @@ async def get_events( Returns: A mapping from event_id to event. """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + events = await self.get_events_as_list( event_ids, redact_behaviour=redact_behaviour, @@ -603,6 +615,10 @@ async def get_events_as_list( Note that the returned list may be smaller than the list of event IDs if not all events could be fetched. """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) if not event_ids: return [] @@ -723,10 +739,11 @@ async def get_events_as_list( return events + @trace @cancellable async def get_unredacted_events_from_cache_or_db( self, - event_ids: Iterable[str], + event_ids: Collection[str], allow_rejected: bool = False, ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the cache or the database. @@ -748,6 +765,11 @@ async def get_unredacted_events_from_cache_or_db( Returns: map from event id to result """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + # Shortcut: check if we have any events in the *in memory* cache - this function # may be called repeatedly for the same event so at this point we cannot reach # out to any external cache for performance reasons. The external cache is @@ -936,7 +958,7 @@ async def _get_events_from_cache( events, update_metrics=update_metrics ) - missing_event_ids = (e for e in events if e not in event_map) + missing_event_ids = [e for e in events if e not in event_map] event_map.update( await self._get_events_from_external_cache( events=missing_event_ids, @@ -946,8 +968,9 @@ async def _get_events_from_cache( return event_map + @trace async def _get_events_from_external_cache( - self, events: Iterable[str], update_metrics: bool = True + self, events: Collection[str], update_metrics: bool = True ) -> Dict[str, EventCacheEntry]: """Fetch events from any configured external cache. @@ -957,6 +980,10 @@ async def _get_events_from_external_cache( events: list of event_ids to fetch update_metrics: Whether to update the cache hit ratio metrics """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "events.length", + str(len(events)), + ) event_map = {} for event_id in events: @@ -1222,6 +1249,7 @@ def fire_errback(exc: Exception) -> None: with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire_errback, e) + @trace async def _get_events_from_db( self, event_ids: Collection[str] ) -> Dict[str, EventCacheEntry]: @@ -1240,6 +1268,11 @@ async def _get_events_from_db( map from event id to result. May return extra events which weren't asked for. """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + fetched_event_ids: Set[str] = set() fetched_events: Dict[str, _EventRow] = {} diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index fd1f5e7fd5..104d141a72 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -20,7 +20,7 @@ # import json from contextlib import contextmanager -from typing import Generator, List, Tuple +from typing import Generator, List, Set, Tuple from unittest import mock from twisted.enterprise.adbapi import ConnectionPool @@ -295,6 +295,53 @@ def test_dedupe(self) -> None: self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) +class GetEventsTestCase(unittest.HomeserverTestCase): + """Test `get_events(...)`/`get_events_as_list(...)`""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store: EventsWorkerStore = hs.get_datastores().main + + def test_get_lots_of_messages(self) -> None: + """Sanity check that `get_events(...)`/`get_events_as_list(...)` works""" + num_events = 100 + + user_id = self.register_user("user", "pass") + user_tok = self.login(user_id, "pass") + + room_id = self.helper.create_room_as(user_id, tok=user_tok) + + event_ids: Set[str] = set() + for i in range(num_events): + event = self.get_success( + inject_event( + self.hs, + room_id=room_id, + type="m.room.message", + sender=user_id, + content={ + "body": f"foo{i}", + "msgtype": "m.text", + }, + ) + ) + event_ids.add(event.event_id) + + # Sanity check that we actually created the events + self.assertEqual(len(event_ids), num_events) + + # This is the function under test + fetched_event_map = self.get_success(self.store.get_events(event_ids)) + + # Sanity check that we got the events back + self.assertIncludes(fetched_event_map.keys(), event_ids, exact=True) + + class DatabaseOutageTestCase(unittest.HomeserverTestCase): """Test event fetching during a database outage."""