Batch up storing state groups when creating new room (#14918)

1.103.0-whithout-watcha
Shay 2 years ago committed by GitHub
parent 335f52d595
commit 1c95ddd09b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      changelog.d/14918.misc
  2. 49
      synapse/events/snapshot.py
  3. 16
      synapse/handlers/message.py
  4. 37
      synapse/handlers/room.py
  5. 4
      synapse/handlers/room_batch.py
  6. 13
      synapse/handlers/room_member.py
  7. 119
      synapse/storage/databases/state/store.py
  8. 25
      tests/handlers/test_message.py
  9. 3
      tests/handlers/test_register.py
  10. 13
      tests/push/test_bulk_push_rule_evaluator.py
  11. 4
      tests/rest/client/test_rooms.py
  12. 6
      tests/storage/test_event_chain.py
  13. 126
      tests/storage/test_state.py
  14. 4
      tests/unittest.py

@ -0,0 +1 @@
Batch up storing state groups when creating a new room.

@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap
if TYPE_CHECKING:
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases import StateGroupDataStore
from synapse.storage.databases.main import DataStore
from synapse.types.state import StateFilter
@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
partial_state: bool
state_map_before_event: Optional[StateMap[str]] = None
@classmethod
async def batch_persist_unpersisted_contexts(
cls,
events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]],
room_id: str,
last_known_state_group: int,
datastore: "StateGroupDataStore",
) -> List[Tuple[EventBase, EventContext]]:
"""
Takes a list of events and their associated unpersisted contexts and persists
the unpersisted contexts, returning a list of events and persisted contexts.
Note that all the events must be in a linear chain (ie a <- b <- c).
Args:
events_and_context: A list of events and their unpersisted contexts
room_id: the room_id for the events
last_known_state_group: the last persisted state group
datastore: a state datastore
"""
amended_events_and_context = await datastore.store_state_deltas_for_batched(
events_and_context, room_id, last_known_state_group
)
events_and_persisted_context = []
for event, unpersisted_context in amended_events_and_context:
if event.is_state():
context = EventContext(
storage=unpersisted_context._storage,
state_group=unpersisted_context.state_group_after_event,
state_group_before_event=unpersisted_context.state_group_before_event,
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
partial_state=unpersisted_context.partial_state,
prev_group=unpersisted_context.state_group_before_event,
delta_ids=unpersisted_context.state_delta_due_to_event,
)
else:
context = EventContext(
storage=unpersisted_context._storage,
state_group=unpersisted_context.state_group_after_event,
state_group_before_event=unpersisted_context.state_group_before_event,
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
partial_state=unpersisted_context.partial_state,
prev_group=unpersisted_context.prev_group_for_state_group_before_event,
delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
)
events_and_persisted_context.append((event, context))
return events_and_persisted_context
async def get_prev_state_ids(
self, state_filter: Optional["StateFilter"] = None
) -> StateMap[str]:

