Do not try to serialize raw aggregations dict. (#11791)

code_spécifique_watcha
Patrick Cloke 3 years ago committed by GitHub
parent 9f2016e96e
commit b784299cbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      changelog.d/11612.bugfix
  2. 1
      changelog.d/11612.misc
  3. 1
      changelog.d/11791.bugfix
  4. 4
      synapse/events/utils.py
  5. 13
      synapse/rest/admin/rooms.py
  6. 11
      synapse/rest/client/room.py
  7. 108
      tests/rest/client/test_relations.py

@ -0,0 +1 @@
Include the bundled aggregations in the `/sync` response, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675).

@ -1 +0,0 @@
Avoid database access in the JSON serialization process.

@ -0,0 +1 @@
Include the bundled aggregations in the `/sync` response, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675).

@ -402,7 +402,7 @@ class EventClientSerializer:
if bundle_aggregations: if bundle_aggregations:
event_aggregations = bundle_aggregations.get(event.event_id) event_aggregations = bundle_aggregations.get(event.event_id)
if event_aggregations: if event_aggregations:
self._injected_bundled_aggregations( self._inject_bundled_aggregations(
event, event,
time_now, time_now,
bundle_aggregations[event.event_id], bundle_aggregations[event.event_id],
@ -411,7 +411,7 @@ class EventClientSerializer:
return serialized_event return serialized_event
def _injected_bundled_aggregations( def _inject_bundled_aggregations(
self, self,
event: EventBase, event: EventBase,
time_now: int, time_now: int,

@ -744,20 +744,15 @@ class RoomEventContextServlet(RestServlet):
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
aggregations = results.pop("aggregations", None)
results["events_before"] = self._event_serializer.serialize_events( results["events_before"] = self._event_serializer.serialize_events(
results["events_before"], results["events_before"], time_now, bundle_aggregations=aggregations
time_now,
bundle_aggregations=results["aggregations"],
) )
results["event"] = self._event_serializer.serialize_event( results["event"] = self._event_serializer.serialize_event(
results["event"], results["event"], time_now, bundle_aggregations=aggregations
time_now,
bundle_aggregations=results["aggregations"],
) )
results["events_after"] = self._event_serializer.serialize_events( results["events_after"] = self._event_serializer.serialize_events(
results["events_after"], results["events_after"], time_now, bundle_aggregations=aggregations
time_now,
bundle_aggregations=results["aggregations"],
) )
results["state"] = self._event_serializer.serialize_events( results["state"] = self._event_serializer.serialize_events(
results["state"], time_now results["state"], time_now

@ -714,18 +714,15 @@ class RoomEventContextServlet(RestServlet):
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
aggregations = results.pop("aggregations", None)
results["events_before"] = self._event_serializer.serialize_events( results["events_before"] = self._event_serializer.serialize_events(
results["events_before"], results["events_before"], time_now, bundle_aggregations=aggregations
time_now,
bundle_aggregations=results["aggregations"],
) )
results["event"] = self._event_serializer.serialize_event( results["event"] = self._event_serializer.serialize_event(
results["event"], time_now, bundle_aggregations=results["aggregations"] results["event"], time_now, bundle_aggregations=aggregations
) )
results["events_after"] = self._event_serializer.serialize_events( results["events_after"] = self._event_serializer.serialize_events(
results["events_after"], results["events_after"], time_now, bundle_aggregations=aggregations
time_now,
bundle_aggregations=results["aggregations"],
) )
results["state"] = self._event_serializer.serialize_events( results["state"] = self._event_serializer.serialize_events(
results["state"], time_now results["state"], time_now

@ -21,6 +21,7 @@ from unittest.mock import patch
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync from synapse.rest.client import login, register, relations, room, sync
from synapse.types import JsonDict
from tests import unittest from tests import unittest
from tests.server import FakeChannel from tests.server import FakeChannel
@ -454,7 +455,14 @@ class RelationsTestCase(unittest.HomeserverTestCase):
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_bundled_aggregations(self): def test_bundled_aggregations(self):
"""Test that annotations, references, and threads get correctly bundled.""" """
Test that annotations, references, and threads get correctly bundled.
Note that this doesn't test against /relations since only thread relations
get bundled via that API. See test_aggregation_get_event_for_thread.
See test_edit for a similar test for edits.
"""
# Setup by sending a variety of relations. # Setup by sending a variety of relations.
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
@ -482,12 +490,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
thread_2 = channel.json_body["event_id"] thread_2 = channel.json_body["event_id"]
def assert_bundle(actual): def assert_bundle(event_json: JsonDict) -> None:
"""Assert the expected values of the bundled aggregations.""" """Assert the expected values of the bundled aggregations."""
relations_dict = event_json["unsigned"].get("m.relations")
# Ensure the fields are as expected. # Ensure the fields are as expected.
self.assertCountEqual( self.assertCountEqual(
actual.keys(), relations_dict.keys(),
( (
RelationTypes.ANNOTATION, RelationTypes.ANNOTATION,
RelationTypes.REFERENCE, RelationTypes.REFERENCE,
@ -503,20 +512,20 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"type": "m.reaction", "key": "b", "count": 1}, {"type": "m.reaction", "key": "b", "count": 1},
] ]
}, },
actual[RelationTypes.ANNOTATION], relations_dict[RelationTypes.ANNOTATION],
) )
self.assertEquals( self.assertEquals(
{"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
actual[RelationTypes.REFERENCE], relations_dict[RelationTypes.REFERENCE],
) )
self.assertEquals( self.assertEquals(
2, 2,
actual[RelationTypes.THREAD].get("count"), relations_dict[RelationTypes.THREAD].get("count"),
) )
self.assertTrue( self.assertTrue(
actual[RelationTypes.THREAD].get("current_user_participated") relations_dict[RelationTypes.THREAD].get("current_user_participated")
) )
# The latest thread event has some fields that don't matter. # The latest thread event has some fields that don't matter.
self.assert_dict( self.assert_dict(
@ -533,20 +542,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"type": "m.room.test", "type": "m.room.test",
"user_id": self.user_id, "user_id": self.user_id,
}, },
actual[RelationTypes.THREAD].get("latest_event"), relations_dict[RelationTypes.THREAD].get("latest_event"),
) )
def _find_and_assert_event(events):
"""
Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
"""
for event in events:
if event["event_id"] == self.parent_id:
break
else:
raise AssertionError(f"Event {self.parent_id} not found in chunk")
assert_bundle(event["unsigned"].get("m.relations"))
# Request the event directly. # Request the event directly.
channel = self.make_request( channel = self.make_request(
"GET", "GET",
@ -554,7 +552,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
assert_bundle(channel.json_body["unsigned"].get("m.relations")) assert_bundle(channel.json_body)
# Request the room messages. # Request the room messages.
channel = self.make_request( channel = self.make_request(
@ -563,7 +561,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
_find_and_assert_event(channel.json_body["chunk"]) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
# Request the room context. # Request the room context.
channel = self.make_request( channel = self.make_request(
@ -572,17 +570,14 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) assert_bundle(channel.json_body["event"])
# Request sync. # Request sync.
channel = self.make_request("GET", "/sync", access_token=self.user_token) channel = self.make_request("GET", "/sync", access_token=self.user_token)
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
self.assertTrue(room_timeline["limited"]) self.assertTrue(room_timeline["limited"])
_find_and_assert_event(room_timeline["events"]) self._find_event_in_chunk(room_timeline["events"])
# Note that /relations is tested separately in test_aggregation_get_event_for_thread
# since it needs different data configured.
def test_aggregation_get_event_for_annotation(self): def test_aggregation_get_event_for_annotation(self):
"""Test that annotations do not get bundled aggregations included """Test that annotations do not get bundled aggregations included
@ -777,25 +772,58 @@ class RelationsTestCase(unittest.HomeserverTestCase):
edit_event_id = channel.json_body["event_id"] edit_event_id = channel.json_body["event_id"]
def assert_bundle(event_json: JsonDict) -> None:
"""Assert the expected values of the bundled aggregations."""
relations_dict = event_json["unsigned"].get("m.relations")
self.assertIn(RelationTypes.REPLACE, relations_dict)
m_replace_dict = relations_dict[RelationTypes.REPLACE]
for key in ["event_id", "sender", "origin_server_ts"]:
self.assertIn(key, m_replace_dict)
self.assert_dict(
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id), f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(channel.json_body["content"], new_body) self.assertEquals(channel.json_body["content"], new_body)
assert_bundle(channel.json_body)
relations_dict = channel.json_body["unsigned"].get("m.relations") # Request the room messages.
self.assertIn(RelationTypes.REPLACE, relations_dict) channel = self.make_request(
"GET",
f"/rooms/{self.room}/messages?dir=b",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
m_replace_dict = relations_dict[RelationTypes.REPLACE] # Request the room context.
for key in ["event_id", "sender", "origin_server_ts"]: channel = self.make_request(
self.assertIn(key, m_replace_dict) "GET",
f"/rooms/{self.room}/context/{self.parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
assert_bundle(channel.json_body["event"])
self.assert_dict( # Request sync, but limit the timeline so it becomes limited (and includes
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict # bundled aggregations).
filter = urllib.parse.quote_plus(
'{"room": {"timeline": {"limit": 2}}}'.encode()
)
channel = self.make_request(
"GET", f"/sync?filter={filter}", access_token=self.user_token
) )
self.assertEquals(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
self.assertTrue(room_timeline["limited"])
assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
def test_multi_edit(self): def test_multi_edit(self):
"""Test that multiple edits, including attempts by people who """Test that multiple edits, including attempts by people who
@ -1102,6 +1130,16 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(channel.json_body["chunk"], []) self.assertEquals(channel.json_body["chunk"], [])
def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
"""
Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
"""
for event in events:
if event["event_id"] == self.parent_id:
return event
raise AssertionError(f"Event {self.parent_id} not found in chunk")
def _send_relation( def _send_relation(
self, self,
relation_type: str, relation_type: str,

Loading…
Cancel
Save