@ -13,7 +13,7 @@
# limitations under the License.
import datetime
from typing import Dict , List , Tuple , Union
from typing import Dict , List , Tuple , Union , cast
import attr
from parameterized import parameterized
@ -26,11 +26,12 @@ from synapse.api.room_versions import (
EventFormatVersions ,
RoomVersion ,
)
from synapse . events import _EventInternalMetadata
from synapse . events import EventBase , _EventInternalMetadata
from synapse . rest import admin
from synapse . rest . client import login , room
from synapse . server import HomeServer
from synapse . storage . database import LoggingTransaction
from synapse . storage . types import Cursor
from synapse . types import JsonDict
from synapse . util import Clock , json_encoder
@ -54,11 +55,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prepare ( self , reactor : MemoryReactor , clock : Clock , hs : HomeServer ) - > None :
self . store = hs . get_datastores ( ) . main
def test_get_prev_events_for_room ( self ) :
def test_get_prev_events_for_room ( self ) - > None :
room_id = " @ROOM:local "
# add a bunch of events and hashes to act as forward extremities
def insert_event ( txn , i ) :
def insert_event ( txn : Cursor , i : int ) - > None :
event_id = " $event_ %i :local " % i
txn . execute (
@ -90,12 +91,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for i in range ( 0 , 10 ) :
self . assertEqual ( " $event_ %i :local " % ( 19 - i ) , r [ i ] )
def test_get_rooms_with_many_extremities ( self ) :
def test_get_rooms_with_many_extremities ( self ) - > None :
room1 = " #room1 "
room2 = " #room2 "
room3 = " #room3 "
def insert_event ( txn , i , room_id ) :
def insert_event ( txn : Cursor , i : int , room_id : str ) - > None :
event_id = " $event_ %i :local " % i
txn . execute (
(
@ -155,7 +156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# | |
# K J
auth_graph = {
auth_graph : Dict [ str , List [ str ] ] = {
" a " : [ " e " ] ,
" b " : [ " e " ] ,
" c " : [ " g " , " i " ] ,
@ -185,7 +186,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# Mark the room as maybe having a cover index.
def store_room ( txn ) :
def store_room ( txn : LoggingTransaction ) - > None :
self . store . db_pool . simple_insert_txn (
txn ,
" rooms " ,
@ -203,7 +204,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
def insert_event ( txn ) :
def insert_event ( txn : LoggingTransaction ) - > None :
stream_ordering = 0
for event_id in auth_graph :
@ -228,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self . hs . datastores . persist_events . _persist_event_auth_chain_txn (
txn ,
[
FakeEvent ( event_id , room_id , auth_graph [ event_id ] )
cast ( EventBase , FakeEvent ( event_id , room_id , auth_graph [ event_id ] ) )
for event_id in auth_graph
] ,
)
@ -243,7 +244,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return room_id
@parameterized . expand ( [ ( True , ) , ( False , ) ] )
def test_auth_chain_ids ( self , use_chain_cover_index : bool ) :
def test_auth_chain_ids ( self , use_chain_cover_index : bool ) - > None :
room_id = self . _setup_auth_chain ( use_chain_cover_index )
# a and b have the same auth chain.
@ -308,7 +309,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self . assertCountEqual ( auth_chain_ids , [ " i " , " j " ] )
@parameterized . expand ( [ ( True , ) , ( False , ) ] )
def test_auth_difference ( self , use_chain_cover_index : bool ) :
def test_auth_difference ( self , use_chain_cover_index : bool ) - > None :
room_id = self . _setup_auth_chain ( use_chain_cover_index )
# Now actually test that various combinations give the right result:
@ -353,7 +354,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
self . assertSetEqual ( difference , set ( ) )
def test_auth_difference_partial_cover ( self ) :
def test_auth_difference_partial_cover ( self ) - > None :
""" Test that we correctly handle rooms where not all events have a chain
cover calculated . This can happen in some obscure edge cases , including
during the background update that calculates the chain cover for old
@ -377,7 +378,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# | |
# K J
auth_graph = {
auth_graph : Dict [ str , List [ str ] ] = {
" a " : [ " e " ] ,
" b " : [ " e " ] ,
" c " : [ " g " , " i " ] ,
@ -408,7 +409,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
def insert_event ( txn ) :
def insert_event ( txn : LoggingTransaction ) - > None :
# First insert the room and mark it as having a chain cover.
self . store . db_pool . simple_insert_txn (
txn ,
@ -447,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self . hs . datastores . persist_events . _persist_event_auth_chain_txn (
txn ,
[
FakeEvent ( event_id , room_id , auth_graph [ event_id ] )
cast ( EventBase , FakeEvent ( event_id , room_id , auth_graph [ event_id ] ) )
for event_id in auth_graph
if event_id != " b "
] ,
@ -465,7 +466,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self . hs . datastores . persist_events . _persist_event_auth_chain_txn (
txn ,
[ FakeEvent ( " b " , room_id , auth_graph [ " b " ] ) ] ,
[ cast ( EventBase , FakeEvent ( " b " , room_id , auth_graph [ " b " ] ) ) ] ,
)
self . store . db_pool . simple_update_txn (
@ -527,7 +528,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
@parameterized . expand (
[ ( room_version , ) for room_version in KNOWN_ROOM_VERSIONS . values ( ) ]
)
def test_prune_inbound_federation_queue ( self , room_version : RoomVersion ) :
def test_prune_inbound_federation_queue ( self , room_version : RoomVersion ) - > None :
""" Test that pruning of inbound federation queues work """
room_id = " some_room_id "
@ -686,7 +687,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
stream_ordering + = 1
def populate_db ( txn : LoggingTransaction ) :
def populate_db ( txn : LoggingTransaction ) - > None :
# Insert the room to satisfy the foreign key constraint of
# `event_failed_pull_attempts`
self . store . db_pool . simple_insert_txn (
@ -760,7 +761,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return _BackfillSetupInfo ( room_id = room_id , depth_map = depth_map )
def test_get_backfill_points_in_room ( self ) :
def test_get_backfill_points_in_room ( self ) - > None :
"""
Test to make sure only backfill points that are older and come before
the ` current_depth ` are returned .
@ -787,7 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_backfill_points_in_room_excludes_events_we_have_attempted (
self ,
) :
) - > None :
"""
Test to make sure that events we have attempted to backfill ( and within
backoff timeout duration ) do not show up as an event to backfill again .
@ -824,7 +825,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration (
self ,
) :
) - > None :
"""
Test to make sure after we fake attempt to backfill event " b3 " many times ,
we can see retry and see the " b3 " again after the backoff timeout duration
@ -941,7 +942,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
" 5 " : 7 ,
}
def populate_db ( txn : LoggingTransaction ) :
def populate_db ( txn : LoggingTransaction ) - > None :
# Insert the room to satisfy the foreign key constraint of
# `event_failed_pull_attempts`
self . store . db_pool . simple_insert_txn (
@ -996,7 +997,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return _BackfillSetupInfo ( room_id = room_id , depth_map = depth_map )
def test_get_insertion_event_backward_extremities_in_room ( self ) :
def test_get_insertion_event_backward_extremities_in_room ( self ) - > None :
"""
Test to make sure only insertion event backward extremities that are
older and come before the ` current_depth ` are returned .
@ -1027,7 +1028,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted (
self ,
) :
) - > None :
"""
Test to make sure that insertion events we have attempted to backfill
( and within backoff timeout duration ) do not show up as an event to
@ -1060,7 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration (
self ,
) :
) - > None :
"""
Test to make sure after we fake attempt to backfill event
" insertion_eventA " many times , we can see retry and see the
@ -1130,9 +1131,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
backfill_event_ids = [ backfill_point [ 0 ] for backfill_point in backfill_points ]
self . assertEqual ( backfill_event_ids , [ " insertion_eventA " ] )
def test_get_event_ids_to_not_pull_from_backoff (
self ,
) :
def test_get_event_ids_to_not_pull_from_backoff ( self ) - > None :
"""
Test to make sure only event IDs we should backoff from are returned .
"""
@ -1157,7 +1156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration (
self ,
) :
) - > None :
"""
Test to make sure no event IDs are returned after the backoff duration has
elapsed .
@ -1187,19 +1186,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self . assertEqual ( event_ids_to_backoff , [ ] )
@attr . s
@attr . s ( auto_attribs = True )
class FakeEvent :
event_id = attr . ib ( )
room_id = attr . ib ( )
auth_events = attr . ib ( )
event_id : str
room_id : str
auth_events : List [ str ]
type = " foo "
state_key = " foo "
internal_metadata = _EventInternalMetadata ( { } )
def auth_event_ids ( self ) :
def auth_event_ids ( self ) - > List [ str ] :
return self . auth_events
def is_state ( self ) :
def is_state ( self ) - > bool :
return True