diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index fe01281c9..91bed4746 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -176,7 +176,7 @@ class FederationClient(FederationBase): ) @log_function - def query_client_keys(self, destination, content): + def query_client_keys(self, destination, content, timeout): """Query device keys for a device hosted on a remote server. Args: @@ -188,10 +188,12 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_device_keys") - return self.transport_layer.query_client_keys(destination, content) + return self.transport_layer.query_client_keys( + destination, content, timeout + ) @log_function - def claim_client_keys(self, destination, content): + def claim_client_keys(self, destination, content, timeout): """Claims one-time keys for a device hosted on a remote server. Args: @@ -203,7 +205,9 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_one_time_keys") - return self.transport_layer.claim_client_keys(destination, content) + return self.transport_layer.claim_client_keys( + destination, content, timeout + ) @defer.inlineCallbacks @log_function diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 3d088e43c..2b138526b 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -298,7 +298,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def query_client_keys(self, destination, query_content): + def query_client_keys(self, destination, query_content, timeout): """Query the device keys for a list of user ids hosted on a remote server. @@ -327,12 +327,13 @@ class TransportLayerClient(object): destination=destination, path=path, data=query_content, + timeout=timeout, ) defer.returnValue(content) @defer.inlineCallbacks @log_function - def claim_client_keys(self, destination, query_content): + def claim_client_keys(self, destination, query_content, timeout): """Claim one-time keys for a list of devices hosted on a remote server. Request: @@ -363,6 +364,7 @@ class TransportLayerClient(object): destination=destination, path=path, data=query_content, + timeout=timeout, ) defer.returnValue(content) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 2c7bfd91e..5bfd70093 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import json import logging from twisted.internet import defer -from synapse.api import errors -import synapse.types +from synapse.api.errors import SynapseError, CodeMessageException +from synapse.types import get_domain_from_id +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred logger = logging.getLogger(__name__) @@ -30,7 +30,6 @@ class E2eKeysHandler(object): self.store = hs.get_datastore() self.federation = hs.get_replication_layer() self.is_mine_id = hs.is_mine_id - self.server_name = hs.hostname # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the @@ -40,7 +39,7 @@ class E2eKeysHandler(object): ) @defer.inlineCallbacks - def query_devices(self, query_body): + def query_devices(self, query_body, timeout): """ Handle a device key query from a client { @@ -63,27 +62,50 @@ class E2eKeysHandler(object): # separate users by domain. # make a map from domain to user_id to device_ids - queries_by_domain = collections.defaultdict(dict) + local_query = {} + remote_queries = {} + for user_id, device_ids in device_keys_query.items(): - user = synapse.types.UserID.from_string(user_id) - queries_by_domain[user.domain][user_id] = device_ids + if self.is_mine_id(user_id): + local_query[user_id] = device_ids + else: + domain = get_domain_from_id(user_id) + remote_queries.setdefault(domain, {})[user_id] = device_ids # do the queries - # TODO: do these in parallel + failures = {} results = {} - for destination, destination_query in queries_by_domain.items(): - if destination == self.server_name: - res = yield self.query_local_devices(destination_query) - else: - res = yield self.federation.query_client_keys( - destination, {"device_keys": destination_query} - ) - res = res["device_keys"] - for user_id, keys in res.items(): - if user_id in destination_query: + if local_query: + local_result = yield self.query_local_devices(local_query) + for user_id, keys in local_result.items(): + if user_id in local_query: results[user_id] = keys - defer.returnValue((200, {"device_keys": results})) + @defer.inlineCallbacks + def do_remote_query(destination): + destination_query = remote_queries[destination] + try: + remote_result = yield self.federation.query_client_keys( + destination, + {"device_keys": destination_query}, + timeout=timeout + ) + for user_id, keys in remote_result["device_keys"].items(): + if user_id in destination_query: + results[user_id] = keys + except CodeMessageException as e: + failures[destination] = { + "status": e.code, "message": e.message + } + + yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(do_remote_query)(destination) + for destination in remote_queries + ])) + + defer.returnValue((200, { + "device_keys": results, "failures": failures, + })) @defer.inlineCallbacks def query_local_devices(self, query): @@ -104,7 +126,7 @@ class E2eKeysHandler(object): if not self.is_mine_id(user_id): logger.warning("Request for keys for non-local user %s", user_id) - raise errors.SynapseError(400, "Not a user here") + raise SynapseError(400, "Not a user here") if not device_ids: local_query.append((user_id, None)) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index f93093dd8..d0556ae34 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -246,7 +246,7 @@ class MatrixFederationHttpClient(object): @defer.inlineCallbacks def put_json(self, destination, path, data={}, json_data_callback=None, - long_retries=False): + long_retries=False, timeout=None): """ Sends the specifed json data using PUT Args: @@ -259,6 +259,8 @@ class MatrixFederationHttpClient(object): use as the request body. long_retries (bool): A boolean that indicates whether we should retry for a short or long time. + timeout(int): How long to try (in ms) the destination for before + giving up. None indicates no timeout. Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result @@ -285,6 +287,7 @@ class MatrixFederationHttpClient(object): body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, long_retries=long_retries, + timeout=timeout, ) if 200 <= response.code < 300: @@ -300,7 +303,8 @@ class MatrixFederationHttpClient(object): defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def post_json(self, destination, path, data={}, long_retries=True): + def post_json(self, destination, path, data={}, long_retries=True, + timeout=None): """ Sends the specifed json data using POST Args: @@ -311,6 +315,8 @@ class MatrixFederationHttpClient(object): the request body. This will be encoded as JSON. long_retries (bool): A boolean that indicates whether we should retry for a short or long time. + timeout(int): How long to try (in ms) the destination for before + giving up. None indicates no timeout. Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result @@ -331,6 +337,7 @@ class MatrixFederationHttpClient(object): body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, long_retries=True, + timeout=timeout, ) if 200 <= response.code < 300: diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index c5ff16adf..8f0572765 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -19,11 +19,12 @@ import simplejson as json from canonicaljson import encode_canonical_json from twisted.internet import defer -import synapse.api.errors -import synapse.server -import synapse.types -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import UserID +from synapse.api.errors import SynapseError, CodeMessageException +from synapse.http.servlet import ( + RestServlet, parse_json_object_from_request, parse_integer +) +from synapse.types import get_domain_from_id +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -88,7 +89,7 @@ class KeyUploadServlet(RestServlet): device_id = requester.device_id if device_id is None: - raise synapse.api.errors.SynapseError( + raise SynapseError( 400, "To upload keys, you must pass device_id when authenticating" ) @@ -195,18 +196,21 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.query_devices(body) + result = yield self.e2e_keys_handler.query_devices(body, timeout) defer.returnValue(result) @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): requester = yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) auth_user_id = requester.user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] result = yield self.e2e_keys_handler.query_devices( - {"device_keys": {user_id: device_ids}} + {"device_keys": {user_id: device_ids}}, + timeout, ) defer.returnValue(result) @@ -244,39 +248,43 @@ class OneTimeKeyServlet(RestServlet): self.auth = hs.get_auth() self.clock = hs.get_clock() self.federation = hs.get_replication_layer() - self.is_mine = hs.is_mine + self.is_mine_id = hs.is_mine_id @defer.inlineCallbacks def on_GET(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) result = yield self.handle_request( - {"one_time_keys": {user_id: {device_id: algorithm}}} + {"one_time_keys": {user_id: {device_id: algorithm}}}, + timeout, ) defer.returnValue(result) @defer.inlineCallbacks def on_POST(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.handle_request(body) + result = yield self.handle_request(body, timeout) defer.returnValue(result) @defer.inlineCallbacks - def handle_request(self, body): + def handle_request(self, body, timeout): local_query = [] remote_queries = {} + for user_id, device_keys in body.get("one_time_keys", {}).items(): - user = UserID.from_string(user_id) - if self.is_mine(user): + if self.is_mine_id(user_id): for device_id, algorithm in device_keys.items(): local_query.append((user_id, device_id, algorithm)) else: - remote_queries.setdefault(user.domain, {})[user_id] = ( - device_keys - ) + domain = get_domain_from_id(user_id) + remote_queries.setdefault(domain, {})[user_id] = device_keys + results = yield self.store.claim_e2e_one_time_keys(local_query) json_result = {} + failures = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): for key_id, json_bytes in keys.items(): @@ -284,15 +292,32 @@ class OneTimeKeyServlet(RestServlet): key_id: json.loads(json_bytes) } - for destination, device_keys in remote_queries.items(): - remote_result = yield self.federation.claim_client_keys( - destination, {"one_time_keys": device_keys} - ) - for user_id, keys in remote_result["one_time_keys"].items(): - if user_id in device_keys: - json_result[user_id] = keys - - defer.returnValue((200, {"one_time_keys": json_result})) + @defer.inlineCallbacks + def claim_client_keys(destination): + device_keys = remote_queries[destination] + try: + remote_result = yield self.federation.claim_client_keys( + destination, + {"one_time_keys": device_keys}, + timeout=timeout + ) + for user_id, keys in remote_result["one_time_keys"].items(): + if user_id in device_keys: + json_result[user_id] = keys + except CodeMessageException as e: + failures[destination] = { + "status": e.code, "message": e.message + } + + yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(claim_client_keys)(destination) + for destination in remote_queries + ])) + + defer.returnValue((200, { + "one_time_keys": json_result, + "failures": failures + })) def register_servlets(hs, http_server):