|
|
|
@ -755,81 +755,145 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
@trace |
|
|
|
|
def _claim_e2e_one_time_keys(txn): |
|
|
|
|
sql = ( |
|
|
|
|
"SELECT key_id, key_json FROM e2e_one_time_keys_json" |
|
|
|
|
" WHERE user_id = ? AND device_id = ? AND algorithm = ?" |
|
|
|
|
" LIMIT 1" |
|
|
|
|
def _claim_e2e_one_time_key_simple( |
|
|
|
|
txn, user_id: str, device_id: str, algorithm: str |
|
|
|
|
) -> Optional[Tuple[str, str]]: |
|
|
|
|
"""Claim OTK for device for DBs that don't support RETURNING. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
A tuple of key name (algorithm + key ID) and key JSON, if an |
|
|
|
|
OTK was found. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
sql = """ |
|
|
|
|
SELECT key_id, key_json FROM e2e_one_time_keys_json |
|
|
|
|
WHERE user_id = ? AND device_id = ? AND algorithm = ? |
|
|
|
|
LIMIT 1 |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
txn.execute(sql, (user_id, device_id, algorithm)) |
|
|
|
|
otk_row = txn.fetchone() |
|
|
|
|
if otk_row is None: |
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
key_id, key_json = otk_row |
|
|
|
|
|
|
|
|
|
self.db_pool.simple_delete_one_txn( |
|
|
|
|
txn, |
|
|
|
|
table="e2e_one_time_keys_json", |
|
|
|
|
keyvalues={ |
|
|
|
|
"user_id": user_id, |
|
|
|
|
"device_id": device_id, |
|
|
|
|
"algorithm": algorithm, |
|
|
|
|
"key_id": key_id, |
|
|
|
|
}, |
|
|
|
|
) |
|
|
|
|
fallback_sql = ( |
|
|
|
|
"SELECT key_id, key_json, used FROM e2e_fallback_keys_json" |
|
|
|
|
" WHERE user_id = ? AND device_id = ? AND algorithm = ?" |
|
|
|
|
" LIMIT 1" |
|
|
|
|
self._invalidate_cache_and_stream( |
|
|
|
|
txn, self.count_e2e_one_time_keys, (user_id, device_id) |
|
|
|
|
) |
|
|
|
|
result = {} |
|
|
|
|
delete = [] |
|
|
|
|
used_fallbacks = [] |
|
|
|
|
for user_id, device_id, algorithm in query_list: |
|
|
|
|
user_result = result.setdefault(user_id, {}) |
|
|
|
|
device_result = user_result.setdefault(device_id, {}) |
|
|
|
|
txn.execute(sql, (user_id, device_id, algorithm)) |
|
|
|
|
otk_row = txn.fetchone() |
|
|
|
|
if otk_row is not None: |
|
|
|
|
key_id, key_json = otk_row |
|
|
|
|
device_result[algorithm + ":" + key_id] = key_json |
|
|
|
|
delete.append((user_id, device_id, algorithm, key_id)) |
|
|
|
|
else: |
|
|
|
|
# no one-time key available, so see if there's a fallback |
|
|
|
|
# key |
|
|
|
|
txn.execute(fallback_sql, (user_id, device_id, algorithm)) |
|
|
|
|
fallback_row = txn.fetchone() |
|
|
|
|
if fallback_row is not None: |
|
|
|
|
key_id, key_json, used = fallback_row |
|
|
|
|
device_result[algorithm + ":" + key_id] = key_json |
|
|
|
|
if not used: |
|
|
|
|
used_fallbacks.append( |
|
|
|
|
(user_id, device_id, algorithm, key_id) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# drop any one-time keys that were claimed |
|
|
|
|
sql = ( |
|
|
|
|
"DELETE FROM e2e_one_time_keys_json" |
|
|
|
|
" WHERE user_id = ? AND device_id = ? AND algorithm = ?" |
|
|
|
|
" AND key_id = ?" |
|
|
|
|
|
|
|
|
|
return f"{algorithm}:{key_id}", key_json |
|
|
|
|
|
|
|
|
|
@trace |
|
|
|
|
def _claim_e2e_one_time_key_returning( |
|
|
|
|
txn, user_id: str, device_id: str, algorithm: str |
|
|
|
|
) -> Optional[Tuple[str, str]]: |
|
|
|
|
"""Claim OTK for device for DBs that support RETURNING. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
A tuple of key name (algorithm + key ID) and key JSON, if an |
|
|
|
|
OTK was found. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
# We can use RETURNING to do the fetch and DELETE in once step. |
|
|
|
|
sql = """ |
|
|
|
|
DELETE FROM e2e_one_time_keys_json |
|
|
|
|
WHERE user_id = ? AND device_id = ? AND algorithm = ? |
|
|
|
|
AND key_id IN ( |
|
|
|
|
SELECT key_id FROM e2e_one_time_keys_json |
|
|
|
|
WHERE user_id = ? AND device_id = ? AND algorithm = ? |
|
|
|
|
LIMIT 1 |
|
|
|
|
) |
|
|
|
|
RETURNING key_id, key_json |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
txn.execute( |
|
|
|
|
sql, (user_id, device_id, algorithm, user_id, device_id, algorithm) |
|
|
|
|
) |
|
|
|
|
for user_id, device_id, algorithm, key_id in delete: |
|
|
|
|
log_kv( |
|
|
|
|
{ |
|
|
|
|
"message": "Executing claim e2e_one_time_keys transaction on database." |
|
|
|
|
} |
|
|
|
|
) |
|
|
|
|
txn.execute(sql, (user_id, device_id, algorithm, key_id)) |
|
|
|
|
log_kv({"message": "finished executing and invalidating cache"}) |
|
|
|
|
self._invalidate_cache_and_stream( |
|
|
|
|
txn, self.count_e2e_one_time_keys, (user_id, device_id) |
|
|
|
|
otk_row = txn.fetchone() |
|
|
|
|
if otk_row is None: |
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
key_id, key_json = otk_row |
|
|
|
|
return f"{algorithm}:{key_id}", key_json |
|
|
|
|
|
|
|
|
|
results = {} |
|
|
|
|
for user_id, device_id, algorithm in query_list: |
|
|
|
|
if self.database_engine.supports_returning: |
|
|
|
|
# If we support RETURNING clause we can use a single query that |
|
|
|
|
# allows us to use autocommit mode. |
|
|
|
|
_claim_e2e_one_time_key = _claim_e2e_one_time_key_returning |
|
|
|
|
db_autocommit = True |
|
|
|
|
else: |
|
|
|
|
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple |
|
|
|
|
db_autocommit = False |
|
|
|
|
|
|
|
|
|
row = await self.db_pool.runInteraction( |
|
|
|
|
"claim_e2e_one_time_keys", |
|
|
|
|
_claim_e2e_one_time_key, |
|
|
|
|
user_id, |
|
|
|
|
device_id, |
|
|
|
|
algorithm, |
|
|
|
|
db_autocommit=db_autocommit, |
|
|
|
|
) |
|
|
|
|
if row: |
|
|
|
|
device_results = results.setdefault(user_id, {}).setdefault( |
|
|
|
|
device_id, {} |
|
|
|
|
) |
|
|
|
|
# mark fallback keys as used |
|
|
|
|
for user_id, device_id, algorithm, key_id in used_fallbacks: |
|
|
|
|
self.db_pool.simple_update_txn( |
|
|
|
|
txn, |
|
|
|
|
"e2e_fallback_keys_json", |
|
|
|
|
{ |
|
|
|
|
device_results[row[0]] = row[1] |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
# No one-time key available, so see if there's a fallback |
|
|
|
|
# key |
|
|
|
|
row = await self.db_pool.simple_select_one( |
|
|
|
|
table="e2e_fallback_keys_json", |
|
|
|
|
keyvalues={ |
|
|
|
|
"user_id": user_id, |
|
|
|
|
"device_id": device_id, |
|
|
|
|
"algorithm": algorithm, |
|
|
|
|
}, |
|
|
|
|
retcols=("key_id", "key_json", "used"), |
|
|
|
|
desc="_get_fallback_key", |
|
|
|
|
allow_none=True, |
|
|
|
|
) |
|
|
|
|
if row is None: |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
key_id = row["key_id"] |
|
|
|
|
key_json = row["key_json"] |
|
|
|
|
used = row["used"] |
|
|
|
|
|
|
|
|
|
# Mark fallback key as used if not already. |
|
|
|
|
if not used: |
|
|
|
|
await self.db_pool.simple_update_one( |
|
|
|
|
table="e2e_fallback_keys_json", |
|
|
|
|
keyvalues={ |
|
|
|
|
"user_id": user_id, |
|
|
|
|
"device_id": device_id, |
|
|
|
|
"algorithm": algorithm, |
|
|
|
|
"key_id": key_id, |
|
|
|
|
}, |
|
|
|
|
{"used": True}, |
|
|
|
|
updatevalues={"used": True}, |
|
|
|
|
desc="_get_fallback_key_set_used", |
|
|
|
|
) |
|
|
|
|
self._invalidate_cache_and_stream( |
|
|
|
|
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) |
|
|
|
|
await self.invalidate_cache_and_stream( |
|
|
|
|
"get_e2e_unused_fallback_key_types", (user_id, device_id) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
return result |
|
|
|
|
device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) |
|
|
|
|
device_results[f"{algorithm}:{key_id}"] = key_json |
|
|
|
|
|
|
|
|
|
return await self.db_pool.runInteraction( |
|
|
|
|
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys |
|
|
|
|
) |
|
|
|
|
return results |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): |
|
|
|
|