Make all `process_replication_rows` methods async (#13304)

More prep work for asyncronous caching, also makes all process_replication_rows methods consistent (presence handler already is so).

Signed off by Nick @ Beeper (@Fizzadar)
1.103.0-whithout-watcha
Nick Mills-Barrett 2 years ago committed by GitHub
parent 96cf81e312
commit 5d4028f217
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      changelog.d/13304.misc
  2. 4
      synapse/handlers/typing.py
  3. 6
      synapse/replication/slave/storage/devices.py
  4. 6
      synapse/replication/slave/storage/push_rule.py
  5. 6
      synapse/replication/slave/storage/pushers.py
  6. 6
      synapse/replication/tcp/client.py
  7. 2
      synapse/storage/_base.py
  8. 4
      synapse/storage/databases/main/account_data.py
  9. 4
      synapse/storage/databases/main/cache.py
  10. 6
      synapse/storage/databases/main/deviceinbox.py
  11. 4
      synapse/storage/databases/main/events_worker.py
  12. 6
      synapse/storage/databases/main/presence.py
  13. 6
      synapse/storage/databases/main/receipts.py
  14. 4
      synapse/storage/databases/main/tags.py

@ -0,0 +1 @@
Make all replication row processing methods asynchronous. Contributed by Nick @ Beeper (@fizzadar).

@ -158,7 +158,7 @@ class FollowerTypingHandler:
except Exception: except Exception:
logger.exception("Error pushing typing notif to remotes") logger.exception("Error pushing typing notif to remotes")
def process_replication_rows( async def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow] self, token: int, rows: List[TypingStream.TypingStreamRow]
) -> None: ) -> None:
"""Should be called whenever we receive updates for typing stream.""" """Should be called whenever we receive updates for typing stream."""
@ -444,7 +444,7 @@ class TypingWriterHandler(FollowerTypingHandler):
return rows, current_id, limited return rows, current_id, limited
def process_replication_rows( async def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow] self, token: int, rows: List[TypingStream.TypingStreamRow]
) -> None: ) -> None:
# The writing process should never get updates from replication. # The writing process should never get updates from replication.

@ -49,7 +49,7 @@ class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
def get_device_stream_token(self) -> int: def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token() return self._device_list_id_gen.get_current_token()
def process_replication_rows( async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None: ) -> None:
if stream_name == DeviceListsStream.NAME: if stream_name == DeviceListsStream.NAME:
@ -59,7 +59,9 @@ class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
self._device_list_id_gen.advance(instance_name, token) self._device_list_id_gen.advance(instance_name, token)
for row in rows: for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token) self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows) return await super().process_replication_rows(
stream_name, instance_name, token, rows
)
def _invalidate_caches_for_devices( def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]

@ -24,7 +24,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self) -> int: def get_max_push_rules_stream_id(self) -> int:
return self._push_rules_stream_id_gen.get_current_token() return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows( async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None: ) -> None:
if stream_name == PushRulesStream.NAME: if stream_name == PushRulesStream.NAME:
@ -33,4 +33,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(row.user_id, token) self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows) return await super().process_replication_rows(
stream_name, instance_name, token, rows
)

@ -40,9 +40,11 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def get_pushers_stream_token(self) -> int: def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token() return self._pushers_id_gen.get_current_token()
def process_replication_rows( async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None: ) -> None:
if stream_name == PushersStream.NAME: if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token) self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows) return await super().process_replication_rows(
stream_name, instance_name, token, rows
)

@ -144,13 +144,15 @@ class ReplicationDataHandler:
token: stream token for this batch of rows token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
""" """
self.store.process_replication_rows(stream_name, instance_name, token, rows) await self.store.process_replication_rows(
stream_name, instance_name, token, rows
)
if self.send_handler: if self.send_handler:
await self.send_handler.process_replication_rows(stream_name, token, rows) await self.send_handler.process_replication_rows(stream_name, token, rows)
if stream_name == TypingStream.NAME: if stream_name == TypingStream.NAME:
self._typing_handler.process_replication_rows(token, rows) await self._typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event( self.notifier.on_new_event(
StreamKeyType.TYPING, token, rooms=[row.room_id for row in rows] StreamKeyType.TYPING, token, rooms=[row.room_id for row in rows]
) )

