|
|
|
@ -79,9 +79,15 @@ class E2eKeysHandler: |
|
|
|
|
"client_keys", self.on_federation_query_client_keys |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Limit the number of in-flight requests from a single device. |
|
|
|
|
self._query_devices_linearizer = Linearizer( |
|
|
|
|
name="query_devices", |
|
|
|
|
max_count=10, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
@trace |
|
|
|
|
async def query_devices( |
|
|
|
|
self, query_body: JsonDict, timeout: int, from_user_id: str |
|
|
|
|
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str |
|
|
|
|
) -> JsonDict: |
|
|
|
|
"""Handle a device key query from a client |
|
|
|
|
|
|
|
|
@ -105,191 +111,197 @@ class E2eKeysHandler: |
|
|
|
|
from_user_id: the user making the query. This is used when |
|
|
|
|
adding cross-signing signatures to limit what signatures users |
|
|
|
|
can see. |
|
|
|
|
from_device_id: the device making the query. This is used to limit |
|
|
|
|
the number of in-flight queries at a time. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
device_keys_query = query_body.get( |
|
|
|
|
"device_keys", {} |
|
|
|
|
) # type: Dict[str, Iterable[str]] |
|
|
|
|
|
|
|
|
|
# separate users by domain. |
|
|
|
|
# make a map from domain to user_id to device_ids |
|
|
|
|
local_query = {} |
|
|
|
|
remote_queries = {} |
|
|
|
|
|
|
|
|
|
for user_id, device_ids in device_keys_query.items(): |
|
|
|
|
# we use UserID.from_string to catch invalid user ids |
|
|
|
|
if self.is_mine(UserID.from_string(user_id)): |
|
|
|
|
local_query[user_id] = device_ids |
|
|
|
|
else: |
|
|
|
|
remote_queries[user_id] = device_ids |
|
|
|
|
|
|
|
|
|
set_tag("local_key_query", local_query) |
|
|
|
|
set_tag("remote_key_query", remote_queries) |
|
|
|
|
|
|
|
|
|
# First get local devices. |
|
|
|
|
# A map of destination -> failure response. |
|
|
|
|
failures = {} # type: Dict[str, JsonDict] |
|
|
|
|
results = {} |
|
|
|
|
if local_query: |
|
|
|
|
local_result = await self.query_local_devices(local_query) |
|
|
|
|
for user_id, keys in local_result.items(): |
|
|
|
|
if user_id in local_query: |
|
|
|
|
results[user_id] = keys |
|
|
|
|
|
|
|
|
|
# Get cached cross-signing keys |
|
|
|
|
cross_signing_keys = await self.get_cross_signing_keys_from_cache( |
|
|
|
|
device_keys_query, from_user_id |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Now attempt to get any remote devices from our local cache. |
|
|
|
|
# A map of destination -> user ID -> device IDs. |
|
|
|
|
remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]] |
|
|
|
|
if remote_queries: |
|
|
|
|
query_list = [] # type: List[Tuple[str, Optional[str]]] |
|
|
|
|
for user_id, device_ids in remote_queries.items(): |
|
|
|
|
if device_ids: |
|
|
|
|
query_list.extend((user_id, device_id) for device_id in device_ids) |
|
|
|
|
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)): |
|
|
|
|
device_keys_query = query_body.get( |
|
|
|
|
"device_keys", {} |
|
|
|
|
) # type: Dict[str, Iterable[str]] |
|
|
|
|
|
|
|
|
|
# separate users by domain. |
|
|
|
|
# make a map from domain to user_id to device_ids |
|
|
|
|
local_query = {} |
|
|
|
|
remote_queries = {} |
|
|
|
|
|
|
|
|
|
for user_id, device_ids in device_keys_query.items(): |
|
|
|
|
# we use UserID.from_string to catch invalid user ids |
|
|
|
|
if self.is_mine(UserID.from_string(user_id)): |
|
|
|
|
local_query[user_id] = device_ids |
|
|
|
|
else: |
|
|
|
|
query_list.append((user_id, None)) |
|
|
|
|
|
|
|
|
|
( |
|
|
|
|
user_ids_not_in_cache, |
|
|
|
|
remote_results, |
|
|
|
|
) = await self.store.get_user_devices_from_cache(query_list) |
|
|
|
|
for user_id, devices in remote_results.items(): |
|
|
|
|
user_devices = results.setdefault(user_id, {}) |
|
|
|
|
for device_id, device in devices.items(): |
|
|
|
|
keys = device.get("keys", None) |
|
|
|
|
device_display_name = device.get("device_display_name", None) |
|
|
|
|
if keys: |
|
|
|
|
result = dict(keys) |
|
|
|
|
unsigned = result.setdefault("unsigned", {}) |
|
|
|
|
if device_display_name: |
|
|
|
|
unsigned["device_display_name"] = device_display_name |
|
|
|
|
user_devices[device_id] = result |
|
|
|
|
|
|
|
|
|
# check for missing cross-signing keys. |
|
|
|
|
for user_id in remote_queries.keys(): |
|
|
|
|
cached_cross_master = user_id in cross_signing_keys["master_keys"] |
|
|
|
|
cached_cross_selfsigning = ( |
|
|
|
|
user_id in cross_signing_keys["self_signing_keys"] |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# check if we are missing only one of cross-signing master or |
|
|
|
|
# self-signing key, but the other one is cached. |
|
|
|
|
# as we need both, this will issue a federation request. |
|
|
|
|
# if we don't have any of the keys, either the user doesn't have |
|
|
|
|
# cross-signing set up, or the cached device list |
|
|
|
|
# is not (yet) updated. |
|
|
|
|
if cached_cross_master ^ cached_cross_selfsigning: |
|
|
|
|
user_ids_not_in_cache.add(user_id) |
|
|
|
|
|
|
|
|
|
# add those users to the list to fetch over federation. |
|
|
|
|
for user_id in user_ids_not_in_cache: |
|
|
|
|
domain = get_domain_from_id(user_id) |
|
|
|
|
r = remote_queries_not_in_cache.setdefault(domain, {}) |
|
|
|
|
r[user_id] = remote_queries[user_id] |
|
|
|
|
|
|
|
|
|
# Now fetch any devices that we don't have in our cache |
|
|
|
|
@trace |
|
|
|
|
async def do_remote_query(destination): |
|
|
|
|
"""This is called when we are querying the device list of a user on |
|
|
|
|
a remote homeserver and their device list is not in the device list |
|
|
|
|
cache. If we share a room with this user and we're not querying for |
|
|
|
|
specific user we will update the cache with their device list. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
destination_query = remote_queries_not_in_cache[destination] |
|
|
|
|
|
|
|
|
|
# We first consider whether we wish to update the device list cache with |
|
|
|
|
# the users device list. We want to track a user's devices when the |
|
|
|
|
# authenticated user shares a room with the queried user and the query |
|
|
|
|
# has not specified a particular device. |
|
|
|
|
# If we update the cache for the queried user we remove them from further |
|
|
|
|
# queries. We use the more efficient batched query_client_keys for all |
|
|
|
|
# remaining users |
|
|
|
|
user_ids_updated = [] |
|
|
|
|
for (user_id, device_list) in destination_query.items(): |
|
|
|
|
if user_id in user_ids_updated: |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
if device_list: |
|
|
|
|
continue |
|
|
|
|
remote_queries[user_id] = device_ids |
|
|
|
|
|
|
|
|
|
set_tag("local_key_query", local_query) |
|
|
|
|
set_tag("remote_key_query", remote_queries) |
|
|
|
|
|
|
|
|
|
# First get local devices. |
|
|
|
|
# A map of destination -> failure response. |
|
|
|
|
failures = {} # type: Dict[str, JsonDict] |
|
|
|
|
results = {} |
|
|
|
|
if local_query: |
|
|
|
|
local_result = await self.query_local_devices(local_query) |
|
|
|
|
for user_id, keys in local_result.items(): |
|
|
|
|
if user_id in local_query: |
|
|
|
|
results[user_id] = keys |
|
|
|
|
|
|
|
|
|
room_ids = await self.store.get_rooms_for_user(user_id) |
|
|
|
|
if not room_ids: |
|
|
|
|
continue |
|
|
|
|
# Get cached cross-signing keys |
|
|
|
|
cross_signing_keys = await self.get_cross_signing_keys_from_cache( |
|
|
|
|
device_keys_query, from_user_id |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# We've decided we're sharing a room with this user and should |
|
|
|
|
# probably be tracking their device lists. However, we haven't |
|
|
|
|
# done an initial sync on the device list so we do it now. |
|
|
|
|
try: |
|
|
|
|
if self._is_master: |
|
|
|
|
user_devices = await self.device_handler.device_list_updater.user_device_resync( |
|
|
|
|
user_id |
|
|
|
|
# Now attempt to get any remote devices from our local cache. |
|
|
|
|
# A map of destination -> user ID -> device IDs. |
|
|
|
|
remote_queries_not_in_cache = ( |
|
|
|
|
{} |
|
|
|
|
) # type: Dict[str, Dict[str, Iterable[str]]] |
|
|
|
|
if remote_queries: |
|
|
|
|
query_list = [] # type: List[Tuple[str, Optional[str]]] |
|
|
|
|
for user_id, device_ids in remote_queries.items(): |
|
|
|
|
if device_ids: |
|
|
|
|
query_list.extend( |
|
|
|
|
(user_id, device_id) for device_id in device_ids |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
user_devices = await self._user_device_resync_client( |
|
|
|
|
user_id=user_id |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
user_devices = user_devices["devices"] |
|
|
|
|
user_results = results.setdefault(user_id, {}) |
|
|
|
|
for device in user_devices: |
|
|
|
|
user_results[device["device_id"]] = device["keys"] |
|
|
|
|
user_ids_updated.append(user_id) |
|
|
|
|
except Exception as e: |
|
|
|
|
failures[destination] = _exception_to_failure(e) |
|
|
|
|
|
|
|
|
|
if len(destination_query) == len(user_ids_updated): |
|
|
|
|
# We've updated all the users in the query and we do not need to |
|
|
|
|
# make any further remote calls. |
|
|
|
|
return |
|
|
|
|
query_list.append((user_id, None)) |
|
|
|
|
|
|
|
|
|
# Remove all the users from the query which we have updated |
|
|
|
|
for user_id in user_ids_updated: |
|
|
|
|
destination_query.pop(user_id) |
|
|
|
|
( |
|
|
|
|
user_ids_not_in_cache, |
|
|
|
|
remote_results, |
|
|
|
|
) = await self.store.get_user_devices_from_cache(query_list) |
|
|
|
|
for user_id, devices in remote_results.items(): |
|
|
|
|
user_devices = results.setdefault(user_id, {}) |
|
|
|
|
for device_id, device in devices.items(): |
|
|
|
|
keys = device.get("keys", None) |
|
|
|
|
device_display_name = device.get("device_display_name", None) |
|
|
|
|
if keys: |
|
|
|
|
result = dict(keys) |
|
|
|
|
unsigned = result.setdefault("unsigned", {}) |
|
|
|
|
if device_display_name: |
|
|
|
|
unsigned["device_display_name"] = device_display_name |
|
|
|
|
user_devices[device_id] = result |
|
|
|
|
|
|
|
|
|
# check for missing cross-signing keys. |
|
|
|
|
for user_id in remote_queries.keys(): |
|
|
|
|
cached_cross_master = user_id in cross_signing_keys["master_keys"] |
|
|
|
|
cached_cross_selfsigning = ( |
|
|
|
|
user_id in cross_signing_keys["self_signing_keys"] |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
remote_result = await self.federation.query_client_keys( |
|
|
|
|
destination, {"device_keys": destination_query}, timeout=timeout |
|
|
|
|
) |
|
|
|
|
# check if we are missing only one of cross-signing master or |
|
|
|
|
# self-signing key, but the other one is cached. |
|
|
|
|
# as we need both, this will issue a federation request. |
|
|
|
|
# if we don't have any of the keys, either the user doesn't have |
|
|
|
|
# cross-signing set up, or the cached device list |
|
|
|
|
# is not (yet) updated. |
|
|
|
|
if cached_cross_master ^ cached_cross_selfsigning: |
|
|
|
|
user_ids_not_in_cache.add(user_id) |
|
|
|
|
|
|
|
|
|
# add those users to the list to fetch over federation. |
|
|
|
|
for user_id in user_ids_not_in_cache: |
|
|
|
|
domain = get_domain_from_id(user_id) |
|
|
|
|
r = remote_queries_not_in_cache.setdefault(domain, {}) |
|
|
|
|
r[user_id] = remote_queries[user_id] |
|
|
|
|
|
|
|
|
|
# Now fetch any devices that we don't have in our cache |
|
|
|
|
@trace |
|
|
|
|
async def do_remote_query(destination): |
|
|
|
|
"""This is called when we are querying the device list of a user on |
|
|
|
|
a remote homeserver and their device list is not in the device list |
|
|
|
|
cache. If we share a room with this user and we're not querying for |
|
|
|
|
specific user we will update the cache with their device list. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
destination_query = remote_queries_not_in_cache[destination] |
|
|
|
|
|
|
|
|
|
# We first consider whether we wish to update the device list cache with |
|
|
|
|
# the users device list. We want to track a user's devices when the |
|
|
|
|
# authenticated user shares a room with the queried user and the query |
|
|
|
|
# has not specified a particular device. |
|
|
|
|
# If we update the cache for the queried user we remove them from further |
|
|
|
|
# queries. We use the more efficient batched query_client_keys for all |
|
|
|
|
# remaining users |
|
|
|
|
user_ids_updated = [] |
|
|
|
|
for (user_id, device_list) in destination_query.items(): |
|
|
|
|
if user_id in user_ids_updated: |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
if device_list: |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
room_ids = await self.store.get_rooms_for_user(user_id) |
|
|
|
|
if not room_ids: |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
# We've decided we're sharing a room with this user and should |
|
|
|
|
# probably be tracking their device lists. However, we haven't |
|
|
|
|
# done an initial sync on the device list so we do it now. |
|
|
|
|
try: |
|
|
|
|
if self._is_master: |
|
|
|
|
user_devices = await self.device_handler.device_list_updater.user_device_resync( |
|
|
|
|
user_id |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
user_devices = await self._user_device_resync_client( |
|
|
|
|
user_id=user_id |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
user_devices = user_devices["devices"] |
|
|
|
|
user_results = results.setdefault(user_id, {}) |
|
|
|
|
for device in user_devices: |
|
|
|
|
user_results[device["device_id"]] = device["keys"] |
|
|
|
|
user_ids_updated.append(user_id) |
|
|
|
|
except Exception as e: |
|
|
|
|
failures[destination] = _exception_to_failure(e) |
|
|
|
|
|
|
|
|
|
if len(destination_query) == len(user_ids_updated): |
|
|
|
|
# We've updated all the users in the query and we do not need to |
|
|
|
|
# make any further remote calls. |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
# Remove all the users from the query which we have updated |
|
|
|
|
for user_id in user_ids_updated: |
|
|
|
|
destination_query.pop(user_id) |
|
|
|
|
|
|
|
|
|
for user_id, keys in remote_result["device_keys"].items(): |
|
|
|
|
if user_id in destination_query: |
|
|
|
|
results[user_id] = keys |
|
|
|
|
try: |
|
|
|
|
remote_result = await self.federation.query_client_keys( |
|
|
|
|
destination, {"device_keys": destination_query}, timeout=timeout |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if "master_keys" in remote_result: |
|
|
|
|
for user_id, key in remote_result["master_keys"].items(): |
|
|
|
|
for user_id, keys in remote_result["device_keys"].items(): |
|
|
|
|
if user_id in destination_query: |
|
|
|
|
cross_signing_keys["master_keys"][user_id] = key |
|
|
|
|
results[user_id] = keys |
|
|
|
|
|
|
|
|
|
if "self_signing_keys" in remote_result: |
|
|
|
|
for user_id, key in remote_result["self_signing_keys"].items(): |
|
|
|
|
if user_id in destination_query: |
|
|
|
|
cross_signing_keys["self_signing_keys"][user_id] = key |
|
|
|
|
if "master_keys" in remote_result: |
|
|
|
|
for user_id, key in remote_result["master_keys"].items(): |
|
|
|
|
if user_id in destination_query: |
|
|
|
|
cross_signing_keys["master_keys"][user_id] = key |
|
|
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
|
failure = _exception_to_failure(e) |
|
|
|
|
failures[destination] = failure |
|
|
|
|
set_tag("error", True) |
|
|
|
|
set_tag("reason", failure) |
|
|
|
|
if "self_signing_keys" in remote_result: |
|
|
|
|
for user_id, key in remote_result["self_signing_keys"].items(): |
|
|
|
|
if user_id in destination_query: |
|
|
|
|
cross_signing_keys["self_signing_keys"][user_id] = key |
|
|
|
|
|
|
|
|
|
await make_deferred_yieldable( |
|
|
|
|
defer.gatherResults( |
|
|
|
|
[ |
|
|
|
|
run_in_background(do_remote_query, destination) |
|
|
|
|
for destination in remote_queries_not_in_cache |
|
|
|
|
], |
|
|
|
|
consumeErrors=True, |
|
|
|
|
).addErrback(unwrapFirstError) |
|
|
|
|
) |
|
|
|
|
except Exception as e: |
|
|
|
|
failure = _exception_to_failure(e) |
|
|
|
|
failures[destination] = failure |
|
|
|
|
set_tag("error", True) |
|
|
|
|
set_tag("reason", failure) |
|
|
|
|
|
|
|
|
|
await make_deferred_yieldable( |
|
|
|
|
defer.gatherResults( |
|
|
|
|
[ |
|
|
|
|
run_in_background(do_remote_query, destination) |
|
|
|
|
for destination in remote_queries_not_in_cache |
|
|
|
|
], |
|
|
|
|
consumeErrors=True, |
|
|
|
|
).addErrback(unwrapFirstError) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
ret = {"device_keys": results, "failures": failures} |
|
|
|
|
ret = {"device_keys": results, "failures": failures} |
|
|
|
|
|
|
|
|
|
ret.update(cross_signing_keys) |
|
|
|
|
ret.update(cross_signing_keys) |
|
|
|
|
|
|
|
|
|
return ret |
|
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
async def get_cross_signing_keys_from_cache( |
|
|
|
|
self, query: Iterable[str], from_user_id: Optional[str] |
|
|
|
|