Add locking to more safely delete state groups: Part 1 (#18107)

Currently nothing stops us from deleting a state group while an in-flight
event still references it. This race is fairly rare today, but we want to be
able to delete state groups more aggressively, so it is important to address
it to ensure that the database remains valid.

This implements the locking, but doesn't actually use it.

See the class docstring of the new data store for an explanation of how
this works.
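
As a rough illustration of how the pieces are intended to fit together, a
hypothetical deletion driver built on the new store might look like the
following (the store methods are real, but `delete_state_groups` itself is a
sketch and is not part of this change):

from typing import Collection

from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.state.deletion import StateDeletionDataStore


async def delete_state_groups(
    store: StateDeletionDataStore, state_groups: Collection[int]
) -> None:
    # 1. Propose the groups for deletion. Nothing is removed yet; any event
    #    persisted in the meantime bumps the deletion timer and the
    #    sequence number.
    await store.mark_state_groups_as_pending_deletion(state_groups)

    # 2. After at least DELAY_BEFORE_DELETION_MS, capture the sequence
    #    numbers, then check (elsewhere) that the groups are unreferenced.
    pending = await store.get_pending_deletions(state_groups)

    def _delete_txn(txn: LoggingTransaction) -> None:
        # 3. In the same transaction as the actual DELETE, skip any group
        #    whose sequence number changed since step 2, as it may have
        #    become referenced again.
        deletable = store.get_state_groups_ready_for_potential_deletion_txn(
            txn, pending
        )
        # 4. The actual `DELETE FROM state_groups ...` would go here.

    await store.db_pool.runInteraction("delete_state_groups", _delete_txn)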

---------

Co-authored-by: Devon Hudson <devon.dmytro@gmail.com>
Erik Johnston, 2025-02-03 18:29:15 +01:00
commit aa6e5c2ecb (parent ac1bf682ff)
13 changed files with 1047 additions and 46 deletions

changelog.d/18107.bugfix (new file)

@@ -0,0 +1 @@
Fix rare edge case where state groups could be deleted while we are persisting new events that reference them.

synapse/handlers/federation_event.py

@@ -151,6 +151,8 @@ class FederationEventHandler:
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
self._state_store = hs.get_datastores().state
self._state_deletion_store = hs.get_datastores().state_deletion
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
@@ -580,7 +582,9 @@ class FederationEventHandler:
room_version.identifier,
state_maps_to_resolve,
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
)
else:
@@ -1179,7 +1183,9 @@ class FederationEventHandler:
room_version,
state_maps,
event_map={event_id: event},
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
except Exception as e:
@@ -1874,7 +1880,9 @@ class FederationEventHandler:
room_version,
[local_state_id_map, claimed_auth_events_id_map],
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
)
else:
@@ -2014,7 +2022,9 @@ class FederationEventHandler:
room_version,
state_sets,
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
)
else:

synapse/state/__init__.py

@@ -59,11 +59,13 @@ from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
from synapse.util.stringutils import shortstr
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.controllers import StateStorageController
from synapse.storage.databases.main import DataStore
from synapse.storage.databases.state.deletion import StateDeletionDataStore
logger = logging.getLogger(__name__)
metrics_logger = logging.getLogger("synapse.state.metrics")
@@ -194,6 +196,8 @@ class StateHandler:
self._storage_controllers = hs.get_storage_controllers()
self._events_shard_config = hs.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self._state_store = hs.get_datastores().state
self._state_deletion_store = hs.get_datastores().state_deletion
self._update_current_state_client = (
ReplicationUpdateCurrentStateRestServlet.make_client(hs)
@@ -475,7 +479,10 @@ class StateHandler:
@trace
@measure_func()
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
self,
room_id: str,
event_ids: StrCollection,
await_full_state: bool = True,
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@@ -511,6 +518,17 @@ class StateHandler:
) = await self._state_storage_controller.get_state_group_delta(
state_group_id
)
if prev_group:
# Ensure that we still have the prev group, and ensure we don't
# delete it while we're persisting the event.
missing_state_group = await self._state_deletion_store.check_state_groups_and_bump_deletion(
{prev_group}
)
if missing_state_group:
prev_group = None
delta_ids = None
return _StateCacheEntry(
state=None,
state_group=state_group_id,
@@ -531,7 +549,9 @@ class StateHandler:
room_version,
state_to_resolve,
None,
state_res_store=StateResolutionStore(self.store),
state_res_store=StateResolutionStore(
self.store, self._state_deletion_store
),
)
return result
@@ -663,7 +683,25 @@ class StateResolutionHandler:
async with self.resolve_linearizer.queue(group_names):
cache = self._state_cache.get(group_names, None)
if cache:
return cache
# Check that the returned cache entry doesn't point to deleted
# state groups.
state_groups_to_check = set()
if cache.state_group is not None:
state_groups_to_check.add(cache.state_group)
if cache.prev_group is not None:
state_groups_to_check.add(cache.prev_group)
missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion(
state_groups_to_check
)
if not missing_state_groups:
return cache
else:
# There are missing state groups, so let's remove the stale
# entry and continue as if it was a cache miss.
self._state_cache.pop(group_names, None)
logger.info(
"Resolving state for %s with groups %s",
@@ -671,6 +709,16 @@ class StateResolutionHandler:
list(group_names),
)
# We double check that none of the state groups have been deleted.
# They shouldn't have been, as all of these state groups should still
# be referenced.
missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion(
group_names
)
if missing_state_groups:
raise Exception(
f"State groups have been deleted: {shortstr(missing_state_groups)}"
)
state_groups_histogram.observe(len(state_groups_ids))
new_state = await self.resolve_events_with_store(
@@ -884,7 +932,8 @@ class StateResolutionStore:
in well defined way.
"""
store: "DataStore"
main_store: "DataStore"
state_deletion_store: "StateDeletionDataStore"
def get_events(
self, event_ids: StrCollection, allow_rejected: bool = False
@@ -899,7 +948,7 @@ class StateResolutionStore:
An awaitable which resolves to a dict from event_id to event.
"""
return self.store.get_events(
return self.main_store.get_events(
event_ids,
redact_behaviour=EventRedactBehaviour.as_is,
get_prev_content=False,
@@ -920,4 +969,4 @@ class StateResolutionStore:
An awaitable that resolves to a set of event IDs.
"""
return self.store.get_auth_chain_difference(room_id, state_sets)
return self.main_store.get_auth_chain_difference(room_id, state_sets)

synapse/storage/controllers/persist_events.py

@@ -332,6 +332,7 @@ class EventsPersistenceStorageController:
# store for now.
self.main_store = stores.main
self.state_store = stores.state
self._state_deletion_store = stores.state_deletion
assert stores.persist_events
self.persist_events_store = stores.persist_events
@@ -549,7 +550,9 @@ class EventsPersistenceStorageController:
room_version,
state_maps_by_state_group,
event_map=None,
state_res_store=StateResolutionStore(self.main_store),
state_res_store=StateResolutionStore(
self.main_store, self._state_deletion_store
),
)
return await res.get_state(self._state_controller, StateFilter.all())
@@ -635,15 +638,20 @@ class EventsPersistenceStorageController:
room_id, [e for e, _ in chunk]
)
await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
new_event_links=new_event_links,
)
# Stop the state groups from being deleted while we're persisting
# them.
async with self._state_deletion_store.persisting_state_group_references(
events_and_contexts
):
await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
new_event_links=new_event_links,
)
return replaced_events
@@ -965,7 +973,9 @@ class EventsPersistenceStorageController:
room_version,
state_groups,
events_map,
state_res_store=StateResolutionStore(self.main_store),
state_res_store=StateResolutionStore(
self.main_store, self._state_deletion_store
),
)
state_resolutions_during_persistence.inc()

synapse/storage/databases/__init__.py

@@ -26,6 +26,7 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.databases.state import StateGroupDataStore
from synapse.storage.databases.state.deletion import StateDeletionDataStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
@@ -49,12 +50,14 @@ class Databases(Generic[DataStoreT]):
main
state
persist_events
state_deletion
"""
databases: List[DatabasePool]
main: "DataStore" # FIXME: https://github.com/matrix-org/synapse/issues/11165: actually an instance of `main_store_class`
state: StateGroupDataStore
persist_events: Optional[PersistEventsStore]
state_deletion: StateDeletionDataStore
def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
# Note we pass in the main store class here as workers use a different main
@@ -63,6 +66,7 @@ class Databases(Generic[DataStoreT]):
self.databases = []
main: Optional[DataStoreT] = None
state: Optional[StateGroupDataStore] = None
state_deletion: Optional[StateDeletionDataStore] = None
persist_events: Optional[PersistEventsStore] = None
for database_config in hs.config.database.databases:
@@ -114,7 +118,8 @@ class Databases(Generic[DataStoreT]):
if state:
raise Exception("'state' data store already configured")
state = StateGroupDataStore(database, db_conn, hs)
state_deletion = StateDeletionDataStore(database, db_conn, hs)
state = StateGroupDataStore(database, db_conn, hs, state_deletion)
db_conn.commit()
@@ -135,7 +140,7 @@ class Databases(Generic[DataStoreT]):
if not main:
raise Exception("No 'main' database configured")
if not state:
if not state or not state_deletion:
raise Exception("No 'state' database configured")
# We use local variables here to ensure that the databases do not have
@@ -143,3 +148,4 @@ class Databases(Generic[DataStoreT]):
self.main = main # type: ignore[assignment]
self.state = state
self.persist_events = persist_events
self.state_deletion = state_deletion

synapse/storage/databases/state/deletion.py (new file)

@@ -0,0 +1,446 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import contextlib
from typing import (
TYPE_CHECKING,
AbstractSet,
AsyncIterator,
Collection,
Mapping,
Set,
Tuple,
)
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.util.stringutils import shortstr
if TYPE_CHECKING:
from synapse.server import HomeServer
class StateDeletionDataStore:
"""Manages deletion of state groups in a safe manner.
Deleting state groups is challenging as before we actually delete them we
need to ensure that there are no in-flight events that refer to the state
groups that we want to delete.
To handle this, we take two approaches. First, before we persist any event
we ensure that the state group still exists and mark in the
`state_groups_persisting` table that the state group is about to be used.
(Note that we have to have the extra table here as state groups and events
can be in different databases, and thus we can't check for the existence of
state groups in the persist event transaction). Once the event has been
persisted, we can remove the row from `state_groups_persisting`. So long as
we check that table before deleting state groups, we can ensure that we
never persist events that reference deleted state groups, maintaining
database integrity.
However, we want to avoid throwing exceptions so deep in the process of
persisting events. So instead of deleting state groups immediately, we mark
them as pending/proposed for deletion and wait for a certain amount of time
before performing the deletion. When we come to handle new events that
reference state groups, we check if they are pending deletion and bump the
time for when they'll be deleted (to give a chance for the event to be
persisted, or not).
When deleting, we need to check that state groups remain unreferenced. There
is a race here where we a) fetch state groups that are ready for deletion,
b) check they're unreferenced, c) the state group becomes referenced but
then gets marked as pending deletion again, and d) during the deletion
transaction we check the `state_groups_pending_deletion` table again, see
that the row exists, and so continue with the deletion. To prevent this from
happening we add a `sequence_number` column to
`state_groups_pending_deletion`, and during deletion we ensure that the
sequence number of each state group we're about to delete hasn't changed
between steps (a) and (d). So long as we always bump the sequence number
whenever a state group may become referenced, the race can never happen.
"""
# How long to wait before we delete state groups. This should be long enough
# for any in-flight events to be persisted. If events take longer to persist
# and any of the state groups they reference have been deleted, then the
# event will fail to persist (as well as any event in the same batch).
DELAY_BEFORE_DELETION_MS = 10 * 60 * 1000
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self._clock = hs.get_clock()
self.db_pool = database
self._instance_name = hs.get_instance_name()
# TODO: Clear from `state_groups_persisting` any holdovers from previous
# running instance.
async def check_state_groups_and_bump_deletion(
self, state_groups: AbstractSet[int]
) -> Collection[int]:
"""Checks to make sure that the state groups haven't been deleted, and
if they're pending deletion we delay it (allowing time for any event
that will use them to finish persisting).
Returns:
The state groups that are missing, if any.
"""
return await self.db_pool.runInteraction(
"check_state_groups_and_bump_deletion",
self._check_state_groups_and_bump_deletion_txn,
state_groups,
)
def _check_state_groups_and_bump_deletion_txn(
self, txn: LoggingTransaction, state_groups: AbstractSet[int]
) -> Collection[int]:
existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups)
self._bump_deletion_txn(txn, existing_state_groups)
missing_state_groups = state_groups - existing_state_groups
if missing_state_groups:
return missing_state_groups
return ()
def _bump_deletion_txn(
self, txn: LoggingTransaction, state_groups: Collection[int]
) -> None:
"""Update any pending deletions of the state group that they may now be
referenced."""
if not state_groups:
return
now = self._clock.time_msec()
if isinstance(self.db_pool.engine, PostgresEngine):
clause, args = make_in_list_sql_clause(
self.db_pool.engine, "state_group", state_groups
)
sql = f"""
UPDATE state_groups_pending_deletion
SET sequence_number = DEFAULT, insertion_ts = ?
WHERE {clause}
"""
args.insert(0, now)
txn.execute(sql, args)
else:
rows = self.db_pool.simple_select_many_txn(
txn,
table="state_groups_pending_deletion",
column="state_group",
iterable=state_groups,
keyvalues={},
retcols=("state_group",),
)
if not rows:
return
state_groups_to_update = [state_group for (state_group,) in rows]
self.db_pool.simple_delete_many_txn(
txn,
table="state_groups_pending_deletion",
column="state_group",
values=state_groups_to_update,
keyvalues={},
)
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_pending_deletion",
keys=("state_group", "insertion_ts"),
values=[(state_group, now) for state_group in state_groups_to_update],
)
def _get_existing_groups_with_lock(
self, txn: LoggingTransaction, state_groups: Collection[int]
) -> AbstractSet[int]:
"""Return which of the given state groups are in the database, and locks
those rows with `KEY SHARE` to ensure they don't get concurrently
deleted."""
clause, args = make_in_list_sql_clause(self.db_pool.engine, "id", state_groups)
sql = f"""
SELECT id FROM state_groups
WHERE {clause}
"""
if isinstance(self.db_pool.engine, PostgresEngine):
# On postgres we add a row level lock to the rows to ensure that we
# conflict with any concurrent DELETEs. A `FOR KEY SHARE` lock does
# not conflict with other reads.
sql += """
FOR KEY SHARE
"""
txn.execute(sql, args)
return {state_group for (state_group,) in txn}
@contextlib.asynccontextmanager
async def persisting_state_group_references(
self, event_and_contexts: Collection[Tuple[EventBase, EventContext]]
) -> AsyncIterator[None]:
"""Wraps the persistence of the given events and contexts, ensuring that
any state groups referenced still exist and that they don't get deleted
during this."""
referenced_state_groups: Set[int] = set()
for event, ctx in event_and_contexts:
if ctx.rejected or event.internal_metadata.is_outlier():
continue
assert ctx.state_group is not None
referenced_state_groups.add(ctx.state_group)
if ctx.state_group_before_event:
referenced_state_groups.add(ctx.state_group_before_event)
if not referenced_state_groups:
# We don't reference any state groups, so nothing to do
yield
return
await self.db_pool.runInteraction(
"mark_state_groups_as_persisting",
self._mark_state_groups_as_persisting_txn,
referenced_state_groups,
)
error = True
try:
yield None
error = False
finally:
await self.db_pool.runInteraction(
"finish_persisting",
self._finish_persisting_txn,
referenced_state_groups,
error=error,
)
def _mark_state_groups_as_persisting_txn(
self, txn: LoggingTransaction, state_groups: Set[int]
) -> None:
"""Marks the given state groups as being persisted."""
existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups)
missing_state_groups = state_groups - existing_state_groups
if missing_state_groups:
raise Exception(
f"state groups have been deleted: {shortstr(missing_state_groups)}"
)
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_persisting",
keys=("state_group", "instance_name"),
values=[(state_group, self._instance_name) for state_group in state_groups],
)
def _finish_persisting_txn(
self, txn: LoggingTransaction, state_groups: Collection[int], error: bool
) -> None:
"""Mark the state groups as having finished persistence.
If `error` is true then we assume the state groups were not persisted,
and so we do not clear them from the pending deletion table.
"""
self.db_pool.simple_delete_many_txn(
txn,
table="state_groups_persisting",
column="state_group",
values=state_groups,
keyvalues={"instance_name": self._instance_name},
)
if error:
# The state groups may or may not have been persisted, so we need to
# bump the deletion to ensure we recheck if they have become
# referenced.
self._bump_deletion_txn(txn, state_groups)
return
self.db_pool.simple_delete_many_batch_txn(
txn,
table="state_groups_pending_deletion",
keys=("state_group",),
values=[(state_group,) for state_group in state_groups],
)
async def mark_state_groups_as_pending_deletion(
self, state_groups: Collection[int]
) -> None:
"""Mark the given state groups as pending deletion"""
now = self._clock.time_msec()
await self.db_pool.simple_upsert_many(
table="state_groups_pending_deletion",
key_names=("state_group",),
key_values=[(state_group,) for state_group in state_groups],
value_names=("insertion_ts",),
value_values=[(now,) for _ in state_groups],
desc="mark_state_groups_as_pending_deletion",
)
async def get_pending_deletions(
self, state_groups: Collection[int]
) -> Mapping[int, int]:
"""Get which state groups are pending deletion.
Returns:
a mapping from state groups that are pending deletion to their
sequence number
"""
rows = await self.db_pool.simple_select_many_batch(
table="state_groups_pending_deletion",
column="state_group",
iterable=state_groups,
retcols=("state_group", "sequence_number"),
keyvalues={},
desc="get_pending_deletions",
)
return dict(rows)
def get_state_groups_ready_for_potential_deletion_txn(
self,
txn: LoggingTransaction,
state_groups_to_sequence_numbers: Mapping[int, int],
) -> Collection[int]:
"""Given a set of state groups, return which state groups can
potentially be deleted.
The state groups must have been checked to see if they remain
unreferenced before calling this function.
Note: This must be called within the same transaction in which the
state groups are deleted.
Args:
state_groups_to_sequence_numbers: The state groups, and the sequence
numbers from before the state groups were checked to see if they
were unreferenced.
Returns:
The subset of state groups that can safely be deleted
"""
if not state_groups_to_sequence_numbers:
return state_groups_to_sequence_numbers
if isinstance(self.db_pool.engine, PostgresEngine):
# On postgres we want to lock the rows FOR UPDATE as early as
# possible, to help surface any conflicts with concurrent
# transactions early.
clause, args = make_in_list_sql_clause(
self.db_pool.engine, "id", state_groups_to_sequence_numbers
)
sql = f"""
SELECT id FROM state_groups
WHERE {clause}
FOR UPDATE
"""
txn.execute(sql, args)
# Check the deletion status in the DB of the given state groups
clause, args = make_in_list_sql_clause(
self.db_pool.engine,
column="state_group",
iterable=state_groups_to_sequence_numbers,
)
sql = f"""
SELECT state_group, insertion_ts, sequence_number FROM (
SELECT state_group, insertion_ts, sequence_number FROM state_groups_pending_deletion
UNION
SELECT state_group, null, null FROM state_groups_persisting
) AS s
WHERE {clause}
"""
txn.execute(sql, args)
# The above query can return up to two rows per state group (one
# for each table), so we track which state groups have had enough
# time elapse and which are not ready to be deleted.
ready_to_be_deleted = set()
not_ready_to_be_deleted = set()
now = self._clock.time_msec()
for state_group, insertion_ts, sequence_number in txn:
if insertion_ts is None:
# A null insertion_ts means that we are currently persisting
# events that reference the state group, so we don't delete
# them.
not_ready_to_be_deleted.add(state_group)
continue
# We know this can't be None if insertion_ts is not None
assert sequence_number is not None
# Check if the sequence number has changed; if it has, then the
# state group may have become referenced since we last checked.
if state_groups_to_sequence_numbers[state_group] != sequence_number:
not_ready_to_be_deleted.add(state_group)
continue
if now - insertion_ts < self.DELAY_BEFORE_DELETION_MS:
# Not enough time has elapsed to allow us to delete.
not_ready_to_be_deleted.add(state_group)
continue
ready_to_be_deleted.add(state_group)
can_be_deleted = ready_to_be_deleted - not_ready_to_be_deleted
if not_ready_to_be_deleted:
# If there are any state groups that aren't ready to be deleted,
# then we must also exclude from deletion any state groups that
# they reference (directly or via a chain of prev groups).
clause, args = make_in_list_sql_clause(
self.db_pool.engine,
column="state_group",
iterable=state_groups_to_sequence_numbers,
)
sql = f"""
WITH RECURSIVE ancestors(state_group) AS (
SELECT DISTINCT prev_state_group
FROM state_group_edges WHERE {clause}
UNION
SELECT prev_state_group
FROM state_group_edges
INNER JOIN ancestors USING (state_group)
)
SELECT state_group FROM ancestors
"""
txn.execute(sql, args)
can_be_deleted.difference_update(state_group for (state_group,) in txn)
return can_be_deleted
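
For intuition, the recursive query above implements roughly the pruning
described in the comment: nothing reachable from a surviving state group via
`state_group_edges` may be deleted. A toy, in-memory equivalent
(`prune_ancestors` and its arguments are hypothetical, for illustration
only):

def prune_ancestors(
    can_be_deleted: set[int],
    not_ready: set[int],
    edges: dict[int, int],  # state_group -> prev_state_group
) -> set[int]:
    # Walk the prev-group chain from each group that is not ready to be
    # deleted and protect everything reachable, since those groups are
    # still needed to reconstruct the survivor's state.
    keep_alive: set[int] = set()
    for group in not_ready:
        prev = edges.get(group)
        while prev is not None and prev not in keep_alive:
            keep_alive.add(prev)
            prev = edges.get(prev)
    return can_be_deleted - keep_alive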

synapse/storage/databases/state/store.py

@@ -36,7 +36,10 @@ import attr
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
from synapse.events.snapshot import (
UnpersistedEventContext,
UnpersistedEventContextBase,
)
from synapse.logging.opentracing import tag_args, trace
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -55,6 +58,7 @@ from synapse.util.cancellation import cancellable
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.state.deletion import StateDeletionDataStore
logger = logging.getLogger(__name__)
@@ -83,8 +87,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
state_deletion_store: "StateDeletionDataStore",
):
super().__init__(database, db_conn, hs)
self._state_deletion_store = state_deletion_store
# Originally the state store used a single DictionaryCache to cache the
# event IDs for the state types in a given state group to avoid hammering
@@ -467,14 +473,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns:
A list of state groups
"""
is_in_db = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
# We need to check that the prev group isn't about to be deleted
is_missing = (
self._state_deletion_store._check_state_groups_and_bump_deletion_txn(
txn,
{prev_group},
)
)
if not is_in_db:
if is_missing:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)
@@ -546,6 +553,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
for key, state_id in context.state_delta_due_to_event.items()
],
)
return events_and_context
return await self.db_pool.runInteraction(
@@ -601,14 +609,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
The state group if successfully created, or None if the state
needs to be persisted as a full state.
"""
is_in_db = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
# We need to check that the prev group isn't about to be deleted
is_missing = (
self._state_deletion_store._check_state_groups_and_bump_deletion_txn(
txn,
{prev_group},
)
)
if not is_in_db:
if is_missing:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)

synapse/storage/schema/__init__.py

@@ -19,7 +19,7 @@
#
#
SCHEMA_VERSION = 88 # remember to update the list below when updating
SCHEMA_VERSION = 89 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -155,6 +155,9 @@ Changes in SCHEMA_VERSION = 88
be posted in response to a resettable timeout or an on-demand action.
- Add background update to fix data integrity issue in the
`sliding_sync_membership_snapshots` -> `forgotten` column
Changes in SCHEMA_VERSION = 89
- Add `state_groups_pending_deletion` and `state_groups_persisting` tables.
"""

New file: schema delta for the state database (SCHEMA_VERSION 89)

@@ -0,0 +1,39 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- See the `StateDeletionDataStore` for details of these tables.
-- We add state groups to this table when we want to later delete them. The
-- `insertion_ts` column indicates when the state group was proposed for
-- deletion (rather than when it should be deleted).
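-- The `sequence_number` column is bumped whenever a state group in this
-- table may have become referenced again (`SET sequence_number = DEFAULT`
-- on Postgres, or a delete-and-reinsert on SQLite); deletion is skipped for
-- any state group whose sequence number has changed since it was checked.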
CREATE TABLE IF NOT EXISTS state_groups_pending_deletion (
sequence_number $%AUTO_INCREMENT_PRIMARY_KEY%$,
state_group BIGINT NOT NULL,
insertion_ts BIGINT NOT NULL
);
CREATE UNIQUE INDEX state_groups_pending_deletion_state_group ON state_groups_pending_deletion(state_group);
CREATE INDEX state_groups_pending_deletion_insertion_ts ON state_groups_pending_deletion(insertion_ts);
-- Holds the state groups the worker is currently persisting.
--
-- The `sequence_number` column of the `state_groups_pending_deletion` table
-- *must* be updated whenever a state group may have become referenced.
CREATE TABLE IF NOT EXISTS state_groups_persisting (
state_group BIGINT NOT NULL,
instance_name TEXT NOT NULL,
PRIMARY KEY (state_group, instance_name)
);
CREATE INDEX state_groups_persisting_instance_name ON state_groups_persisting(instance_name);

tests/handlers/test_federation_event.py

@@ -807,6 +807,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
main_store = self.hs.get_datastores().main
state_deletion_store = self.hs.get_datastores().state_deletion
# Create the room.
kermit_user_id = self.register_user("kermit", "test")
@@ -958,7 +959,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
bert_member_event.event_id: bert_member_event,
rejected_kick_event.event_id: rejected_kick_event,
},
state_res_store=StateResolutionStore(main_store),
state_res_store=StateResolutionStore(
main_store, state_deletion_store
),
)
),
[bert_member_event.event_id, rejected_kick_event.event_id],
@@ -1003,7 +1006,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
rejected_power_levels_event.event_id,
],
event_map={},
state_res_store=StateResolutionStore(main_store),
state_res_store=StateResolutionStore(
main_store, state_deletion_store
),
full_conflicted_set=set(),
)
),

tests/rest/client/test_rooms.py

@@ -742,7 +742,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
self.assertEqual(34, channel.resource_usage.db_txn_count)
self.assertEqual(36, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -755,7 +755,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
self.assertEqual(36, channel.resource_usage.db_txn_count)
self.assertEqual(38, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id

New file: tests for the StateDeletionDataStore

@@ -0,0 +1,411 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import logging
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests.test_utils.event_injection import create_event
from tests.unittest import HomeserverTestCase
logger = logging.getLogger(__name__)
class StateDeletionStoreTestCase(HomeserverTestCase):
"""Tests for the StateDeletionStore."""
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.state_store = hs.get_datastores().state
self.state_deletion_store = hs.get_datastores().state_deletion
self.user_id = self.register_user("test", "password")
tok = self.login("test", "password")
self.room_id = self.helper.create_room_as(self.user_id, tok=tok)
def check_if_can_be_deleted(self, state_group: int) -> bool:
"""Check if the state group is pending deletion."""
state_group_to_sequence_number = self.get_success(
self.state_deletion_store.get_pending_deletions([state_group])
)
can_be_deleted = self.get_success(
self.state_deletion_store.db_pool.runInteraction(
"test_existing_pending_deletion_is_cleared",
self.state_deletion_store.get_state_groups_ready_for_potential_deletion_txn,
state_group_to_sequence_number,
)
)
return state_group in can_be_deleted
def test_no_deletion(self) -> None:
"""Test that calling persisting_state_group_references is fine if
nothing is pending deletion"""
event, context = self.get_success(
create_event(
self.hs,
room_id=self.room_id,
type="m.test",
sender=self.user_id,
)
)
ctx_mgr = self.state_deletion_store.persisting_state_group_references(
[(event, context)]
)
self.get_success(ctx_mgr.__aenter__())
self.get_success(ctx_mgr.__aexit__(None, None, None))
def test_no_deletion_error(self) -> None:
"""Test that calling persisting_state_group_references is fine if
nothing is pending deletion, but an error occurs."""
event, context = self.get_success(
create_event(
self.hs,
room_id=self.room_id,
type="m.test",
sender=self.user_id,
)
)
ctx_mgr = self.state_deletion_store.persisting_state_group_references(
[(event, context)]
)
self.get_success(ctx_mgr.__aenter__())
self.get_success(ctx_mgr.__aexit__(Exception, Exception("test"), None))
def test_existing_pending_deletion_is_cleared(self) -> None:
"""Test that the pending deletion flag gets cleared when the state group
gets persisted."""
event, context = self.get_success(
create_event(
self.hs,
room_id=self.room_id,
type="m.test",
state_key="",
sender=self.user_id,
)
)
assert context.state_group is not None
# Mark a state group that we're referencing as pending deletion.
self.get_success(
self.state_deletion_store.mark_state_groups_as_pending_deletion(
[context.state_group]
)
)
ctx_mgr = self.state_deletion_store.persisting_state_group_references(
[(event, context)]
)
self.get_success(ctx_mgr.__aenter__())
self.get_success(ctx_mgr.__aexit__(None, None, None))
# The pending deletion flag should be cleared
pending_deletion = self.get_success(
self.state_deletion_store.db_pool.simple_select_one_onecol(
table="state_groups_pending_deletion",
keyvalues={"state_group": context.state_group},
retcol="1",
allow_none=True,
desc="test_existing_pending_deletion_is_cleared",
)
)
self.assertIsNone(pending_deletion)
def test_pending_deletion_is_cleared_during_persist(self) -> None:
"""Test that the pending deletion flag is cleared when a state group
gets marked for deletion during persistence"""
event, context = self.get_success(
create_event(
self.hs,
room_id=self.room_id,
type="m.test",
state_key="",
sender=self.user_id,
)
)
assert context.state_group is not None
ctx_mgr = self.state_deletion_store.persisting_state_group_references(
[(event, context)]
)
self.get_success(ctx_mgr.__aenter__())
# Mark the state group that we're referencing as pending deletion,
# *after* we have started persisting.
self.get_success(
self.state_deletion_store.mark_state_groups_as_pending_deletion(
[context.state_group]
)
)
self.get_success(ctx_mgr.__aexit__(None, None, None))
# The pending deletion flag should be cleared
pending_deletion = self.get_success(
self.state_deletion_store.db_pool.simple_select_one_onecol(
table="state_groups_pending_deletion",
keyvalues={"state_group": context.state_group},
retcol="1",
allow_none=True,
desc="test_existing_pending_deletion_is_cleared",
)
)
self.assertIsNone(pending_deletion)
def test_deletion_check(self) -> None:
"""Test that the `get_state_groups_that_can_be_purged_txn` check is
correct during different points of the lifecycle of persisting an
event."""
event, context = self.get_success(
create_event(
self.hs,
room_id=self.room_id,
type="m.test",
state_key="",
sender=self.user_id,
)
)
assert context.state_group is not None
self.get_success(
self.state_deletion_store.mark_state_groups_as_pending_deletion(
[context.state_group]
)
)
# We shouldn't be able to delete the state group as not enough time has passed
can_be_deleted = self.check_if_can_be_deleted(context.state_group)
self.assertFalse(can_be_deleted)
# After enough time we can delete the state group
self.reactor.advance(
1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
)
can_be_deleted = self.check_if_can_be_deleted(context.state_group)
self.assertTrue(can_be_deleted)
ctx_mgr = self.state_deletion_store.persisting_state_group_references(
[(event, context)]
)
self.get_success(ctx_mgr.__aenter__())
# But once we start persisting we can't delete the state group
can_be_deleted = self.check_if_can_be_deleted(context.state_group)
self.assertFalse(can_be_deleted)
self.get_success(ctx_mgr.__aexit__(None, None, None))
# The pending deletion flag should have been cleared now that
# persistence has finished, so the state group can't be deleted.
can_be_deleted = self.check_if_can_be_deleted(context.state_group)
self.assertFalse(can_be_deleted)
def test_deletion_error_during_persistence(self) -> None:
"""Test that state groups remain marked as pending deletion if persisting
the event fails."""
event, context = self.get_success(
create_event(
self.hs,
room_id=self.room_id,
type="m.test",
state_key="",
sender=self.user_id,
)
)
assert context.state_group is not None
# Mark a state group that we're referencing as pending deletion.
self.get_success(
self.state_deletion_store.mark_state_groups_as_pending_deletion(
[context.state_group]
)
)
ctx_mgr = self.state_deletion_store.persisting_state_group_references(
[(event, context)]
)
self.get_success(ctx_mgr.__aenter__())
self.get_success(ctx_mgr.__aexit__(Exception, Exception("test"), None))
# We should be able to delete the state group after a certain amount of
# time
self.reactor.advance(
1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
)
can_be_deleted = self.check_if_can_be_deleted(context.state_group)
self.assertTrue(can_be_deleted)
def test_race_between_check_and_insert(self) -> None:
"""Check that we correctly handle the race where we go to delete a
state group, check that it is unreferenced, and then it becomes
referenced just before we delete it."""
event, context = self.get_success(
create_event(
self.hs,
room_id=self.room_id,
type="m.test",
state_key="",
sender=self.user_id,
)
)
assert context.state_group is not None
# Mark a state group that we're referencing as pending deletion.
self.get_success(
self.state_deletion_store.mark_state_groups_as_pending_deletion(
[context.state_group]
)
)
# Advance time enough so we can delete the state group
self.reactor.advance(
1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
)
# Check that we'd be able to delete this state group.
state_group_to_sequence_number = self.get_success(
self.state_deletion_store.get_pending_deletions([context.state_group])
)
can_be_deleted = self.get_success(
self.state_deletion_store.db_pool.runInteraction(
"test_existing_pending_deletion_is_cleared",
self.state_deletion_store.get_state_groups_ready_for_potential_deletion_txn,
state_group_to_sequence_number,
)
)
self.assertCountEqual(can_be_deleted, [context.state_group])
# ... in the real world we'd check that the state group isn't referenced here ...
# Now we persist the event to reference the state group, *after* we
# check that the state group wasn't referenced
ctx_mgr = self.state_deletion_store.persisting_state_group_references(
[(event, context)]
)
self.get_success(ctx_mgr.__aenter__())
self.get_success(ctx_mgr.__aexit__(Exception, Exception("test"), None))
# We simulate a pause (required to hit the race)
self.reactor.advance(
1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
)
# We should no longer be able to delete the state group without
# having to recheck whether it's referenced.
can_be_deleted = self.get_success(
self.state_deletion_store.db_pool.runInteraction(
"test_existing_pending_deletion_is_cleared",
self.state_deletion_store.get_state_groups_ready_for_potential_deletion_txn,
state_group_to_sequence_number,
)
)
self.assertCountEqual(can_be_deleted, [])
def test_remove_ancestors_from_can_delete(self) -> None:
"""Test that if a state group is not ready to be deleted, we also don't
delete anything that is referenced by it"""
event, context = self.get_success(
create_event(
self.hs,
room_id=self.room_id,
type="m.test",
state_key="",
sender=self.user_id,
)
)
assert context.state_group is not None
# Create a new state group that references the one from the event
new_state_group = self.get_success(
self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=context.state_group,
delta_ids={},
current_state_ids=None,
)
)
# Mark them both as pending deletion
self.get_success(
self.state_deletion_store.mark_state_groups_as_pending_deletion(
[context.state_group, new_state_group]
)
)
# Advance time enough that both state groups become ready for
# deletion.
self.reactor.advance(
1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
)
# We can now delete both state groups
self.assertTrue(self.check_if_can_be_deleted(context.state_group))
self.assertTrue(self.check_if_can_be_deleted(new_state_group))
# Use the new_state_group to bump its deletion time
self.get_success(
self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=new_state_group,
delta_ids={},
current_state_ids=None,
)
)
# We should now not be able to delete either of the state groups.
state_group_to_sequence_number = self.get_success(
self.state_deletion_store.get_pending_deletions(
[context.state_group, new_state_group]
)
)
# We shouldn't be able to delete the state group as not enough time has passed
can_be_deleted = self.get_success(
self.state_deletion_store.db_pool.runInteraction(
"test_existing_pending_deletion_is_cleared",
self.state_deletion_store.get_state_groups_ready_for_potential_deletion_txn,
state_group_to_sequence_number,
)
)
self.assertCountEqual(can_be_deleted, [])

tests/test_state.py

@@ -31,7 +31,7 @@ from typing import (
Tuple,
cast,
)
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from twisted.internet import defer
@@ -221,7 +221,16 @@ class Graph:
class StateTestCase(unittest.TestCase):
def setUp(self) -> None:
self.dummy_store = _DummyStore()
storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
# Add a dummy deletion store that always reports that we have all
# the necessary state groups.
dummy_deletion_store = AsyncMock()
dummy_deletion_store.check_state_groups_and_bump_deletion.return_value = []
storage_controllers = Mock(
main=self.dummy_store,
state=self.dummy_store,
)
hs = Mock(
spec_set=[
"config",
@@ -241,7 +250,10 @@ class StateTestCase(unittest.TestCase):
)
clock = cast(Clock, MockClock())
hs.config = default_config("tesths", True)
hs.get_datastores.return_value = Mock(main=self.dummy_store)
hs.get_datastores.return_value = Mock(
main=self.dummy_store,
state_deletion=dummy_deletion_store,
)
hs.get_state_handler.return_value = None
hs.get_clock.return_value = clock
hs.get_macaroon_generator.return_value = MacaroonGenerator(