@ -47,7 +47,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self.database_engine = database.engine self.database_engine = database.engine
self.db_pool = database self.db_pool = database
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,

@ -414,7 +414,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
) )
) )
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -437,7 +437,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
) )
self._account_data_stream_cache.entity_has_changed(row.user_id, token) self._account_data_stream_cache.entity_has_changed(row.user_id, token)
super().process_replication_rows(stream_name, instance_name, token, rows) await super().process_replication_rows(stream_name, instance_name, token, rows)
async def add_account_data_to_room( async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict self, user_id: str, room_id: str, account_data_type: str, content: JsonDict

@ -119,7 +119,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_all_updated_caches", get_all_updated_caches_txn "get_all_updated_caches", get_all_updated_caches_txn
) )
def process_replication_rows( async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None: ) -> None:
if stream_name == EventsStream.NAME: if stream_name == EventsStream.NAME:
@ -154,7 +154,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
else: else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys) self._attempt_to_invalidate_cache(row.cache_func, row.keys)
super().process_replication_rows(stream_name, instance_name, token, rows) await super().process_replication_rows(stream_name, instance_name, token, rows)
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data data = row.data

@ -128,7 +128,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
prefilled_cache=device_outbox_prefill, prefilled_cache=device_outbox_prefill,
) )
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -148,7 +148,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
self._device_federation_outbox_stream_cache.entity_has_changed( self._device_federation_outbox_stream_cache.entity_has_changed(
row.entity, token row.entity, token
) )
return super().process_replication_rows(stream_name, instance_name, token, rows) return await super().process_replication_rows(
stream_name, instance_name, token, rows
)
def get_to_device_stream_token(self) -> int: def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token() return self._device_inbox_id_gen.get_current_token()

@ -280,7 +280,7 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id", id_column="chain_id",
) )
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -292,7 +292,7 @@ class EventsWorkerStore(SQLBaseStore):
elif stream_name == BackfillStream.NAME: elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(instance_name, -token) self._backfill_id_gen.advance(instance_name, -token)
super().process_replication_rows(stream_name, instance_name, token, rows) await super().process_replication_rows(stream_name, instance_name, token, rows)
async def have_censored_event(self, event_id: str) -> bool: async def have_censored_event(self, event_id: str) -> bool:
"""Check if an event has been censored, i.e. if the content of the event has been erased """Check if an event has been censored, i.e. if the content of the event has been erased

@ -431,7 +431,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
self._presence_on_startup = [] self._presence_on_startup = []
return active_on_startup return active_on_startup
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -443,4 +443,6 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
for row in rows: for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token) self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,)) self._get_presence_for_user.invalidate((row.user_id,))
return super().process_replication_rows(stream_name, instance_name, token, rows) return await super().process_replication_rows(
stream_name, instance_name, token, rows
)

@ -589,7 +589,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"get_unread_event_push_actions_by_room_for_user", (room_id,) "get_unread_event_push_actions_by_room_for_user", (room_id,)
) )
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -604,7 +604,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
self._receipts_stream_cache.entity_has_changed(row.room_id, token) self._receipts_stream_cache.entity_has_changed(row.room_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows) return await super().process_replication_rows(
stream_name, instance_name, token, rows
)
def _insert_linearized_receipt_txn( def _insert_linearized_receipt_txn(
self, self,

@ -292,7 +292,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
# than the id that the client has. # than the id that the client has.
pass pass
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -305,7 +305,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
self.get_tags_for_user.invalidate((row.user_id,)) self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token) self._account_data_stream_cache.entity_has_changed(row.user_id, token)
super().process_replication_rows(stream_name, instance_name, token, rows) await super().process_replication_rows(stream_name, instance_name, token, rows)
class TagsStore(TagsWorkerStore): class TagsStore(TagsWorkerStore):

Loading…
Cancel
Save