Speed up `@cachedList` (#13591)

This speeds things up by ~2x.

The vast majority of the time is now spent in `LruCache` moving things around the linked lists.

We do this via two things:
1. Don't create a deferred per-key during bulk set operations in `DeferredCache`. Instead, only create them if a subsequent caller asks for the key.
2. Add a bulk lookup API to `DeferredCache` rather than use a loop.
1.103.0-whithout-watcha
Erik Johnston 2 years ago committed by GitHub
parent 05c9c7363b
commit f7ddfe17a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      changelog.d/13591.misc
  2. 346
      synapse/util/caches/deferred_cache.py
  3. 89
      synapse/util/caches/descriptors.py
  4. 3
      synapse/util/caches/treecache.py

@ -0,0 +1 @@
Improve performance of `@cachedList`.

@ -14,15 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import enum
import threading
from typing import (
Callable,
Collection,
Dict,
Generic,
Iterable,
MutableMapping,
Optional,
Set,
Sized,
Tuple,
TypeVar,
Union,
cast,
@ -31,7 +35,6 @@ from typing import (
from prometheus_client import Gauge
from twisted.internet import defer
from twisted.python import failure
from twisted.python.failure import Failure
from synapse.util.async_helpers import ObservableDeferred
@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache: Union[
TreeCache, "MutableMapping[KT, CacheEntry]"
TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
] = cache_type()
def metrics_cb() -> None:
@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
Raises:
KeyError if the key is not found in the cache
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks)
val.add_invalidation_callback(key, callback)
if update_metrics:
m = self.cache.metrics
assert m # we always have a name, so should always have metrics
m.inc_hits()
return val.deferred.observe()
return val.deferred(key)
callbacks = (callback,) if callback else ()
val2 = self.cache.get(
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
else:
return defer.succeed(val2)
def get_bulk(
self,
keys: Collection[KT],
callback: Optional[Callable[[], None]] = None,
) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
"""Bulk lookup of items in the cache.
Returns:
A 3-tuple of:
1. a dict of key/value of items already cached;
2. a deferred that resolves to a dict of key/value of items
we're already fetching; and
3. a collection of keys that don't appear in the previous two.
"""
# The cached results
cached = {}
# List of pending deferreds
pending = []
# Dict that gets filled out when the pending deferreds complete
pending_results = {}
# List of keys that aren't in either cache
missing = []
callbacks = (callback,) if callback else ()
for key in keys:
# Check if its in the main cache.
immediate_value = self.cache.get(
key,
_Sentinel.sentinel,
callbacks=callbacks,
)
if immediate_value is not _Sentinel.sentinel:
cached[key] = immediate_value
continue
# Check if its in the pending cache
pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if pending_value is not _Sentinel.sentinel:
pending_value.add_invalidation_callback(key, callback)
def completed_cb(value: VT, key: KT) -> VT:
pending_results[key] = value
return value
# Add a callback to fill out `pending_results` when that completes
d = pending_value.deferred(key).addCallback(completed_cb, key)
pending.append(d)
continue
# Not in either cache
missing.append(key)
# If we've got pending deferreds, squash them into a single one that
# returns `pending_results`.
pending_deferred = None
if pending:
pending_deferred = defer.gatherResults(
pending, consumeErrors=True
).addCallback(lambda _: pending_results)
return (cached, pending_deferred, missing)
def get_immediate(
self, key: KT, default: T, update_metrics: bool = True
) -> Union[VT, T]:
@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
value: a deferred which will complete with a result to add to the cache
callback: An optional callback to be called when the entry is invalidated
"""
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread()
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache.pop(key, None)
# XXX: why don't we invalidate the entry in `self.cache` yet?
# we can save a whole load of effort if the deferred is ready.
if value.called:
result = value.result
if not isinstance(result, failure.Failure):
self.cache.set(key, cast(VT, result), callbacks)
return value
# otherwise, we'll add an entry to the _pending_deferred_cache for now,
# and add callbacks to add it to the cache properly later.
entry = CacheEntrySingle[KT, VT](value)
entry.add_invalidation_callback(key, callback)
self._pending_deferred_cache[key] = entry
deferred = entry.deferred(key).addCallbacks(
self._completed_callback,
self._error_callback,
callbackArgs=(entry, key),
errbackArgs=(entry, key),
)
observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
# we return a new Deferred which will be called before any subsequent observers.
return deferred
self._pending_deferred_cache[key] = entry
def start_bulk_input(
self,
keys: Collection[KT],
callback: Optional[Callable[[], None]] = None,
) -> "CacheMultipleEntries[KT, VT]":
"""Bulk set API for use when fetching multiple keys at once from the DB.
def compare_and_pop() -> bool:
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
return False
def cb(result: VT) -> None:
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
def eb(_fail: Failure) -> None:
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
Called *before* starting the fetch from the DB, and the caller *must*
call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
"""
# we return a new Deferred which will be called before any subsequent observers.
return observable.observe()
entry = CacheMultipleEntries[KT, VT]()
entry.add_global_invalidation_callback(callback)
for key in keys:
self._pending_deferred_cache[key] = entry
return entry
def _completed_callback(
self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
) -> VT:
"""Called when a deferred is completed."""
# We check if the current entry matches the entry associated with the
# deferred. If they don't match then it got invalidated.
current_entry = self._pending_deferred_cache.pop(key, None)
if current_entry is not entry:
if current_entry:
self._pending_deferred_cache[key] = current_entry
return value
self.cache.set(key, value, entry.get_invalidation_callbacks(key))
return value
def _error_callback(
self,
failure: Failure,
entry: "CacheEntry[KT, VT]",
key: KT,
) -> Failure:
"""Called when a deferred errors."""
# We check if the current entry matches the entry associated with the
# deferred. If they don't match then it got invalidated.
current_entry = self._pending_deferred_cache.pop(key, None)
if current_entry is not entry:
if current_entry:
self._pending_deferred_cache[key] = current_entry
return failure
for cb in entry.get_invalidation_callbacks(key):
cb()
return failure
def prefill(
self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
) -> None:
callbacks = [callback] if callback else []
callbacks = (callback,) if callback else ()
self.cache.set(key, value, callbacks=callbacks)
self._pending_deferred_cache.pop(key, None)
def invalidate(self, key: KT) -> None:
"""Delete a key, or tree of entries
@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
# for future queries and (b) stop it being persisted as a proper entry
# _pending_deferred_cache, which will (a) stop it being returned for
# future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry:
# _pending_deferred_cache.pop should either return a CacheEntry, or, in the
# case of a TreeCache, a dict of keys to cache entries. Either way calling
# iterate_tree_cache_entry on it will do the right thing.
for entry in iterate_tree_cache_entry(entry):
entry.invalidate()
for cb in entry.get_invalidation_callbacks(key):
cb()
def invalidate_all(self) -> None:
self.check_thread()
self.cache.clear()
for entry in self._pending_deferred_cache.values():
entry.invalidate()
for key, entry in self._pending_deferred_cache.items():
for cb in entry.get_invalidation_callbacks(key):
cb()
self._pending_deferred_cache.clear()
class CacheEntry:
__slots__ = ["deferred", "callbacks", "invalidated"]
class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
"""Abstract class for entries in `DeferredCache[KT, VT]`"""
def __init__(
self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
):
self.deferred = deferred
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self) -> None:
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
@abc.abstractmethod
def deferred(self, key: KT) -> "defer.Deferred[VT]":
"""Get a deferred that a caller can wait on to get the value at the
given key"""
...
@abc.abstractmethod
def add_invalidation_callback(
self, key: KT, callback: Optional[Callable[[], None]]
) -> None:
"""Add an invalidation callback"""
...
@abc.abstractmethod
def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
"""Get all invalidation callbacks"""
...
class CacheEntrySingle(CacheEntry[KT, VT]):
"""An implementation of `CacheEntry` wrapping a deferred that results in a
single cache entry.
"""
__slots__ = ["_deferred", "_callbacks"]
def __init__(self, deferred: "defer.Deferred[VT]") -> None:
self._deferred = ObservableDeferred(deferred, consumeErrors=True)
self._callbacks: Set[Callable[[], None]] = set()
def deferred(self, key: KT) -> "defer.Deferred[VT]":
return self._deferred.observe()
def add_invalidation_callback(
self, key: KT, callback: Optional[Callable[[], None]]
) -> None:
if callback is None:
return
self._callbacks.add(callback)
def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
return self._callbacks
class CacheMultipleEntries(CacheEntry[KT, VT]):
"""Cache entry that is used for bulk lookups and insertions."""
__slots__ = ["_deferred", "_callbacks", "_global_callbacks"]
def __init__(self) -> None:
self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
self._global_callbacks: Set[Callable[[], None]] = set()
def deferred(self, key: KT) -> "defer.Deferred[VT]":
if not self._deferred:
self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
return self._deferred.observe().addCallback(lambda res: res.get(key))
def add_invalidation_callback(
self, key: KT, callback: Optional[Callable[[], None]]
) -> None:
if callback is None:
return
self._callbacks.setdefault(key, set()).add(callback)
def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
return self._callbacks.get(key, set()) | self._global_callbacks
def add_global_invalidation_callback(
self, callback: Optional[Callable[[], None]]
) -> None:
"""Add a callback for when any keys get invalidated."""
if callback is None:
return
self._global_callbacks.add(callback)
def complete_bulk(
self,
cache: DeferredCache[KT, VT],
result: Dict[KT, VT],
) -> None:
"""Called when there is a result"""
for key, value in result.items():
cache._completed_callback(value, self, key)
if self._deferred:
self._deferred.callback(result)
def error_bulk(
self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
) -> None:
"""Called when bulk lookup failed."""
for key in keys:
cache._error_callback(failure, self, key)
if self._deferred:
self._deferred.errback(failure)

@ -25,6 +25,7 @@ from typing import (
Generic,
Hashable,
Iterable,
List,
Mapping,
Optional,
Sequence,
@ -440,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
results = {}
def update_results_dict(res: Any, arg: Hashable) -> None:
results[arg] = res
# list of deferreds to wait for
cached_defers = []
missing = set()
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
@ -457,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
def arg_to_cache_key(arg: Hashable) -> Hashable:
return arg
def cache_key_to_arg(key: tuple) -> Hashable:
return key
else:
keylist = list(keyargs)
@ -464,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keylist[self.list_pos] = arg
return tuple(keylist)
for arg in list_args:
try:
res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
if not res.called:
res.addCallback(update_results_dict, arg)
cached_defers.append(res)
else:
results[arg] = res.result
except KeyError:
missing.add(arg)
def cache_key_to_arg(key: tuple) -> Hashable:
return key[self.list_pos]
cache_keys = [arg_to_cache_key(arg) for arg in list_args]
immediate_results, pending_deferred, missing = cache.get_bulk(
cache_keys, callback=invalidate_callback
)
results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
cached_defers: List["defer.Deferred[Any]"] = []
if pending_deferred:
def update_results(r: Dict) -> None:
for k, v in r.items():
results[cache_key_to_arg(k)] = v
pending_deferred.addCallback(update_results)
cached_defers.append(pending_deferred)
if missing:
# we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
for arg in missing:
deferred: "defer.Deferred[Any]" = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
cached_defers.append(
cache.set(key, deferred, callback=invalidate_callback)
)
cache_entry = cache.start_bulk_input(missing, invalidate_callback)
def complete_all(res: Dict[Hashable, Any]) -> None:
# the wrapped function has completed. It returns a dict.
# We can now update our own result map, and then resolve the
# observable deferreds in the cache.
for e, d1 in deferreds_map.items():
val = res.get(e, None)
# make sure we update the results map before running the
# deferreds, because as soon as we run the last deferred, the
# gatherResults() below will complete and return the result
# dict to our caller.
results[e] = val
d1.callback(val)
missing_results = {}
for key in missing:
arg = cache_key_to_arg(key)
val = res.get(arg, None)
results[arg] = val
missing_results[key] = val
cache_entry.complete_bulk(cache, missing_results)
def errback_all(f: Failure) -> None:
# the wrapped function has failed. Propagate the failure into
# the cache, which will invalidate the entry, and cause the
# relevant cached_deferreds to fail, which will propagate the
# failure to our caller.
for d1 in deferreds_map.values():
d1.errback(f)
cache_entry.error_bulk(cache, missing, f)
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
args_to_call[self.list_name] = {
cache_key_to_arg(key) for key in missing
}
# dispatch the call, and attach the two handlers
defer.maybeDeferred(
missing_d = defer.maybeDeferred(
preserve_fn(self.orig), **args_to_call
).addCallbacks(complete_all, errback_all)
cached_defers.append(missing_d)
if cached_defers:
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(

@ -135,6 +135,9 @@ class TreeCache:
def values(self):
return iterate_tree_cache_entry(self.root)
def items(self):
return iterate_tree_cache_items((), self.root)
def __len__(self) -> int:
return self.size

Loading…
Cancel
Save