From acbfdc3442f706150c6312e534ca4ecec8548582 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 12:17:16 +0100 Subject: [PATCH 01/16] Refactor EventContext to accept state during init --- synapse/events/snapshot.py | 48 ++++++++++--------- synapse/state.py | 97 +++++++++++++++++++++----------------- 2 files changed, 82 insertions(+), 63 deletions(-) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index f83a1581a..fbbe8dd49 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -60,22 +60,22 @@ class EventContext(object): "app_service", ] - def __init__(self): + def __init__(self, state_group, current_state_ids, prev_state_ids, + prev_group=None, delta_ids=None): # The current state including the current event - self.current_state_ids = None + self.current_state_ids = current_state_ids # The current state excluding the current event - self.prev_state_ids = None - self.state_group = None - - self.rejected = False + self.prev_state_ids = prev_state_ids + self.state_group = state_group # A previously persisted state group and a delta between that # and this state. - self.prev_group = None - self.delta_ids = None + self.prev_group = prev_group + self.delta_ids = delta_ids - self.prev_state_events = None + self.prev_state_events = [] + self.rejected = False self.app_service = None def serialize(self, event): @@ -123,27 +123,33 @@ class EventContext(object): Returns: EventContext """ - context = EventContext() - context.state_group = input["state_group"] - context.rejected = input["rejected"] - context.prev_group = input["prev_group"] - context.delta_ids = _decode_state_dict(input["delta_ids"]) - context.prev_state_events = input["prev_state_events"] - # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. 
prev_state_id = input["prev_state_id"] event_type = input["event_type"] event_state_key = input["event_state_key"] - context.current_state_ids = yield store.get_state_ids_for_group( - context.state_group, + state_group = input["state_group"] + + current_state_ids = yield store.get_state_ids_for_group( + state_group, ) if prev_state_id and event_state_key: - context.prev_state_ids = dict(context.current_state_ids) - context.prev_state_ids[(event_type, event_state_key)] = prev_state_id + prev_state_ids = dict(current_state_ids) + prev_state_ids[(event_type, event_state_key)] = prev_state_id else: - context.prev_state_ids = context.current_state_ids + prev_state_ids = current_state_ids + + context = EventContext( + state_group=state_group, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + prev_group=input["prev_group"], + delta_ids = _decode_state_dict(input["delta_ids"]), + ) + + context.rejected = input["rejected"] + context.prev_state_events = input["prev_state_events"] app_service_id = input["app_service_id"] if app_service_id: diff --git a/synapse/state.py b/synapse/state.py index 504caae2f..a70869500 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -203,25 +203,27 @@ class StateHandler(object): # If this is an outlier, then we know it shouldn't have any current # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. 
- context = EventContext() if old_state: - context.prev_state_ids = { + prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } if event.is_state(): - context.current_state_ids = dict(context.prev_state_ids) + current_state_ids = dict(prev_state_ids) key = (event.type, event.state_key) - context.current_state_ids[key] = event.event_id + current_state_ids[key] = event.event_id else: - context.current_state_ids = context.prev_state_ids + current_state_ids = prev_state_ids else: - context.current_state_ids = {} - context.prev_state_ids = {} - context.prev_state_events = [] + current_state_ids = {} + prev_state_ids = {} # We don't store state for outliers, so we don't generate a state - # froup for it. - context.state_group = None + # group for it. + context = EventContext( + state_group=None, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + ) defer.returnValue(context) @@ -230,31 +232,35 @@ class StateHandler(object): # Let's just correctly fill out the context and create a # new state group for it. 
- context = EventContext() - context.prev_state_ids = { + prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } if event.is_state(): key = (event.type, event.state_key) - if key in context.prev_state_ids: - replaces = context.prev_state_ids[key] + if key in prev_state_ids: + replaces = prev_state_ids[key] if replaces != event.event_id: # Paranoia check event.unsigned["replaces_state"] = replaces - context.current_state_ids = dict(context.prev_state_ids) - context.current_state_ids[key] = event.event_id + current_state_ids = dict(prev_state_ids) + current_state_ids[key] = event.event_id else: - context.current_state_ids = context.prev_state_ids + current_state_ids = prev_state_ids - context.state_group = yield self.store.store_state_group( + state_group = yield self.store.store_state_group( event.event_id, event.room_id, prev_group=None, delta_ids=None, - current_state_ids=context.current_state_ids, + current_state_ids=current_state_ids, + ) + + context = EventContext( + state_group=state_group, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, ) - context.prev_state_events = [] defer.returnValue(context) logger.debug("calling resolve_state_groups from compute_event_context") @@ -262,47 +268,47 @@ class StateHandler(object): event.room_id, [e for e, _ in event.prev_events], ) - curr_state = entry.state + prev_state_ids = entry.state + prev_group = None + delta_ids = None - context = EventContext() - context.prev_state_ids = curr_state if event.is_state(): # If this is a state event then we need to create a new state # group for the state after this event. 
key = (event.type, event.state_key) - if key in context.prev_state_ids: - replaces = context.prev_state_ids[key] + if key in prev_state_ids: + replaces = prev_state_ids[key] event.unsigned["replaces_state"] = replaces - context.current_state_ids = dict(context.prev_state_ids) - context.current_state_ids[key] = event.event_id + current_state_ids = dict(prev_state_ids) + current_state_ids[key] = event.event_id if entry.state_group: # If the state at the event has a state group assigned then # we can use that as the prev group - context.prev_group = entry.state_group - context.delta_ids = { + prev_group = entry.state_group + delta_ids = { key: event.event_id } elif entry.prev_group: # If the state at the event only has a prev group, then we can # use that as a prev group too. - context.prev_group = entry.prev_group - context.delta_ids = dict(entry.delta_ids) - context.delta_ids[key] = event.event_id + prev_group = entry.prev_group + delta_ids = dict(entry.delta_ids) + delta_ids[key] = event.event_id - context.state_group = yield self.store.store_state_group( + state_group = yield self.store.store_state_group( event.event_id, event.room_id, - prev_group=context.prev_group, - delta_ids=context.delta_ids, - current_state_ids=context.current_state_ids, + prev_group=prev_group, + delta_ids=delta_ids, + current_state_ids=current_state_ids, ) else: - context.current_state_ids = context.prev_state_ids - context.prev_group = entry.prev_group - context.delta_ids = entry.delta_ids + current_state_ids = prev_state_ids + prev_group = entry.prev_group + delta_ids = entry.delta_ids if entry.state_group is None: entry.state_group = yield self.store.store_state_group( @@ -310,13 +316,20 @@ class StateHandler(object): event.room_id, prev_group=entry.prev_group, delta_ids=entry.delta_ids, - current_state_ids=context.current_state_ids, + current_state_ids=current_state_ids, ) entry.state_id = entry.state_group - context.state_group = entry.state_group + state_group = entry.state_group + 
+ context = EventContext( + state_group=state_group, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + prev_group=prev_group, + delta_ids=delta_ids, + ) - context.prev_state_events = [] defer.returnValue(context) @defer.inlineCallbacks From 959f4b9074f6e893dc3c8e622e8c17fd229ad319 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 12:22:59 +0100 Subject: [PATCH 02/16] Newsfile --- changelog.d/3577.misc | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 changelog.d/3577.misc diff --git a/changelog.d/3577.misc b/changelog.d/3577.misc new file mode 100644 index 000000000..e69de29bb From 842cdece42e59a4181a496761447f0cb00053c05 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 12:31:35 +0100 Subject: [PATCH 03/16] pep8 --- synapse/events/snapshot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index fbbe8dd49..5e02ef1a5 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -145,7 +145,7 @@ class EventContext(object): current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, prev_group=input["prev_group"], - delta_ids = _decode_state_dict(input["delta_ids"]), + delta_ids=_decode_state_dict(input["delta_ids"]), ) context.rejected = input["rejected"] From 440b8845b531db05e7c4f646e48dee7635cf1f0a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 12:38:46 +0100 Subject: [PATCH 04/16] Make EventContext lazy load state --- synapse/events/snapshot.py | 153 +++++++++++++++++++++++++++---------- synapse/state.py | 6 +- 2 files changed, 115 insertions(+), 44 deletions(-) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 5e02ef1a5..f9568638a 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -19,18 +19,12 @@ from frozendict import frozendict from twisted.internet import defer +from synapse.util.logcontext import 
make_deferred_yieldable, run_in_background + class EventContext(object): """ Attributes: - current_state_ids (dict[(str, str), str]): - The current state map including the current event. - (type, state_key) -> event_id - - prev_state_ids (dict[(str, str), str]): - The current state map excluding the current event. - (type, state_key) -> event_id - state_group (int|None): state group id, if the state has been stored as a state group. This is usually only None if e.g. the event is an outlier. @@ -47,36 +41,71 @@ class EventContext(object): prev_state_events (?): XXX: is this ever set to anything other than the empty list? + + _current_state_ids (dict[(str, str), str]|None): + The current state map including the current event. None if outlier + or we haven't fetched the state from DB yet. + (type, state_key) -> event_id + + _prev_state_ids (dict[(str, str), str]|None): + The current state map excluding the current event. None if outlier + or we haven't fetched the state from DB yet. + (type, state_key) -> event_id + + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have + been calculated. 
None if we haven't started calculating yet + + _prev_state_id (str|None): If set then the event associated with the + context overrode the _prev_state_id + + _event_type (str): The type of the event the context is associated with + + _event_state_key (str|None): The state_key of the event the context is + associated with """ __slots__ = [ - "current_state_ids", - "prev_state_ids", "state_group", "rejected", "prev_group", "delta_ids", "prev_state_events", "app_service", + "_current_state_ids", + "_prev_state_ids", + "_prev_state_id", + "_event_type", + "_event_state_key", + "_fetching_state_deferred", ] - def __init__(self, state_group, current_state_ids, prev_state_ids, - prev_group=None, delta_ids=None): + @staticmethod + def with_state(state_group, current_state_ids, prev_state_ids, + prev_group=None, delta_ids=None): + context = EventContext() + # The current state including the current event - self.current_state_ids = current_state_ids + context._current_state_ids = current_state_ids # The current state excluding the current event - self.prev_state_ids = prev_state_ids - self.state_group = state_group + context._prev_state_ids = prev_state_ids + context.state_group = state_group + + context._prev_state_id = None + context._event_type = None + context._event_state_key = None + context._fetching_state_deferred = defer.succeed(None) # A previously persisted state group and a delta between that # and this state. 
- self.prev_group = prev_group - self.delta_ids = delta_ids + context.prev_group = prev_group + context.delta_ids = delta_ids + + context.prev_state_events = [] - self.prev_state_events = [] + context.rejected = False + context.app_service = None - self.rejected = False - self.app_service = None + return context def serialize(self, event): """Converts self to a type that can be serialized as JSON, and then @@ -123,30 +152,17 @@ class EventContext(object): Returns: EventContext """ + context = EventContext() + # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. - prev_state_id = input["prev_state_id"] - event_type = input["event_type"] - event_state_key = input["event_state_key"] + context._prev_state_id = input["prev_state_id"] + context._event_type = input["event_type"] + context._event_state_key = input["event_state_key"] - state_group = input["state_group"] - - current_state_ids = yield store.get_state_ids_for_group( - state_group, - ) - if prev_state_id and event_state_key: - prev_state_ids = dict(current_state_ids) - prev_state_ids[(event_type, event_state_key)] = prev_state_id - else: - prev_state_ids = current_state_ids - - context = EventContext( - state_group=state_group, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, - prev_group=input["prev_group"], - delta_ids=_decode_state_dict(input["delta_ids"]), - ) + context.state_group = input["state_group"] + context.prev_group = input["prev_group"] + context.delta_ids = _decode_state_dict(input["delta_ids"]) context.rejected = input["rejected"] context.prev_state_events = input["prev_state_events"] @@ -157,6 +173,61 @@ class EventContext(object): defer.returnValue(context) + @defer.inlineCallbacks + def get_current_state_ids(self, store): + """Gets the current state IDs + + Returns: + Deferred[dict[(str, str), str]|None]: Returns None if state_group + is None, which happens when the associated event is an outlier. 
+ """ + + if not self._fetching_state_deferred: + self._fetching_state_deferred = run_in_background( + self._fill_out_state, store, + ) + + yield make_deferred_yieldable(self._fetching_state_deferred) + + defer.returnValue(self._current_state_ids) + + @defer.inlineCallbacks + def get_prev_state_ids(self, store): + """Gets the prev state IDs + + Returns: + Deferred[dict[(str, str), str]|None]: Returns None if state_group + is None, which happens when the associated event is an outlier. + """ + + if not self._fetching_state_deferred: + self._fetching_state_deferred = run_in_background( + self._fill_out_state, store, + ) + + yield make_deferred_yieldable(self._fetching_state_deferred) + + defer.returnValue(self._prev_state_ids) + + @defer.inlineCallbacks + def _fill_out_state(self, store): + """Called to populate the _current_state_ids and _prev_state_ids + attributes by loading from the database. + """ + if self.state_group is None: + return + + self._current_state_ids = yield store.get_state_ids_for_group( + self.state_group, + ) + if self._prev_state_id and self._event_state_key is not None: + self._prev_state_ids = dict(self._current_state_ids) + + key = (self._event_type, self._event_state_key) + self._prev_state_ids[key] = self._prev_state_id + else: + self._prev_state_ids = self._current_state_ids + def _encode_state_dict(state_dict): """Since dicts of (type, state_key) -> event_id cannot be serialized in diff --git a/synapse/state.py b/synapse/state.py index a70869500..32125c95d 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -219,7 +219,7 @@ class StateHandler(object): # We don't store state for outliers, so we don't generate a state # group for it. 
- context = EventContext( + context = EventContext.with_state( state_group=None, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, @@ -255,7 +255,7 @@ class StateHandler(object): current_state_ids=current_state_ids, ) - context = EventContext( + context = EventContext.with_state( state_group=state_group, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, @@ -322,7 +322,7 @@ class StateHandler(object): state_group = entry.state_group - context = EventContext( + context = EventContext.with_state( state_group=state_group, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, From e42510ba635b3e4d83215e4f5634ca51411996e0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 13:00:22 +0100 Subject: [PATCH 05/16] Use new getters --- synapse/api/auth.py | 6 ++++-- synapse/handlers/_base.py | 3 ++- synapse/handlers/federation.py | 23 ++++++++++++++------- synapse/handlers/message.py | 26 +++++++++++++++--------- synapse/handlers/room_member.py | 9 +++++--- synapse/push/bulk_push_rule_evaluator.py | 7 ++++--- synapse/storage/events.py | 2 +- synapse/storage/push_rule.py | 7 +++++-- synapse/storage/roommember.py | 7 +++++-- 9 files changed, 59 insertions(+), 31 deletions(-) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index bc629832d..535bdb449 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -65,8 +65,9 @@ class Auth(object): @defer.inlineCallbacks def check_from_context(self, event, context, do_sig_check=True): + prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.compute_auth_events( - event, context.prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -544,7 +545,8 @@ class Auth(object): @defer.inlineCallbacks def add_auth_events(self, builder, context): - auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids) + 
prev_state_ids = yield context.get_prev_state_ids(self.store) + auth_ids = yield self.compute_auth_events(builder, prev_state_ids) auth_events_entries = yield self.store.add_event_hashes( auth_ids diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index b6a8b3aa3..704181d2d 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -112,8 +112,9 @@ class BaseHandler(object): guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: + current_state_ids = yield context.get_current_state_ids(self.store) current_state = yield self.store.get_events( - list(context.current_state_ids.values()) + list(current_state_ids.values()) ) else: current_state = yield self.state_handler.get_current_state( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a6d391c4e..98dd4a7fd 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -486,7 +486,10 @@ class FederationHandler(BaseHandler): # joined the room. Don't bother if the user is just # changing their profile info. 
newly_joined = True - prev_state_id = context.prev_state_ids.get( + + prev_state_ids = yield context.get_prev_state_ids(self.store) + + prev_state_id = prev_state_ids.get( (event.type, event.state_key) ) if prev_state_id: @@ -1106,10 +1109,12 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) - state_ids = list(context.prev_state_ids.values()) + prev_state_ids = yield context.get_prev_state_ids(self.store) + + state_ids = list(prev_state_ids.values()) auth_chain = yield self.store.get_auth_chain(state_ids) - state = yield self.store.get_events(list(context.prev_state_ids.values())) + state = yield self.store.get_events(list(prev_state_ids.values())) defer.returnValue({ "state": list(state.values()), @@ -1635,8 +1640,9 @@ class FederationHandler(BaseHandler): ) if not auth_events: + prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, context.prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -1876,9 +1882,10 @@ class FederationHandler(BaseHandler): break if do_resolution: + prev_state_ids = yield context.get_prev_state_ids(self.store) # 1. Get what we think is the auth chain. 
auth_ids = yield self.auth.compute_auth_events( - event, context.prev_state_ids + event, prev_state_ids ) local_auth_chain = yield self.store.get_auth_chain( auth_ids, include_given=True @@ -2222,7 +2229,8 @@ class FederationHandler(BaseHandler): event.content["third_party_invite"]["signed"]["token"] ) original_invite = None - original_invite_id = context.prev_state_ids.get(key) + prev_state_ids = yield context.get_prev_state_ids(self.store) + original_invite_id = prev_state_ids.get(key) if original_invite_id: original_invite = yield self.store.get_event( original_invite_id, allow_none=True @@ -2264,7 +2272,8 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - invite_event_id = context.prev_state_ids.get( + prev_state_ids = yield context.get_prev_state_ids(self.store) + invite_event_id = prev_state_ids.get( (EventTypes.ThirdPartyInvite, token,) ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index abc07ea87..c4bcd9018 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -630,7 +630,8 @@ class EventCreationHandler(object): If so, returns the version of the event in context. Otherwise, returns None. 
""" - prev_event_id = context.prev_state_ids.get((event.type, event.state_key)) + prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_event_id = prev_state_ids.get((event.type, event.state_key)) prev_event = yield self.store.get_event(prev_event_id, allow_none=True) if not prev_event: return @@ -752,8 +753,8 @@ class EventCreationHandler(object): event = builder.build() logger.debug( - "Created event %s with state: %s", - event.event_id, context.prev_state_ids, + "Created event %s", + event.event_id, ) defer.returnValue( @@ -884,9 +885,11 @@ class EventCreationHandler(object): e.sender == event.sender ) + current_state_ids = yield context.get_current_state_ids(self.store) + state_to_include_ids = [ e_id - for k, e_id in iteritems(context.current_state_ids) + for k, e_id in iteritems(current_state_ids) if k[0] in self.hs.config.room_invite_state_types or k == (EventTypes.Member, event.sender) ] @@ -922,8 +925,9 @@ class EventCreationHandler(object): ) if event.type == EventTypes.Redaction: + prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, context.prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -943,11 +947,13 @@ class EventCreationHandler(object): "You don't have permission to redact events" ) - if event.type == EventTypes.Create and context.prev_state_ids: - raise AuthError( - 403, - "Changing the room create event is forbidden", - ) + if event.type == EventTypes.Create: + prev_state_ids = yield context.get_prev_state_ids(self.store) + if prev_state_ids: + raise AuthError( + 403, + "Changing the room create event is forbidden", + ) (event_stream_id, max_stream_id) = yield self.store.persist_event( event, context=context diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 00f2e279b..a832d9180 100644 --- 
a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -201,7 +201,9 @@ class RoomMemberHandler(object): ratelimit=ratelimit, ) - prev_member_event_id = context.prev_state_ids.get( + prev_state_ids = yield context.get_prev_state_ids(self.store) + + prev_member_event_id = prev_state_ids.get( (EventTypes.Member, target.to_string()), None ) @@ -496,9 +498,10 @@ class RoomMemberHandler(object): if prev_event is not None: return + prev_state_ids = yield context.get_prev_state_ids(self.store) if event.membership == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(context.prev_state_ids) + guest_can_join = yield self._can_guest_join(prev_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. @@ -517,7 +520,7 @@ class RoomMemberHandler(object): ratelimit=ratelimit, ) - prev_member_event_id = context.prev_state_ids.get( + prev_member_event_id = prev_state_ids.get( (EventTypes.Member, event.state_key), None ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index bb181d94e..1d14d3639 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -112,7 +112,8 @@ class BulkPushRuleEvaluator(object): @defer.inlineCallbacks def _get_power_levels_and_sender_level(self, event, context): - pl_event_id = context.prev_state_ids.get(POWER_KEY) + prev_state_ids = yield context.get_prev_state_ids(self.store) + pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and # not having a power level event is an extreme edge case @@ -120,7 +121,7 @@ class BulkPushRuleEvaluator(object): auth_events = {POWER_KEY: pl_event} else: auth_events_ids = yield self.auth.compute_auth_events( - event, context.prev_state_ids, for_verification=False, + event, prev_state_ids, for_verification=False, ) 
auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -304,7 +305,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = context.current_state_ids + current_state_ids = yield context.get_current_state_ids(self.store) push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 4ff0fdc4a..bf4f3ee92 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -549,7 +549,7 @@ class EventsStore(EventsWorkerStore): if ctx.state_group in state_groups_map: continue - state_groups_map[ctx.state_group] = ctx.current_state_ids + state_groups_map[ctx.state_group] = yield ctx.get_current_state_ids(self) # We need to map the event_ids to their state groups. First, let's # check if the event is one we're persisting, in which case we can diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index be655d287..af564b1b4 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -186,6 +186,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, defer.returnValue(results) + @defer.inlineCallbacks def bulk_get_push_rules_for_room(self, event, context): state_group = context.state_group if not state_group: @@ -195,9 +196,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore, # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._bulk_get_push_rules_for_room( - event.room_id, state_group, context.current_state_ids, event=event + current_state_ids = yield context.get_current_state_ids(self) + result = yield self._bulk_get_push_rules_for_room( + event.room_id, state_group, current_state_ids, event=event ) + defer.returnValue(result) @cachedInlineCallbacks(num_args=2, cache_context=True) def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids, 
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 02a802bed..a27702a7a 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -232,6 +232,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): defer.returnValue(user_who_share_room) + @defer.inlineCallbacks def get_joined_users_from_context(self, event, context): state_group = context.state_group if not state_group: @@ -241,11 +242,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._get_joined_users_from_context( - event.room_id, state_group, context.current_state_ids, + current_state_ids = yield context.get_current_state_ids(self) + result = yield self._get_joined_users_from_context( + event.room_id, state_group, current_state_ids, event=event, context=context, ) + defer.returnValue(result) def get_joined_users_from_state(self, room_id, state_entry): state_group = state_entry.state_group From 027bc01a1bc254fe08140c6e91a9fb945b08486f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 13:02:09 +0100 Subject: [PATCH 06/16] Add support for updating state --- synapse/events/snapshot.py | 19 +++++++++++++++++++ synapse/handlers/federation.py | 32 +++++++++++++++++++++++--------- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index f9568638a..b090751bf 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -228,6 +228,25 @@ class EventContext(object): else: self._prev_state_ids = self._current_state_ids + @defer.inlineCallbacks + def update_state(self, state_group, prev_state_ids, current_state_ids, + delta_ids): + """Replace the state in the context + """ + + # We need to make sure we wait for any ongoing fetching of state + # to complete so that the updated state doesn't get clobbered + if self._fetching_state_deferred: + yield 
make_deferred_yieldable(self._fetching_state_deferred) + + self.state_group = state_group + self._prev_state_ids = prev_state_ids + self._current_state_ids = current_state_ids + self.delta_ids = delta_ids + + # We need to ensure that that we've marked as having fetched the state + self._fetching_state_deferred = defer.succeed(None) + def _encode_state_dict(state_dict): """Since dicts of (type, state_key) -> event_id cannot be serialized in diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 98dd4a7fd..14654d59f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1975,21 +1975,35 @@ class FederationHandler(BaseHandler): k: a.event_id for k, a in iteritems(auth_events) if k != event_key } - context.current_state_ids = dict(context.current_state_ids) - context.current_state_ids.update(state_updates) + current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = dict(current_state_ids) + + current_state_ids.update(state_updates) + if context.delta_ids is not None: - context.delta_ids = dict(context.delta_ids) - context.delta_ids.update(state_updates) - context.prev_state_ids = dict(context.prev_state_ids) - context.prev_state_ids.update({ + delta_ids = dict(context.delta_ids) + delta_ids.update(state_updates) + + prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = dict(prev_state_ids) + + prev_state_ids.update({ k: a.event_id for k, a in iteritems(auth_events) }) - context.state_group = yield self.store.store_state_group( + + state_group = yield self.store.store_state_group( event.event_id, event.room_id, prev_group=context.prev_group, - delta_ids=context.delta_ids, - current_state_ids=context.current_state_ids, + delta_ids=delta_ids, + current_state_ids=current_state_ids, + ) + + yield context.update_state( + state_group=state_group, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + delta_ids=delta_ids, ) 
@defer.inlineCallbacks From f3182bb1d0ba496bef710a529c4cd7a99da72061 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 13:19:24 +0100 Subject: [PATCH 07/16] Newsfile --- changelog.d/3579.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/3579.misc diff --git a/changelog.d/3579.misc b/changelog.d/3579.misc new file mode 100644 index 000000000..2374dc0c4 --- /dev/null +++ b/changelog.d/3579.misc @@ -0,0 +1 @@ +Lazily load state on master process when using workers to reduce DB consumption From 8fbe418777a62ea6dd5fd811d63e30c42589650b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 13:33:49 +0100 Subject: [PATCH 08/16] Fix unit tests --- .../replication/slave/storage/test_events.py | 8 ++-- tests/test_state.py | 47 ++++++++++++++----- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index cea01d93e..f5b47f5ec 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -222,9 +222,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): state_ids = { key: e.event_id for key, e in state.items() } - context = EventContext() - context.current_state_ids = state_ids - context.prev_state_ids = state_ids + context = EventContext.with_state( + state_group=None, + current_state_ids=state_ids, + prev_state_ids=state_ids + ) else: state_handler = self.hs.get_state_handler() context = yield state_handler.compute_event_context(event) diff --git a/tests/test_state.py b/tests/test_state.py index c0f2d1152..429a18cbf 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -204,7 +204,8 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context - self.assertEqual(2, len(context_store["D"].prev_state_ids)) + prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) 
+ self.assertEqual(2, len(prev_state_ids)) @defer.inlineCallbacks def test_branch_basic_conflict(self): @@ -255,9 +256,11 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context + prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + self.assertSetEqual( {"START", "A", "C"}, - {e_id for e_id in context_store["D"].prev_state_ids.values()} + {e_id for e_id in prev_state_ids.values()} ) @defer.inlineCallbacks @@ -318,9 +321,11 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context + prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store) + self.assertSetEqual( {"START", "A", "B", "C"}, - {e for e in context_store["E"].prev_state_ids.values()} + {e for e in prev_state_ids.values()} ) @defer.inlineCallbacks @@ -398,9 +403,11 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context + prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, - {e for e in context_store["D"].prev_state_ids.values()} + {e for e in prev_state_ids.values()} ) def _add_depths(self, nodes, edges): @@ -429,8 +436,10 @@ class StateTestCase(unittest.TestCase): event, old_state=old_state ) + current_state_ids = yield context.get_current_state_ids(self.store) + self.assertEqual( - set(e.event_id for e in old_state), set(context.current_state_ids.values()) + set(e.event_id for e in old_state), set(current_state_ids.values()) ) self.assertIsNotNone(context.state_group) @@ -449,8 +458,10 @@ class StateTestCase(unittest.TestCase): event, old_state=old_state ) + prev_state_ids = yield context.get_prev_state_ids(self.store) + self.assertEqual( - set(e.event_id for e in old_state), set(context.prev_state_ids.values()) + set(e.event_id for e in old_state), 
set(prev_state_ids.values()) ) @defer.inlineCallbacks @@ -475,9 +486,11 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) + current_state_ids = yield context.get_current_state_ids(self.store) + self.assertEqual( set([e.event_id for e in old_state]), - set(context.current_state_ids.values()) + set(current_state_ids.values()) ) self.assertEqual(group_name, context.state_group) @@ -504,9 +517,11 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) + prev_state_ids = yield context.get_prev_state_ids(self.store) + self.assertEqual( set([e.event_id for e in old_state]), - set(context.prev_state_ids.values()) + set(prev_state_ids.values()) ) self.assertIsNotNone(context.state_group) @@ -545,7 +560,9 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, ) - self.assertEqual(len(context.current_state_ids), 6) + current_state_ids = yield context.get_current_state_ids(self.store) + + self.assertEqual(len(current_state_ids), 6) self.assertIsNotNone(context.state_group) @@ -585,7 +602,9 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, ) - self.assertEqual(len(context.current_state_ids), 6) + current_state_ids = yield context.get_current_state_ids(self.store) + + self.assertEqual(len(current_state_ids), 6) self.assertIsNotNone(context.state_group) @@ -642,8 +661,10 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, ) + current_state_ids = yield context.get_current_state_ids(self.store) + self.assertEqual( - old_state_2[3].event_id, context.current_state_ids[("test1", "1")] + old_state_2[3].event_id, current_state_ids[("test1", "1")] ) # Reverse the depth to make sure we are actually using the depths @@ -670,8 +691,10 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, 
old_state_2, ) + current_state_ids = yield context.get_current_state_ids(self.store) + self.assertEqual( - old_state_1[3].event_id, context.current_state_ids[("test1", "1")] + old_state_1[3].event_id, current_state_ids[("test1", "1")] ) def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2, From 4797ed000e612b68c418ce8c342bdd7ecc16b198 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 15:05:40 +0100 Subject: [PATCH 09/16] Update docstrings to make sense --- synapse/events/snapshot.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index b090751bf..a6d7bf570 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -55,13 +55,16 @@ class EventContext(object): _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have been calculated. None if we haven't started calculating yet - _prev_state_id (str|None): If set then the event associated with the - context overrode the _prev_state_id - - _event_type (str): The type of the event the context is associated with + _event_type (str): The type of the event the context is associated with. + Only set when state has not been fetched yet. _event_state_key (str|None): The state_key of the event the context is - associated with + associated with. Only set when state has not been fetched yet. + + _prev_state_id (str|None): If the event associated with the context is + a state event, then `_prev_state_id` is the event_id of the state + that was replaced. + Only set when state has not been fetched yet. """ __slots__ = [ From 999bcf9d016fb7fd9ad5a9daf4f0ec6d25a10717 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 15:24:21 +0100 Subject: [PATCH 10/16] Fix EventContext when using workers We were: 1. Not correctly setting all attributes 2. 
Using defer.inlineCallbacks in a non-generator --- synapse/events/snapshot.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index a6d7bf570..e31eceb92 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -82,6 +82,11 @@ class EventContext(object): "_fetching_state_deferred", ] + def __init__(self): + self.prev_state_events = [] + self.rejected = False + self.app_service = None + @staticmethod def with_state(state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None): @@ -103,11 +108,6 @@ class EventContext(object): context.prev_group = prev_group context.delta_ids = delta_ids - context.prev_state_events = [] - - context.rejected = False - context.app_service = None - return context def serialize(self, event): @@ -143,7 +143,6 @@ class EventContext(object): } @staticmethod - @defer.inlineCallbacks def deserialize(store, input): """Converts a dict that was produced by `serialize` back into a EventContext. 
@@ -162,6 +161,7 @@ class EventContext(object): context._prev_state_id = input["prev_state_id"] context._event_type = input["event_type"] context._event_state_key = input["event_state_key"] + context._fetching_state_deferred = None context.state_group = input["state_group"] context.prev_group = input["prev_group"] @@ -174,7 +174,7 @@ class EventContext(object): if app_service_id: context.app_service = store.get_app_service_by_id(app_service_id) - defer.returnValue(context) + return context @defer.inlineCallbacks def get_current_state_ids(self, store): From a4d24781bf2290311af8f901862d83134bbd6229 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 15:28:51 +0100 Subject: [PATCH 11/16] Newsfile --- changelog.d/3581.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/3581.misc diff --git a/changelog.d/3581.misc b/changelog.d/3581.misc new file mode 100644 index 000000000..2374dc0c4 --- /dev/null +++ b/changelog.d/3581.misc @@ -0,0 +1 @@ +Lazily load state on master process when using workers to reduce DB consumption From 0faa3223cdf996aa18376a7420a43061a6691638 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 16:28:00 +0100 Subject: [PATCH 12/16] Fix missing attributes on workers. This was missed during the transition from attribute to getter for getting state from context. 
--- synapse/events/snapshot.py | 10 ++++++---- synapse/handlers/message.py | 5 +++-- synapse/replication/http/send_event.py | 7 +++++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index e31eceb92..a59064b41 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -110,7 +110,8 @@ class EventContext(object): return context - def serialize(self, event): + @defer.inlineCallbacks + def serialize(self, event, store): """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -126,11 +127,12 @@ class EventContext(object): # the prev_state_ids, so if we're a state event we include the event # id that we replaced in the state. if event.is_state(): - prev_state_id = self.prev_state_ids.get((event.type, event.state_key)) + prev_state_ids = yield self.get_prev_state_ids(store) + prev_state_id = prev_state_ids.get((event.type, event.state_key)) else: prev_state_id = None - return { + defer.returnValue({ "prev_state_id": prev_state_id, "event_type": event.type, "event_state_key": event.state_key if event.is_state() else None, @@ -140,7 +142,7 @@ class EventContext(object): "delta_ids": _encode_state_dict(self.delta_ids), "prev_state_events": self.prev_state_events, "app_service_id": self.app_service.id if self.app_service else None - } + }) @staticmethod def deserialize(store, input): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c4bcd9018..7571975c2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -807,8 +807,9 @@ class EventCreationHandler(object): # If we're a worker we need to hit out to the master. 
if self.config.worker_app: yield send_event_to_master( - self.hs.get_clock(), - self.http_client, + clock=self.hs.get_clock(), + store=self.store, + client=self.http_client, host=self.config.worker_replication_host, port=self.config.worker_replication_http_port, requester=requester, diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 2eede5479..5227bc333 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -34,12 +34,13 @@ logger = logging.getLogger(__name__) @defer.inlineCallbacks -def send_event_to_master(clock, client, host, port, requester, event, context, +def send_event_to_master(clock, store, client, host, port, requester, event, context, ratelimit, extra_users): """Send event to be handled on the master Args: clock (synapse.util.Clock) + store (DataStore) client (SimpleHttpClient) host (str): host of master port (int): port on master listening for HTTP replication @@ -53,11 +54,13 @@ def send_event_to_master(clock, client, host, port, requester, event, context, host, port, event.event_id, ) + serialized_context = yield context.serialize(event, store) + payload = { "event": event.get_pdu_json(), "internal_metadata": event.internal_metadata.get_dict(), "rejected_reason": event.rejected_reason, - "context": context.serialize(event), + "context": serialized_context, "requester": requester.serialize(), "ratelimit": ratelimit, "extra_users": [u.to_string() for u in extra_users], From 9f41ad491dd950058562e25045bc6ae3539ddc79 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 16:31:46 +0100 Subject: [PATCH 13/16] Newsfile --- changelog.d/3582.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/3582.misc diff --git a/changelog.d/3582.misc b/changelog.d/3582.misc new file mode 100644 index 000000000..2374dc0c4 --- /dev/null +++ b/changelog.d/3582.misc @@ -0,0 +1 @@ +Lazily load state on master process when using workers to reduce DB 
consumption From 50c60e5fadbefff6785c17dda9eecf88286dba30 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 17:21:40 +0100 Subject: [PATCH 14/16] Only get cached state from context in persist_event We don't want to bother pulling out the current state from the DB until we know we have to. Checking the context for state is just an optimisation. --- synapse/events/snapshot.py | 13 +++++++++++++ synapse/storage/events.py | 4 +++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index a59064b41..c439b5380 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -163,6 +163,9 @@ class EventContext(object): context._prev_state_id = input["prev_state_id"] context._event_type = input["event_type"] context._event_state_key = input["event_state_key"] + + context._current_state_ids = None + context._prev_state_ids = None context._fetching_state_deferred = None context.state_group = input["state_group"] @@ -214,6 +217,16 @@ class EventContext(object): defer.returnValue(self._prev_state_ids) + def get_cached_current_state_ids(self): + """Gets the current state IDs if we have them already cached. + + Returns: + dict[(str, str), str]|None: Returns None if state_group + is None, which happens when the associated event is an outlier. 
+ """ + + return self._current_state_ids + @defer.inlineCallbacks def _fill_out_state(self, store): """Called to populate the _current_state_ids and _prev_state_ids diff --git a/synapse/storage/events.py b/synapse/storage/events.py index bf4f3ee92..dc0b3c2eb 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -549,7 +549,9 @@ class EventsStore(EventsWorkerStore): if ctx.state_group in state_groups_map: continue - state_groups_map[ctx.state_group] = yield ctx.get_current_state_ids(self) + current_state_ids = ctx.get_cached_current_state_ids() + if current_state_ids is not None: + state_groups_map[ctx.state_group] = current_state_ids # We need to map the event_ids to their state groups. First, let's # check if the event is one we're persisting, in which case we can From 2d5bba151bec414ea6f16417c5eae335205c0df0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 17:29:32 +0100 Subject: [PATCH 15/16] Newsfile --- changelog.d/3584.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/3584.misc diff --git a/changelog.d/3584.misc b/changelog.d/3584.misc new file mode 100644 index 000000000..2374dc0c4 --- /dev/null +++ b/changelog.d/3584.misc @@ -0,0 +1 @@ +Lazily load state on master process when using workers to reduce DB consumption From 8b9f164fff6cf821ff5bc702f3660c0f0eb320e7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Jul 2018 17:43:01 +0100 Subject: [PATCH 16/16] Comments --- synapse/events/snapshot.py | 5 +++-- synapse/storage/events.py | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index c439b5380..189212b0f 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -221,8 +221,9 @@ class EventContext(object): """Gets the current state IDs if we have them already cached. 
Returns: - dict[(str, str), str]|None: Returns None if state_group - is None, which happens when the associated event is an outlier. + dict[(str, str), str]|None: Returns None if we haven't cached the + state or if state_group is None, which happens when the associated + event is an outlier. """ return self._current_state_ids diff --git a/synapse/storage/events.py b/synapse/storage/events.py index dc0b3c2eb..c2910094d 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -549,6 +549,9 @@ class EventsStore(EventsWorkerStore): if ctx.state_group in state_groups_map: continue + # We're only interested in pulling out state that has already + # been cached in the context. We'll pull stuff out of the DB later + # if necessary. current_state_ids = ctx.get_cached_current_state_ids() if current_state_ids is not None: state_groups_map[ctx.state_group] = current_state_ids