@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from twisted . internet import defer
from twisted . internet . defer import CancelledError , Deferred
from twisted . internet . defer import CancelledError , Deferred , ensureDeferred
from twisted . internet . task import Clock
from twisted . python . failure import Failure
from synapse . logging . context import (
SENTINEL_CONTEXT ,
@ -21,7 +24,11 @@ from synapse.logging.context import (
PreserveLoggingContext ,
current_context ,
)
from synapse . util . async_helpers import ObservableDeferred , timeout_deferred
from synapse . util . async_helpers import (
ObservableDeferred ,
concurrently_execute ,
timeout_deferred ,
)
from tests . unittest import TestCase
@ -171,3 +178,107 @@ class TimeoutDeferredTest(TestCase):
)
self . failureResultOf ( timing_out_d , defer . TimeoutError )
self . assertIs ( current_context ( ) , context_one )
class _TestException ( Exception ) :
pass
class ConcurrentlyExecuteTest ( TestCase ) :
def test_limits_runners ( self ) :
""" If we have more tasks than runners, we should get the limit of runners """
started = 0
waiters = [ ]
processed = [ ]
async def callback ( v ) :
# when we first enter, bump the start count
nonlocal started
started + = 1
# record the fact we got an item
processed . append ( v )
# wait for the goahead before returning
d2 = Deferred ( )
waiters . append ( d2 )
await d2
# set it going
d2 = ensureDeferred ( concurrently_execute ( callback , [ 1 , 2 , 3 , 4 , 5 ] , 3 ) )
# check we got exactly 3 processes
self . assertEqual ( started , 3 )
self . assertEqual ( len ( waiters ) , 3 )
# let one finish
waiters . pop ( ) . callback ( 0 )
# ... which should start another
self . assertEqual ( started , 4 )
self . assertEqual ( len ( waiters ) , 3 )
# we still shouldn't be done
self . assertNoResult ( d2 )
# finish the job
while waiters :
waiters . pop ( ) . callback ( 0 )
# check everything got done
self . assertEqual ( started , 5 )
self . assertCountEqual ( processed , [ 1 , 2 , 3 , 4 , 5 ] )
self . successResultOf ( d2 )
def test_preserves_stacktraces ( self ) :
""" Test that the stacktrace from an exception thrown in the callback is preserved """
d1 = Deferred ( )
async def callback ( v ) :
# alas, this doesn't work at all without an await here
await d1
raise _TestException ( " bah " )
async def caller ( ) :
try :
await concurrently_execute ( callback , [ 1 ] , 2 )
except _TestException as e :
tb = traceback . extract_tb ( e . __traceback__ )
# we expect to see "caller", "concurrently_execute" and "callback".
self . assertEqual ( tb [ 0 ] . name , " caller " )
self . assertEqual ( tb [ 1 ] . name , " concurrently_execute " )
self . assertEqual ( tb [ - 1 ] . name , " callback " )
else :
self . fail ( " No exception thrown " )
d2 = ensureDeferred ( caller ( ) )
d1 . callback ( 0 )
self . successResultOf ( d2 )
def test_preserves_stacktraces_on_preformed_failure ( self ) :
""" Test that the stacktrace on a Failure returned by the callback is preserved """
d1 = Deferred ( )
f = Failure ( _TestException ( " bah " ) )
async def callback ( v ) :
# alas, this doesn't work at all without an await here
await d1
await defer . fail ( f )
async def caller ( ) :
try :
await concurrently_execute ( callback , [ 1 ] , 2 )
except _TestException as e :
tb = traceback . extract_tb ( e . __traceback__ )
# we expect to see "caller", "concurrently_execute", "callback",
# and some magic from inside ensureDeferred that happens when .fail
# is called.
self . assertEqual ( tb [ 0 ] . name , " caller " )
self . assertEqual ( tb [ 1 ] . name , " concurrently_execute " )
self . assertEqual ( tb [ - 2 ] . name , " callback " )
else :
self . fail ( " No exception thrown " )
d2 = ensureDeferred ( caller ( ) )
d1 . callback ( 0 )
self . successResultOf ( d2 )