diff --git a/changelog.d/9470.bugfix b/changelog.d/9470.bugfix new file mode 100644 index 000000000..c1b7dbb17 --- /dev/null +++ b/changelog.d/9470.bugfix @@ -0,0 +1 @@ +Fix missing startup checks for the consistency of certain PostgreSQL sequences. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 464692644..f1ba529a2 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -49,7 +49,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor -from synapse.storage.util.sequence import build_sequence_generator from synapse.types import Collection # python 3 does not have a maximum int value @@ -381,7 +380,10 @@ class DatabasePool: _TXN_ID = 0 def __init__( - self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + self, + hs, + database_config: DatabaseConnectionConfig, + engine: BaseDatabaseEngine, ): self.hs = hs self._clock = hs.get_clock() @@ -420,16 +422,6 @@ class DatabasePool: self._check_safe_to_upsert, ) - # We define this sequence here so that it can be referenced from both - # the DataStore and PersistEventStore. - def get_chain_id_txn(txn): - txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") - return txn.fetchone()[0] - - self.event_chain_id_gen = build_sequence_generator( - engine, get_chain_id_txn, "event_auth_chain_id" - ) - def is_running(self) -> bool: """Is the database pool currently running""" return self._db_pool.running diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index a7a11a5bc..cd1ceac50 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -44,6 +44,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry from synapse.storage.types import Connection from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder from synapse.util.iterutils import batch_iter, sorted_topologically @@ -114,12 +115,6 @@ class PersistEventsStore: ) # type: MultiWriterIdGenerator self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator - # The consistency of this cannot be checked when the ID generator is - # created since the database might not yet be up-to-date. - self.db_pool.event_chain_id_gen.check_consistency( - db_conn, "event_auth_chains", "chain_id" # type: ignore - ) - # This should only exist on instances that are configured to write assert ( hs.get_instance_name() in hs.config.worker.writers.events @@ -485,6 +480,7 @@ class PersistEventsStore: self._add_chain_cover_index( txn, self.db_pool, + self.store.event_chain_id_gen, event_to_room_id, event_to_types, event_to_auth_chain, @@ -495,6 +491,7 @@ class PersistEventsStore: cls, txn, db_pool: DatabasePool, + event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], event_to_auth_chain: Dict[str, List[str]], @@ -641,6 +638,7 @@ class PersistEventsStore: new_chain_tuples = cls._allocate_chain_ids( txn, db_pool, + event_chain_id_gen, event_to_room_id, event_to_types, event_to_auth_chain, @@ -779,6 +777,7 @@ class PersistEventsStore: def _allocate_chain_ids( txn, db_pool: DatabasePool, + event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], event_to_auth_chain: Dict[str, List[str]], @@ -891,7 +890,7 @@ class PersistEventsStore: chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1] # Generate new chain IDs for all unallocated chain IDs. - newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn( + newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn( txn, len(unallocated_chain_ids) ) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 89274e75f..c1626ccf2 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -917,6 +917,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): PersistEventsStore._add_chain_cover_index( txn, self.db_pool, + self.event_chain_id_gen, event_to_room_id, event_to_types, event_to_auth_chain, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index c8850a470..edbe42f2b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import Collection, JsonDict, get_domain_from_id from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -156,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore): self._event_fetch_list = [] self._event_fetch_ongoing = 0 + # We define this sequence here so that it can be referenced from both + # the DataStore and PersistEventStore. + def get_chain_id_txn(txn): + txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") + return txn.fetchone()[0] + + self.event_chain_id_gen = build_sequence_generator( + db_conn, + database.engine, + get_chain_id_txn, + "event_auth_chain_id", + table="event_auth_chains", + id_column="chain_id", + ) + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == EventsStream.NAME: self._stream_id_gen.advance(instance_name, token) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index d5b550781..61a7556e5 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -23,7 +23,7 @@ import attr from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Connection, Cursor @@ -70,7 +70,12 @@ class TokenLookupResult: class RegistrationWorkerStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.config = hs.config @@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # call `find_max_generated_user_id_localpart` each time, which is # expensive if there are many entries. self._user_id_seq = build_sequence_generator( + db_conn, database.engine, find_max_generated_user_id_localpart, "user_id_seq", + table=None, + id_column=None, ) self._account_validity = hs.config.account_validity @@ -1036,7 +1044,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._clock = hs.get_clock() diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index b16b9905d..e2240703a 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -97,10 +97,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return txn.fetchone()[0] self._state_group_seq_gen = build_sequence_generator( - self.database_engine, get_max_state_group_txn, "state_group_id_seq" - ) - self._state_group_seq_gen.check_consistency( - db_conn, table="state_groups", id_column="id" + db_conn, + self.database_engine, + get_max_state_group_txn, + "state_group_id_seq", + table="state_groups", + id_column="id", ) @cached(max_entries=10000, iterable=True) diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 3ea637b28..36a67e701 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -251,9 +251,14 @@ class LocalSequenceGenerator(SequenceGenerator): def build_sequence_generator( + db_conn: "LoggingDatabaseConnection", database_engine: BaseDatabaseEngine, get_first_callback: GetFirstCallbackType, sequence_name: str, + table: Optional[str], + id_column: Optional[str], + stream_name: Optional[str] = None, + positive: bool = True, ) -> SequenceGenerator: """Get the best impl of SequenceGenerator available @@ -265,8 +270,23 @@ def build_sequence_generator( get_first_callback: a callback which gets the next sequence ID. Used if we're on sqlite. sequence_name: the name of a postgres sequence to use. + table, id_column, stream_name, positive: If set then `check_consistency` + is called on the created sequence. See docstring for + `check_consistency` details. """ if isinstance(database_engine, PostgresEngine): - return PostgresSequenceGenerator(sequence_name) + seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator else: - return LocalSequenceGenerator(get_first_callback) + seq = LocalSequenceGenerator(get_first_callback) + + if table: + assert id_column + seq.check_consistency( + db_conn=db_conn, + table=table, + id_column=id_column, + stream_name=stream_name, + positive=positive, + ) + + return seq