@ -574,7 +574,7 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
) -> Tuple[EventBase, EventContext]:
) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""
Given a dict from a client, create a new event. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
@ -721,8 +721,6 @@ class EventCreationHandler:
current_state_group=current_state_group,
)
context = await unpersisted_context.persist(event)
# In an ideal world we wouldn't need the second part of this condition. However,
# this behaviour isn't spec'd yet, meaning we should be able to deactivate this
# behaviour. Another reason is that this code is also evaluated each time a new
@ -739,7 +737,7 @@ class EventCreationHandler:
assert state_map is not None
prev_event_id = state_map.get((EventTypes.Member, event.sender))
else:
prev_state_ids = await context.get_prev_state_ids(
prev_state_ids = await unpersisted_context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
@ -764,8 +762,7 @@ class EventCreationHandler:
)
self.validator.validate_new(event, self.config)
return event, context
return event, unpersisted_context
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
@ -1005,7 +1002,7 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
event, context = await self.create_event(
event, unpersisted_context = await self.create_event(
requester,
event_dict,
txn_id=txn_id,
@ -1016,6 +1013,7 @@ class EventCreationHandler:
historical=historical,
depth=depth,
)
context = await unpersisted_context.persist(event)
assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
event.sender,
@ -1190,7 +1188,6 @@ class EventCreationHandler:
if for_batch:
assert prev_event_ids is not None
assert state_map is not None
assert current_state_group is not None
auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
@ -2046,7 +2043,7 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
event, context = await self.create_event(
event, unpersisted_context = await self.create_event(
requester,
{
"type": EventTypes.Dummy,
@ -2055,6 +2052,7 @@ class EventCreationHandler:
"sender": user_id,
},
)
context = await unpersisted_context.persist(event)
event.internal_metadata.proactively_send = False

@ -51,6 +51,7 @@ from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
from synapse.events.snapshot import UnpersistedEventContext
from synapse.events.utils import copy_and_fixup_power_levels_contents
from synapse.handlers.relations import BundledAggregations
from synapse.module_api import NOT_SPAM
@ -211,7 +212,7 @@ class RoomCreationHandler:
# the required power level to send the tombstone event.
(
tombstone_event,
tombstone_context,
tombstone_unpersisted_context,
) = await self.event_creation_handler.create_event(
requester,
{
@ -225,6 +226,9 @@ class RoomCreationHandler:
},
},
)
tombstone_context = await tombstone_unpersisted_context.persist(
tombstone_event
)
validate_event_for_room_version(tombstone_event)
await self._event_auth_handler.check_auth_rules_from_context(
tombstone_event
@ -1092,7 +1096,7 @@ class RoomCreationHandler:
content: JsonDict,
for_batch: bool,
**kwargs: Any,
) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]:
"""
Creates an event and associated event context.
Args:
@ -1111,20 +1115,23 @@ class RoomCreationHandler:
event_dict = create_event_dict(etype, content, **kwargs)
new_event, new_context = await self.event_creation_handler.create_event(
(
new_event,
new_unpersisted_context,
) = await self.event_creation_handler.create_event(
creator,
event_dict,
prev_event_ids=prev_event,
depth=depth,
state_map=state_map,
for_batch=for_batch,
current_state_group=current_state_group,
)
depth += 1
prev_event = [new_event.event_id]
state_map[(new_event.type, new_event.state_key)] = new_event.event_id
return new_event, new_context
return new_event, new_unpersisted_context
try:
config = self._presets_dict[preset_config]
@ -1134,10 +1141,10 @@ class RoomCreationHandler:
)
creation_content.update({"creator": creator_id})
creation_event, creation_context = await create_event(
creation_event, unpersisted_creation_context = await create_event(
EventTypes.Create, creation_content, False
)
creation_context = await unpersisted_creation_context.persist(creation_event)
logger.debug("Sending %s in new room", EventTypes.Member)
ev = await self.event_creation_handler.handle_new_client_event(
requester=creator,
@ -1181,7 +1188,6 @@ class RoomCreationHandler:
power_event, power_context = await create_event(
EventTypes.PowerLevels, pl_content, True
)
current_state_group = power_context._state_group
events_to_send.append((power_event, power_context))
else:
power_level_content: JsonDict = {
@ -1230,14 +1236,12 @@ class RoomCreationHandler:
power_level_content,
True,
)
current_state_group = pl_context._state_group
events_to_send.append((pl_event, pl_context))
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
room_alias_event, room_alias_context = await create_event(
EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
)
current_state_group = room_alias_context._state_group
events_to_send.append((room_alias_event, room_alias_context))
if (EventTypes.JoinRules, "") not in initial_state:
@ -1246,7 +1250,6 @@ class RoomCreationHandler:
{"join_rule": config["join_rules"]},
True,
)
current_state_group = join_rules_context._state_group
events_to_send.append((join_rules_event, join_rules_context))
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
@ -1255,7 +1258,6 @@ class RoomCreationHandler:
{"history_visibility": config["history_visibility"]},
True,
)
current_state_group = visibility_context._state_group
events_to_send.append((visibility_event, visibility_context))
if config["guest_can_join"]:
@ -1265,14 +1267,12 @@ class RoomCreationHandler:
{EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
True,
)
current_state_group = guest_access_context._state_group
events_to_send.append((guest_access_event, guest_access_context))
for (etype, state_key), content in initial_state.items():
event, context = await create_event(
etype, content, True, state_key=state_key
)
current_state_group = context._state_group
events_to_send.append((event, context))
if config["encrypted"]:
@ -1284,9 +1284,16 @@ class RoomCreationHandler:
)
events_to_send.append((encryption_event, encryption_context))
datastore = self.hs.get_datastores().state
events_and_context = (
await UnpersistedEventContext.batch_persist_unpersisted_contexts(
events_to_send, room_id, current_state_group, datastore
)
)
last_event = await self.event_creation_handler.handle_new_client_event(
creator,
events_to_send,
events_and_context,
ignore_shadow_ban=True,
ratelimit=False,
)

@ -327,7 +327,7 @@ class RoomBatchHandler:
# Mark all events as historical
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
event, context = await self.event_creation_handler.create_event(
event, unpersisted_context = await self.event_creation_handler.create_event(
await self.create_requester_for_user_id_from_app_service(
ev["sender"], app_service_requester.app_service
),
@ -345,7 +345,7 @@ class RoomBatchHandler:
historical=True,
depth=inherited_depth,
)
context = await unpersisted_context.persist(event)
assert context._state_group
# Normally this is done when persisting the event but we have to

@ -414,7 +414,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
max_retries = 5
for i in range(max_retries):
try:
event, context = await self.event_creation_handler.create_event(
(
event,
unpersisted_context,
) = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
@ -435,7 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
outlier=outlier,
historical=historical,
)
context = await unpersisted_context.persist(event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)
@ -1944,7 +1947,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
max_retries = 5
for i in range(max_retries):
try:
event, context = await self.event_creation_handler.create_event(
(
event,
unpersisted_context,
) = await self.event_creation_handler.create_event(
requester,
event_dict,
txn_id=txn_id,
@ -1952,6 +1958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
auth_event_ids=auth_event_ids,
outlier=True,
)
context = await unpersisted_context.persist(event)
event.internal_metadata.out_of_band_membership = True
result_event = (

@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se
import attr
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@ -401,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
fetched_keys=non_member_types,
)
async def store_state_deltas_for_batched(
self,
events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]],
room_id: str,
prev_group: int,
) -> List[Tuple[EventBase, UnpersistedEventContext]]:
"""Generate and store state deltas for a group of events and contexts created to be
batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c).
Args:
events_and_context: the events to generate and store a state groups for
and their associated contexts
room_id: the id of the room the events were created for
prev_group: the state group of the last event persisted before the batched events
were created
"""
def insert_deltas_group_txn(
txn: LoggingTransaction,
events_and_context: List[Tuple[EventBase, UnpersistedEventContext]],
prev_group: int,
) -> List[Tuple[EventBase, UnpersistedEventContext]]:
"""Generate and store state groups for the provided events and contexts.
Requires that we have the state as a delta from the last persisted state group.
Returns:
A list of state groups
"""
is_in_db = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)
)
num_state_groups = sum(
1 for event, _ in events_and_context if event.is_state()
)
state_groups = self._state_group_seq_gen.get_next_mult_txn(
txn, num_state_groups
)
sg_before = prev_group
state_group_iter = iter(state_groups)
for event, context in events_and_context:
if not event.is_state():
context.state_group_after_event = sg_before
context.state_group_before_event = sg_before
continue
sg_after = next(state_group_iter)
context.state_group_after_event = sg_after
context.state_group_before_event = sg_before
context.state_delta_due_to_event = {
(event.type, event.state_key): event.event_id
}
sg_before = sg_after
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups",
keys=("id", "room_id", "event_id"),
values=[
(context.state_group_after_event, room_id, event.event_id)
for event, context in events_and_context
if event.is_state()
],
)
self.db_pool.simple_insert_many_txn(
txn,
table="state_group_edges",
keys=("state_group", "prev_state_group"),
values=[
(
context.state_group_after_event,
context.state_group_before_event,
)
for event, context in events_and_context
if event.is_state()
],
)
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
keys=("state_group", "room_id", "type", "state_key", "event_id"),
values=[
(
context.state_group_after_event,
room_id,
key[0],
key[1],
state_id,
)
for event, context in events_and_context
if context.state_delta_due_to_event is not None
for key, state_id in context.state_delta_due_to_event.items()
],
)
return events_and_context
return await self.db_pool.runInteraction(
"store_state_deltas_for_batched.insert_deltas_group",
insert_deltas_group_txn,
events_and_context,
prev_group,
)
async def store_state_group(
self,
event_id: str,

@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
@ -79,7 +79,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
return memberEvent, memberEventContext
def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
def _create_duplicate_event(
self, txn_id: str
) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""Create a new event with the given transaction ID. All events produced
by this method will be considered duplicates.
"""
@ -107,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_suitably_random"
event1, context = self._create_duplicate_event(txn_id)
event1, unpersisted_context = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event1))
ret_event1 = self.get_success(
self.handler.handle_new_client_event(
@ -119,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(event1.event_id, ret_event1.event_id)
event2, context = self._create_duplicate_event(txn_id)
event2, unpersisted_context = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event2))
# We want to test that the deduplication at the persit event end works,
# so we want to make sure we test with different events.
@ -140,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_event` directly also does the right
# thing.
event3, context = self._create_duplicate_event(txn_id)
event3, unpersisted_context = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event3))
self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success(
@ -154,7 +160,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_events` directly also does the right
# thing.
event4, context = self._create_duplicate_event(txn_id)
event4, unpersisted_context = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event4))
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
@ -174,8 +181,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_else_suitably_random"
# Create two duplicate events to persist at the same time
event1, context1 = self._create_duplicate_event(txn_id)
event2, context2 = self._create_duplicate_event(txn_id)
event1, unpersisted_context1 = self._create_duplicate_event(txn_id)
context1 = self.get_success(unpersisted_context1.persist(event1))
event2, unpersisted_context2 = self._create_duplicate_event(txn_id)
context2 = self.get_success(unpersisted_context2.persist(event2))
# Ensure their event IDs are different to start with
self.assertNotEqual(event1.event_id, event2.event_id)

