|
|
|
@ -15,6 +15,7 @@ |
|
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
|
import logging |
|
|
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union |
|
|
|
|
|
|
|
|
|
import attr |
|
|
|
|
from prometheus_client import Counter |
|
|
|
@ -25,16 +26,16 @@ from synapse.events import EventBase |
|
|
|
|
from synapse.events.snapshot import EventContext |
|
|
|
|
from synapse.state import POWER_KEY |
|
|
|
|
from synapse.util.async_helpers import Linearizer |
|
|
|
|
from synapse.util.caches import register_cache |
|
|
|
|
from synapse.util.caches import CacheMetric, register_cache |
|
|
|
|
from synapse.util.caches.descriptors import lru_cache |
|
|
|
|
from synapse.util.caches.lrucache import LruCache |
|
|
|
|
|
|
|
|
|
from .push_rule_evaluator import PushRuleEvaluatorForEvent |
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
|
from synapse.app.homeserver import HomeServer |
|
|
|
|
|
|
|
|
|
rules_by_room = {} |
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
push_rules_invalidation_counter = Counter( |
|
|
|
@ -101,7 +102,7 @@ class BulkPushRuleEvaluator: |
|
|
|
|
room at once. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, hs): |
|
|
|
|
def __init__(self, hs: "HomeServer"): |
|
|
|
|
self.hs = hs |
|
|
|
|
self.store = hs.get_datastore() |
|
|
|
|
self.auth = hs.get_auth() |
|
|
|
@ -113,7 +114,9 @@ class BulkPushRuleEvaluator: |
|
|
|
|
resizable=False, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
async def _get_rules_for_event(self, event, context): |
|
|
|
|
async def _get_rules_for_event( |
|
|
|
|
self, event: EventBase, context: EventContext |
|
|
|
|
) -> Dict[str, List[Dict[str, Any]]]: |
|
|
|
|
"""This gets the rules for all users in the room at the time of the event, |
|
|
|
|
as well as the push rules for the invitee if the event is an invite. |
|
|
|
|
|
|
|
|
@ -140,11 +143,8 @@ class BulkPushRuleEvaluator: |
|
|
|
|
return rules_by_user |
|
|
|
|
|
|
|
|
|
@lru_cache() |
|
|
|
|
def _get_rules_for_room(self, room_id): |
|
|
|
|
def _get_rules_for_room(self, room_id: str) -> "RulesForRoom": |
|
|
|
|
"""Get the current RulesForRoom object for the given room id |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
RulesForRoom |
|
|
|
|
""" |
|
|
|
|
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache |
|
|
|
|
# before any lookup methods get called on it as otherwise there may be |
|
|
|
@ -156,20 +156,21 @@ class BulkPushRuleEvaluator: |
|
|
|
|
self.room_push_rule_cache_metrics, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
async def _get_power_levels_and_sender_level(self, event, context): |
|
|
|
|
async def _get_power_levels_and_sender_level( |
|
|
|
|
self, event: EventBase, context: EventContext |
|
|
|
|
) -> Tuple[dict, int]: |
|
|
|
|
prev_state_ids = await context.get_prev_state_ids() |
|
|
|
|
pl_event_id = prev_state_ids.get(POWER_KEY) |
|
|
|
|
if pl_event_id: |
|
|
|
|
# fastpath: if there's a power level event, that's all we need, and |
|
|
|
|
# not having a power level event is an extreme edge case |
|
|
|
|
pl_event = await self.store.get_event(pl_event_id) |
|
|
|
|
auth_events = {POWER_KEY: pl_event} |
|
|
|
|
auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} |
|
|
|
|
else: |
|
|
|
|
auth_events_ids = self.auth.compute_auth_events( |
|
|
|
|
event, prev_state_ids, for_verification=False |
|
|
|
|
) |
|
|
|
|
auth_events = await self.store.get_events(auth_events_ids) |
|
|
|
|
auth_events = {(e.type, e.state_key): e for e in auth_events.values()} |
|
|
|
|
auth_events_dict = await self.store.get_events(auth_events_ids) |
|
|
|
|
auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()} |
|
|
|
|
|
|
|
|
|
sender_level = get_user_power_level(event.sender, auth_events) |
|
|
|
|
|
|
|
|
@ -177,7 +178,9 @@ class BulkPushRuleEvaluator: |
|
|
|
|
|
|
|
|
|
return pl_event.content if pl_event else {}, sender_level |
|
|
|
|
|
|
|
|
|
async def action_for_event_by_user(self, event, context) -> None: |
|
|
|
|
async def action_for_event_by_user( |
|
|
|
|
self, event: EventBase, context: EventContext |
|
|
|
|
) -> None: |
|
|
|
|
"""Given an event and context, evaluate the push rules, check if the message |
|
|
|
|
should increment the unread count, and insert the results into the |
|
|
|
|
event_push_actions_staging table. |
|
|
|
@ -185,7 +188,7 @@ class BulkPushRuleEvaluator: |
|
|
|
|
count_as_unread = _should_count_as_unread(event, context) |
|
|
|
|
|
|
|
|
|
rules_by_user = await self._get_rules_for_event(event, context) |
|
|
|
|
actions_by_user = {} |
|
|
|
|
actions_by_user = {} # type: Dict[str, List[Union[dict, str]]] |
|
|
|
|
|
|
|
|
|
room_members = await self.store.get_joined_users_from_context(event, context) |
|
|
|
|
|
|
|
|
@ -198,7 +201,7 @@ class BulkPushRuleEvaluator: |
|
|
|
|
event, len(room_members), sender_power_level, power_levels |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
condition_cache = {} |
|
|
|
|
condition_cache = {} # type: Dict[str, bool] |
|
|
|
|
|
|
|
|
|
for uid, rules in rules_by_user.items(): |
|
|
|
|
if event.sender == uid: |
|
|
|
@ -249,7 +252,13 @@ class BulkPushRuleEvaluator: |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _condition_checker(evaluator, conditions, uid, display_name, cache): |
|
|
|
|
def _condition_checker( |
|
|
|
|
evaluator: PushRuleEvaluatorForEvent, |
|
|
|
|
conditions: List[dict], |
|
|
|
|
uid: str, |
|
|
|
|
display_name: str, |
|
|
|
|
cache: Dict[str, bool], |
|
|
|
|
) -> bool: |
|
|
|
|
for cond in conditions: |
|
|
|
|
_id = cond.get("_id", None) |
|
|
|
|
if _id: |
|
|
|
@ -277,15 +286,19 @@ class RulesForRoom: |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__( |
|
|
|
|
self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics |
|
|
|
|
self, |
|
|
|
|
hs: "HomeServer", |
|
|
|
|
room_id: str, |
|
|
|
|
rules_for_room_cache: LruCache, |
|
|
|
|
room_push_rule_cache_metrics: CacheMetric, |
|
|
|
|
): |
|
|
|
|
""" |
|
|
|
|
Args: |
|
|
|
|
hs (HomeServer) |
|
|
|
|
room_id (str) |
|
|
|
|
hs: The HomeServer object. |
|
|
|
|
room_id: The room ID. |
|
|
|
|
rules_for_room_cache: The cache object that caches these |
|
|
|
|
RoomsForUser objects. |
|
|
|
|
room_push_rule_cache_metrics (CacheMetric) |
|
|
|
|
room_push_rule_cache_metrics: The metrics object |
|
|
|
|
""" |
|
|
|
|
self.room_id = room_id |
|
|
|
|
self.is_mine_id = hs.is_mine_id |
|
|
|
@ -294,8 +307,10 @@ class RulesForRoom: |
|
|
|
|
|
|
|
|
|
self.linearizer = Linearizer(name="rules_for_room") |
|
|
|
|
|
|
|
|
|
self.member_map = {} # event_id -> (user_id, state) |
|
|
|
|
self.rules_by_user = {} # user_id -> rules |
|
|
|
|
# event_id -> (user_id, state) |
|
|
|
|
self.member_map = {} # type: Dict[str, Tuple[str, str]] |
|
|
|
|
# user_id -> rules |
|
|
|
|
self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]] |
|
|
|
|
|
|
|
|
|
# The last state group we updated the caches for. If the state_group of |
|
|
|
|
# a new event comes along, we know that we can just return the cached |
|
|
|
@ -315,7 +330,7 @@ class RulesForRoom: |
|
|
|
|
# calculate push for) |
|
|
|
|
# These never need to be invalidated as we will never set up push for |
|
|
|
|
# them. |
|
|
|
|
self.uninteresting_user_set = set() |
|
|
|
|
self.uninteresting_user_set = set() # type: Set[str] |
|
|
|
|
|
|
|
|
|
# We need to be clever on the invalidating caches callbacks, as |
|
|
|
|
# otherwise the invalidation callback holds a reference to the object, |
|
|
|
@ -325,7 +340,9 @@ class RulesForRoom: |
|
|
|
|
# to self around in the callback. |
|
|
|
|
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) |
|
|
|
|
|
|
|
|
|
async def get_rules(self, event, context): |
|
|
|
|
async def get_rules( |
|
|
|
|
self, event: EventBase, context: EventContext |
|
|
|
|
) -> Dict[str, List[Dict[str, dict]]]: |
|
|
|
|
"""Given an event context return the rules for all users who are |
|
|
|
|
currently in the room. |
|
|
|
|
""" |
|
|
|
@ -356,6 +373,8 @@ class RulesForRoom: |
|
|
|
|
else: |
|
|
|
|
current_state_ids = await context.get_current_state_ids() |
|
|
|
|
push_rules_delta_state_cache_metric.inc_misses() |
|
|
|
|
# Ensure the state IDs exist. |
|
|
|
|
assert current_state_ids is not None |
|
|
|
|
|
|
|
|
|
push_rules_state_size_counter.inc(len(current_state_ids)) |
|
|
|
|
|
|
|
|
@ -420,18 +439,23 @@ class RulesForRoom: |
|
|
|
|
return ret_rules_by_user |
|
|
|
|
|
|
|
|
|
async def _update_rules_with_member_event_ids( |
|
|
|
|
self, ret_rules_by_user, member_event_ids, state_group, event |
|
|
|
|
): |
|
|
|
|
self, |
|
|
|
|
ret_rules_by_user: Dict[str, list], |
|
|
|
|
member_event_ids: Dict[str, str], |
|
|
|
|
state_group: Optional[int], |
|
|
|
|
event: EventBase, |
|
|
|
|
) -> None: |
|
|
|
|
"""Update the partially filled rules_by_user dict by fetching rules for |
|
|
|
|
any newly joined users in the `member_event_ids` list. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets |
|
|
|
|
ret_rules_by_user: Partially filled dict of push rules. Gets |
|
|
|
|
updated with any new rules. |
|
|
|
|
member_event_ids (dict): Dict of user id to event id for membership events |
|
|
|
|
member_event_ids: Dict of user id to event id for membership events |
|
|
|
|
that have happened since the last time we filled rules_by_user |
|
|
|
|
state_group: The state group we are currently computing push rules |
|
|
|
|
for. Used when updating the cache. |
|
|
|
|
event: The event we are currently computing push rules for. |
|
|
|
|
""" |
|
|
|
|
sequence = self.sequence |
|
|
|
|
|
|
|
|
@ -449,19 +473,19 @@ class RulesForRoom: |
|
|
|
|
if logger.isEnabledFor(logging.DEBUG): |
|
|
|
|
logger.debug("Found members %r: %r", self.room_id, members.values()) |
|
|
|
|
|
|
|
|
|
user_ids = { |
|
|
|
|
joined_user_ids = { |
|
|
|
|
user_id |
|
|
|
|
for user_id, membership in members.values() |
|
|
|
|
if membership == Membership.JOIN |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
logger.debug("Joined: %r", user_ids) |
|
|
|
|
logger.debug("Joined: %r", joined_user_ids) |
|
|
|
|
|
|
|
|
|
# Previously we only considered users with pushers or read receipts in that |
|
|
|
|
# room. We can't do this anymore because we use push actions to calculate unread |
|
|
|
|
# counts, which don't rely on the user having pushers or sent a read receipt into |
|
|
|
|
# the room. Therefore we just need to filter for local users here. |
|
|
|
|
user_ids = list(filter(self.is_mine_id, user_ids)) |
|
|
|
|
user_ids = list(filter(self.is_mine_id, joined_user_ids)) |
|
|
|
|
|
|
|
|
|
rules_by_user = await self.store.bulk_get_push_rules( |
|
|
|
|
user_ids, on_invalidate=self.invalidate_all_cb |
|
|
|
@ -473,7 +497,7 @@ class RulesForRoom: |
|
|
|
|
|
|
|
|
|
self.update_cache(sequence, members, ret_rules_by_user, state_group) |
|
|
|
|
|
|
|
|
|
def invalidate_all(self): |
|
|
|
|
def invalidate_all(self) -> None: |
|
|
|
|
# Note: Don't hand this function directly to an invalidation callback |
|
|
|
|
# as it keeps a reference to self and will stop this instance from being |
|
|
|
|
# GC'd if it gets dropped from the rules_to_user cache. Instead use |
|
|
|
@ -485,7 +509,7 @@ class RulesForRoom: |
|
|
|
|
self.rules_by_user = {} |
|
|
|
|
push_rules_invalidation_counter.inc() |
|
|
|
|
|
|
|
|
|
def update_cache(self, sequence, members, rules_by_user, state_group): |
|
|
|
|
def update_cache(self, sequence, members, rules_by_user, state_group) -> None: |
|
|
|
|
if sequence == self.sequence: |
|
|
|
|
self.member_map.update(members) |
|
|
|
|
self.rules_by_user = rules_by_user |
|
|
|
@ -506,7 +530,7 @@ class _Invalidation: |
|
|
|
|
cache = attr.ib(type=LruCache) |
|
|
|
|
room_id = attr.ib(type=str) |
|
|
|
|
|
|
|
|
|
def __call__(self): |
|
|
|
|
def __call__(self) -> None: |
|
|
|
|
rules = self.cache.get(self.room_id, None, update_metrics=False) |
|
|
|
|
if rules: |
|
|
|
|
rules.invalidate_all() |
|
|
|
|