Pass around the reactor explicitly (#3385)

pull/14/head
Amber Brown 7 years ago committed by GitHub
parent c2eff937ac
commit 77ac14b960
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 30
      synapse/handlers/auth.py
  2. 1
      synapse/handlers/message.py
  3. 9
      synapse/handlers/user_directory.py
  4. 6
      synapse/http/client.py
  5. 5
      synapse/http/matrixfederationclient.py
  6. 3
      synapse/notifier.py
  7. 6
      synapse/replication/http/send_event.py
  8. 3
      synapse/rest/media/v1/media_repository.py
  9. 7
      synapse/rest/media/v1/media_storage.py
  10. 19
      synapse/server.py
  11. 3
      synapse/storage/background_updates.py
  12. 6
      synapse/storage/client_ips.py
  13. 3
      synapse/storage/event_push_actions.py
  14. 6
      synapse/storage/events_worker.py
  15. 32
      synapse/util/__init__.py
  16. 25
      synapse/util/async.py
  17. 16
      synapse/util/file_consumer.py
  18. 3
      synapse/util/ratelimitutils.py
  19. 9
      tests/crypto/test_keyring.py
  20. 6
      tests/rest/client/test_transactions.py
  21. 5
      tests/rest/media/v1/test_media_storage.py
  22. 6
      tests/util/test_file_consumer.py
  23. 7
      tests/util/test_linearizer.py
  24. 11
      tests/util/test_logcontext.py
  25. 7
      tests/utils.py

@ -33,6 +33,7 @@ import logging
import bcrypt
import pymacaroons
import simplejson
import attr
import synapse.util.stringutils as stringutils
@ -854,7 +855,11 @@ class AuthHandler(BaseHandler):
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
bcrypt.gensalt(self.bcrypt_rounds))
return make_deferred_yieldable(threads.deferToThread(_do_hash))
return make_deferred_yieldable(
threads.deferToThreadPool(
self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash
),
)
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
@ -874,16 +879,21 @@ class AuthHandler(BaseHandler):
)
if stored_hash:
return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
return make_deferred_yieldable(
threads.deferToThreadPool(
self.hs.get_reactor(),
self.hs.get_reactor().getThreadPool(),
_do_validate_hash,
),
)
else:
return defer.succeed(False)
class MacaroonGeneartor(object):
def __init__(self, hs):
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
self.macaroon_secret_key = hs.config.macaroon_secret_key
@attr.s
class MacaroonGenerator(object):
hs = attr.ib()
def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
@ -901,7 +911,7 @@ class MacaroonGeneartor(object):
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.clock.time_msec()
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
@ -913,9 +923,9 @@ class MacaroonGeneartor(object):
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
location=self.server_name,
location=self.hs.config.server_name,
identifier="key",
key=self.macaroon_secret_key)
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon

@ -806,6 +806,7 @@ class EventCreationHandler(object):
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
yield send_event_to_master(
self.hs.get_clock(),
self.http_client,
host=self.config.worker_replication_host,
port=self.config.worker_replication_http_port,

@ -19,7 +19,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.storage.roommember import ProfileInfo
from synapse.util.metrics import Measure
from synapse.util.async import sleep
from synapse.types import get_localpart_from_id
from six import iteritems
@ -174,7 +173,7 @@ class UserDirectoryHandler(object):
logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
yield self._handle_initial_room(room_id)
num_processed_rooms += 1
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
logger.info("Processed all rooms.")
@ -188,7 +187,7 @@ class UserDirectoryHandler(object):
logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
yield self._handle_local_user(user_id)
num_processed_users += 1
yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
yield self.clock.sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
logger.info("Processed all users")
@ -236,7 +235,7 @@ class UserDirectoryHandler(object):
count = 0
for user_id in user_ids:
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
if not self.is_mine_id(user_id):
count += 1
@ -251,7 +250,7 @@ class UserDirectoryHandler(object):
continue
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
count += 1
user_set = (user_id, other_user_id)

@ -98,8 +98,8 @@ class SimpleHttpClient(object):
method, uri, *args, **kwargs
)
add_timeout_to_deferred(
request_deferred,
60, cancelled_to_request_timed_out_error,
request_deferred, 60, self.hs.get_reactor(),
cancelled_to_request_timed_out_error,
)
response = yield make_deferred_yieldable(request_deferred)
@ -115,7 +115,7 @@ class SimpleHttpClient(object):
"Error sending request to %s %s: %s %s",
method, redact_uri(uri), type(e).__name__, e.message
)
raise e
raise
@defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}, headers=None):