@ -507,7 +507,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Lower the permissions of the inviter.
event_creation_handler = self.hs.get_event_creation_handler()
requester = create_requester(inviter)
event, context = self.get_success(
event, unpersisted_context = self.get_success(
event_creation_handler.create_event(
requester,
{
@ -519,6 +519,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
},
)
)
context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_creation_handler.handle_new_client_event(
requester, events_and_context=[(event, context)]

@ -130,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
# Create a new message event, and try to evaluate it under the dodgy
# power level event.
event, context = self.get_success(
event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@ -145,6 +145,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
prev_event_ids=[pl_event_id],
)
)
context = self.get_success(unpersisted_context.persist(event))
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# should not raise
@ -170,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
"""Ensure that push rules are not calculated when disabled in the config"""
# Create a new message event which should cause a notification.
event, context = self.get_success(
event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@ -184,6 +185,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
context = self.get_success(unpersisted_context.persist(event))
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# Mock the method which calculates push rules -- we do this instead of
@ -200,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
) -> bool:
"""Returns true iff the `mentions` trigger an event push action."""
# Create a new message event which should cause a notification.
event, context = self.get_success(
event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@ -211,7 +213,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
context = self.get_success(unpersisted_context.persist(event))
# Execute the push rule machinery.
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
@ -390,7 +392,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# Create & persist an event to use as the parent of the relation.
event, context = self.get_success(
event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@ -404,6 +406,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
context = self.get_success(unpersisted_context.persist(event))
self.get_success(
self.event_creation_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]

@ -713,7 +713,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
self.assertEqual(33, channel.resource_usage.db_txn_count)
self.assertEqual(30, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@ -726,7 +726,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
self.assertEqual(36, channel.resource_usage.db_txn_count)
self.assertEqual(32, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id

@ -522,7 +522,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_prev_events_for_room(room_id)
)
event, context = self.get_success(
event, unpersisted_context = self.get_success(
event_handler.create_event(
self.requester,
{
@ -535,6 +535,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
prev_event_ids=latest_event_ids,
)
)
context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]
@ -544,7 +545,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
assert state_ids1 is not None
state1 = set(state_ids1.values())
event, context = self.get_success(
event, unpersisted_context = self.get_success(
event_handler.create_event(
self.requester,
{
@ -557,6 +558,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
prev_event_ids=latest_event_ids,
)
)
context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]

@ -496,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
def test_batched_state_group_storing(self) -> None:
creation_event = self.inject_state_event(
self.room, self.u_alice, EventTypes.Create, "", {}
)
state_to_event = self.get_success(
self.storage.state.get_state_groups(
self.room.to_string(), [creation_event.event_id]
)
)
current_state_group = list(state_to_event.keys())[0]
# create some unpersisted events and event contexts to store against room
events_and_context = []
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Name,
"sender": self.u_alice.to_string(),
"state_key": "",
"room_id": self.room.to_string(),
"content": {"name": "first rename of room"},
},
)
event1, unpersisted_context1 = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
events_and_context.append((event1, unpersisted_context1))
builder2 = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.JoinRules,
"sender": self.u_alice.to_string(),
"state_key": "",
"room_id": self.room.to_string(),
"content": {"join_rule": "private"},
},
)
event2, unpersisted_context2 = self.get_success(
self.event_creation_handler.create_new_client_event(builder2)
)
events_and_context.append((event2, unpersisted_context2))
builder3 = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Message,
"sender": self.u_alice.to_string(),
"room_id": self.room.to_string(),
"content": {"body": "hello from event 3", "msgtype": "m.text"},
},
)
event3, unpersisted_context3 = self.get_success(
self.event_creation_handler.create_new_client_event(builder3)
)
events_and_context.append((event3, unpersisted_context3))
builder4 = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.JoinRules,
"sender": self.u_alice.to_string(),
"state_key": "",
"room_id": self.room.to_string(),
"content": {"join_rule": "public"},
},
)
event4, unpersisted_context4 = self.get_success(
self.event_creation_handler.create_new_client_event(builder4)
)
events_and_context.append((event4, unpersisted_context4))
processed_events_and_context = self.get_success(
self.hs.get_datastores().state.store_state_deltas_for_batched(
events_and_context, self.room.to_string(), current_state_group
)
)
# check that only state events are in state_groups, and all state events are in state_groups
res = self.get_success(
self.store.db_pool.simple_select_list(
table="state_groups",
keyvalues=None,
retcols=("event_id",),
)
)
events = []
for result in res:
self.assertNotIn(event3.event_id, result)
events.append(result.get("event_id"))
for event, _ in processed_events_and_context:
if event.is_state():
self.assertIn(event.event_id, events)
# check that each unique state has state group in state_groups_state and that the
# type/state key is correct, and check that each state event's state group
# has an entry and prev event in state_group_edges
for event, context in processed_events_and_context:
if event.is_state():
state = self.get_success(
self.store.db_pool.simple_select_list(
table="state_groups_state",
keyvalues={"state_group": context.state_group_after_event},
retcols=("type", "state_key"),
)
)
self.assertEqual(event.type, state[0].get("type"))
self.assertEqual(event.state_key, state[0].get("state_key"))
groups = self.get_success(
self.store.db_pool.simple_select_list(
table="state_group_edges",
keyvalues={"state_group": str(context.state_group_after_event)},
retcols=("*",),
)
)
self.assertEqual(
context.state_group_before_event, groups[0].get("prev_state_group")
)

@ -723,7 +723,7 @@ class HomeserverTestCase(TestCase):
event_creator = self.hs.get_event_creation_handler()
requester = create_requester(user)
event, context = self.get_success(
event, unpersisted_context = self.get_success(
event_creator.create_event(
requester,
{
@ -735,7 +735,7 @@ class HomeserverTestCase(TestCase):
prev_event_ids=prev_event_ids,
)
)
context = self.get_success(unpersisted_context.persist(event))
if soft_failed:
event.internal_metadata.soft_failed = True

Loading…
Cancel
Save