|
|
|
@ -337,7 +337,7 @@ class StateHandler(object): |
|
|
|
|
for st in state_groups_ids.values() |
|
|
|
|
] |
|
|
|
|
with Measure(self.clock, "state._resolve_events"): |
|
|
|
|
new_state, _ = Resolver.resolve_events( |
|
|
|
|
new_state, _ = resolve_events( |
|
|
|
|
state_sets, event_type, state_key |
|
|
|
|
) |
|
|
|
|
new_state = { |
|
|
|
@ -392,11 +392,11 @@ class StateHandler(object): |
|
|
|
|
) |
|
|
|
|
with Measure(self.clock, "state._resolve_events"): |
|
|
|
|
if event.is_state(): |
|
|
|
|
return Resolver.resolve_events( |
|
|
|
|
return resolve_events( |
|
|
|
|
state_sets, event.type, event.state_key |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
return Resolver.resolve_events(state_sets) |
|
|
|
|
return resolve_events(state_sets) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ordered_events(events): |
|
|
|
@ -406,138 +406,136 @@ def _ordered_events(events): |
|
|
|
|
return sorted(events, key=key_func) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Resolver(object): |
|
|
|
|
@staticmethod |
|
|
|
|
def resolve_events(state_sets, event_type=None, state_key=""): |
|
|
|
|
""" |
|
|
|
|
Returns |
|
|
|
|
(dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple |
|
|
|
|
(new_state, prev_states). new_state is a map from (type, state_key) |
|
|
|
|
to event. prev_states is a list of event_ids. |
|
|
|
|
""" |
|
|
|
|
state = {} |
|
|
|
|
for st in state_sets: |
|
|
|
|
for e in st: |
|
|
|
|
state.setdefault( |
|
|
|
|
(e.type, e.state_key), |
|
|
|
|
{} |
|
|
|
|
)[e.event_id] = e |
|
|
|
|
|
|
|
|
|
unconflicted_state = { |
|
|
|
|
k: v.values()[0] for k, v in state.items() |
|
|
|
|
if len(v.values()) == 1 |
|
|
|
|
} |
|
|
|
|
def resolve_events(state_sets, event_type=None, state_key=""): |
|
|
|
|
""" |
|
|
|
|
Returns |
|
|
|
|
(dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple |
|
|
|
|
(new_state, prev_states). new_state is a map from (type, state_key) |
|
|
|
|
to event. prev_states is a list of event_ids. |
|
|
|
|
""" |
|
|
|
|
state = {} |
|
|
|
|
for st in state_sets: |
|
|
|
|
for e in st: |
|
|
|
|
state.setdefault( |
|
|
|
|
(e.type, e.state_key), |
|
|
|
|
{} |
|
|
|
|
)[e.event_id] = e |
|
|
|
|
|
|
|
|
|
unconflicted_state = { |
|
|
|
|
k: v.values()[0] for k, v in state.items() |
|
|
|
|
if len(v.values()) == 1 |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
conflicted_state = { |
|
|
|
|
k: v.values() |
|
|
|
|
for k, v in state.items() |
|
|
|
|
if len(v.values()) > 1 |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if event_type: |
|
|
|
|
prev_states_events = conflicted_state.get( |
|
|
|
|
(event_type, state_key), [] |
|
|
|
|
) |
|
|
|
|
prev_states = [s.event_id for s in prev_states_events] |
|
|
|
|
else: |
|
|
|
|
prev_states = [] |
|
|
|
|
|
|
|
|
|
auth_events = { |
|
|
|
|
k: e for k, e in unconflicted_state.items() |
|
|
|
|
if k[0] in AuthEventTypes |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
resolved_state = _resolve_state_events( |
|
|
|
|
conflicted_state, auth_events |
|
|
|
|
) |
|
|
|
|
except: |
|
|
|
|
logger.exception("Failed to resolve state") |
|
|
|
|
raise |
|
|
|
|
|
|
|
|
|
conflicted_state = { |
|
|
|
|
k: v.values() |
|
|
|
|
for k, v in state.items() |
|
|
|
|
if len(v.values()) > 1 |
|
|
|
|
} |
|
|
|
|
new_state = unconflicted_state |
|
|
|
|
new_state.update(resolved_state) |
|
|
|
|
|
|
|
|
|
if event_type: |
|
|
|
|
prev_states_events = conflicted_state.get( |
|
|
|
|
(event_type, state_key), [] |
|
|
|
|
return new_state, prev_states |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_state_events(conflicted_state, auth_events): |
|
|
|
|
""" This is where we actually decide which of the conflicted state to |
|
|
|
|
use. |
|
|
|
|
|
|
|
|
|
We resolve conflicts in the following order: |
|
|
|
|
1. power levels |
|
|
|
|
2. join rules |
|
|
|
|
3. memberships |
|
|
|
|
4. other events. |
|
|
|
|
""" |
|
|
|
|
resolved_state = {} |
|
|
|
|
power_key = (EventTypes.PowerLevels, "") |
|
|
|
|
if power_key in conflicted_state: |
|
|
|
|
events = conflicted_state[power_key] |
|
|
|
|
logger.debug("Resolving conflicted power levels %r", events) |
|
|
|
|
resolved_state[power_key] = _resolve_auth_events( |
|
|
|
|
events, auth_events) |
|
|
|
|
|
|
|
|
|
auth_events.update(resolved_state) |
|
|
|
|
|
|
|
|
|
for key, events in conflicted_state.items(): |
|
|
|
|
if key[0] == EventTypes.JoinRules: |
|
|
|
|
logger.debug("Resolving conflicted join rules %r", events) |
|
|
|
|
resolved_state[key] = _resolve_auth_events( |
|
|
|
|
events, |
|
|
|
|
auth_events |
|
|
|
|
) |
|
|
|
|
prev_states = [s.event_id for s in prev_states_events] |
|
|
|
|
else: |
|
|
|
|
prev_states = [] |
|
|
|
|
|
|
|
|
|
auth_events = { |
|
|
|
|
k: e for k, e in unconflicted_state.items() |
|
|
|
|
if k[0] in AuthEventTypes |
|
|
|
|
} |
|
|
|
|
auth_events.update(resolved_state) |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
resolved_state = Resolver._resolve_state_events( |
|
|
|
|
conflicted_state, auth_events |
|
|
|
|
for key, events in conflicted_state.items(): |
|
|
|
|
if key[0] == EventTypes.Member: |
|
|
|
|
logger.debug("Resolving conflicted member lists %r", events) |
|
|
|
|
resolved_state[key] = _resolve_auth_events( |
|
|
|
|
events, |
|
|
|
|
auth_events |
|
|
|
|
) |
|
|
|
|
except: |
|
|
|
|
logger.exception("Failed to resolve state") |
|
|
|
|
raise |
|
|
|
|
|
|
|
|
|
new_state = unconflicted_state |
|
|
|
|
new_state.update(resolved_state) |
|
|
|
|
auth_events.update(resolved_state) |
|
|
|
|
|
|
|
|
|
return new_state, prev_states |
|
|
|
|
for key, events in conflicted_state.items(): |
|
|
|
|
if key not in resolved_state: |
|
|
|
|
logger.debug("Resolving conflicted state %r:%r", key, events) |
|
|
|
|
resolved_state[key] = _resolve_normal_events( |
|
|
|
|
events, auth_events |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def _resolve_state_events(conflicted_state, auth_events): |
|
|
|
|
""" This is where we actually decide which of the conflicted state to |
|
|
|
|
use. |
|
|
|
|
return resolved_state |
|
|
|
|
|
|
|
|
|
We resolve conflicts in the following order: |
|
|
|
|
1. power levels |
|
|
|
|
2. join rules |
|
|
|
|
3. memberships |
|
|
|
|
4. other events. |
|
|
|
|
""" |
|
|
|
|
resolved_state = {} |
|
|
|
|
power_key = (EventTypes.PowerLevels, "") |
|
|
|
|
if power_key in conflicted_state: |
|
|
|
|
events = conflicted_state[power_key] |
|
|
|
|
logger.debug("Resolving conflicted power levels %r", events) |
|
|
|
|
resolved_state[power_key] = Resolver._resolve_auth_events( |
|
|
|
|
events, auth_events) |
|
|
|
|
|
|
|
|
|
auth_events.update(resolved_state) |
|
|
|
|
|
|
|
|
|
for key, events in conflicted_state.items(): |
|
|
|
|
if key[0] == EventTypes.JoinRules: |
|
|
|
|
logger.debug("Resolving conflicted join rules %r", events) |
|
|
|
|
resolved_state[key] = Resolver._resolve_auth_events( |
|
|
|
|
events, |
|
|
|
|
auth_events |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
auth_events.update(resolved_state) |
|
|
|
|
def _resolve_auth_events(events, auth_events): |
|
|
|
|
reverse = [i for i in reversed(_ordered_events(events))] |
|
|
|
|
|
|
|
|
|
for key, events in conflicted_state.items(): |
|
|
|
|
if key[0] == EventTypes.Member: |
|
|
|
|
logger.debug("Resolving conflicted member lists %r", events) |
|
|
|
|
resolved_state[key] = Resolver._resolve_auth_events( |
|
|
|
|
events, |
|
|
|
|
auth_events |
|
|
|
|
) |
|
|
|
|
auth_events = dict(auth_events) |
|
|
|
|
|
|
|
|
|
prev_event = reverse[0] |
|
|
|
|
for event in reverse[1:]: |
|
|
|
|
auth_events[(prev_event.type, prev_event.state_key)] = prev_event |
|
|
|
|
try: |
|
|
|
|
# The signatures have already been checked at this point |
|
|
|
|
event_auth.check(event, auth_events, do_sig_check=False) |
|
|
|
|
prev_event = event |
|
|
|
|
except AuthError: |
|
|
|
|
return prev_event |
|
|
|
|
|
|
|
|
|
auth_events.update(resolved_state) |
|
|
|
|
return event |
|
|
|
|
|
|
|
|
|
for key, events in conflicted_state.items(): |
|
|
|
|
if key not in resolved_state: |
|
|
|
|
logger.debug("Resolving conflicted state %r:%r", key, events) |
|
|
|
|
resolved_state[key] = Resolver._resolve_normal_events( |
|
|
|
|
events, auth_events |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
return resolved_state |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def _resolve_auth_events(events, auth_events): |
|
|
|
|
reverse = [i for i in reversed(_ordered_events(events))] |
|
|
|
|
|
|
|
|
|
auth_events = dict(auth_events) |
|
|
|
|
|
|
|
|
|
prev_event = reverse[0] |
|
|
|
|
for event in reverse[1:]: |
|
|
|
|
auth_events[(prev_event.type, prev_event.state_key)] = prev_event |
|
|
|
|
try: |
|
|
|
|
# The signatures have already been checked at this point |
|
|
|
|
event_auth.check(event, auth_events, do_sig_check=False) |
|
|
|
|
prev_event = event |
|
|
|
|
except AuthError: |
|
|
|
|
return prev_event |
|
|
|
|
|
|
|
|
|
return event |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def _resolve_normal_events(events, auth_events): |
|
|
|
|
for event in _ordered_events(events): |
|
|
|
|
try: |
|
|
|
|
# The signatures have already been checked at this point |
|
|
|
|
event_auth.check(event, auth_events, do_sig_check=False) |
|
|
|
|
return event |
|
|
|
|
except AuthError: |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
# Use the last event (the one with the least depth) if they all fail |
|
|
|
|
# the auth check. |
|
|
|
|
return event |
|
|
|
|
def _resolve_normal_events(events, auth_events): |
|
|
|
|
for event in _ordered_events(events): |
|
|
|
|
try: |
|
|
|
|
# The signatures have already been checked at this point |
|
|
|
|
event_auth.check(event, auth_events, do_sig_check=False) |
|
|
|
|
return event |
|
|
|
|
except AuthError: |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
# Use the last event (the one with the least depth) if they all fail |
|
|
|
|
# the auth check. |
|
|
|
|
return event |
|
|
|
|