@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone
from synapse.http import cancelled_to_request_timed_out_error
from synapse.http.endpoint import matrix_federation_endpoint
import synapse.metrics
from synapse.util.async import sleep, add_timeout_to_deferred
from synapse.util.async import add_timeout_to_deferred
from synapse.util import logcontext
from synapse.util.logcontext import make_deferred_yieldable
import synapse.util.retryutils
@ -193,6 +193,7 @@ class MatrixFederationHttpClient(object):
add_timeout_to_deferred(
request_deferred,
timeout / 1000. if timeout else 60,
self.hs.get_reactor(),
cancelled_to_request_timed_out_error,
)
response = yield make_deferred_yieldable(
@ -234,7 +235,7 @@ class MatrixFederationHttpClient(object):
delay = min(delay, 2)
delay *= random.uniform(0.8, 1.4)
yield sleep(delay)
yield self.clock.sleep(delay)
retries_left -= 1
else:
raise

@ -161,6 +161,7 @@ class Notifier(object):
self.user_to_user_stream = {}
self.room_to_user_streams = {}
self.hs = hs
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = []
@ -340,6 +341,7 @@ class Notifier(object):
add_timeout_to_deferred(
listener.deferred,
(end_time - now) / 1000.,
self.hs.get_reactor(),
)
with PreserveLoggingContext():
yield listener.deferred
@ -561,6 +563,7 @@ class Notifier(object):
add_timeout_to_deferred(
listener.deferred.addTimeout,
(end_time - now) / 1000.,
self.hs.get_reactor(),
)
try:
with PreserveLoggingContext():

@ -21,7 +21,6 @@ from synapse.api.errors import (
from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.async import sleep
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.metrics import Measure
from synapse.types import Requester, UserID
@ -33,11 +32,12 @@ logger = logging.getLogger(__name__)
@defer.inlineCallbacks
def send_event_to_master(client, host, port, requester, event, context,
def send_event_to_master(clock, client, host, port, requester, event, context,
ratelimit, extra_users):
"""Send event to be handled on the master
Args:
clock (synapse.util.Clock)
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication
@ -77,7 +77,7 @@ def send_event_to_master(client, host, port, requester, event, context,
# If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway.
yield sleep(1)
yield clock.sleep(1)
except MatrixCodeMessageException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And

@ -58,6 +58,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository(object):
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
@ -94,7 +95,7 @@ class MediaRepository(object):
storage_providers.append(provider)
self.media_storage = MediaStorage(
self.primary_base_path, self.filepaths, storage_providers,
self.hs, self.primary_base_path, self.filepaths, storage_providers,
)
self.clock.looping_call(

@ -37,13 +37,15 @@ class MediaStorage(object):
"""Responsible for storing/fetching files from local sources.
Args:
hs (synapse.server.Homeserver)
local_media_directory (str): Base path where we store media on disk
filepaths (MediaFilePaths)
storage_providers ([StorageProvider]): List of StorageProvider that are
used to fetch and store files.
"""
def __init__(self, local_media_directory, filepaths, storage_providers):
def __init__(self, hs, local_media_directory, filepaths, storage_providers):
self.hs = hs
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
@ -175,7 +177,8 @@ class MediaStorage(object):
res = yield provider.fetch(path, file_info)
if res:
with res:
consumer = BackgroundFileConsumer(open(local_path, "w"))
consumer = BackgroundFileConsumer(
open(local_path, "w"), self.hs.get_reactor())
yield res.write_to_consumer(consumer)
yield consumer.wait()
defer.returnValue(local_path)

@ -40,7 +40,7 @@ from synapse.federation.transport.client import TransportLayerClient
from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler
@ -165,15 +165,19 @@ class HomeServer(object):
'server_notices_sender',
]
def __init__(self, hostname, **kwargs):
def __init__(self, hostname, reactor=None, **kwargs):
"""
Args:
hostname : The hostname for the server.
"""
if not reactor:
from twisted.internet import reactor
self._reactor = reactor
self.hostname = hostname
self._building = {}
self.clock = Clock()
self.clock = Clock(reactor)
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
@ -186,6 +190,12 @@ class HomeServer(object):
self.datastore = DataStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def get_reactor(self):
"""
Fetch the Twisted reactor in use by this HomeServer.
"""
return self._reactor
def get_ip_from_request(self, request):
# X-Forwarded-For is handled by our custom request type.
return request.getClientIP()
@ -261,7 +271,7 @@ class HomeServer(object):
return AuthHandler(self)
def build_macaroon_generator(self):
return MacaroonGeneartor(self)
return MacaroonGenerator(self)
def build_device_handler(self):
return DeviceHandler(self)
@ -328,6 +338,7 @@ class HomeServer(object):
return adbapi.ConnectionPool(
name,
cp_reactor=self.get_reactor(),
**self.db_config.get("args", {})
)

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.util.async
from ._base import SQLBaseStore
from . import engines
@ -92,7 +91,7 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.info("Starting background schema updates")
while True:
yield synapse.util.async.sleep(
yield self.hs.get_clock().sleep(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
try:

@ -15,7 +15,7 @@
import logging
from twisted.internet import defer, reactor
from twisted.internet import defer
from ._base import Cache
from . import background_updates
@ -70,7 +70,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
now=None):

@ -16,7 +16,6 @@
from synapse.storage._base import SQLBaseStore, LoggingTransaction
from twisted.internet import defer
from synapse.util.async import sleep
from synapse.util.caches.descriptors import cachedInlineCallbacks
import logging
@ -800,7 +799,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
if caught_up:
break
yield sleep(5)
yield self.hs.get_clock().sleep(5)
finally:
self._doing_notif_rotation = False

@ -14,7 +14,7 @@
# limitations under the License.
from ._base import SQLBaseStore
from twisted.internet import defer, reactor
from twisted.internet import defer
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
@ -265,7 +265,7 @@ class EventsWorkerStore(SQLBaseStore):
except Exception:
logger.exception("Failed to callback")
with PreserveLoggingContext():
reactor.callFromThread(fire, event_list, row_dict)
self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
@ -278,7 +278,7 @@ class EventsWorkerStore(SQLBaseStore):
if event_list:
with PreserveLoggingContext():
reactor.callFromThread(fire, event_list)
self.hs.get_reactor().callFromThread(fire, event_list)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):

@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task
import time
import logging
from itertools import islice
import attr
from twisted.internet import defer, task
from synapse.util.logcontext import PreserveLoggingContext
logger = logging.getLogger(__name__)
@ -31,16 +30,24 @@ def unwrapFirstError(failure):
return failure.value.subFailure
@attr.s
class Clock(object):
"""A small utility that obtains current time-of-day so that time may be
mocked during unit-tests.
TODO(paul): Also move the sleep() functionality into it
"""
A Clock wraps a Twisted reactor and provides utilities on top of it.
"""
_reactor = attr.ib()
@defer.inlineCallbacks
def sleep(self, seconds):
d = defer.Deferred()
with PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds)
res = yield d
defer.returnValue(res)
def time(self):
"""Returns the current system time in seconds since epoch."""
return time.time()
return self._reactor.seconds()
def time_msec(self):
"""Returns the current system time in miliseconds since epoch."""
@ -56,6 +63,7 @@ class Clock(object):
msec(float): How long to wait between calls in milliseconds.
"""
call = task.LoopingCall(f)
call.clock = self._reactor
call.start(msec / 1000.0, now=False)
return call
@ -73,7 +81,7 @@ class Clock(object):
callback(*args, **kwargs)
with PreserveLoggingContext():
return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer, ignore_errs=False):
try:

@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.python import failure
from .logcontext import (
PreserveLoggingContext, make_deferred_yieldable, run_in_background
)
from synapse.util import logcontext, unwrapFirstError
from synapse.util import logcontext, unwrapFirstError, Clock
from contextlib import contextmanager
@ -31,15 +31,6 @@ from six.moves import range
logger = logging.getLogger(__name__)
@defer.inlineCallbacks
def sleep(seconds):
d = defer.Deferred()
with PreserveLoggingContext():
reactor.callLater(seconds, d.callback, seconds)
res = yield d
defer.returnValue(res)
class ObservableDeferred(object):
"""Wraps a deferred object so that we can add observer deferreds. These
observer deferreds do not affect the callback chain of the original
@ -172,13 +163,18 @@ class Linearizer(object):
# do some work.
"""
def __init__(self, name=None):
def __init__(self, name=None, clock=None):
if name is None:
self.name = id(self)
else:
self.name = name
self.key_to_defer = {}
if not clock:
from twisted.internet import reactor
clock = Clock(reactor)
self._clock = clock
@defer.inlineCallbacks
def queue(self, key):
# If there is already a deferred in the queue, we pull it out so that
@ -219,7 +215,7 @@ class Linearizer(object):
# the context manager, but it needs to happen while we hold the
# lock, and the context manager's exit code must be synchronous,
# so actually this is the only sensible place.
yield sleep(0)
yield self._clock.sleep(0)
else:
logger.info("Acquired uncontended linearizer lock %r for key %r",
@ -396,7 +392,7 @@ class DeferredTimeoutError(Exception):
"""
def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
"""
Add a timeout to a deferred by scheduling it to be cancelled after
timeout seconds.
@ -411,6 +407,7 @@ def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
Args:
deferred (defer.Deferred): deferred to be timed out
timeout (Number): seconds to time out after
reactor (twisted.internet.reactor): the Twisted reactor to use
on_timeout_cancel (callable): A callable which is called immediately
after the deferred times out, and not if this deferred is

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import threads, reactor
from twisted.internet import threads
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@ -27,6 +27,7 @@ class BackgroundFileConsumer(object):
Args:
file_obj (file): The file like object to write to. Closed when
finished.
reactor (twisted.internet.reactor): the Twisted reactor to use
"""
# For PushProducers pause if we have this many unwritten slices
@ -34,9 +35,11 @@ class BackgroundFileConsumer(object):
# And resume once the size of the queue is less than this
_RESUME_ON_QUEUE_SIZE = 2
def __init__(self, file_obj):
def __init__(self, file_obj, reactor):
self._file_obj = file_obj
self._reactor = reactor
# Producer we're registered with
self._producer = None
@ -71,7 +74,10 @@ class BackgroundFileConsumer(object):
self._producer = producer
self.streaming = streaming
self._finished_deferred = run_in_background(
threads.deferToThread, self._writer
threads.deferToThreadPool,
self._reactor,
self._reactor.getThreadPool(),
self._writer,
)
if not streaming:
self._producer.resumeProducing()
@ -109,7 +115,7 @@ class BackgroundFileConsumer(object):
# producer.
if self._producer and self._paused_producer:
if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
reactor.callFromThread(self._resume_paused_producer)
self._reactor.callFromThread(self._resume_paused_producer)
bytes = self._bytes_queue.get()
@ -121,7 +127,7 @@ class BackgroundFileConsumer(object):
# If its a pull producer then we need to explicitly ask for
# more stuff.
if not self.streaming and self._producer:
reactor.callFromThread(self._producer.resumeProducing)
self._reactor.callFromThread(self._producer.resumeProducing)
except Exception as e:
self._write_exception = e
raise

@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
from synapse.util.logcontext import (
run_in_background, make_deferred_yieldable,
PreserveLoggingContext,
@ -153,7 +152,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req",
id(request_id),
)
ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0)
ret_defer = run_in_background(self.clock.sleep, self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id)

@ -19,10 +19,10 @@ import signedjson.sign
from mock import Mock
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
from synapse.util import async, logcontext
from synapse.util import logcontext, Clock
from synapse.util.logcontext import LoggingContext
from tests import unittest, utils
from twisted.internet import defer
from twisted.internet import defer, reactor
class MockPerspectiveServer(object):
@ -118,6 +118,7 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_verify_json_objects_for_server_awaits_previous_requests(self):
clock = Clock(reactor)
key1 = signedjson.key.generate_signing_key(1)
kr = keyring.Keyring(self.hs)
@ -167,7 +168,7 @@ class KeyringTestCase(unittest.TestCase):
# wait a tick for it to send the request to the perspectives server
# (it first tries the datastore)
yield async.sleep(1) # XXX find out why this takes so long!
yield clock.sleep(1) # XXX find out why this takes so long!
self.http_client.post_json.assert_called_once()
self.assertIs(LoggingContext.current_context(), context_11)
@ -183,7 +184,7 @@ class KeyringTestCase(unittest.TestCase):
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)],
)
yield async.sleep(1)
yield clock.sleep(1)
self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None)

