@ -14,15 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import enum
import threading
from typing import (
Callable ,
Collection ,
Dict ,
Generic ,
Iterable ,
MutableMapping ,
Optional ,
Set ,
Sized ,
Tuple ,
TypeVar ,
Union ,
cast ,
@ -31,7 +35,6 @@ from typing import (
from prometheus_client import Gauge
from twisted . internet import defer
from twisted . python import failure
from twisted . python . failure import Failure
from synapse . util . async_helpers import ObservableDeferred
@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self . _pending_deferred_cache : Union [
TreeCache , " MutableMapping[KT, CacheEntry] "
TreeCache , " MutableMapping[KT, CacheEntry[KT, VT] ] "
] = cache_type ( )
def metrics_cb ( ) - > None :
@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
Raises :
KeyError if the key is not found in the cache
"""
callbacks = [ callback ] if callback else [ ]
val = self . _pending_deferred_cache . get ( key , _Sentinel . sentinel )
if val is not _Sentinel . sentinel :
val . callbacks . update ( callbacks )
val . add_invalidation_callback ( key , callback )
if update_metrics :
m = self . cache . metrics
assert m # we always have a name, so should always have metrics
m . inc_hits ( )
return val . deferred . observe ( )
return val . deferred ( key )
callbacks = ( callback , ) if callback else ( )
val2 = self . cache . get (
key , _Sentinel . sentinel , callbacks = callbacks , update_metrics = update_metrics
@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
else :
return defer . succeed ( val2 )
def get_bulk (
self ,
keys : Collection [ KT ] ,
callback : Optional [ Callable [ [ ] , None ] ] = None ,
) - > Tuple [ Dict [ KT , VT ] , Optional [ " defer.Deferred[Dict[KT, VT]] " ] , Collection [ KT ] ] :
""" Bulk lookup of items in the cache.
Returns :
A 3 - tuple of :
1. a dict of key / value of items already cached ;
2. a deferred that resolves to a dict of key / value of items
we ' re already fetching; and
3. a collection of keys that don ' t appear in the previous two.
"""
# The cached results
cached = { }
# List of pending deferreds
pending = [ ]
# Dict that gets filled out when the pending deferreds complete
pending_results = { }
# List of keys that aren't in either cache
missing = [ ]
callbacks = ( callback , ) if callback else ( )
for key in keys :
# Check if its in the main cache.
immediate_value = self . cache . get (
key ,
_Sentinel . sentinel ,
callbacks = callbacks ,
)
if immediate_value is not _Sentinel . sentinel :
cached [ key ] = immediate_value
continue
# Check if its in the pending cache
pending_value = self . _pending_deferred_cache . get ( key , _Sentinel . sentinel )
if pending_value is not _Sentinel . sentinel :
pending_value . add_invalidation_callback ( key , callback )
def completed_cb ( value : VT , key : KT ) - > VT :
pending_results [ key ] = value
return value
# Add a callback to fill out `pending_results` when that completes
d = pending_value . deferred ( key ) . addCallback ( completed_cb , key )
pending . append ( d )
continue
# Not in either cache
missing . append ( key )
# If we've got pending deferreds, squash them into a single one that
# returns `pending_results`.
pending_deferred = None
if pending :
pending_deferred = defer . gatherResults (
pending , consumeErrors = True
) . addCallback ( lambda _ : pending_results )
return ( cached , pending_deferred , missing )
def get_immediate (
self , key : KT , default : T , update_metrics : bool = True
) - > Union [ VT , T ] :
@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
value : a deferred which will complete with a result to add to the cache
callback : An optional callback to be called when the entry is invalidated
"""
if not isinstance ( value , defer . Deferred ) :
raise TypeError ( " not a Deferred " )
callbacks = [ callback ] if callback else [ ]
self . check_thread ( )
existing_entry = self . _pending_deferred_cache . pop ( key , None )
if existing_entry :
existing_entry . invalidate ( )
self . _pending_deferred_cache . pop ( key , None )
# XXX: why don't we invalidate the entry in `self.cache` yet?
# we can save a whole load of effort if the deferred is ready.
if value . called :
result = value . result
if not isinstance ( result , failure . Failure ) :
self . cache . set ( key , cast ( VT , result ) , callbacks )
return value
# otherwise, we'll add an entry to the _pending_deferred_cache for now,
# and add callbacks to add it to the cache properly later.
entry = CacheEntrySingle [ KT , VT ] ( value )
entry . add_invalidation_callback ( key , callback )
self . _pending_deferred_cache [ key ] = entry
deferred = entry . deferred ( key ) . addCallbacks (
self . _completed_callback ,
self . _error_callback ,
callbackArgs = ( entry , key ) ,
errbackArgs = ( entry , key ) ,
)
observable = ObservableDeferred ( value , consumeErrors = True )
observer = observable . observe ( )
entry = CacheEntry ( deferred = observable , callbacks = callbacks )
# we return a new Deferred which will be called before any subsequent observers.
return deferred
self . _pending_deferred_cache [ key ] = entry
def start_bulk_input (
self ,
keys : Collection [ KT ] ,
callback : Optional [ Callable [ [ ] , None ] ] = None ,
) - > " CacheMultipleEntries[KT, VT] " :
""" Bulk set API for use when fetching multiple keys at once from the DB.
def compare_and_pop ( ) - > bool :
""" Check if our entry is still the one in _pending_deferred_cache, and
if so , pop it .
Returns true if the entries matched .
"""
existing_entry = self . _pending_deferred_cache . pop ( key , None )
if existing_entry is entry :
return True
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None :
self . _pending_deferred_cache [ key ] = existing_entry
return False
def cb ( result : VT ) - > None :
if compare_and_pop ( ) :
self . cache . set ( key , result , entry . callbacks )
else :
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry . invalidate ( )
def eb ( _fail : Failure ) - > None :
compare_and_pop ( )
entry . invalidate ( )
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer . addCallbacks ( cb , eb )
Called * before * starting the fetch from the DB , and the caller * must *
call either ` complete_bulk ( . . ) ` or ` error_bulk ( . . ) ` on the return value .
"""
# we return a new Deferred which will be called before any subsequent observers.
return observable . observe ( )
entry = CacheMultipleEntries [ KT , VT ] ( )
entry . add_global_invalidation_callback ( callback )
for key in keys :
self . _pending_deferred_cache [ key ] = entry
return entry
def _completed_callback (
self , value : VT , entry : " CacheEntry[KT, VT] " , key : KT
) - > VT :
""" Called when a deferred is completed. """
# We check if the current entry matches the entry associated with the
# deferred. If they don't match then it got invalidated.
current_entry = self . _pending_deferred_cache . pop ( key , None )
if current_entry is not entry :
if current_entry :
self . _pending_deferred_cache [ key ] = current_entry
return value
self . cache . set ( key , value , entry . get_invalidation_callbacks ( key ) )
return value
def _error_callback (
self ,
failure : Failure ,
entry : " CacheEntry[KT, VT] " ,
key : KT ,
) - > Failure :
""" Called when a deferred errors. """
# We check if the current entry matches the entry associated with the
# deferred. If they don't match then it got invalidated.
current_entry = self . _pending_deferred_cache . pop ( key , None )
if current_entry is not entry :
if current_entry :
self . _pending_deferred_cache [ key ] = current_entry
return failure
for cb in entry . get_invalidation_callbacks ( key ) :
cb ( )
return failure
def prefill (
self , key : KT , value : VT , callback : Optional [ Callable [ [ ] , None ] ] = None
) - > None :
callbacks = [ callback ] if callback else [ ]
callbacks = ( callback , ) if callback else ( )
self . cache . set ( key , value , callbacks = callbacks )
self . _pending_deferred_cache . pop ( key , None )
def invalidate ( self , key : KT ) - > None :
""" Delete a key, or tree of entries
@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
self . cache . del_multi ( key )
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
# for f uture queries and (b) stop it being persisted as a proper entry
# _pending_deferred_cache, which will (a) stop it being returned for
# future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self . _pending_deferred_cache . pop ( key , None )
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry :
# _pending_deferred_cache.pop should either return a CacheEntry, or, in the
# case of a TreeCache, a dict of keys to cache entries. Either way calling
# iterate_tree_cache_entry on it will do the right thing.
for entry in iterate_tree_cache_entry ( entry ) :
entry . invalidate ( )
for cb in entry . get_invalidation_callbacks ( key ) :
cb ( )
def invalidate_all ( self ) - > None :
self . check_thread ( )
self . cache . clear ( )
for entry in self . _pending_deferred_cache . values ( ) :
entry . invalidate ( )
for key , entry in self . _pending_deferred_cache . items ( ) :
for cb in entry . get_invalidation_callbacks ( key ) :
cb ( )
self . _pending_deferred_cache . clear ( )
class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
    """Abstract class for entries in `DeferredCache[KT, VT]`"""

    @abc.abstractmethod
    def deferred(self, key: KT) -> "defer.Deferred[VT]":
        """Get a deferred that a caller can wait on to get the value at the
        given key"""
        ...

    @abc.abstractmethod
    def add_invalidation_callback(
        self, key: KT, callback: Optional[Callable[[], None]]
    ) -> None:
        """Add an invalidation callback"""
        ...

    @abc.abstractmethod
    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
        """Get all invalidation callbacks"""
        ...
class CacheEntrySingle(CacheEntry[KT, VT]):
    """An implementation of `CacheEntry` wrapping a deferred that results in a
    single cache entry.
    """

    __slots__ = ["_deferred", "_callbacks"]

    def __init__(self, deferred: "defer.Deferred[VT]") -> None:
        # Wrap in an ObservableDeferred so that several callers can wait on
        # the same underlying lookup.
        self._deferred = ObservableDeferred(deferred, consumeErrors=True)
        self._callbacks: Set[Callable[[], None]] = set()

    def deferred(self, key: KT) -> "defer.Deferred[VT]":
        return self._deferred.observe()

    def add_invalidation_callback(
        self, key: KT, callback: Optional[Callable[[], None]]
    ) -> None:
        if callback is not None:
            self._callbacks.add(callback)

    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
        return self._callbacks
class CacheMultipleEntries(CacheEntry[KT, VT]):
    """Cache entry that is used for bulk lookups and insertions."""

    __slots__ = ["_deferred", "_callbacks", "_global_callbacks"]

    def __init__(self) -> None:
        # Created lazily, the first time somebody waits on a key.
        self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
        # Per-key invalidation callbacks.
        self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
        # Callbacks fired when *any* key is invalidated.
        self._global_callbacks: Set[Callable[[], None]] = set()

    def deferred(self, key: KT) -> "defer.Deferred[VT]":
        if self._deferred is None:
            self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
        # Project the bulk result dict down to this key's value.
        return self._deferred.observe().addCallback(lambda res: res.get(key))

    def add_invalidation_callback(
        self, key: KT, callback: Optional[Callable[[], None]]
    ) -> None:
        if callback is not None:
            self._callbacks.setdefault(key, set()).add(callback)

    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
        # Union of the key-specific and the global callbacks.
        return self._callbacks.get(key, set()) | self._global_callbacks

    def add_global_invalidation_callback(
        self, callback: Optional[Callable[[], None]]
    ) -> None:
        """Add a callback for when any keys get invalidated."""
        if callback is not None:
            self._global_callbacks.add(callback)

    def complete_bulk(
        self,
        cache: DeferredCache[KT, VT],
        result: Dict[KT, VT],
    ) -> None:
        """Called when there is a result"""
        for key, value in result.items():
            cache._completed_callback(value, self, key)

        if self._deferred:
            self._deferred.callback(result)

    def error_bulk(
        self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
    ) -> None:
        """Called when bulk lookup failed."""
        for key in keys:
            cache._error_callback(failure, self, key)

        if self._deferred:
            self._deferred.errback(failure)