mirror of https://github.com/watcha-fr/synapse
Clean up caching/locking of OIDC metadata load (#9362)
Ensure that we lock correctly to prevent multiple concurrent metadata load requests, and generally clean up the way we construct the metadata cache.
parent
0ad087273c
commit
3b754aea27
@ -0,0 +1 @@ |
||||
Clean up the code to load the metadata for OpenID Connect identity providers. |
@ -0,0 +1,129 @@ |
||||
# -*- coding: utf-8 -*- |
||||
# Copyright 2021 The Matrix.org Foundation C.I.C. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union |
||||
|
||||
from twisted.internet.defer import Deferred |
||||
from twisted.python.failure import Failure |
||||
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background |
||||
|
||||
TV = TypeVar("TV") |
||||
|
||||
|
||||
class CachedCall(Generic[TV]):
    """A wrapper for asynchronous calls whose results should be shared

    This is useful for wrapping asynchronous functions, where there might be multiple
    callers, but we only want to call the underlying function once (and have the result
    returned to all callers).

    Similar results can be achieved via a lock of some form, but that typically requires
    more boilerplate (and ends up being less efficient).

    Correctly handles Synapse logcontexts (logs and resource usage for the underlying
    function are logged against the logcontext which is active when get() is first
    called).

    Example usage:

        _cached_val = CachedCall(_load_prop)

        async def handle_request() -> X:
            # We can call this multiple times, but it will result in a single call to
            # _load_prop().
            return await _cached_val.get()

        async def _load_prop() -> X:
            await difficult_operation()


    The implementation is deliberately single-shot (ie, once the call is initiated,
    there is no way to ask for it to be run again). This keeps the implementation and
    semantics simple. If you want to make a new call, simply replace the whole
    CachedCall object.
    """

    __slots__ = ["_callable", "_deferred", "_result"]

    def __init__(self, f: Callable[[], Awaitable[TV]]):
        """
        Args:
            f: The underlying function. Only one call to this function will be alive
                at once (per instance of CachedCall)
        """
        # The underlying function; cleared (to allow it to be GCed) as soon as the
        # single call has been started.
        self._callable = f  # type: Optional[Callable[[], Awaitable[TV]]]
        # The Deferred tracking the in-flight (or completed) call, once started.
        self._deferred = None  # type: Optional[Deferred]
        # The eventual outcome of the call: a Failure if it raised, otherwise the
        # function's return value.
        self._result = None  # type: Union[None, Failure, TV]

    async def get(self) -> TV:
        """Kick off the call if necessary, and return the result"""

        # Fire off the callable now if this is our first time
        if not self._deferred:
            # run_in_background means the work is logged against the logcontext
            # which is active on this first call to get().
            self._deferred = run_in_background(self._callable)

            # we will never need the callable again, so make sure it can be GCed
            self._callable = None

            # once the deferred completes, store the result. We cannot simply leave the
            # result in the deferred, since if it's a Failure, GCing the deferred
            # would then log a critical error about unhandled Failures.
            def got_result(r):
                self._result = r

            self._deferred.addBoth(got_result)

        # TODO: consider cancellation semantics. Currently, if the call to get()
        #   is cancelled, the underlying call will continue (and any future calls
        #   will get the result/exception), which I think is *probably* ok, modulo
        #   the fact the underlying call may be logged to a cancelled logcontext,
        #   and any eventual exception may not be reported.

        # we can now await the deferred, and once it completes, return the result.
        await make_deferred_yieldable(self._deferred)

        # I *think* this is the easiest way to correctly raise a Failure without having
        # to gut-wrench into the implementation of Deferred.
        d = Deferred()
        d.callback(self._result)
        return await d
||||
|
||||
|
||||
class RetryOnExceptionCachedCall(Generic[TV]):
    """A wrapper around CachedCall which will retry the call if an exception is thrown

    This is used in much the same way as CachedCall, but adds some extra functionality
    so that if the underlying function throws an exception, then the next call to get()
    will initiate another call to the underlying function. (Any calls to get() which
    are already pending will raise the exception.)
    """

    # Fix: this was previously `slots = [...]`, which is just a dead class attribute;
    # it must be spelled `__slots__` to actually suppress the per-instance __dict__.
    __slots__ = ["_cachedcall"]

    def __init__(self, f: Callable[[], Awaitable[TV]]):
        """
        Args:
            f: The underlying function. Only one call to this function will be alive
                at once (per instance of RetryOnExceptionCachedCall)
        """

        async def _wrapper() -> TV:
            try:
                return await f()
            except Exception:
                # the call raised an exception: replace the underlying CachedCall to
                # trigger another call next time get() is called
                self._cachedcall = CachedCall(_wrapper)
                raise

        self._cachedcall = CachedCall(_wrapper)

    async def get(self) -> TV:
        """Kick off the call if necessary, and return the result (or raise the
        exception raised by the underlying function)."""
        return await self._cachedcall.get()
@ -0,0 +1,161 @@ |
||||
# -*- coding: utf-8 -*- |
||||
# Copyright 2021 The Matrix.org Foundation C.I.C. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
from unittest.mock import Mock |
||||
|
||||
from twisted.internet import defer |
||||
from twisted.internet.defer import Deferred |
||||
|
||||
from synapse.util.caches.cached_call import CachedCall, RetryOnExceptionCachedCall |
||||
|
||||
from tests.test_utils import get_awaitable_result |
||||
from tests.unittest import TestCase |
||||
|
||||
|
||||
class CachedCallTestCase(TestCase):
    def test_get(self):
        """
        Happy-path test case: makes a couple of calls and makes sure they behave
        correctly
        """
        gate = Deferred()

        async def _slow():
            return await gate

        slow_call = Mock(side_effect=_slow)
        cached_call = CachedCall(slow_call)

        # nothing should happen until the first get()
        slow_call.assert_not_called()

        # kick off a couple of concurrent getters
        results = []

        async def _caller():
            results.append(await cached_call.get())

        d1 = defer.ensureDeferred(_caller())
        d2 = defer.ensureDeferred(_caller())

        # both getters are still waiting...
        self.assertNoResult(d1)
        self.assertNoResult(d2)

        # ... but the underlying function was kicked off exactly once, with no params
        slow_call.assert_called_once_with()

        # completing the underlying deferred should resolve both pending getters
        gate.callback(123)
        self.assertEqual(results, [123, 123])
        self.successResultOf(d1)
        self.successResultOf(d2)

        # a later call should complete immediately, without re-invoking the function
        slow_call.reset_mock()
        self.assertEqual(get_awaitable_result(cached_call.get()), 123)
        slow_call.assert_not_called()

    def test_fast_call(self):
        """
        Test the behaviour when the underlying function completes immediately
        """

        async def _instant():
            return 12

        fast_call = Mock(side_effect=_instant)
        cached_call = CachedCall(fast_call)

        # nothing should happen until the first get()
        fast_call.assert_not_called()

        # each call to the getter should complete immediately with the same value
        for _ in range(2):
            self.assertEqual(get_awaitable_result(cached_call.get()), 12)

        # despite two gets, the underlying function ran exactly once
        fast_call.assert_called_once_with()
||||
|
||||
|
||||
class RetryOnExceptionCachedCallTestCase(TestCase):
    def test_get(self):
        # wrap a function which will raise once its gating deferred completes
        gate = Deferred()

        async def _failing():
            await gate
            raise ValueError("moo")

        underlying = Mock(side_effect=_failing)
        cached_call = RetryOnExceptionCachedCall(underlying)

        # nothing should happen until the first get()
        underlying.assert_not_called()

        # kick off a couple of concurrent getters, collecting their exceptions
        caught = []

        async def _caller():
            try:
                await cached_call.get()
            except Exception as exc:
                caught.append(exc)

        d1 = defer.ensureDeferred(_caller())
        d2 = defer.ensureDeferred(_caller())

        # both getters are still waiting...
        self.assertNoResult(d1)
        self.assertNoResult(d2)

        # ... but the underlying function ran exactly once, with no params
        underlying.assert_called_once_with()

        # completing the deferred makes both pending getters see the ValueError
        gate.callback(0)
        self.assertEqual(len(caught), 2)
        for exc in caught:
            self.assertIsInstance(exc, ValueError)
            self.assertEqual(exc.args, ("moo",))

        # now make the underlying function succeed instead, and fire off another
        # pair of calls: the failure should have triggered a retry
        gate = Deferred()

        async def _succeeding():
            return await gate

        underlying.reset_mock()
        underlying.side_effect = _succeeding
        d3 = defer.ensureDeferred(cached_call.get())
        d4 = defer.ensureDeferred(cached_call.get())

        self.assertNoResult(d3)
        self.assertNoResult(d4)
        underlying.assert_called_once_with()

        # let the retried call complete, and check both getters see the result
        gate.callback(123)
        self.assertEqual(self.successResultOf(d3), 123)
        self.assertEqual(self.successResultOf(d4), 123)

        # the successful result is now cached: further gets complete immediately
        underlying.reset_mock()
        self.assertEqual(get_awaitable_result(cached_call.get()), 123)
        underlying.assert_not_called()
Loading…
Reference in new issue