@ -26,6 +26,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse . storage . databases . main . signatures import SignatureWorkerStore
from synapse . types import Collection
from synapse . util . caches . descriptors import cached
from synapse . util . caches . lrucache import LruCache
from synapse . util . iterutils import batch_iter
logger = logging . getLogger ( __name__ )
@ -40,6 +41,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
self . _delete_old_forward_extrem_cache , 60 * 60 * 1000
)
# Cache of event ID to list of auth event IDs and their depths.
self . _event_auth_cache = LruCache (
500000 , " _event_auth_cache " , size_callback = len
) # type: LruCache[str, List[Tuple[str, int]]]
async def get_auth_chain (
self , event_ids : Collection [ str ] , include_given : bool = False
) - > List [ EventBase ] :
@ -84,17 +90,45 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else :
results = set ( )
base_sql = " SELECT DISTINCT auth_id FROM event_auth WHERE "
# We pull out the depth simply so that we can populate the
# `_event_auth_cache` cache.
base_sql = """
SELECT a . event_id , auth_id , depth
FROM event_auth AS a
INNER JOIN events AS e ON ( e . event_id = a . auth_id )
WHERE
"""
front = set ( event_ids )
while front :
new_front = set ( )
for chunk in batch_iter ( front , 100 ) :
clause , args = make_in_list_sql_clause (
txn . database_engine , " event_id " , chunk
)
txn . execute ( base_sql + clause , args )
new_front . update ( r [ 0 ] for r in txn )
# Pull the auth events either from the cache or DB.
to_fetch = [ ] # Event IDs to fetch from DB # type: List[str]
for event_id in chunk :
res = self . _event_auth_cache . get ( event_id )
if res is None :
to_fetch . append ( event_id )
else :
new_front . update ( auth_id for auth_id , depth in res )
if to_fetch :
clause , args = make_in_list_sql_clause (
txn . database_engine , " a.event_id " , to_fetch
)
txn . execute ( base_sql + clause , args )
# Note we need to batch up the results by event ID before
# adding to the cache.
to_cache = { }
for event_id , auth_event_id , auth_event_depth in txn :
to_cache . setdefault ( event_id , [ ] ) . append (
( auth_event_id , auth_event_depth )
)
new_front . add ( auth_event_id )
for event_id , auth_events in to_cache . items ( ) :
self . _event_auth_cache . set ( event_id , auth_events )
new_front - = results
@ -213,14 +247,38 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
break
# Fetch the auth events and their depths of the N last events we're
# currently walking
# currently walking, either from cache or DB.
search , chunk = search [ : - 100 ] , search [ - 100 : ]
clause , args = make_in_list_sql_clause (
txn . database_engine , " a.event_id " , [ e_id for _ , e_id in chunk ]
)
txn . execute ( base_sql + clause , args )
for event_id , auth_event_id , auth_event_depth in txn :
found = [ ] # Results found # type: List[Tuple[str, str, int]]
to_fetch = [ ] # Event IDs to fetch from DB # type: List[str]
for _ , event_id in chunk :
res = self . _event_auth_cache . get ( event_id )
if res is None :
to_fetch . append ( event_id )
else :
found . extend ( ( event_id , auth_id , depth ) for auth_id , depth in res )
if to_fetch :
clause , args = make_in_list_sql_clause (
txn . database_engine , " a.event_id " , to_fetch
)
txn . execute ( base_sql + clause , args )
# We parse the results and add the to the `found` set and the
# cache (note we need to batch up the results by event ID before
# adding to the cache).
to_cache = { }
for event_id , auth_event_id , auth_event_depth in txn :
to_cache . setdefault ( event_id , [ ] ) . append (
( auth_event_id , auth_event_depth )
)
found . append ( ( event_id , auth_event_id , auth_event_depth ) )
for event_id , auth_events in to_cache . items ( ) :
self . _event_auth_cache . set ( event_id , auth_events )
for event_id , auth_event_id , auth_event_depth in found :
event_to_auth_events . setdefault ( event_id , set ( ) ) . add ( auth_event_id )
sets = event_to_missing_sets . get ( auth_event_id )