Correctly handle changes to required state config in sliding sync (#17785)
Fixes #17698

This handles `required_state` changes by checking whether new state has been
added to the config and, if so, fetching and returning it from the current
state.

This also takes care to ensure that, given a state entry S that is added,
removed and then re-added, we do *not* send S down a second time if there have
been no changes to S in the current state. This is fine for the Rust SDK (as it
just remembers all state), but we might decide not to require this behaviour in
the MSC. If we decide to always send down S then it's easy enough to rip out
all the code.
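
For illustration, a minimal sketch of that rule using plain dicts rather than Synapse's `RoomSyncConfig`/`StateFilter` types, ignoring wildcards and `$ME`/`$LAZY`; the helper name is made up:

```python
# Minimal sketch (illustrative only) of the rule above, using plain dicts
# instead of Synapse's RoomSyncConfig/StateFilter and ignoring wildcards,
# $ME and $LAZY. `update_effective_config` is a made-up name.
from typing import Dict, Set, Tuple


def update_effective_config(
    effective: Dict[str, Set[str]],   # config remembered for the connection
    requested: Dict[str, Set[str]],   # config in the current request
    changed: Set[Tuple[str, str]],    # state that changed since the last sync
) -> Tuple[Dict[str, Set[str]], Set[Tuple[str, str]]]:
    to_fetch: Set[Tuple[str, str]] = set()
    new_effective = {t: set(keys) for t, keys in effective.items()}

    for event_type in effective.keys() | requested.keys():
        old = effective.get(event_type, set())
        new = requested.get(event_type, set())

        # Newly requested entries are pulled from current state and returned.
        for state_key in new - old:
            new_effective.setdefault(event_type, set()).add(state_key)
            to_fetch.add((event_type, state_key))

        # Entries removed from the request are only dropped from the
        # remembered config if their state has changed; otherwise they are
        # kept, so re-adding them later does not re-send an unchanged event.
        for state_key in old - new:
            if (event_type, state_key) in changed:
                new_effective[event_type].discard(state_key)

    return new_effective, to_fetch
```

Re-adding an entry that was never dropped from the remembered config therefore produces nothing to fetch, which is the "don't send S down a second time" behaviour described above.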

---------

Co-authored-by: Eric Eastwood <[email protected]>
erikjohnston and MadLittleMods authored Oct 14, 2024
1 parent ae6179b commit d025b5a
Showing 7 changed files with 1,188 additions and 14 deletions.
1 change: 1 addition & 0 deletions changelog.d/17785.bugfix
@@ -0,0 +1 @@
Fix bug with sliding sync where the server would not return state that was added to the `required_state` config.
1 change: 1 addition & 0 deletions changelog.d/17805.bugfix
@@ -0,0 +1 @@
Fix bug with sliding sync where the server would not return state that was added to the `required_state` config.
234 changes: 224 additions & 10 deletions synapse/handlers/sliding_sync/__init__.py
@@ -14,7 +14,7 @@

import logging
from itertools import chain
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple
from typing import TYPE_CHECKING, AbstractSet, Dict, List, Mapping, Optional, Set, Tuple

from prometheus_client import Histogram
from typing_extensions import assert_never
@@ -522,6 +522,8 @@ async def get_room_sync_data(

state_reset_out_of_room = True

prev_room_sync_config = previous_connection_state.room_configs.get(room_id)

# Determine whether we should limit the timeline to the token range.
#
# We should return historical messages (before token range) in the
@@ -550,7 +552,6 @@
# or `limited` mean for clients that interpret them correctly. In future this
# behavior is almost certainly going to change.
#
# TODO: Also handle changes to `required_state`
from_bound = None
initial = True
ignore_timeline_bound = False
@@ -571,7 +572,6 @@

log_kv({"sliding_sync.room_status": room_status})

prev_room_sync_config = previous_connection_state.room_configs.get(room_id)
if prev_room_sync_config is not None:
# Check if the timeline limit has increased, if so ignore the
# timeline bound and record the change (see "XXX: Odd behavior"
@@ -582,8 +582,6 @@
):
ignore_timeline_bound = True

# TODO: Check for changes in `required_state``

log_kv(
{
"sliding_sync.from_bound": from_bound,
@@ -997,6 +995,10 @@ async def get_room_sync_data(
include_others=required_state_filter.include_others,
)

# The required state map to store in the room sync config, if it has
# changed.
changed_required_state_map: Optional[Mapping[str, AbstractSet[str]]] = None

# We can return all of the state that was requested if this was the first
# time we've sent the room down this connection.
room_state: StateMap[EventBase] = {}
@@ -1010,6 +1012,29 @@ async def get_room_sync_data(
else:
assert from_bound is not None

if prev_room_sync_config is not None:
# Check if there are any changes to the required state config
# that we need to handle.
changed_required_state_map, added_state_filter = (
_required_state_changes(
user.to_string(),
previous_room_config=prev_room_sync_config,
room_sync_config=room_sync_config,
state_deltas=room_state_delta_id_map,
)
)

if added_state_filter:
# Some state entries got added, so we pull out the current
# state for them. If we don't do this we'd only send down new deltas.
state_ids = await self.get_current_state_ids_at(
room_id=room_id,
room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
state_filter=added_state_filter,
to_token=to_token,
)
room_state_delta_id_map.update(state_ids)

events = await self.store.get_events(
state_filter.filter_state(room_state_delta_id_map).values()
)
@@ -1108,10 +1133,13 @@ async def get_room_sync_data(
# sensible order again.
bump_stamp = 0

unstable_expanded_timeline = False
prev_room_sync_config = previous_connection_state.room_configs.get(room_id)
room_sync_required_state_map_to_persist = room_sync_config.required_state_map
if changed_required_state_map:
room_sync_required_state_map_to_persist = changed_required_state_map

# Record the `room_sync_config` if `ignore_timeline_bound` is set (which means
# that the `timeline_limit` has increased)
unstable_expanded_timeline = False
if ignore_timeline_bound:
# FIXME: We signal the fact that we're sending down more events to
# the client by setting `unstable_expanded_timeline` to true (see
@@ -1120,7 +1148,7 @@ async def get_room_sync_data(

new_connection_state.room_configs[room_id] = RoomSyncConfig(
timeline_limit=room_sync_config.timeline_limit,
required_state_map=room_sync_config.required_state_map,
required_state_map=room_sync_required_state_map_to_persist,
)
elif prev_room_sync_config is not None:
# If the result is `limited` then we need to record that the
@@ -1149,10 +1177,14 @@
):
new_connection_state.room_configs[room_id] = RoomSyncConfig(
timeline_limit=room_sync_config.timeline_limit,
required_state_map=room_sync_config.required_state_map,
required_state_map=room_sync_required_state_map_to_persist,
)

# TODO: Record changes in required_state.
elif changed_required_state_map is not None:
new_connection_state.room_configs[room_id] = RoomSyncConfig(
timeline_limit=room_sync_config.timeline_limit,
required_state_map=room_sync_required_state_map_to_persist,
)

else:
new_connection_state.room_configs[room_id] = room_sync_config
@@ -1285,3 +1317,185 @@ async def _get_bump_stamp(
return new_bump_event_pos.stream

return None


def _required_state_changes(
user_id: str,
*,
previous_room_config: "RoomSyncConfig",
room_sync_config: RoomSyncConfig,
state_deltas: StateMap[str],
) -> Tuple[Optional[Mapping[str, AbstractSet[str]]], StateFilter]:
"""Calculates the changes between the required state room config from the
previous requests compared with the current request.
This does two things. First, it calculates if we need to update the room
config due to changes to required state. Secondly, it works out which state
entries we need to pull from current state and return due to the state entry
now appearing in the required state when it previously wasn't (on top of the
state deltas).
This function tries to ensure to handle the case where a state entry is
added, removed and then added again to the required state. In that case we
only want to re-send that entry down sync if it has changed.
Returns:
A 2-tuple of updated required state config (or None if there is no update)
and the state filter to use to fetch extra current state that we need to
return.
"""

prev_required_state_map = previous_room_config.required_state_map
request_required_state_map = room_sync_config.required_state_map

if prev_required_state_map == request_required_state_map:
# There has been no change. Return immediately.
return None, StateFilter.none()

prev_wildcard = prev_required_state_map.get(StateValues.WILDCARD, set())
request_wildcard = request_required_state_map.get(StateValues.WILDCARD, set())

# If we were previously fetching everything ("*", "*"), always update the effective
room required state config to match the request. And since we were previously
# already fetching everything, we don't have to fetch anything now that they've
# narrowed.
if StateValues.WILDCARD in prev_wildcard:
return request_required_state_map, StateFilter.none()

# If an event type wildcard has been added or removed we don't try to do
# anything fancy, and instead always update the effective room required
# state config to match the request.
if request_wildcard - prev_wildcard:
# Some keys were added, so we need to fetch everything
return request_required_state_map, StateFilter.all()
if prev_wildcard - request_wildcard:
# Keys were only removed, so we don't have to fetch everything.
return request_required_state_map, StateFilter.none()

# Contains updates to the required state map compared with the previous room
# config. This has the same format as `RoomSyncConfig.required_state_map`.
changes: Dict[str, AbstractSet[str]] = {}

# The set of types/state keys that we need to fetch and return to the
# client. Passed to `StateFilter.from_types(...)`
added: List[Tuple[str, Optional[str]]] = []

# First we calculate what, if anything, has been *added*.
for event_type in (
prev_required_state_map.keys() | request_required_state_map.keys()
):
old_state_keys = prev_required_state_map.get(event_type, set())
request_state_keys = request_required_state_map.get(event_type, set())

if old_state_keys == request_state_keys:
# No change to this type
continue

if not request_state_keys - old_state_keys:
# Nothing *added*, so we skip. Removals happen below.
continue

# Always update changes to include the newly added keys
changes[event_type] = request_state_keys

if StateValues.WILDCARD in old_state_keys:
# We were previously fetching everything for this type, so we don't need to
# fetch anything new.
continue

# Record the new state keys to fetch for this type.
if StateValues.WILDCARD in request_state_keys:
# If we have added a wildcard then we always just fetch everything.
added.append((event_type, None))
else:
for state_key in request_state_keys - old_state_keys:
if state_key == StateValues.ME:
added.append((event_type, user_id))
elif state_key == StateValues.LAZY:
# We handle lazy loading separately (outside this function),
# so don't need to explicitly add anything here.
#
# LAZY values should also be ignored for event types that are
# not membership.
pass
else:
added.append((event_type, state_key))

added_state_filter = StateFilter.from_types(added)

# Convert the list of state deltas to a map from event type to the state keys that have
# changed.
changed_types_to_state_keys: Dict[str, Set[str]] = {}
for event_type, state_key in state_deltas:
changed_types_to_state_keys.setdefault(event_type, set()).add(state_key)

# Figure out what changes we need to apply to the effective required state
# config.
for event_type, changed_state_keys in changed_types_to_state_keys.items():
old_state_keys = prev_required_state_map.get(event_type, set())
request_state_keys = request_required_state_map.get(event_type, set())

if old_state_keys == request_state_keys:
# No change.
continue

if request_state_keys - old_state_keys:
# We've expanded the set of state keys, so we just clobber the
# current set with the new set.
#
# We could also keep entries whose state hasn't changed but which are no
# longer in the requested required state. However, that's a sufficiently
# rare edge case that we can ignore (as it's only a performance
# optimization).
changes[event_type] = request_state_keys
continue

old_state_key_wildcard = StateValues.WILDCARD in old_state_keys
request_state_key_wildcard = StateValues.WILDCARD in request_state_keys

if old_state_key_wildcard != request_state_key_wildcard:
# If a state_key wildcard has been added or removed, we always update the
# effective room required state config to match the request.
changes[event_type] = request_state_keys
continue

if event_type == EventTypes.Member:
old_state_key_lazy = StateValues.LAZY in old_state_keys
request_state_key_lazy = StateValues.LAZY in request_state_keys

if old_state_key_lazy != request_state_key_lazy:
# If a "$LAZY" has been added or removed we always update the effective room
# required state config to match the request.
changes[event_type] = request_state_keys
continue

# Handle "$ME" values by adding "$ME" if the state key matches the user
# ID.
if user_id in changed_state_keys:
changed_state_keys.add(StateValues.ME)

# At this point there are no wildcards and no additions to the set of
# state keys requested, only deletions.
#
# We only remove state keys from the effective state if they've been
# removed from the request *and* the state has changed. This ensures
# that if a client removes and then re-adds a state key, we only send
# down the associated current state event if it's changed (rather than
# sending down the same event twice).
invalidated = (old_state_keys - request_state_keys) & changed_state_keys
if invalidated:
changes[event_type] = old_state_keys - invalidated

if changes:
# Update the required state config based on the changes.
new_required_state_map = dict(prev_required_state_map)
for event_type, state_keys in changes.items():
if state_keys:
new_required_state_map[event_type] = state_keys
else:
# Remove entries with empty state keys.
new_required_state_map.pop(event_type, None)

return new_required_state_map, added_state_filter
else:
return None, added_state_filter
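
As a rough usage sketch (not a test from this commit), the add/remove/re-add case described in the commit message would play out like this, assuming `RoomSyncConfig` and `_required_state_changes` from the module above are in scope and constructed as in the diff:

```python
# Hypothetical usage sketch for the remove-then-re-add case; the config
# values here are made up.
prev = RoomSyncConfig(
    timeline_limit=10,
    required_state_map={"m.room.name": {""}},
)
without_name = RoomSyncConfig(
    timeline_limit=10,
    required_state_map={},
)

# Step 1: the entry is removed from the request, but m.room.name has not
# changed, so it is kept in the effective config (no update needed).
changes, added = _required_state_changes(
    "@user:example.com",
    previous_room_config=prev,
    room_sync_config=without_name,
    state_deltas={},
)
assert changes is None and not added

# Step 2: the entry is re-added. Because it was never dropped from the
# effective config, it does not count as newly added, so the unchanged
# m.room.name event is not sent down a second time.
changes, added = _required_state_changes(
    "@user:example.com",
    previous_room_config=prev,   # effective config still contains the entry
    room_sync_config=prev,       # request matches it again
    state_deltas={},
)
assert changes is None and not added
```

The `not added` checks lean on the new `StateFilter.__bool__` added below in `synapse/types/state.py`.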
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/sliding_sync.py
@@ -386,8 +386,8 @@ def _get_and_clear_connection_positions_txn(
required_state_map: Dict[int, Dict[str, Set[str]]] = {}
for row in rows:
state = required_state_map[row[0]] = {}
for event_type, state_keys in db_to_json(row[1]):
state[event_type] = set(state_keys)
for event_type, state_key in db_to_json(row[1]):
state.setdefault(event_type, set()).add(state_key)

# Get all the room configs, looking up the required state from the map
# above.
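
The read-path fix above implies the serialised `required_state` is a flat JSON list of `[event_type, state_key]` pairs; under that assumption, the old and new loops behave roughly like this (illustrative only):

```python
# Illustrative only, assuming the stored JSON is a flat list of
# [event_type, state_key] pairs, as the corrected loop implies.
import json

serialised = '[["m.room.member", "@a:x"], ["m.room.member", "@b:x"]]'

# Old behaviour: the second element was treated as an iterable of state keys,
# so a string got exploded into characters and later pairs for the same event
# type clobbered earlier ones.
old: dict = {}
for event_type, state_keys in json.loads(serialised):
    old[event_type] = set(state_keys)
# old == {"m.room.member": {"@", "b", ":", "x"}}

# New behaviour: each pair contributes one state key, accumulated per type.
new: dict = {}
for event_type, state_key in json.loads(serialised):
    new.setdefault(event_type, set()).add(state_key)
# new == {"m.room.member": {"@a:x", "@b:x"}}
```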
7 changes: 7 additions & 0 deletions synapse/types/state.py
@@ -616,6 +616,13 @@ def __contains__(self, key: Any) -> bool:

return False

def __bool__(self) -> bool:
"""Returns true if this state filter will match any state, or false if
this is the empty filter"""
if self.include_others:
return True
return bool(self.types)


_ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True)
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
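
The new `__bool__` is what lets the handler above guard the extra fetch with a plain `if added_state_filter:`; a quick sketch of the truthiness it provides:

```python
from synapse.types.state import StateFilter

# Quick sketch of the truthiness the new __bool__ provides.
assert not StateFilter.none()                         # matches no state
assert StateFilter.all()                              # include_others=True
assert StateFilter.from_types([("m.room.name", "")])  # has explicit types
```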