|
|
|
@ -11,9 +11,12 @@ |
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
|
import hashlib |
|
|
|
|
import json |
|
|
|
|
import logging |
|
|
|
|
import time |
|
|
|
|
import uuid |
|
|
|
|
import warnings |
|
|
|
|
from collections import deque |
|
|
|
|
from io import SEEK_END, BytesIO |
|
|
|
|
from typing import ( |
|
|
|
@ -27,6 +30,7 @@ from typing import ( |
|
|
|
|
Type, |
|
|
|
|
Union, |
|
|
|
|
) |
|
|
|
|
from unittest.mock import Mock |
|
|
|
|
|
|
|
|
|
import attr |
|
|
|
|
from typing_extensions import Deque |
|
|
|
@ -53,11 +57,24 @@ from twisted.web.http_headers import Headers |
|
|
|
|
from twisted.web.resource import IResource |
|
|
|
|
from twisted.web.server import Request, Site |
|
|
|
|
|
|
|
|
|
from synapse.config.database import DatabaseConnectionConfig |
|
|
|
|
from synapse.http.site import SynapseRequest |
|
|
|
|
from synapse.server import HomeServer |
|
|
|
|
from synapse.storage import DataStore |
|
|
|
|
from synapse.storage.engines import PostgresEngine, create_engine |
|
|
|
|
from synapse.types import JsonDict |
|
|
|
|
from synapse.util import Clock |
|
|
|
|
|
|
|
|
|
from tests.utils import setup_test_homeserver as _sth |
|
|
|
|
from tests.utils import ( |
|
|
|
|
LEAVE_DB, |
|
|
|
|
POSTGRES_BASE_DB, |
|
|
|
|
POSTGRES_HOST, |
|
|
|
|
POSTGRES_PASSWORD, |
|
|
|
|
POSTGRES_USER, |
|
|
|
|
USE_POSTGRES_FOR_TESTS, |
|
|
|
|
MockClock, |
|
|
|
|
default_config, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
@ -450,14 +467,11 @@ class ThreadPool: |
|
|
|
|
return d |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_test_homeserver(cleanup_func, *args, **kwargs): |
|
|
|
|
def _make_test_homeserver_synchronous(server: HomeServer) -> None: |
|
|
|
|
""" |
|
|
|
|
Set up a synchronous test server, driven by the reactor used by |
|
|
|
|
the homeserver. |
|
|
|
|
Make the given test homeserver's database interactions synchronous. |
|
|
|
|
""" |
|
|
|
|
server = _sth(cleanup_func, *args, **kwargs) |
|
|
|
|
|
|
|
|
|
# Make the thread pool synchronous. |
|
|
|
|
clock = server.get_clock() |
|
|
|
|
|
|
|
|
|
for database in server.get_datastores().databases: |
|
|
|
@ -485,6 +499,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): |
|
|
|
|
|
|
|
|
|
pool.runWithConnection = runWithConnection |
|
|
|
|
pool.runInteraction = runInteraction |
|
|
|
|
# Replace the thread pool with a threadless 'thread' pool |
|
|
|
|
pool.threadpool = ThreadPool(clock._reactor) |
|
|
|
|
pool.running = True |
|
|
|
|
|
|
|
|
@ -492,8 +507,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): |
|
|
|
|
# thread, so we need to disable the dedicated thread behaviour. |
|
|
|
|
server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False |
|
|
|
|
|
|
|
|
|
return server |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: |
|
|
|
|
clock = ThreadedMemoryReactorClock() |
|
|
|
@ -673,3 +686,171 @@ def connect_client( |
|
|
|
|
client.makeConnection(FakeTransport(server, reactor)) |
|
|
|
|
|
|
|
|
|
return client, server |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestHomeServer(HomeServer): |
|
|
|
|
DATASTORE_CLASS = DataStore |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_test_homeserver( |
|
|
|
|
cleanup_func, |
|
|
|
|
name="test", |
|
|
|
|
config=None, |
|
|
|
|
reactor=None, |
|
|
|
|
homeserver_to_use: Type[HomeServer] = TestHomeServer, |
|
|
|
|
**kwargs, |
|
|
|
|
): |
|
|
|
|
""" |
|
|
|
|
Setup a homeserver suitable for running tests against. Keyword arguments |
|
|
|
|
are passed to the Homeserver constructor. |
|
|
|
|
|
|
|
|
|
If no datastore is supplied, one is created and given to the homeserver. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
cleanup_func : The function used to register a cleanup routine for |
|
|
|
|
after the test. |
|
|
|
|
|
|
|
|
|
Calling this method directly is deprecated: you should instead derive from |
|
|
|
|
HomeserverTestCase. |
|
|
|
|
""" |
|
|
|
|
if reactor is None: |
|
|
|
|
from twisted.internet import reactor |
|
|
|
|
|
|
|
|
|
if config is None: |
|
|
|
|
config = default_config(name, parse=True) |
|
|
|
|
|
|
|
|
|
config.ldap_enabled = False |
|
|
|
|
|
|
|
|
|
if "clock" not in kwargs: |
|
|
|
|
kwargs["clock"] = MockClock() |
|
|
|
|
|
|
|
|
|
if USE_POSTGRES_FOR_TESTS: |
|
|
|
|
test_db = "synapse_test_%s" % uuid.uuid4().hex |
|
|
|
|
|
|
|
|
|
database_config = { |
|
|
|
|
"name": "psycopg2", |
|
|
|
|
"args": { |
|
|
|
|
"database": test_db, |
|
|
|
|
"host": POSTGRES_HOST, |
|
|
|
|
"password": POSTGRES_PASSWORD, |
|
|
|
|
"user": POSTGRES_USER, |
|
|
|
|
"cp_min": 1, |
|
|
|
|
"cp_max": 5, |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
else: |
|
|
|
|
database_config = { |
|
|
|
|
"name": "sqlite3", |
|
|
|
|
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if "db_txn_limit" in kwargs: |
|
|
|
|
database_config["txn_limit"] = kwargs["db_txn_limit"] |
|
|
|
|
|
|
|
|
|
database = DatabaseConnectionConfig("master", database_config) |
|
|
|
|
config.database.databases = [database] |
|
|
|
|
|
|
|
|
|
db_engine = create_engine(database.config) |
|
|
|
|
|
|
|
|
|
# Create the database before we actually try and connect to it, based off |
|
|
|
|
# the template database we generate in setupdb() |
|
|
|
|
if isinstance(db_engine, PostgresEngine): |
|
|
|
|
db_conn = db_engine.module.connect( |
|
|
|
|
database=POSTGRES_BASE_DB, |
|
|
|
|
user=POSTGRES_USER, |
|
|
|
|
host=POSTGRES_HOST, |
|
|
|
|
password=POSTGRES_PASSWORD, |
|
|
|
|
) |
|
|
|
|
db_conn.autocommit = True |
|
|
|
|
cur = db_conn.cursor() |
|
|
|
|
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) |
|
|
|
|
cur.execute( |
|
|
|
|
"CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) |
|
|
|
|
) |
|
|
|
|
cur.close() |
|
|
|
|
db_conn.close() |
|
|
|
|
|
|
|
|
|
hs = homeserver_to_use( |
|
|
|
|
name, |
|
|
|
|
config=config, |
|
|
|
|
version_string="Synapse/tests", |
|
|
|
|
reactor=reactor, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Install @cache_in_self attributes |
|
|
|
|
for key, val in kwargs.items(): |
|
|
|
|
setattr(hs, "_" + key, val) |
|
|
|
|
|
|
|
|
|
# Mock TLS |
|
|
|
|
hs.tls_server_context_factory = Mock() |
|
|
|
|
hs.tls_client_options_factory = Mock() |
|
|
|
|
|
|
|
|
|
hs.setup() |
|
|
|
|
if homeserver_to_use == TestHomeServer: |
|
|
|
|
hs.setup_background_tasks() |
|
|
|
|
|
|
|
|
|
if isinstance(db_engine, PostgresEngine): |
|
|
|
|
database = hs.get_datastores().databases[0] |
|
|
|
|
|
|
|
|
|
# We need to do cleanup on PostgreSQL |
|
|
|
|
def cleanup(): |
|
|
|
|
import psycopg2 |
|
|
|
|
|
|
|
|
|
# Close all the db pools |
|
|
|
|
database._db_pool.close() |
|
|
|
|
|
|
|
|
|
dropped = False |
|
|
|
|
|
|
|
|
|
# Drop the test database |
|
|
|
|
db_conn = db_engine.module.connect( |
|
|
|
|
database=POSTGRES_BASE_DB, |
|
|
|
|
user=POSTGRES_USER, |
|
|
|
|
host=POSTGRES_HOST, |
|
|
|
|
password=POSTGRES_PASSWORD, |
|
|
|
|
) |
|
|
|
|
db_conn.autocommit = True |
|
|
|
|
cur = db_conn.cursor() |
|
|
|
|
|
|
|
|
|
# Try a few times to drop the DB. Some things may hold on to the |
|
|
|
|
# database for a few more seconds due to flakiness, preventing |
|
|
|
|
# us from dropping it when the test is over. If we can't drop |
|
|
|
|
# it, warn and move on. |
|
|
|
|
for _ in range(5): |
|
|
|
|
try: |
|
|
|
|
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) |
|
|
|
|
db_conn.commit() |
|
|
|
|
dropped = True |
|
|
|
|
except psycopg2.OperationalError as e: |
|
|
|
|
warnings.warn( |
|
|
|
|
"Couldn't drop old db: " + str(e), category=UserWarning |
|
|
|
|
) |
|
|
|
|
time.sleep(0.5) |
|
|
|
|
|
|
|
|
|
cur.close() |
|
|
|
|
db_conn.close() |
|
|
|
|
|
|
|
|
|
if not dropped: |
|
|
|
|
warnings.warn("Failed to drop old DB.", category=UserWarning) |
|
|
|
|
|
|
|
|
|
if not LEAVE_DB: |
|
|
|
|
# Register the cleanup hook |
|
|
|
|
cleanup_func(cleanup) |
|
|
|
|
|
|
|
|
|
# bcrypt is far too slow to be doing in unit tests |
|
|
|
|
# Need to let the HS build an auth handler and then mess with it |
|
|
|
|
# because AuthHandler's constructor requires the HS, so we can't make one |
|
|
|
|
# beforehand and pass it in to the HS's constructor (chicken / egg) |
|
|
|
|
async def hash(p): |
|
|
|
|
return hashlib.md5(p.encode("utf8")).hexdigest() |
|
|
|
|
|
|
|
|
|
hs.get_auth_handler().hash = hash |
|
|
|
|
|
|
|
|
|
async def validate_hash(p, h): |
|
|
|
|
return hashlib.md5(p.encode("utf8")).hexdigest() == h |
|
|
|
|
|
|
|
|
|
hs.get_auth_handler().validate_hash = validate_hash |
|
|
|
|
|
|
|
|
|
# Make the threadpool and database transactions synchronous for testing. |
|
|
|
|
_make_test_homeserver_synchronous(hs) |
|
|
|
|
|
|
|
|
|
return hs |
|
|
|
|