incorporate more review

pull/14/head
Matthew Hodgson 6 years ago
parent efcdacad7d
commit cd241d6bda
  1. 12
      synapse/handlers/sync.py
  2. 36
      synapse/storage/state.py
  3. 9
      tests/storage/test_state.py

@ -1526,6 +1526,9 @@ def _calculate_state(
previous (dict): state at the end of the previous sync (or empty dict previous (dict): state at the end of the previous sync (or empty dict
if this is an initial sync) if this is an initial sync)
current (dict): state at the end of the timeline current (dict): state at the end of the timeline
lazy_load_members (bool): whether to return members from timeline_start
or not. assumes that timeline_start has already been filtered to
include only the members the client needs to know about.
Returns: Returns:
dict dict
@ -1545,9 +1548,12 @@ def _calculate_state(
p_ids = set(e for e in previous.values()) p_ids = set(e for e in previous.values())
tc_ids = set(e for e in timeline_contains.values()) tc_ids = set(e for e in timeline_contains.values())
# track the membership events in the state as of the start of the timeline # If we are lazyloading room members, we explicitly add the membership events
# so we can add them back in to the state if we're lazyloading. We don't # for the senders in the timeline into the state block returned by /sync,
# add them into state if they're already contained in the timeline. # as we may not have sent them to the client before. We find these membership
# events by filtering them out of timeline_start, which has already been filtered
# to only include membership events for the senders in the timeline.
if lazy_load_members: if lazy_load_members:
ll_ids = set( ll_ids = set(
e for t, e in timeline_start.iteritems() e for t, e in timeline_start.iteritems()

@ -185,7 +185,7 @@ class StateGroupWorkerStore(SQLBaseStore):
}) })
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, types, filtered_types=None): def _get_state_groups_from_groups(self, groups, types):
"""Returns the state groups for a given set of groups, filtering on """Returns the state groups for a given set of groups, filtering on
types of state events. types of state events.
@ -194,9 +194,6 @@ class StateGroupWorkerStore(SQLBaseStore):
types (Iterable[str, str|None]|None): list of 2-tuples of the form types (Iterable[str, str|None]|None): list of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all (`type`, `state_key`), where a `state_key` of `None` matches all
state_keys for the `type`. If None, all types are returned. state_keys for the `type`. If None, all types are returned.
filtered_types(Iterable[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
dictionary state_group -> (dict of (type, state_key) -> event id) dictionary state_group -> (dict of (type, state_key) -> event id)
@ -207,14 +204,14 @@ class StateGroupWorkerStore(SQLBaseStore):
for chunk in chunks: for chunk in chunks:
res = yield self.runInteraction( res = yield self.runInteraction(
"_get_state_groups_from_groups", "_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn, chunk, types, filtered_types, self._get_state_groups_from_groups_txn, chunk, types,
) )
results.update(res) results.update(res)
defer.returnValue(results) defer.returnValue(results)
def _get_state_groups_from_groups_txn( def _get_state_groups_from_groups_txn(
self, txn, groups, types=None, filtered_types=None, self, txn, groups, types=None,
): ):
results = {group: {} for group in groups} results = {group: {} for group in groups}
@ -266,17 +263,6 @@ class StateGroupWorkerStore(SQLBaseStore):
) )
for etype, state_key in types for etype, state_key in types
] ]
if filtered_types is not None:
# XXX: check whether this slows postgres down like a list of
# ORs does too?
unique_types = set(filtered_types)
clause_to_args.append(
(
"AND type <> ? " * len(unique_types),
list(unique_types)
)
)
else: else:
# If types is None we fetch all the state, and so just use an # If types is None we fetch all the state, and so just use an
# empty where clause with no extra args. # empty where clause with no extra args.
@ -306,13 +292,6 @@ class StateGroupWorkerStore(SQLBaseStore):
where_clauses.append("(type = ? AND state_key = ?)") where_clauses.append("(type = ? AND state_key = ?)")
where_args.extend([typ[0], typ[1]]) where_args.extend([typ[0], typ[1]])
if filtered_types is not None:
unique_types = set(filtered_types)
where_clauses.append(
"(" + " AND ".join(["type <> ?"] * len(unique_types)) + ")"
)
where_args.extend(list(unique_types))
where_clause = "AND (%s)" % (" OR ".join(where_clauses)) where_clause = "AND (%s)" % (" OR ".join(where_clauses))
else: else:
where_clause = "" where_clause = ""
@ -643,13 +622,13 @@ class StateGroupWorkerStore(SQLBaseStore):
# cache. Hence, if we are doing a wildcard lookup, populate the # cache. Hence, if we are doing a wildcard lookup, populate the
# cache fully so that we can do an efficient lookup next time. # cache fully so that we can do an efficient lookup next time.
if types and any(k is None for (t, k) in types): if filtered_types or (types and any(k is None for (t, k) in types)):
types_to_fetch = None types_to_fetch = None
else: else:
types_to_fetch = types types_to_fetch = types
group_to_state_dict = yield self._get_state_groups_from_groups( group_to_state_dict = yield self._get_state_groups_from_groups(
missing_groups, types_to_fetch, filtered_types missing_groups, types_to_fetch
) )
for group, group_state_dict in iteritems(group_to_state_dict): for group, group_state_dict in iteritems(group_to_state_dict):
@ -659,7 +638,10 @@ class StateGroupWorkerStore(SQLBaseStore):
if types: if types:
for k, v in iteritems(group_state_dict): for k, v in iteritems(group_state_dict):
(typ, _) = k (typ, _) = k
if k in types or (typ, None) in types: if (
(k in types or (typ, None) in types) or
(filtered_types and typ not in filtered_types)
):
state_dict[k] = v state_dict[k] = v
else: else:
state_dict.update(group_state_dict) state_dict.update(group_state_dict)

@ -158,3 +158,12 @@ class StateStoreTestCase(tests.unittest.TestCase):
(e2.type, e2.state_key): e2, (e2.type, e2.state_key): e2,
(e3.type, e3.state_key): e3, (e3.type, e3.state_key): e3,
}, state) }, state)
state = yield self.store.get_state_for_event(
e5.event_id, [], filtered_types=[EventTypes.Member],
)
self.assertStateMapEqual({
(e1.type, e1.state_key): e1,
(e2.type, e2.state_key): e2,
}, state)

Loading…
Cancel
Save