@ -1,9 +1,9 @@
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
from twisted.internet import defer
from twisted.internet import defer, reactor
from mock import Mock, call
from synapse.util import async
from synapse.util import Clock
from synapse.util.logcontext import LoggingContext
from tests import unittest
from tests.utils import MockClock
@ -46,7 +46,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test_logcontexts_with_async_result(self):
@defer.inlineCallbacks
def cb():
yield async.sleep(0)
yield Clock(reactor).sleep(0)
defer.returnValue("yay")
@defer.inlineCallbacks

@ -14,7 +14,7 @@
# limitations under the License.
from twisted.internet import defer
from twisted.internet import defer, reactor
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import MediaStorage
@ -38,6 +38,7 @@ class MediaStorageTests(unittest.TestCase):
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
hs = Mock()
hs.get_reactor = Mock(return_value=reactor)
hs.config.media_store_path = self.primary_base_path
storage_providers = [FileStorageProviderBackend(
@ -46,7 +47,7 @@ class MediaStorageTests(unittest.TestCase):
self.filepaths = MediaFilePaths(self.primary_base_path)
self.media_storage = MediaStorage(
self.primary_base_path, self.filepaths, storage_providers,
hs, self.primary_base_path, self.filepaths, storage_providers,
)
def tearDown(self):

@ -30,7 +30,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_pull_consumer(self):
string_file = StringIO()
consumer = BackgroundFileConsumer(string_file)
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = DummyPullProducer()
@ -54,7 +54,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_push_consumer(self):
string_file = BlockingStringWrite()
consumer = BackgroundFileConsumer(string_file)
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = NonCallableMock(spec_set=[])
@ -80,7 +80,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_push_producer_feedback(self):
string_file = BlockingStringWrite()
consumer = BackgroundFileConsumer(string_file)
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])

