|
|
|
@ -19,11 +19,12 @@ import simplejson as json |
|
|
|
|
from canonicaljson import encode_canonical_json |
|
|
|
|
from twisted.internet import defer |
|
|
|
|
|
|
|
|
|
import synapse.api.errors |
|
|
|
|
import synapse.server |
|
|
|
|
import synapse.types |
|
|
|
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request |
|
|
|
|
from synapse.types import UserID |
|
|
|
|
from synapse.api.errors import SynapseError, CodeMessageException |
|
|
|
|
from synapse.http.servlet import ( |
|
|
|
|
RestServlet, parse_json_object_from_request, parse_integer |
|
|
|
|
) |
|
|
|
|
from synapse.types import get_domain_from_id |
|
|
|
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred |
|
|
|
|
from ._base import client_v2_patterns |
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
@ -88,7 +89,7 @@ class KeyUploadServlet(RestServlet): |
|
|
|
|
device_id = requester.device_id |
|
|
|
|
|
|
|
|
|
if device_id is None: |
|
|
|
|
raise synapse.api.errors.SynapseError( |
|
|
|
|
raise SynapseError( |
|
|
|
|
400, |
|
|
|
|
"To upload keys, you must pass device_id when authenticating" |
|
|
|
|
) |
|
|
|
@ -195,18 +196,21 @@ class KeyQueryServlet(RestServlet): |
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def on_POST(self, request, user_id, device_id): |
|
|
|
|
yield self.auth.get_user_by_req(request) |
|
|
|
|
timeout = parse_integer(request, "timeout", 10 * 1000) |
|
|
|
|
body = parse_json_object_from_request(request) |
|
|
|
|
result = yield self.e2e_keys_handler.query_devices(body) |
|
|
|
|
result = yield self.e2e_keys_handler.query_devices(body, timeout) |
|
|
|
|
defer.returnValue(result) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def on_GET(self, request, user_id, device_id): |
|
|
|
|
requester = yield self.auth.get_user_by_req(request) |
|
|
|
|
timeout = parse_integer(request, "timeout", 10 * 1000) |
|
|
|
|
auth_user_id = requester.user.to_string() |
|
|
|
|
user_id = user_id if user_id else auth_user_id |
|
|
|
|
device_ids = [device_id] if device_id else [] |
|
|
|
|
result = yield self.e2e_keys_handler.query_devices( |
|
|
|
|
{"device_keys": {user_id: device_ids}} |
|
|
|
|
{"device_keys": {user_id: device_ids}}, |
|
|
|
|
timeout, |
|
|
|
|
) |
|
|
|
|
defer.returnValue(result) |
|
|
|
|
|
|
|
|
@ -244,39 +248,43 @@ class OneTimeKeyServlet(RestServlet): |
|
|
|
|
self.auth = hs.get_auth() |
|
|
|
|
self.clock = hs.get_clock() |
|
|
|
|
self.federation = hs.get_replication_layer() |
|
|
|
|
self.is_mine = hs.is_mine |
|
|
|
|
self.is_mine_id = hs.is_mine_id |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def on_GET(self, request, user_id, device_id, algorithm): |
|
|
|
|
yield self.auth.get_user_by_req(request) |
|
|
|
|
timeout = parse_integer(request, "timeout", 10 * 1000) |
|
|
|
|
result = yield self.handle_request( |
|
|
|
|
{"one_time_keys": {user_id: {device_id: algorithm}}} |
|
|
|
|
{"one_time_keys": {user_id: {device_id: algorithm}}}, |
|
|
|
|
timeout, |
|
|
|
|
) |
|
|
|
|
defer.returnValue(result) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def on_POST(self, request, user_id, device_id, algorithm): |
|
|
|
|
yield self.auth.get_user_by_req(request) |
|
|
|
|
timeout = parse_integer(request, "timeout", 10 * 1000) |
|
|
|
|
body = parse_json_object_from_request(request) |
|
|
|
|
result = yield self.handle_request(body) |
|
|
|
|
result = yield self.handle_request(body, timeout) |
|
|
|
|
defer.returnValue(result) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def handle_request(self, body): |
|
|
|
|
def handle_request(self, body, timeout): |
|
|
|
|
local_query = [] |
|
|
|
|
remote_queries = {} |
|
|
|
|
|
|
|
|
|
for user_id, device_keys in body.get("one_time_keys", {}).items(): |
|
|
|
|
user = UserID.from_string(user_id) |
|
|
|
|
if self.is_mine(user): |
|
|
|
|
if self.is_mine_id(user_id): |
|
|
|
|
for device_id, algorithm in device_keys.items(): |
|
|
|
|
local_query.append((user_id, device_id, algorithm)) |
|
|
|
|
else: |
|
|
|
|
remote_queries.setdefault(user.domain, {})[user_id] = ( |
|
|
|
|
device_keys |
|
|
|
|
) |
|
|
|
|
domain = get_domain_from_id(user_id) |
|
|
|
|
remote_queries.setdefault(domain, {})[user_id] = device_keys |
|
|
|
|
|
|
|
|
|
results = yield self.store.claim_e2e_one_time_keys(local_query) |
|
|
|
|
|
|
|
|
|
json_result = {} |
|
|
|
|
failures = {} |
|
|
|
|
for user_id, device_keys in results.items(): |
|
|
|
|
for device_id, keys in device_keys.items(): |
|
|
|
|
for key_id, json_bytes in keys.items(): |
|
|
|
@ -284,15 +292,32 @@ class OneTimeKeyServlet(RestServlet): |
|
|
|
|
key_id: json.loads(json_bytes) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for destination, device_keys in remote_queries.items(): |
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def claim_client_keys(destination): |
|
|
|
|
device_keys = remote_queries[destination] |
|
|
|
|
try: |
|
|
|
|
remote_result = yield self.federation.claim_client_keys( |
|
|
|
|
destination, {"one_time_keys": device_keys} |
|
|
|
|
destination, |
|
|
|
|
{"one_time_keys": device_keys}, |
|
|
|
|
timeout=timeout |
|
|
|
|
) |
|
|
|
|
for user_id, keys in remote_result["one_time_keys"].items(): |
|
|
|
|
if user_id in device_keys: |
|
|
|
|
json_result[user_id] = keys |
|
|
|
|
except CodeMessageException as e: |
|
|
|
|
failures[destination] = { |
|
|
|
|
"status": e.code, "message": e.message |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
yield preserve_context_over_deferred(defer.gatherResults([ |
|
|
|
|
preserve_fn(claim_client_keys)(destination) |
|
|
|
|
for destination in remote_queries |
|
|
|
|
])) |
|
|
|
|
|
|
|
|
|
defer.returnValue((200, {"one_time_keys": json_result})) |
|
|
|
|
defer.returnValue((200, { |
|
|
|
|
"one_time_keys": json_result, |
|
|
|
|
"failures": failures |
|
|
|
|
})) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def register_servlets(hs, http_server): |
|
|
|
|