|
|
|
@ -26,6 +26,7 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
self.store = Mock( |
|
|
|
|
spec_set=[ |
|
|
|
|
"get_state_groups", |
|
|
|
|
"add_event_hashes", |
|
|
|
|
] |
|
|
|
|
) |
|
|
|
|
hs = Mock(spec=["get_datastore"]) |
|
|
|
@ -44,17 +45,20 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
self.create_event(type="test2", state_key=""), |
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
yield self.state.annotate_event_with_state(event, old_state=old_state) |
|
|
|
|
context = yield self.state.compute_event_context( |
|
|
|
|
event, old_state=old_state |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
for k, v in event.old_state_events.items(): |
|
|
|
|
for k, v in context.current_state.items(): |
|
|
|
|
type, state_key = k |
|
|
|
|
self.assertEqual(type, v.type) |
|
|
|
|
self.assertEqual(state_key, v.state_key) |
|
|
|
|
|
|
|
|
|
self.assertEqual(set(old_state), set(event.old_state_events.values())) |
|
|
|
|
self.assertDictEqual(event.old_state_events, event.state_events) |
|
|
|
|
self.assertEqual( |
|
|
|
|
set(old_state), set(context.current_state.values()) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
self.assertIsNone(event.state_group) |
|
|
|
|
self.assertIsNone(context.state_group) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def test_annotate_with_old_state(self): |
|
|
|
@ -66,21 +70,21 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
self.create_event(type="test2", state_key=""), |
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
yield self.state.annotate_event_with_state(event, old_state=old_state) |
|
|
|
|
context = yield self.state.compute_event_context( |
|
|
|
|
event, old_state=old_state |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
for k, v in event.old_state_events.items(): |
|
|
|
|
for k, v in context.current_state.items(): |
|
|
|
|
type, state_key = k |
|
|
|
|
self.assertEqual(type, v.type) |
|
|
|
|
self.assertEqual(state_key, v.state_key) |
|
|
|
|
|
|
|
|
|
self.assertEqual( |
|
|
|
|
set(old_state + [event]), |
|
|
|
|
set(event.old_state_events.values()) |
|
|
|
|
set(old_state), |
|
|
|
|
set(context.current_state.values()) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
self.assertDictEqual(event.old_state_events, event.state_events) |
|
|
|
|
|
|
|
|
|
self.assertIsNone(event.state_group) |
|
|
|
|
self.assertIsNone(context.state_group) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def test_trivial_annotate_message(self): |
|
|
|
@ -99,30 +103,19 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
group_name: old_state, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
yield self.state.annotate_event_with_state(event) |
|
|
|
|
context = yield self.state.compute_event_context(event) |
|
|
|
|
|
|
|
|
|
for k, v in event.old_state_events.items(): |
|
|
|
|
for k, v in context.current_state.items(): |
|
|
|
|
type, state_key = k |
|
|
|
|
self.assertEqual(type, v.type) |
|
|
|
|
self.assertEqual(state_key, v.state_key) |
|
|
|
|
|
|
|
|
|
self.assertEqual( |
|
|
|
|
set([e.event_id for e in old_state]), |
|
|
|
|
set([e.event_id for e in event.old_state_events.values()]) |
|
|
|
|
set([e.event_id for e in context.current_state.values()]) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
self.assertDictEqual( |
|
|
|
|
{ |
|
|
|
|
k: v.event_id |
|
|
|
|
for k, v in event.old_state_events.items() |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
k: v.event_id |
|
|
|
|
for k, v in event.state_events.items() |
|
|
|
|
} |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
self.assertEqual(group_name, event.state_group) |
|
|
|
|
self.assertEqual(group_name, context.state_group) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def test_trivial_annotate_state(self): |
|
|
|
@ -141,38 +134,19 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
group_name: old_state, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
yield self.state.annotate_event_with_state(event) |
|
|
|
|
context = yield self.state.compute_event_context(event) |
|
|
|
|
|
|
|
|
|
for k, v in event.old_state_events.items(): |
|
|
|
|
for k, v in context.current_state.items(): |
|
|
|
|
type, state_key = k |
|
|
|
|
self.assertEqual(type, v.type) |
|
|
|
|
self.assertEqual(state_key, v.state_key) |
|
|
|
|
|
|
|
|
|
self.assertEqual( |
|
|
|
|
set([e.event_id for e in old_state]), |
|
|
|
|
set([e.event_id for e in event.old_state_events.values()]) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
self.assertEqual( |
|
|
|
|
set([e.event_id for e in old_state] + [event.event_id]), |
|
|
|
|
set([e.event_id for e in event.state_events.values()]) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
new_state = { |
|
|
|
|
k: v.event_id |
|
|
|
|
for k, v in event.state_events.items() |
|
|
|
|
} |
|
|
|
|
old_state = { |
|
|
|
|
k: v.event_id |
|
|
|
|
for k, v in event.old_state_events.items() |
|
|
|
|
} |
|
|
|
|
old_state[(event.type, event.state_key)] = event.event_id |
|
|
|
|
self.assertDictEqual( |
|
|
|
|
old_state, |
|
|
|
|
new_state |
|
|
|
|
set([e.event_id for e in context.current_state.values()]) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
self.assertIsNone(event.state_group) |
|
|
|
|
self.assertIsNone(context.state_group) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def test_resolve_message_conflict(self): |
|
|
|
@ -199,16 +173,11 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
group_name_2: old_state_2, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
yield self.state.annotate_event_with_state(event) |
|
|
|
|
context = yield self.state.compute_event_context(event) |
|
|
|
|
|
|
|
|
|
self.assertEqual(len(event.old_state_events), 5) |
|
|
|
|
|
|
|
|
|
self.assertEqual( |
|
|
|
|
set([e.event_id for e in event.state_events.values()]), |
|
|
|
|
set([e.event_id for e in event.old_state_events.values()]) |
|
|
|
|
) |
|
|
|
|
self.assertEqual(len(context.current_state), 5) |
|
|
|
|
|
|
|
|
|
self.assertIsNone(event.state_group) |
|
|
|
|
self.assertIsNone(context.state_group) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def test_resolve_state_conflict(self): |
|
|
|
@ -235,19 +204,11 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
group_name_2: old_state_2, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
yield self.state.annotate_event_with_state(event) |
|
|
|
|
context = yield self.state.compute_event_context(event) |
|
|
|
|
|
|
|
|
|
self.assertEqual(len(event.old_state_events), 5) |
|
|
|
|
self.assertEqual(len(context.current_state), 5) |
|
|
|
|
|
|
|
|
|
expected_new = event.old_state_events |
|
|
|
|
expected_new[(event.type, event.state_key)] = event |
|
|
|
|
|
|
|
|
|
self.assertEqual( |
|
|
|
|
set([e.event_id for e in expected_new.values()]), |
|
|
|
|
set([e.event_id for e in event.state_events.values()]), |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
self.assertIsNone(event.state_group) |
|
|
|
|
self.assertIsNone(context.state_group) |
|
|
|
|
|
|
|
|
|
def create_event(self, name=None, type=None, state_key=None): |
|
|
|
|
self.event_id += 1 |
|
|
|
@ -266,6 +227,9 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
event.state_key = state_key |
|
|
|
|
event.event_id = event_id |
|
|
|
|
|
|
|
|
|
event.is_state = lambda: (state_key is not None) |
|
|
|
|
event.unsigned = {} |
|
|
|
|
|
|
|
|
|
event.user_id = "@user_id:example.com" |
|
|
|
|
event.room_id = "!room_id:example.com" |
|
|
|
|
|
|
|
|
|