|
|
@ -12,7 +12,7 @@ |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
# limitations under the License. |
|
|
|
import logging |
|
|
|
import logging |
|
|
|
from typing import Any, Callable, Iterable, List, Optional, Tuple |
|
|
|
from typing import Any, Iterable, List, Optional, Tuple |
|
|
|
|
|
|
|
|
|
|
|
from canonicaljson import encode_canonical_json |
|
|
|
from canonicaljson import encode_canonical_json |
|
|
|
from parameterized import parameterized |
|
|
|
from parameterized import parameterized |
|
|
@ -21,7 +21,7 @@ from twisted.test.proto_helpers import MemoryReactor |
|
|
|
|
|
|
|
|
|
|
|
from synapse.api.constants import ReceiptTypes |
|
|
|
from synapse.api.constants import ReceiptTypes |
|
|
|
from synapse.api.room_versions import RoomVersions |
|
|
|
from synapse.api.room_versions import RoomVersions |
|
|
|
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict |
|
|
|
from synapse.events import EventBase, make_event_from_dict |
|
|
|
from synapse.events.snapshot import EventContext |
|
|
|
from synapse.events.snapshot import EventContext |
|
|
|
from synapse.handlers.room import RoomEventSource |
|
|
|
from synapse.handlers.room import RoomEventSource |
|
|
|
from synapse.server import HomeServer |
|
|
|
from synapse.server import HomeServer |
|
|
@ -46,32 +46,9 @@ ROOM_ID = "!room:test" |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dict_equals(self: EventBase, other: EventBase) -> bool: |
|
|
|
|
|
|
|
me = encode_canonical_json(self.get_pdu_json()) |
|
|
|
|
|
|
|
them = encode_canonical_json(other.get_pdu_json()) |
|
|
|
|
|
|
|
return me == them |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def patch__eq__(cls: object) -> Callable[[], None]: |
|
|
|
|
|
|
|
eq = getattr(cls, "__eq__", None) |
|
|
|
|
|
|
|
cls.__eq__ = dict_equals # type: ignore[assignment] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unpatch() -> None: |
|
|
|
|
|
|
|
if eq is not None: |
|
|
|
|
|
|
|
cls.__eq__ = eq # type: ignore[method-assign] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return unpatch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): |
|
|
|
class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): |
|
|
|
STORE_TYPE = EventsWorkerStore |
|
|
|
STORE_TYPE = EventsWorkerStore |
|
|
|
|
|
|
|
|
|
|
|
def setUp(self) -> None: |
|
|
|
|
|
|
|
# Patch up the equality operator for events so that we can check |
|
|
|
|
|
|
|
# whether lists of events match using assertEqual |
|
|
|
|
|
|
|
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)] |
|
|
|
|
|
|
|
super().setUp() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: |
|
|
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: |
|
|
|
super().prepare(reactor, clock, hs) |
|
|
|
super().prepare(reactor, clock, hs) |
|
|
|
|
|
|
|
|
|
|
@ -84,8 +61,14 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): |
|
|
|
) |
|
|
|
) |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def tearDown(self) -> None: |
|
|
|
def assertEventsEqual( |
|
|
|
[unpatch() for unpatch in self.unpatches] |
|
|
|
self, first: EventBase, second: EventBase, msg: Optional[Any] = None |
|
|
|
|
|
|
|
) -> None: |
|
|
|
|
|
|
|
self.assertEqual( |
|
|
|
|
|
|
|
encode_canonical_json(first.get_pdu_json()), |
|
|
|
|
|
|
|
encode_canonical_json(second.get_pdu_json()), |
|
|
|
|
|
|
|
msg, |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def test_get_latest_event_ids_in_room(self) -> None: |
|
|
|
def test_get_latest_event_ids_in_room(self) -> None: |
|
|
|
create = self.persist(type="m.room.create", key="", creator=USER_ID) |
|
|
|
create = self.persist(type="m.room.create", key="", creator=USER_ID) |
|
|
@ -107,7 +90,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): |
|
|
|
|
|
|
|
|
|
|
|
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") |
|
|
|
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") |
|
|
|
self.replicate() |
|
|
|
self.replicate() |
|
|
|
self.check("get_event", [msg.event_id], msg) |
|
|
|
self.check("get_event", [msg.event_id], msg, asserter=self.assertEventsEqual) |
|
|
|
|
|
|
|
|
|
|
|
redaction = self.persist(type="m.room.redaction", redacts=msg.event_id) |
|
|
|
redaction = self.persist(type="m.room.redaction", redacts=msg.event_id) |
|
|
|
self.replicate() |
|
|
|
self.replicate() |
|
|
@ -119,7 +102,9 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): |
|
|
|
redacted = make_event_from_dict( |
|
|
|
redacted = make_event_from_dict( |
|
|
|
msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() |
|
|
|
msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() |
|
|
|
) |
|
|
|
) |
|
|
|
self.check("get_event", [msg.event_id], redacted) |
|
|
|
self.check( |
|
|
|
|
|
|
|
"get_event", [msg.event_id], redacted, asserter=self.assertEventsEqual |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def test_backfilled_redactions(self) -> None: |
|
|
|
def test_backfilled_redactions(self) -> None: |
|
|
|
self.persist(type="m.room.create", key="", creator=USER_ID) |
|
|
|
self.persist(type="m.room.create", key="", creator=USER_ID) |
|
|
@ -127,7 +112,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): |
|
|
|
|
|
|
|
|
|
|
|
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") |
|
|
|
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") |
|
|
|
self.replicate() |
|
|
|
self.replicate() |
|
|
|
self.check("get_event", [msg.event_id], msg) |
|
|
|
self.check("get_event", [msg.event_id], msg, asserter=self.assertEventsEqual) |
|
|
|
|
|
|
|
|
|
|
|
redaction = self.persist( |
|
|
|
redaction = self.persist( |
|
|
|
type="m.room.redaction", redacts=msg.event_id, backfill=True |
|
|
|
type="m.room.redaction", redacts=msg.event_id, backfill=True |
|
|
@ -141,7 +126,9 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): |
|
|
|
redacted = make_event_from_dict( |
|
|
|
redacted = make_event_from_dict( |
|
|
|
msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() |
|
|
|
msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() |
|
|
|
) |
|
|
|
) |
|
|
|
self.check("get_event", [msg.event_id], redacted) |
|
|
|
self.check( |
|
|
|
|
|
|
|
"get_event", [msg.event_id], redacted, asserter=self.assertEventsEqual |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def test_invites(self) -> None: |
|
|
|
def test_invites(self) -> None: |
|
|
|
self.persist(type="m.room.create", key="", creator=USER_ID) |
|
|
|
self.persist(type="m.room.create", key="", creator=USER_ID) |
|
|
|