|
|
|
@ -24,6 +24,7 @@ from typing import ( |
|
|
|
|
Mapping, |
|
|
|
|
Optional, |
|
|
|
|
Sequence, |
|
|
|
|
Set, |
|
|
|
|
Tuple, |
|
|
|
|
Union, |
|
|
|
|
cast, |
|
|
|
@ -1260,6 +1261,65 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker |
|
|
|
|
Returns: |
|
|
|
|
A map of user ID -> a map device ID -> a map of key ID -> JSON. |
|
|
|
|
""" |
|
|
|
|
if isinstance(self.database_engine, PostgresEngine): |
|
|
|
|
return await self.db_pool.runInteraction( |
|
|
|
|
"_claim_e2e_fallback_keys_bulk", |
|
|
|
|
self._claim_e2e_fallback_keys_bulk_txn, |
|
|
|
|
query_list, |
|
|
|
|
db_autocommit=True, |
|
|
|
|
) |
|
|
|
|
# Use an UPDATE FROM... RETURNING combined with a VALUES block to do |
|
|
|
|
# everything in one query. Note: this is also supported in SQLite 3.33.0, |
|
|
|
|
# (see https://www.sqlite.org/lang_update.html#update_from), but we do not |
|
|
|
|
# have an equivalent of psycopg2's execute_values to do this in one query. |
|
|
|
|
else: |
|
|
|
|
return await self._claim_e2e_fallback_keys_simple(query_list) |
|
|
|
|
|
|
|
|
|
def _claim_e2e_fallback_keys_bulk_txn( |
|
|
|
|
self, |
|
|
|
|
txn: LoggingTransaction, |
|
|
|
|
query_list: Iterable[Tuple[str, str, str, bool]], |
|
|
|
|
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: |
|
|
|
|
"""Efficient implementation of claim_e2e_fallback_keys for Postgres. |
|
|
|
|
|
|
|
|
|
Safe to autocommit: this is a single query. |
|
|
|
|
""" |
|
|
|
|
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} |
|
|
|
|
|
|
|
|
|
sql = """ |
|
|
|
|
WITH claims(user_id, device_id, algorithm, mark_as_used) AS ( |
|
|
|
|
VALUES ? |
|
|
|
|
) |
|
|
|
|
UPDATE e2e_fallback_keys_json k |
|
|
|
|
SET used = used OR mark_as_used |
|
|
|
|
FROM claims |
|
|
|
|
WHERE (k.user_id, k.device_id, k.algorithm) = (claims.user_id, claims.device_id, claims.algorithm) |
|
|
|
|
RETURNING k.user_id, k.device_id, k.algorithm, k.key_id, k.key_json; |
|
|
|
|
""" |
|
|
|
|
claimed_keys = cast( |
|
|
|
|
List[Tuple[str, str, str, str, str]], |
|
|
|
|
txn.execute_values(sql, query_list), |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
seen_user_device: Set[Tuple[str, str]] = set() |
|
|
|
|
for user_id, device_id, algorithm, key_id, key_json in claimed_keys: |
|
|
|
|
device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) |
|
|
|
|
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json) |
|
|
|
|
|
|
|
|
|
if (user_id, device_id) in seen_user_device: |
|
|
|
|
continue |
|
|
|
|
seen_user_device.add((user_id, device_id)) |
|
|
|
|
self._invalidate_cache_and_stream( |
|
|
|
|
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
return results |
|
|
|
|
|
|
|
|
|
async def _claim_e2e_fallback_keys_simple( |
|
|
|
|
self, |
|
|
|
|
query_list: Iterable[Tuple[str, str, str, bool]], |
|
|
|
|
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: |
|
|
|
|
"""Naive, inefficient implementation of claim_e2e_fallback_keys for SQLite.""" |
|
|
|
|
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} |
|
|
|
|
for user_id, device_id, algorithm, mark_as_used in query_list: |
|
|
|
|
row = await self.db_pool.simple_select_one( |
|
|
|
|