@ -19,7 +19,6 @@ import urllib
import urlparse
from mock import Mock , patch
from twisted . enterprise . adbapi import ConnectionPool
from twisted . internet import defer , reactor
from synapse . api . errors import CodeMessageException , cs_error
@ -60,30 +59,37 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config . update_user_directory = False
config . use_frozen_dicts = True
config . database_config = { " name " : " sqlite3 " }
config . ldap_enabled = False
if " clock " not in kargs :
kargs [ " clock " ] = MockClock ( )
config . database_config = {
" name " : " sqlite3 " ,
" args " : {
" database " : " :memory: " ,
" cp_min " : 1 ,
" cp_max " : 1 ,
} ,
}
db_engine = create_engine ( config . database_config )
# we need to configure the connection pool to run the on_new_connection
# function, so that we can test code that uses custom sqlite functions
# (like rank).
config . database_config [ " args " ] [ " cp_openfun " ] = db_engine . on_new_connection
if datastore is None :
# we need to configure the connection pool to run the on_new_connection
# function, so that we can test code that uses custom sqlite functions
# (like rank).
db_pool = SQLiteMemoryDbPool (
cp_openfun = db_engine . on_new_connection ,
)
yield db_pool . prepare ( )
hs = HomeServer (
name , db_pool = db_pool , config = config ,
name , config = config ,
db_config = config . database_config ,
version_string = " Synapse/tests " ,
database_engine = db_engine ,
get_db_conn = db_pool . get_db_conn ,
room_list_handler = object ( ) ,
tls_server_context_factory = Mock ( ) ,
* * kargs
)
yield prepare_database ( hs . get_db_conn ( ) , db_engine , config )
hs . setup ( )
else :
hs = HomeServer (
@ -308,38 +314,6 @@ class MockClock(object):
return d
class SQLiteMemoryDbPool ( ConnectionPool , object ) :
def __init__ ( self , * * kwargs ) :
connkw = {
" cp_min " : 1 ,
" cp_max " : 1 ,
}
connkw . update ( kwargs )
super ( SQLiteMemoryDbPool , self ) . __init__ (
" sqlite3 " , " :memory: " , * * connkw
)
self . config = Mock ( )
self . config . password_providers = [ ]
self . config . database_config = { " name " : " sqlite3 " }
def prepare ( self ) :
engine = self . create_engine ( )
return self . runWithConnection (
lambda conn : prepare_database ( conn , engine , self . config )
)
def get_db_conn ( self ) :
conn = self . connect ( )
engine = self . create_engine ( )
prepare_database ( conn , engine , self . config )
return conn
def create_engine ( self ) :
return create_engine ( self . config . database_config )
def _format_call ( args , kwargs ) :
return " , " . join (
[ " %r " % ( a ) for a in args ] +