|
|
|
@ -15,6 +15,7 @@ |
|
|
|
|
|
|
|
|
|
import logging |
|
|
|
|
import time |
|
|
|
|
from typing import Any, Callable, Dict, Generic, Tuple, TypeVar, Union |
|
|
|
|
|
|
|
|
|
import attr |
|
|
|
|
from sortedcontainers import SortedList |
|
|
|
@ -23,15 +24,19 @@ from synapse.util.caches import register_cache |
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
SENTINEL = object() |
|
|
|
|
SENTINEL = object() # type: Any |
|
|
|
|
|
|
|
|
|
T = TypeVar("T") |
|
|
|
|
KT = TypeVar("KT") |
|
|
|
|
VT = TypeVar("VT") |
|
|
|
|
|
|
|
|
|
class TTLCache: |
|
|
|
|
|
|
|
|
|
class TTLCache(Generic[KT, VT]): |
|
|
|
|
"""A key/value cache implementation where each entry has its own TTL""" |
|
|
|
|
|
|
|
|
|
def __init__(self, cache_name, timer=time.time): |
|
|
|
|
def __init__(self, cache_name: str, timer: Callable[[], float] = time.time): |
|
|
|
|
# map from key to _CacheEntry |
|
|
|
|
self._data = {} |
|
|
|
|
self._data = {} # type: Dict[KT, _CacheEntry] |
|
|
|
|
|
|
|
|
|
# the _CacheEntries, sorted by expiry time |
|
|
|
|
self._expiry_list = SortedList() # type: SortedList[_CacheEntry] |
|
|
|
@ -40,26 +45,27 @@ class TTLCache: |
|
|
|
|
|
|
|
|
|
self._metrics = register_cache("ttl", cache_name, self, resizable=False) |
|
|
|
|
|
|
|
|
|
def set(self, key, value, ttl): |
|
|
|
|
def set(self, key: KT, value: VT, ttl: float) -> None: |
|
|
|
|
"""Add/update an entry in the cache |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
key: key for this entry |
|
|
|
|
value: value for this entry |
|
|
|
|
ttl (float): TTL for this entry, in seconds |
|
|
|
|
ttl: TTL for this entry, in seconds |
|
|
|
|
""" |
|
|
|
|
expiry = self._timer() + ttl |
|
|
|
|
|
|
|
|
|
self.expire() |
|
|
|
|
e = self._data.pop(key, SENTINEL) |
|
|
|
|
if e != SENTINEL: |
|
|
|
|
if e is not SENTINEL: |
|
|
|
|
assert isinstance(e, _CacheEntry) |
|
|
|
|
self._expiry_list.remove(e) |
|
|
|
|
|
|
|
|
|
entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value) |
|
|
|
|
self._data[key] = entry |
|
|
|
|
self._expiry_list.add(entry) |
|
|
|
|
|
|
|
|
|
def get(self, key, default=SENTINEL): |
|
|
|
|
def get(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: |
|
|
|
|
"""Get a value from the cache |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
@ -72,23 +78,23 @@ class TTLCache: |
|
|
|
|
""" |
|
|
|
|
self.expire() |
|
|
|
|
e = self._data.get(key, SENTINEL) |
|
|
|
|
if e == SENTINEL: |
|
|
|
|
if e is SENTINEL: |
|
|
|
|
self._metrics.inc_misses() |
|
|
|
|
if default == SENTINEL: |
|
|
|
|
if default is SENTINEL: |
|
|
|
|
raise KeyError(key) |
|
|
|
|
return default |
|
|
|
|
assert isinstance(e, _CacheEntry) |
|
|
|
|
self._metrics.inc_hits() |
|
|
|
|
return e.value |
|
|
|
|
|
|
|
|
|
def get_with_expiry(self, key): |
|
|
|
|
def get_with_expiry(self, key: KT) -> Tuple[VT, float, float]: |
|
|
|
|
"""Get a value, and its expiry time, from the cache |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
key: key to look up |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
Tuple[Any, float, float]: the value from the cache, the expiry time |
|
|
|
|
and the TTL |
|
|
|
|
A tuple of the value from the cache, the expiry time and the TTL |
|
|
|
|
|
|
|
|
|
Raises: |
|
|
|
|
KeyError if the entry is not found |
|
|
|
@ -102,7 +108,7 @@ class TTLCache: |
|
|
|
|
self._metrics.inc_hits() |
|
|
|
|
return e.value, e.expiry_time, e.ttl |
|
|
|
|
|
|
|
|
|
def pop(self, key, default=SENTINEL): |
|
|
|
|
def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: # type: ignore |
|
|
|
|
"""Remove a value from the cache |
|
|
|
|
|
|
|
|
|
If key is in the cache, remove it and return its value, else return default. |
|
|
|
@ -118,29 +124,30 @@ class TTLCache: |
|
|
|
|
""" |
|
|
|
|
self.expire() |
|
|
|
|
e = self._data.pop(key, SENTINEL) |
|
|
|
|
if e == SENTINEL: |
|
|
|
|
if e is SENTINEL: |
|
|
|
|
self._metrics.inc_misses() |
|
|
|
|
if default == SENTINEL: |
|
|
|
|
if default is SENTINEL: |
|
|
|
|
raise KeyError(key) |
|
|
|
|
return default |
|
|
|
|
assert isinstance(e, _CacheEntry) |
|
|
|
|
self._expiry_list.remove(e) |
|
|
|
|
self._metrics.inc_hits() |
|
|
|
|
return e.value |
|
|
|
|
|
|
|
|
|
def __getitem__(self, key): |
|
|
|
|
def __getitem__(self, key: KT) -> VT: |
|
|
|
|
return self.get(key) |
|
|
|
|
|
|
|
|
|
def __delitem__(self, key): |
|
|
|
|
def __delitem__(self, key: KT) -> None: |
|
|
|
|
self.pop(key) |
|
|
|
|
|
|
|
|
|
def __contains__(self, key): |
|
|
|
|
def __contains__(self, key: KT) -> bool: |
|
|
|
|
return key in self._data |
|
|
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
|
def __len__(self) -> int: |
|
|
|
|
self.expire() |
|
|
|
|
return len(self._data) |
|
|
|
|
|
|
|
|
|
def expire(self): |
|
|
|
|
def expire(self) -> None: |
|
|
|
|
"""Run the expiry on the cache. Any entries whose expiry times are due will |
|
|
|
|
be removed |
|
|
|
|
""" |
|
|
|
@ -158,7 +165,7 @@ class _CacheEntry: |
|
|
|
|
"""TTLCache entry""" |
|
|
|
|
|
|
|
|
|
# expiry_time is the first attribute, so that entries are sorted by expiry. |
|
|
|
|
expiry_time = attr.ib() |
|
|
|
|
ttl = attr.ib() |
|
|
|
|
expiry_time = attr.ib(type=float) |
|
|
|
|
ttl = attr.ib(type=float) |
|
|
|
|
key = attr.ib() |
|
|
|
|
value = attr.ib() |
|
|
|
|