@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util import async, logcontext
from synapse.util import logcontext, Clock
from tests import unittest
from twisted.internet import defer
from twisted.internet import defer, reactor
from synapse.util.async import Linearizer
from six.moves import range
@ -53,7 +54,7 @@ class LinearizerTestCase(unittest.TestCase):
self.assertEqual(
logcontext.LoggingContext.current_context(), lc)
if sleep:
yield async.sleep(0)
yield Clock(reactor).sleep(0)
self.assertEqual(
logcontext.LoggingContext.current_context(), lc)

@ -3,8 +3,7 @@ from twisted.internet import defer
from twisted.internet import reactor
from .. import unittest
from synapse.util.async import sleep
from synapse.util import logcontext
from synapse.util import logcontext, Clock
from synapse.util.logcontext import LoggingContext
@ -22,18 +21,20 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_sleep(self):
clock = Clock(reactor)
@defer.inlineCallbacks
def competing_callback():
with LoggingContext() as competing_context:
competing_context.request = "competing"
yield sleep(0)
yield clock.sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
with LoggingContext() as context_one:
context_one.request = "one"
yield sleep(0)
yield clock.sleep(0)
self._check_test_key("one")
def _test_run_in_background(self, function):
@ -87,7 +88,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_run_in_background_with_blocking_fn(self):
@defer.inlineCallbacks
def blocking_function():
yield sleep(0)
yield Clock(reactor).sleep(0)
return self._test_run_in_background(blocking_function)

@ -37,11 +37,15 @@ USE_POSTGRES_FOR_TESTS = False
@defer.inlineCallbacks
def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None,
**kargs):
"""Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor. If no datastore is supplied a
datastore backed by an in-memory sqlite db will be given to the HS.
"""
if reactor is None:
from twisted.internet import reactor
if config is None:
config = Mock()
config.signing_key = [MockKey()]
@ -110,6 +114,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
reactor=reactor,
**kargs
)
db_conn = hs.get_db_conn()

Loading…
Cancel
Save