@ -15,12 +15,17 @@
import atexit
import os
from typing import Any , Callable , Dict , List , Tuple , Union , overload
import attr
from typing_extensions import Literal , ParamSpec
from synapse . api . constants import EventTypes
from synapse . api . room_versions import RoomVersions
from synapse . config . homeserver import HomeServerConfig
from synapse . config . server import DEFAULT_ROOM_VERSION
from synapse . logging . context import current_context , set_current_context
from synapse . server import HomeServer
from synapse . storage . database import LoggingDatabaseConnection
from synapse . storage . engines import create_engine
from synapse . storage . prepare_database import prepare_database
@ -50,12 +55,11 @@ SQLITE_PERSIST_DB = os.environ.get("SYNAPSE_TEST_PERSIST_SQLITE_DB") is not None
POSTGRES_DBNAME_FOR_INITIAL_CREATE = " postgres "
def setupdb ( ) :
def setupdb ( ) - > None :
# If we're using PostgreSQL, set up the db once
if USE_POSTGRES_FOR_TESTS :
# create a PostgresEngine
db_engine = create_engine ( { " name " : " psycopg2 " , " args " : { } } )
# connect to postgres to create the base database.
db_conn = db_engine . module . connect (
user = POSTGRES_USER ,
@ -82,11 +86,11 @@ def setupdb():
port = POSTGRES_PORT ,
password = POSTGRES_PASSWORD ,
)
db _conn = LoggingDatabaseConnection ( db_conn , db_engine , " tests " )
prepare_database ( db _conn, db_engine , None )
db _conn. close ( )
logging _conn = LoggingDatabaseConnection ( db_conn , db_engine , " tests " )
prepare_database ( logging _conn, db_engine , None )
logging _conn. close ( )
def _cleanup ( ) :
def _cleanup ( ) - > None :
db_conn = db_engine . module . connect (
user = POSTGRES_USER ,
host = POSTGRES_HOST ,
@ -103,7 +107,19 @@ def setupdb():
atexit . register ( _cleanup )
def default_config ( name , parse = False ) :
@overload
def default_config ( name : str , parse : Literal [ False ] = . . . ) - > Dict [ str , object ] :
. . .
@overload
def default_config ( name : str , parse : Literal [ True ] ) - > HomeServerConfig :
. . .
def default_config (
name : str , parse : bool = False
) - > Union [ Dict [ str , object ] , HomeServerConfig ] :
"""
Create a reasonable test config .
"""
@ -181,90 +197,122 @@ def default_config(name, parse=False):
return config_dict
def mock_getRawHeaders ( headers = None ) :
def mock_getRawHeaders ( headers = None ) : # type: ignore[no-untyped-def]
headers = headers if headers is not None else { }
def getRawHeaders ( name , default = None ) :
def getRawHeaders ( name , default = None ) : # type: ignore[no-untyped-def]
# If the requested header is present, the real twisted function returns
# List[str] if name is a str and List[bytes] if name is a bytes.
# This mock doesn't support that behaviour.
# Fortunately, none of the current callers of mock_getRawHeaders() provide a
# headers dict, so we don't encounter this discrepancy in practice.
return headers . get ( name , default )
return getRawHeaders
P = ParamSpec ( " P " )
@attr . s ( slots = True , auto_attribs = True )
class Timer :
absolute_time : float
callback : Callable [ [ ] , None ]
expired : bool
# TODO: Make this generic over a ParamSpec?
@attr . s ( slots = True , auto_attribs = True )
class Looper :
func : Callable [ . . . , Any ]
interval : float # seconds
last : float
args : Tuple [ object , . . . ]
kwargs : Dict [ str , object ]
class MockClock :
now = 1000
now = 1000. 0
def __init__ ( self ) :
# list of lists of [absolute_time, callback, expired] in no particular
# order
self . timers = [ ]
self . loopers = [ ]
def __init__ ( self ) - > None :
# Timers in no particular order
self . timers : List [ Timer ] = [ ]
self . loopers : List [ Looper ] = [ ]
def time ( self ) :
def time ( self ) - > float :
return self . now
def time_msec ( self ) :
return self . time ( ) * 1000
def time_msec ( self ) - > int :
return int ( self . time ( ) * 1000 )
def call_later ( self , delay , callback , * args , * * kwargs ) :
def call_later (
self ,
delay : float ,
callback : Callable [ P , object ] ,
* args : P . args ,
* * kwargs : P . kwargs ,
) - > Timer :
ctx = current_context ( )
def wrapped_callback ( ) :
def wrapped_callback ( ) - > None :
set_current_context ( ctx )
callback ( * args , * * kwargs )
t = [ self . now + delay , wrapped_callback , False ]
t = Timer ( self . now + delay , wrapped_callback , False )
self . timers . append ( t )
return t
def looping_call ( self , function , interval , * args , * * kwargs ) :
self . loopers . append ( [ function , interval / 1000.0 , self . now , args , kwargs ] )
def cancel_call_later ( self , timer , ignore_errs = False ) :
if timer [ 2 ] :
def looping_call (
self ,
function : Callable [ P , object ] ,
interval : float ,
* args : P . args ,
* * kwargs : P . kwargs ,
) - > None :
# This type-ignore should be redundant once we use a mypy release with
# https://github.com/python/mypy/pull/12668.
self . loopers . append ( Looper ( function , interval / 1000.0 , self . now , args , kwargs ) ) # type: ignore[arg-type]
def cancel_call_later ( self , timer : Timer , ignore_errs : bool = False ) - > None :
if timer . expired :
if not ignore_errs :
raise Exception ( " Cannot cancel an expired timer " )
timer [ 2 ] = True
timer . expired = True
self . timers = [ t for t in self . timers if t != timer ]
# For unit testing
def advance_time ( self , secs ) :
def advance_time ( self , secs : float ) - > None :
self . now + = secs
timers = self . timers
self . timers = [ ]
for t in timers :
time , callback , expired = t
if expired :
if t . expired :
raise Exception ( " Timer already expired " )
if self . now > = time :
t [ 2 ] = True
callback ( )
if self . now > = t . absolute_t ime:
t . expired = True
t . callback ( )
else :
self . timers . append ( t )
for looped in self . loopers :
func , interval , last , args , kwargs = looped
if last + interval < self . now :
func ( * args , * * kwargs )
looped [ 2 ] = self . now
if looped . last + looped . interval < self . now :
looped . func ( * looped . args , * * looped . kwargs )
looped . last = self . now
def advance_time_msec ( self , ms ) :
def advance_time_msec ( self , ms : float ) - > None :
self . advance_time ( ms / 1000.0 )
def time_bound_deferred ( self , d , * args , * * kwargs ) :
# We don't bother timing things out for now.
return d
async def create_room ( hs , room_id : str , creator_id : str ) :
async def create_room ( hs : HomeServer , room_id : str , creator_id : str ) - > None :
""" Creates and persist a creation event for the given room """
persistence_store = hs . get_storage_controllers ( ) . persistence
assert persistence_store is not None
store = hs . get_datastores ( ) . main
event_builder_factory = hs . get_event_builder_factory ( )
event_creation_handler = hs . get_event_creation_handler ( )