Track notification counts per thread (implement MSC3773). (#13776)

When retrieving notification counts, segment the results by thread ID, but
let the client choose whether they are returned as individual threads or as
a single summed field by opting in via a sync filter flag.

The summarization code is also updated to be per-thread, instead of
per-room.
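
A sketch of how a client exercises the opt-in (the unstable field names are
taken from the diff below; the counts and event IDs are illustrative):

# The client opts in through the room timeline section of its sync filter
# (only honoured when the homeserver config sets msc3773_enabled):
sync_filter = {
    "room": {
        "timeline": {
            "org.matrix.msc3773.unread_thread_notifications": True,
        },
    },
}

# Each joined room in the sync response then carries per-thread counts next
# to the usual main-timeline counts:
joined_room_sync = {
    "unread_notifications": {"notification_count": 2, "highlight_count": 0},
    "org.matrix.msc3773.unread_thread_notifications": {
        "$thread_root_id": {"notification_count": 1, "highlight_count": 0},
    },
}

# Without the opt-in, per-thread counts are folded into "unread_notifications",
# so legacy clients see unchanged behaviour.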
Patrick Cloke committed via GitHub 2 years ago
parent 94017e867d · commit b4ec4f5e71
Files changed (number of changed lines in parentheses):

1. changelog.d/13776.feature (1)
2. synapse/api/constants.py (3)
3. synapse/api/filtering.py (10)
4. synapse/config/experimental.py (2)
5. synapse/handlers/sync.py (40)
6. synapse/push/bulk_push_rule_evaluator.py (4)
7. synapse/push/push_tools.py (9)
8. synapse/rest/client/sync.py (4)
9. synapse/rest/client/versions.py (3)
10. synapse/storage/database.py (2)
11. synapse/storage/databases/main/event_push_actions.py (180)
12. synapse/storage/schema/__init__.py (6)
13. synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql (29)
14. synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres (19)
15. synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite (101)
16. tests/replication/slave/storage/test_events.py (17)
17. tests/storage/test_event_push_actions.py (169)

changelog.d/13776.feature (new file)

@@ -0,0 +1 @@
Experimental support for thread-specific notifications ([MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)).

synapse/api/constants.py

@@ -31,6 +31,9 @@ MAX_ALIAS_LENGTH = 255
 # the maximum length for a user id is 255 characters
 MAX_USERID_LENGTH = 255
 
+# Constant value used for the pseudo-thread which is the main timeline.
+MAIN_TIMELINE: Final = "main"
+
 class Membership:

synapse/api/filtering.py

@@ -84,6 +84,7 @@ ROOM_EVENT_FILTER_SCHEMA = {
         "contains_url": {"type": "boolean"},
         "lazy_load_members": {"type": "boolean"},
         "include_redundant_members": {"type": "boolean"},
+        "org.matrix.msc3773.unread_thread_notifications": {"type": "boolean"},
         # Include or exclude events with the provided labels.
         # cf https://github.com/matrix-org/matrix-doc/pull/2326
         "org.matrix.labels": {"type": "array", "items": {"type": "string"}},
@@ -240,6 +241,9 @@ class FilterCollection:
     def include_redundant_members(self) -> bool:
         return self._room_state_filter.include_redundant_members
 
+    def unread_thread_notifications(self) -> bool:
+        return self._room_timeline_filter.unread_thread_notifications
+
     async def filter_presence(
         self, events: Iterable[UserPresenceState]
     ) -> List[UserPresenceState]:
@@ -304,6 +308,12 @@ class Filter:
         self.include_redundant_members = filter_json.get(
             "include_redundant_members", False
         )
+        if hs.config.experimental.msc3773_enabled:
+            self.unread_thread_notifications: bool = filter_json.get(
+                "org.matrix.msc3773.unread_thread_notifications", False
+            )
+        else:
+            self.unread_thread_notifications = False
 
         self.types = filter_json.get("types", None)
         self.not_types = filter_json.get("not_types", [])

synapse/config/experimental.py

@@ -99,6 +99,8 @@ class ExperimentalConfig(Config):
         self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False)
         # MSC3772: A push rule for mutual relations.
         self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False)
+        # MSC3773: Thread notifications
+        self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
 
         # MSC3715: dir param on /relations.
         self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False)

synapse/handlers/sync.py

@@ -40,7 +40,7 @@ from synapse.handlers.relations import BundledAggregations
 from synapse.logging.context import current_context
 from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
 from synapse.push.clientformat import format_push_rules_for_user
-from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.storage.databases.main.event_push_actions import RoomNotifCounts
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -128,6 +128,7 @@ class JoinedSyncResult:
     ephemeral: List[JsonDict]
     account_data: List[JsonDict]
     unread_notifications: JsonDict
+    unread_thread_notifications: JsonDict
     summary: Optional[JsonDict]
     unread_count: int
@@ -278,6 +279,8 @@ class SyncHandler:
         self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync
 
+        self._msc3773_enabled = hs.config.experimental.msc3773_enabled
+
     async def wait_for_sync_for_user(
         self,
         requester: Requester,
@@ -1288,7 +1291,7 @@ class SyncHandler:
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
-    ) -> NotifCounts:
+    ) -> RoomNotifCounts:
         with Measure(self.clock, "unread_notifs_for_room_id"):
             return await self.store.get_unread_event_push_actions_by_room_for_user(
@@ -2353,6 +2356,7 @@ class SyncHandler:
                 ephemeral=ephemeral,
                 account_data=account_data_events,
                 unread_notifications=unread_notifications,
+                unread_thread_notifications={},
                 summary=summary,
                 unread_count=0,
             )
@@ -2360,10 +2364,36 @@ class SyncHandler:
         if room_sync or always_include:
             notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
 
-            unread_notifications["notification_count"] = notifs.notify_count
-            unread_notifications["highlight_count"] = notifs.highlight_count
+            # Notifications for the main timeline.
+            notify_count = notifs.main_timeline.notify_count
+            highlight_count = notifs.main_timeline.highlight_count
+            unread_count = notifs.main_timeline.unread_count
 
-            room_sync.unread_count = notifs.unread_count
+            # Check the sync configuration.
+            if (
+                self._msc3773_enabled
+                and sync_config.filter_collection.unread_thread_notifications()
+            ):
+                # And add info for each thread.
+                room_sync.unread_thread_notifications = {
+                    thread_id: {
+                        "notification_count": thread_notifs.notify_count,
+                        "highlight_count": thread_notifs.highlight_count,
+                    }
+                    for thread_id, thread_notifs in notifs.threads.items()
+                    if thread_id is not None
+                }
+            else:
+                # Combine the unread counts for all threads and main timeline.
+                for thread_notifs in notifs.threads.values():
+                    notify_count += thread_notifs.notify_count
+                    highlight_count += thread_notifs.highlight_count
+                    unread_count += thread_notifs.unread_count
+
+            unread_notifications["notification_count"] = notify_count
+            unread_notifications["highlight_count"] = highlight_count
+            room_sync.unread_count = unread_count
 
             sync_result_builder.joined.append(room_sync)
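
A hypothetical worked example of the two branches above (made-up counts; the
thread IDs are placeholders):

# Main timeline has 2 notifications; threads "$a" and "$b" have 1 and 3.
main_notify_count = 2
thread_notify_counts = {"$a": 1, "$b": 3}

# Opted-in client: notification_count stays 2 and the thread counts are
# reported separately under unread_thread_notifications.
# Legacy client: the else-branch sums everything into one field.
combined = main_notify_count + sum(thread_notify_counts.values())
assert combined == 6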

synapse/push/bulk_push_rule_evaluator.py

@@ -31,7 +31,7 @@ from typing import (
 from prometheus_client import Counter
 
-from synapse.api.constants import EventTypes, Membership, RelationTypes
+from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes
 from synapse.event_auth import auth_types_for_event, get_user_power_level
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext

@@ -280,7 +280,7 @@ class BulkPushRuleEvaluator:
         # If the event does not have a relation, then cannot have any mutual
         # relations or thread ID.
         relations = {}
-        thread_id = "main"
+        thread_id = MAIN_TIMELINE
         if relation:
             relations = await self._get_mutual_relations(
                 relation.parent_id,

synapse/push/push_tools.py

@@ -39,7 +39,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
     await concurrently_execute(get_room_unread_count, joins, 10)
 
     for notifs in room_notifs:
-        if notifs.notify_count == 0:
+        # Combine the counts from all the threads.
+        notify_count = notifs.main_timeline.notify_count + sum(
+            n.notify_count for n in notifs.threads.values()
+        )
+
+        if notify_count == 0:
             continue
 
         if group_by_room:

@@ -47,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
             badge += 1
         else:
             # increment the badge count by the number of unread messages in the room
-            badge += notifs.notify_count
+            badge += notify_count
     return badge

synapse/rest/client/sync.py

@@ -509,6 +509,10 @@ class SyncRestServlet(RestServlet):
             ephemeral_events = room.ephemeral
             result["ephemeral"] = {"events": ephemeral_events}
         result["unread_notifications"] = room.unread_notifications
+        if room.unread_thread_notifications:
+            result[
+                "org.matrix.msc3773.unread_thread_notifications"
+            ] = room.unread_thread_notifications
         result["summary"] = room.summary
         if self._msc2654_enabled:
             result["org.matrix.msc2654.unread_count"] = room.unread_count

synapse/rest/client/versions.py

@@ -103,8 +103,9 @@ class VersionsRestServlet(RestServlet):
                     "org.matrix.msc3030": self.config.experimental.msc3030_enabled,
                     # Adds support for thread relations, per MSC3440.
                     "org.matrix.msc3440.stable": True,  # TODO: remove when "v1.3" is added above
-                    # Support for thread read receipts.
+                    # Support for thread read receipts & notification counts.
                     "org.matrix.msc3771": self.config.experimental.msc3771_enabled,
+                    "org.matrix.msc3773": self.config.experimental.msc3773_enabled,
                     # Allows moderators to fetch redacted event content as described in MSC2815
                     "fi.mau.msc2815": self.config.experimental.msc2815_enabled,
                     # Adds support for login token requests as per MSC3882

synapse/storage/database.py

@@ -94,7 +94,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
     "event_search": "event_search_event_id_idx",
    "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx",
     "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx",
-    "event_push_summary": "event_push_summary_unique_index",
+    "event_push_summary": "event_push_summary_unique_index2",
     "receipts_linearized": "receipts_linearized_unique_index",
     "receipts_graph": "receipts_graph_unique_index",
 }

synapse/storage/databases/main/event_push_actions.py

@@ -88,7 +88,7 @@ from typing import (
 import attr
 
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (

@@ -157,7 +157,7 @@ class UserPushAction(EmailPushAction):
 @attr.s(slots=True, auto_attribs=True)
 class NotifCounts:
     """
-    The per-user, per-room count of notifications. Used by sync and push.
+    The per-user, per-room, per-thread count of notifications. Used by sync and push.
     """
 
     notify_count: int = 0
@@ -165,6 +165,21 @@ class NotifCounts:
     highlight_count: int = 0
 
 
+@attr.s(slots=True, auto_attribs=True)
+class RoomNotifCounts:
+    """
+    The per-user, per-room count of notifications. Used by sync and push.
+    """
+
+    main_timeline: NotifCounts
+    # Map of thread ID to the notification counts.
+    threads: Dict[str, NotifCounts]
+
+    def __len__(self) -> int:
+        # To properly account for the amount of space in any caches.
+        return len(self.threads) + 1
+
+
 def _serialize_action(
     actions: Collection[Union[Mapping, str]], is_highlight: bool
 ) -> str:
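
The `__len__` above matters because the result is stored in a cache declared
with iterable=True (see the @cached change below), which charges an entry by
its length instead of counting it as 1. A rough stand-alone sketch of that
accounting, with plain dataclasses standing in for the attrs classes:

from dataclasses import dataclass, field
from typing import Dict

@dataclass
class NotifCounts:
    notify_count: int = 0
    unread_count: int = 0
    highlight_count: int = 0

@dataclass
class RoomNotifCounts:
    main_timeline: NotifCounts
    threads: Dict[str, NotifCounts] = field(default_factory=dict)

    def __len__(self) -> int:
        # One unit for the main timeline plus one per thread.
        return len(self.threads) + 1

counts = RoomNotifCounts(NotifCounts(notify_count=2))
counts.threads["$thread_root"] = NotifCounts(notify_count=1)
assert len(counts) == 2  # a room with many threads costs more cache space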
@@ -338,12 +353,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         return result
 
-    @cached(tree=True, max_entries=5000)
+    @cached(tree=True, max_entries=5000, iterable=True)
     async def get_unread_event_push_actions_by_room_for_user(
         self,
         room_id: str,
         user_id: str,
-    ) -> NotifCounts:
+    ) -> RoomNotifCounts:
         """Get the notification count, the highlight count and the unread message count
         for a given user in a given room after their latest read receipt.
@@ -356,8 +371,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             user_id: The user to retrieve the counts for.
 
         Returns
-            A NotifCounts object containing the notification count, the highlight count
-            and the unread message count.
+            A RoomNotifCounts object containing the notification count, the
+            highlight count and the unread message count for both the main timeline
+            and threads.
         """
         return await self.db_pool.runInteraction(
             "get_unread_event_push_actions_by_room",
@@ -371,7 +387,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         txn: LoggingTransaction,
         room_id: str,
         user_id: str,
-    ) -> NotifCounts:
+    ) -> RoomNotifCounts:
         # Get the stream ordering of the user's latest receipt in the room.
         result = self.get_last_unthreaded_receipt_for_user_txn(
             txn,
@@ -406,7 +422,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         room_id: str,
         user_id: str,
         receipt_stream_ordering: int,
-    ) -> NotifCounts:
+    ) -> RoomNotifCounts:
         """Get the number of unread messages for a user/room that have happened
         since the given stream ordering.
@@ -418,12 +434,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             receipt in the room. If there are no receipts, the stream ordering
             of the user's join event.
 
-        Returns
-            A NotifCounts object containing the notification count, the highlight count
-            and the unread message count.
+        Returns:
+            A RoomNotifCounts object containing the notification count, the
+            highlight count and the unread message count for both the main timeline
+            and threads.
         """
 
-        counts = NotifCounts()
+        main_counts = NotifCounts()
+        thread_counts: Dict[str, NotifCounts] = {}
+
+        def _get_thread(thread_id: str) -> NotifCounts:
+            if thread_id == MAIN_TIMELINE:
+                return main_counts
+            return thread_counts.setdefault(thread_id, NotifCounts())
 
         # First we pull the counts from the summary table.
         #
@@ -440,52 +463,61 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # receipt).
         txn.execute(
             """
-                SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
+                SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id
                 FROM event_push_summary
                 WHERE room_id = ? AND user_id = ?
                 AND (
                     (last_receipt_stream_ordering IS NULL AND stream_ordering > ?)
                     OR last_receipt_stream_ordering = ?
-                )
+                ) AND (notif_count != 0 OR COALESCE(unread_count, 0) != 0)
             """,
             (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
         )
-        row = txn.fetchone()
 
-        summary_stream_ordering = 0
-        if row:
-            summary_stream_ordering = row[0]
-            counts.notify_count += row[1]
-            counts.unread_count += row[2]
+        max_summary_stream_ordering = 0
+        for summary_stream_ordering, notif_count, unread_count, thread_id in txn:
+            counts = _get_thread(thread_id)
+            counts.notify_count += notif_count
+            counts.unread_count += unread_count
+
+            # Summaries will only be used if they have not been invalidated by
+            # a recent receipt; track the latest stream ordering or a valid summary.
+            #
+            # Note that since there's only one read receipt in the room per user,
+            # valid summaries are contiguous.
+            max_summary_stream_ordering = max(
+                summary_stream_ordering, max_summary_stream_ordering
+            )
 
         # Next we need to count highlights, which aren't summarised
         sql = """
-            SELECT COUNT(*) FROM event_push_actions
+            SELECT COUNT(*), thread_id FROM event_push_actions
             WHERE user_id = ?
                 AND room_id = ?
                 AND stream_ordering > ?
                 AND highlight = 1
+            GROUP BY thread_id
         """
         txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
-        row = txn.fetchone()
-        if row:
-            counts.highlight_count += row[0]
+        for highlight_count, thread_id in txn:
+            _get_thread(thread_id).highlight_count += highlight_count
 
         # Finally we need to count push actions that aren't included in the
         # summary returned above. This might be due to recent events that haven't
         # been summarised yet or the summary is out of date due to a recent read
         # receipt.
         start_unread_stream_ordering = max(
-            receipt_stream_ordering, summary_stream_ordering
+            receipt_stream_ordering, max_summary_stream_ordering
         )
-        notify_count, unread_count = self._get_notif_unread_count_for_user_room(
+        unread_counts = self._get_notif_unread_count_for_user_room(
             txn, room_id, user_id, start_unread_stream_ordering
         )
 
-        counts.notify_count += notify_count
-        counts.unread_count += unread_count
+        for notif_count, unread_count, thread_id in unread_counts:
+            counts = _get_thread(thread_id)
+            counts.notify_count += notif_count
+            counts.unread_count += unread_count
 
-        return counts
+        return RoomNotifCounts(main_counts, thread_counts)
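
To recap the transaction above, the per-thread totals are assembled from
three sources. A condensed sketch with hypothetical rows standing in for the
SQL results (not real table access):

from collections import defaultdict

# (stream_ordering, notif_count, unread_count, thread_id) from event_push_summary:
summary_rows = [(105, 3, 3, "main")]
# (count, thread_id) for highlights, which are never summarised:
highlight_rows = [(1, "$thread")]
# (notif_count, unread_count, thread_id) for actions newer than the summaries:
recent_rows = [(2, 2, "$thread")]

notify = defaultdict(int)

# 1. Counts already rolled up into the summary table.
max_summary_stream_ordering = 0
for stream_ordering, notif_count, _unread, thread_id in summary_rows:
    notify[thread_id] += notif_count
    max_summary_stream_ordering = max(max_summary_stream_ordering, stream_ordering)

# 2. Highlight counts, grouped per thread.
highlights = {thread_id: count for count, thread_id in highlight_rows}

# 3. Push actions newer than both the receipt and the summaries.
for notif_count, _unread, thread_id in recent_rows:
    notify[thread_id] += notif_count

assert dict(notify) == {"main": 3, "$thread": 2}
assert highlights == {"$thread": 1}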
def _get_notif_unread_count_for_user_room( def _get_notif_unread_count_for_user_room(
self, self,
@ -494,7 +526,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: str, user_id: str,
stream_ordering: int, stream_ordering: int,
max_stream_ordering: Optional[int] = None, max_stream_ordering: Optional[int] = None,
) -> Tuple[int, int]: ) -> List[Tuple[int, int, str]]:
"""Returns the notify and unread counts from `event_push_actions` for """Returns the notify and unread counts from `event_push_actions` for
the given user/room in the given range. the given user/room in the given range.
@@ -510,13 +542,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 If this is not given, then no maximum is applied.
 
         Return:
-            A tuple of the notif count and unread count in the given range.
+            A tuple of the notif count and unread count in the given range for
+            each thread.
         """
 
         # If there have been no events in the room since the stream ordering,
         # there can't be any push actions either.
         if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
-            return 0, 0
+            return []
 
         clause = ""
         args = [user_id, room_id, stream_ordering]
@@ -527,26 +560,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             # If the max stream ordering is less than the min stream ordering,
             # then obviously there are zero push actions in that range.
             if max_stream_ordering <= stream_ordering:
-                return 0, 0
+                return []
 
         sql = f"""
             SELECT
                COUNT(CASE WHEN notif = 1 THEN 1 END),
-               COUNT(CASE WHEN unread = 1 THEN 1 END)
+               COUNT(CASE WHEN unread = 1 THEN 1 END),
+               thread_id
             FROM event_push_actions ea
             WHERE user_id = ?
                AND room_id = ?
                AND ea.stream_ordering > ?
                {clause}
+            GROUP BY thread_id
         """
 
         txn.execute(sql, args)
-        row = txn.fetchone()
-
-        if row:
-            return cast(Tuple[int, int], row)
-
-        return 0, 0
+        return cast(List[Tuple[int, int, str]], txn.fetchall())
 
     async def get_push_action_users_in_range(
         self, min_stream_ordering: int, max_stream_ordering: int
@@ -1099,26 +1129,34 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             # Fetch the notification counts between the stream ordering of the
             # latest receipt and what was previously summarised.
-            notif_count, unread_count = self._get_notif_unread_count_for_user_room(
+            unread_counts = self._get_notif_unread_count_for_user_room(
                 txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
             )
 
-            # Replace the previous summary with the new counts.
-            #
-            # TODO(threads): Upsert per-thread instead of setting them all to main.
-            self.db_pool.simple_upsert_txn(
+            # First mark the summary for all threads in the room as cleared.
+            self.db_pool.simple_update_txn(
                 txn,
                 table="event_push_summary",
-                keyvalues={"room_id": room_id, "user_id": user_id},
-                values={
-                    "notif_count": notif_count,
-                    "unread_count": unread_count,
+                keyvalues={"user_id": user_id, "room_id": room_id},
+                updatevalues={
+                    "notif_count": 0,
+                    "unread_count": 0,
                     "stream_ordering": old_rotate_stream_ordering,
                     "last_receipt_stream_ordering": stream_ordering,
-                    "thread_id": "main",
                 },
             )
 
+            # Then any updated threads get their notification count and unread
+            # count updated.
+            self.db_pool.simple_update_many_txn(
+                txn,
+                table="event_push_summary",
+                key_names=("room_id", "user_id", "thread_id"),
+                key_values=[(room_id, user_id, row[2]) for row in unread_counts],
+                value_names=("notif_count", "unread_count"),
+                value_values=[(row[0], row[1]) for row in unread_counts],
+            )
+
         # We always update `event_push_summary_last_receipt_stream_id` to
         # ensure that we don't rescan the same receipts for remote users.
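
The clear-then-update shape above avoids having to work out which threads a
receipt has cleared: every summary row for the room is zeroed, then only
threads that still have push actions after the receipt are rewritten. A toy
model with hypothetical data:

# Summary rows keyed by (room_id, user_id, thread_id).
summaries = {
    ("!room", "@alice", "main"): {"notif_count": 4, "unread_count": 4},
    ("!room", "@alice", "$thread"): {"notif_count": 2, "unread_count": 2},
}
# Counts recomputed after the new read receipt: (notif, unread, thread_id).
unread_counts = [(1, 1, "$thread")]

for row in summaries.values():
    row.update(notif_count=0, unread_count=0)
for notif_count, unread_count, thread_id in unread_counts:
    summaries[("!room", "@alice", thread_id)].update(
        notif_count=notif_count, unread_count=unread_count
    )

assert summaries[("!room", "@alice", "main")]["notif_count"] == 0
assert summaries[("!room", "@alice", "$thread")]["notif_count"] == 1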
@@ -1204,23 +1242,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
-            SELECT user_id, room_id,
+            SELECT user_id, room_id, thread_id,
                 coalesce(old.%s, 0) + upd.cnt,
                 upd.stream_ordering
             FROM (
-                SELECT user_id, room_id, count(*) as cnt,
+                SELECT user_id, room_id, thread_id, count(*) as cnt,
                     max(ea.stream_ordering) as stream_ordering
                 FROM event_push_actions AS ea
-                LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+                LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
                 WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ?
                     AND (
                         old.last_receipt_stream_ordering IS NULL
                         OR old.last_receipt_stream_ordering < ea.stream_ordering
                     )
                     AND %s = 1
-                GROUP BY user_id, room_id
+                GROUP BY user_id, room_id, thread_id
             ) AS upd
-            LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+            LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
         """
 
         # First get the count of unread messages.
@@ -1234,11 +1272,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # object because we might not have the same amount of rows in each of them. To do
         # this, we use a dict indexed on the user ID and room ID to make it easier to
         # populate.
-        summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
+        summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {}
         for row in txn:
-            summaries[(row[0], row[1])] = _EventPushSummary(
-                unread_count=row[2],
-                stream_ordering=row[3],
+            summaries[(row[0], row[1], row[2])] = _EventPushSummary(
+                unread_count=row[3],
+                stream_ordering=row[4],
                 notif_count=0,
             )
@@ -1249,34 +1287,35 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )
 
         for row in txn:
-            if (row[0], row[1]) in summaries:
-                summaries[(row[0], row[1])].notif_count = row[2]
+            if (row[0], row[1], row[2]) in summaries:
+                summaries[(row[0], row[1], row[2])].notif_count = row[3]
             else:
                 # Because the rules on notifying are different than the rules on marking
                 # a message unread, we might end up with messages that notify but aren't
                 # marked unread, so we might not have a summary for this (user, room)
                 # tuple to complete.
-                summaries[(row[0], row[1])] = _EventPushSummary(
+                summaries[(row[0], row[1], row[2])] = _EventPushSummary(
                     unread_count=0,
-                    stream_ordering=row[3],
-                    notif_count=row[2],
+                    stream_ordering=row[4],
+                    notif_count=row[3],
                 )
 
         logger.info("Rotating notifications, handling %d rows", len(summaries))
 
-        # TODO(threads): Update on a per-thread basis.
         self.db_pool.simple_upsert_many_txn(
             txn,
             table="event_push_summary",
-            key_names=("user_id", "room_id"),
-            key_values=[(user_id, room_id) for user_id, room_id in summaries],
-            value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"),
+            key_names=("user_id", "room_id", "thread_id"),
+            key_values=[
+                (user_id, room_id, thread_id)
+                for user_id, room_id, thread_id in summaries
+            ],
+            value_names=("notif_count", "unread_count", "stream_ordering"),
             value_values=[
                 (
                     summary.notif_count,
                     summary.unread_count,
                     summary.stream_ordering,
-                    "main",
                 )
                 for summary in summaries.values()
             ],
@@ -1288,7 +1327,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )
 
     async def _remove_old_push_actions_that_have_rotated(self) -> None:
-        """Clear out old push actions that have been summarised."""
+        """
+        Clear out old push actions that have been summarised (and are older than
+        1 day ago).
+        """
 
         # We want to clear out anything that is older than a day that *has* already
         # been rotated.

synapse/storage/schema/__init__.py

@@ -90,9 +90,9 @@ Changes in SCHEMA_VERSION = 73;
 
 SCHEMA_COMPAT_VERSION = (
-    # The groups tables are no longer accessible, so synapses with SCHEMA_VERSION < 72
-    # could break.
-    72
+    # The threads_id column must exist for event_push_actions, event_push_summary,
+    # receipts_linearized, and receipts_graph.
+    73
 )
 """Limit on how far the synapse codebase can be rolled back without breaking db compat

synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql (new file)

@@ -0,0 +1,29 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

-- Forces the background updates from 06thread_notifications.sql to run in the
-- foreground as code will now require those to be "done".
DELETE FROM background_updates WHERE update_name = 'event_push_backfill_thread_id';

-- Overwrite any null thread_id columns.
UPDATE event_push_actions_staging SET thread_id = 'main' WHERE thread_id IS NULL;
UPDATE event_push_actions SET thread_id = 'main' WHERE thread_id IS NULL;
UPDATE event_push_summary SET thread_id = 'main' WHERE thread_id IS NULL;

-- Do not run the event_push_summary_unique_index job if it is pending; the
-- thread_id field will be made required.
DELETE FROM background_updates WHERE update_name = 'event_push_summary_unique_index';
DROP INDEX IF EXISTS event_push_summary_unique_index;

synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres (new file)

@@ -0,0 +1,19 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

-- The columns can now be made non-nullable.
ALTER TABLE event_push_actions_staging ALTER COLUMN thread_id SET NOT NULL;
ALTER TABLE event_push_actions ALTER COLUMN thread_id SET NOT NULL;
ALTER TABLE event_push_summary ALTER COLUMN thread_id SET NOT NULL;

synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite (new file)

@@ -0,0 +1,101 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

-- SQLite doesn't support modifying columns to an existing table, so it must
-- be recreated.

-- Create the new tables.
CREATE TABLE event_push_actions_staging_new (
    event_id TEXT NOT NULL,
    user_id TEXT NOT NULL,
    actions TEXT NOT NULL,
    notif SMALLINT NOT NULL,
    highlight SMALLINT NOT NULL,
    unread SMALLINT,
    thread_id TEXT NOT NULL,
    inserted_ts BIGINT
);

CREATE TABLE event_push_actions_new (
    room_id TEXT NOT NULL,
    event_id TEXT NOT NULL,
    user_id TEXT NOT NULL,
    profile_tag VARCHAR(32),
    actions TEXT NOT NULL,
    topological_ordering BIGINT,
    stream_ordering BIGINT,
    notif SMALLINT,
    highlight SMALLINT,
    unread SMALLINT,
    thread_id TEXT NOT NULL,
    CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag)
);

CREATE TABLE event_push_summary_new (
    user_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    notif_count BIGINT NOT NULL,
    stream_ordering BIGINT NOT NULL,
    unread_count BIGINT,
    last_receipt_stream_ordering BIGINT,
    thread_id TEXT NOT NULL
);

-- Swap the indexes.
DROP INDEX IF EXISTS event_push_actions_staging_id;
CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging_new(event_id);

DROP INDEX IF EXISTS event_push_actions_room_id_user_id;
DROP INDEX IF EXISTS event_push_actions_rm_tokens;
DROP INDEX IF EXISTS event_push_actions_stream_ordering;
DROP INDEX IF EXISTS event_push_actions_u_highlight;
DROP INDEX IF EXISTS event_push_actions_highlights_index;
CREATE INDEX event_push_actions_room_id_user_id on event_push_actions_new(room_id, user_id);
CREATE INDEX event_push_actions_rm_tokens on event_push_actions_new( user_id, room_id, topological_ordering, stream_ordering );
CREATE INDEX event_push_actions_stream_ordering on event_push_actions_new( stream_ordering, user_id );
CREATE INDEX event_push_actions_u_highlight ON event_push_actions_new (user_id, stream_ordering);
CREATE INDEX event_push_actions_highlights_index ON event_push_actions_new (user_id, room_id, topological_ordering, stream_ordering);

-- Copy the data.
INSERT INTO event_push_actions_staging_new (event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts)
    SELECT event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts
    FROM event_push_actions_staging;

INSERT INTO event_push_actions_new (room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id)
    SELECT room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id
    FROM event_push_actions;

INSERT INTO event_push_summary_new (user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id)
    SELECT user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id
    FROM event_push_summary;

-- Drop the old tables.
DROP TABLE event_push_actions_staging;
DROP TABLE event_push_actions;
DROP TABLE event_push_summary;

-- Rename the tables.
ALTER TABLE event_push_actions_staging_new RENAME TO event_push_actions_staging;
ALTER TABLE event_push_actions_new RENAME TO event_push_actions;
ALTER TABLE event_push_summary_new RENAME TO event_push_summary;

-- Re-run background updates from 72/02event_push_actions_index.sql and
-- 72/06thread_notifications.sql.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
    (7307, 'event_push_summary_unique_index2', '{}')
    ON CONFLICT (update_name) DO NOTHING;
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
    (7307, 'event_push_actions_stream_highlight_index', '{}')
    ON CONFLICT (update_name) DO NOTHING;
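
The copy-and-swap dance above is needed because SQLite cannot make an
existing column NOT NULL in place. A runnable toy demonstration of the same
pattern (hypothetical table names, not the Synapse schema):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE t (id TEXT NOT NULL, thread_id TEXT);  -- nullable column
    INSERT INTO t VALUES ('$ev', NULL);
    UPDATE t SET thread_id = 'main' WHERE thread_id IS NULL;  -- backfill first

    CREATE TABLE t_new (id TEXT NOT NULL, thread_id TEXT NOT NULL);
    INSERT INTO t_new SELECT id, thread_id FROM t;
    DROP TABLE t;
    ALTER TABLE t_new RENAME TO t;
    """
)
assert conn.execute("SELECT thread_id FROM t").fetchone() == ("main",)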

tests/replication/slave/storage/test_events.py

@@ -22,7 +22,10 @@ from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
 from synapse.handlers.room import RoomEventSource
 from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.storage.databases.main.event_push_actions import (
+    NotifCounts,
+    RoomNotifCounts,
+)
 from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
 from synapse.types import PersistedEventPosition

@@ -178,7 +181,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2],
-            NotifCounts(highlight_count=0, unread_count=0, notify_count=0),
+            RoomNotifCounts(
+                NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {}
+            ),
         )
 
         self.persist(

@@ -191,7 +196,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2],
-            NotifCounts(highlight_count=0, unread_count=0, notify_count=1),
+            RoomNotifCounts(
+                NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {}
+            ),
         )
 
         self.persist(

@@ -206,7 +213,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2],
-            NotifCounts(highlight_count=1, unread_count=0, notify_count=2),
+            RoomNotifCounts(
+                NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {}
+            ),
         )
 
     def test_get_rooms_for_user_with_stream_ordering(self):

tests/storage/test_event_push_actions.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Optional, Tuple
 
 from twisted.test.proto_helpers import MemoryReactor

@@ -20,6 +20,7 @@ from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
 from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
@@ -133,13 +134,14 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
                 )
             )
             self.assertEqual(
-                counts,
+                counts.main_timeline,
                 NotifCounts(
                     notify_count=noitf_count,
                     unread_count=0,
                     highlight_count=highlight_count,
                 ),
             )
+            self.assertEqual(counts.threads, {})
 
         def _create_event(highlight: bool = False) -> str:
             result = self.helper.send_event(
@@ -186,6 +188,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         _assert_counts(0, 0)
 
         _create_event()
+        _assert_counts(1, 0)
         _rotate()
         _assert_counts(1, 0)
@@ -236,6 +239,168 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         _rotate()
         _assert_counts(0, 0)
 
+    def test_count_aggregation_threads(self) -> None:
+        """
+        This is essentially the same test as test_count_aggregation, but adds
+        events to the main timeline and to a thread.
+        """
+
+        user_id, token, _, other_token, room_id = self._create_users_and_room()
+        thread_id: str
+
+        last_event_id: str
+
+        def _assert_counts(
+            noitf_count: int,
+            highlight_count: int,
+            thread_notif_count: int,
+            thread_highlight_count: int,
+        ) -> None:
+            counts = self.get_success(
+                self.store.db_pool.runInteraction(
+                    "get-unread-counts",
+                    self.store._get_unread_counts_by_receipt_txn,
+                    room_id,
+                    user_id,
+                )
+            )
+            self.assertEqual(
+                counts.main_timeline,
+                NotifCounts(
+                    notify_count=noitf_count,
+                    unread_count=0,
+                    highlight_count=highlight_count,
+                ),
+            )
+            if thread_notif_count or thread_highlight_count:
+                self.assertEqual(
+                    counts.threads,
+                    {
+                        thread_id: NotifCounts(
+                            notify_count=thread_notif_count,
+                            unread_count=0,
+                            highlight_count=thread_highlight_count,
+                        ),
+                    },
+                )
+            else:
+                self.assertEqual(counts.threads, {})
+
+        def _create_event(
+            highlight: bool = False, thread_id: Optional[str] = None
+        ) -> str:
+            content: JsonDict = {
+                "msgtype": "m.text",
+                "body": user_id if highlight else "msg",
+            }
+            if thread_id:
+                content["m.relates_to"] = {
+                    "rel_type": "m.thread",
+                    "event_id": thread_id,
+                }
+
+            result = self.helper.send_event(
+                room_id,
+                type="m.room.message",
+                content=content,
+                tok=other_token,
+            )
+            nonlocal last_event_id
+            last_event_id = result["event_id"]
+            return last_event_id
+
+        def _rotate() -> None:
+            self.get_success(self.store._rotate_notifs())
+
+        def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None:
+            self.get_success(
+                self.store.insert_receipt(
+                    room_id,
+                    "m.read",
+                    user_id=user_id,
+                    event_ids=[event_id],
+                    thread_id=thread_id,
+                    data={},
+                )
+            )
+
+        _assert_counts(0, 0, 0, 0)
+        thread_id = _create_event()
+        _assert_counts(1, 0, 0, 0)
+        _rotate()
+        _assert_counts(1, 0, 0, 0)
+
+        _create_event(thread_id=thread_id)
+        _assert_counts(1, 0, 1, 0)
+        _rotate()
+        _assert_counts(1, 0, 1, 0)
+
+        _create_event()
+        _assert_counts(2, 0, 1, 0)
+        _rotate()
+        _assert_counts(2, 0, 1, 0)
+
+        event_id = _create_event(thread_id=thread_id)
+        _assert_counts(2, 0, 2, 0)
+        _rotate()
+        _assert_counts(2, 0, 2, 0)
+
+        _create_event()
+        _create_event(thread_id=thread_id)
+        _mark_read(event_id)
+        _assert_counts(1, 0, 1, 0)
+
+        _mark_read(last_event_id)
+        _assert_counts(0, 0, 0, 0)
+
+        _create_event()
+        _create_event(thread_id=thread_id)
+        _assert_counts(1, 0, 1, 0)
+        _rotate()
+        _assert_counts(1, 0, 1, 0)
+
+        # Delete old event push actions, this should not affect the (summarised) count.
+        self.get_success(self.store._remove_old_push_actions_that_have_rotated())
+        _assert_counts(1, 0, 1, 0)
+
+        _mark_read(last_event_id)
+        _assert_counts(0, 0, 0, 0)
+
+        _create_event(True)
+        _assert_counts(1, 1, 0, 0)
+        _rotate()
+        _assert_counts(1, 1, 0, 0)
+
+        event_id = _create_event(True, thread_id)
+        _assert_counts(1, 1, 1, 1)
+        _rotate()
+        _assert_counts(1, 1, 1, 1)
+
+        # Check that adding another notification and rotating after highlight
+        # works.
+        _create_event()
+        _rotate()
+        _assert_counts(2, 1, 1, 1)
+
+        _create_event(thread_id=thread_id)
+        _rotate()
+        _assert_counts(2, 1, 2, 1)
+
+        # Check that sending read receipts at different points results in the
+        # right counts.
+        _mark_read(event_id)
+        _assert_counts(1, 0, 1, 0)
+        _mark_read(last_event_id)
+        _assert_counts(0, 0, 0, 0)
+
+        _create_event(True)
+        _create_event(True, thread_id)
+        _assert_counts(1, 1, 1, 1)
+        _mark_read(last_event_id)
+        _assert_counts(0, 0, 0, 0)
+        _rotate()
+        _assert_counts(0, 0, 0, 0)
+
     def test_find_first_stream_ordering_after_ts(self) -> None:
         def add_event(so: int, ts: int) -> None:
             self.get_success(
