|
|
|
@ -27,6 +27,8 @@ from synapse.api.errors import SynapseError, Codes |
|
|
|
|
from synapse.util.retryutils import get_retry_limiter |
|
|
|
|
from synapse.util import unwrapFirstError |
|
|
|
|
|
|
|
|
|
from synapse.util.async import ObservableDeferred |
|
|
|
|
|
|
|
|
|
from OpenSSL import crypto |
|
|
|
|
|
|
|
|
|
from collections import namedtuple |
|
|
|
@ -88,6 +90,8 @@ class Keyring(object): |
|
|
|
|
"Not signed with a supported algorithm", |
|
|
|
|
Codes.UNAUTHORIZED, |
|
|
|
|
)) |
|
|
|
|
else: |
|
|
|
|
deferreds[group_id] = defer.Deferred() |
|
|
|
|
|
|
|
|
|
group = KeyGroup(server_name, group_id, key_ids) |
|
|
|
|
|
|
|
|
@ -133,10 +137,41 @@ class Keyring(object): |
|
|
|
|
Codes.UNAUTHORIZED, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
deferreds.update(self.get_server_verify_keys( |
|
|
|
|
group_id_to_group |
|
|
|
|
)) |
|
|
|
|
server_to_deferred = { |
|
|
|
|
server_name: defer.Deferred() |
|
|
|
|
for server_name, _ in server_and_json |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
# We want to wait for any previous lookups to complete before |
|
|
|
|
# proceeding. |
|
|
|
|
wait_on_deferred = self.wait_for_previous_lookups( |
|
|
|
|
[server_name for server_name, _ in server_and_json], |
|
|
|
|
server_to_deferred, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Actually start fetching keys. |
|
|
|
|
wait_on_deferred.addBoth( |
|
|
|
|
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# When we've finished fetching all the keys for a given server_name, |
|
|
|
|
# resolve the deferred passed to `wait_for_previous_lookups` so that |
|
|
|
|
# any lookups waiting will proceed. |
|
|
|
|
server_to_gids = {} |
|
|
|
|
|
|
|
|
|
def remove_deferreds(res, server_name, group_id): |
|
|
|
|
server_to_gids[server_name].discard(group_id) |
|
|
|
|
if not server_to_gids[server_name]: |
|
|
|
|
server_to_deferred.pop(server_name).callback(None) |
|
|
|
|
return res |
|
|
|
|
|
|
|
|
|
for g_id, deferred in deferreds.items(): |
|
|
|
|
server_name = group_id_to_group[g_id].server_name |
|
|
|
|
server_to_gids.setdefault(server_name, set()).add(g_id) |
|
|
|
|
deferred.addBoth(remove_deferreds, server_name, g_id) |
|
|
|
|
|
|
|
|
|
# Pass those keys to handle_key_deferred so that the json object |
|
|
|
|
# signatures can be verified |
|
|
|
|
return [ |
|
|
|
|
handle_key_deferred( |
|
|
|
|
group_id_to_group[g_id], |
|
|
|
@ -145,7 +180,30 @@ class Keyring(object): |
|
|
|
|
for g_id in group_ids |
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
def get_server_verify_keys(self, group_id_to_group): |
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def wait_for_previous_lookups(self, server_names, server_to_deferred): |
|
|
|
|
"""Waits for any previous key lookups for the given servers to finish. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
server_names (list): list of server_names we want to lookup |
|
|
|
|
server_to_deferred (dict): server_name to deferred which gets |
|
|
|
|
resolved once we've finished looking up keys for that server |
|
|
|
|
""" |
|
|
|
|
while True: |
|
|
|
|
wait_on = [ |
|
|
|
|
self.key_downloads[server_name] |
|
|
|
|
for server_name in server_names |
|
|
|
|
if server_name in self.key_downloads |
|
|
|
|
] |
|
|
|
|
if wait_on: |
|
|
|
|
yield defer.DeferredList(wait_on) |
|
|
|
|
else: |
|
|
|
|
break |
|
|
|
|
|
|
|
|
|
for server_name, deferred in server_to_deferred: |
|
|
|
|
self.key_downloads[server_name] = ObservableDeferred(deferred) |
|
|
|
|
|
|
|
|
|
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred): |
|
|
|
|
"""Takes a dict of KeyGroups and tries to find at least one key for |
|
|
|
|
each group. |
|
|
|
|
""" |
|
|
|
@ -157,11 +215,6 @@ class Keyring(object): |
|
|
|
|
self.get_keys_from_server, # Then try directly |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
group_deferreds = { |
|
|
|
|
group_id: defer.Deferred() |
|
|
|
|
for group_id in group_id_to_group |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def do_iterations(): |
|
|
|
|
merged_results = {} |
|
|
|
@ -182,7 +235,7 @@ class Keyring(object): |
|
|
|
|
for group in group_id_to_group.values(): |
|
|
|
|
for key_id in group.key_ids: |
|
|
|
|
if key_id in merged_results[group.server_name]: |
|
|
|
|
group_deferreds.pop(group.group_id).callback(( |
|
|
|
|
group_id_to_deferred[group.group_id].callback(( |
|
|
|
|
group.group_id, |
|
|
|
|
group.server_name, |
|
|
|
|
key_id, |
|
|
|
@ -205,7 +258,7 @@ class Keyring(object): |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for group in missing_groups.values(): |
|
|
|
|
group_deferreds.pop(group.group_id).errback(SynapseError( |
|
|
|
|
group_id_to_deferred[group.group_id].errback(SynapseError( |
|
|
|
|
401, |
|
|
|
|
"No key for %s with id %s" % ( |
|
|
|
|
group.server_name, group.key_ids, |
|
|
|
@ -214,13 +267,13 @@ class Keyring(object): |
|
|
|
|
)) |
|
|
|
|
|
|
|
|
|
def on_err(err): |
|
|
|
|
for deferred in group_deferreds.values(): |
|
|
|
|
deferred.errback(err) |
|
|
|
|
group_deferreds.clear() |
|
|
|
|
for deferred in group_id_to_deferred.values(): |
|
|
|
|
if not deferred.called: |
|
|
|
|
deferred.errback(err) |
|
|
|
|
|
|
|
|
|
do_iterations().addErrback(on_err) |
|
|
|
|
|
|
|
|
|
return group_deferreds |
|
|
|
|
return group_id_to_deferred |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def get_keys_from_store(self, server_name_and_key_ids): |
|
|
|
|