|
|
|
@ -1,5 +1,6 @@ |
|
|
|
|
# -*- coding: utf-8 -*- |
|
|
|
|
# Copyright 2014-2016 OpenMarket Ltd |
|
|
|
|
# Copyright 2019 New Vector Ltd. |
|
|
|
|
# |
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
@ -13,15 +14,15 @@ |
|
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
|
import itertools |
|
|
|
|
import logging |
|
|
|
|
|
|
|
|
|
import six |
|
|
|
|
|
|
|
|
|
from signedjson.key import decode_verify_key_bytes |
|
|
|
|
|
|
|
|
|
from twisted.internet import defer |
|
|
|
|
|
|
|
|
|
from synapse.util.caches.descriptors import cachedInlineCallbacks |
|
|
|
|
from synapse.util import batch_iter |
|
|
|
|
from synapse.util.caches.descriptors import cached, cachedList |
|
|
|
|
|
|
|
|
|
from ._base import SQLBaseStore |
|
|
|
|
|
|
|
|
@ -38,36 +39,50 @@ else: |
|
|
|
|
class KeyStore(SQLBaseStore): |
|
|
|
|
"""Persistence for signature verification keys |
|
|
|
|
""" |
|
|
|
|
@cachedInlineCallbacks() |
|
|
|
|
def _get_server_verify_key(self, server_name, key_id): |
|
|
|
|
verify_key_bytes = yield self._simple_select_one_onecol( |
|
|
|
|
table="server_signature_keys", |
|
|
|
|
keyvalues={"server_name": server_name, "key_id": key_id}, |
|
|
|
|
retcol="verify_key", |
|
|
|
|
desc="_get_server_verify_key", |
|
|
|
|
allow_none=True, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if verify_key_bytes: |
|
|
|
|
defer.returnValue(decode_verify_key_bytes(key_id, bytes(verify_key_bytes))) |
|
|
|
|
@cached() |
|
|
|
|
def _get_server_verify_key(self, server_name_and_key_id): |
|
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def get_server_verify_keys(self, server_name, key_ids): |
|
|
|
|
"""Retrieve the NACL verification key for a given server for the given |
|
|
|
|
key_ids |
|
|
|
|
@cachedList( |
|
|
|
|
cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" |
|
|
|
|
) |
|
|
|
|
def get_server_verify_keys(self, server_name_and_key_ids): |
|
|
|
|
""" |
|
|
|
|
Args: |
|
|
|
|
server_name (str): The name of the server. |
|
|
|
|
key_ids (iterable[str]): key_ids to try and look up. |
|
|
|
|
server_name_and_key_ids (iterable[Tuple[str, str]]): |
|
|
|
|
iterable of (server_name, key-id) tuples to fetch keys for |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
Deferred: resolves to dict[str, VerifyKey]: map from |
|
|
|
|
key_id to verification key. |
|
|
|
|
Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]: |
|
|
|
|
map from (server_name, key_id) -> VerifyKey, or None if the key is |
|
|
|
|
unknown |
|
|
|
|
""" |
|
|
|
|
keys = {} |
|
|
|
|
for key_id in key_ids: |
|
|
|
|
key = yield self._get_server_verify_key(server_name, key_id) |
|
|
|
|
if key: |
|
|
|
|
keys[key_id] = key |
|
|
|
|
defer.returnValue(keys) |
|
|
|
|
|
|
|
|
|
def _get_keys(txn, batch): |
|
|
|
|
"""Processes a batch of keys to fetch, and adds the result to `keys`.""" |
|
|
|
|
|
|
|
|
|
# batch_iter always returns tuples so it's safe to do len(batch) |
|
|
|
|
sql = ( |
|
|
|
|
"SELECT server_name, key_id, verify_key FROM server_signature_keys " |
|
|
|
|
"WHERE 1=0" |
|
|
|
|
) + " OR (server_name=? AND key_id=?)" * len(batch) |
|
|
|
|
|
|
|
|
|
txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) |
|
|
|
|
|
|
|
|
|
for row in txn: |
|
|
|
|
server_name, key_id, key_bytes = row |
|
|
|
|
keys[(server_name, key_id)] = decode_verify_key_bytes( |
|
|
|
|
key_id, bytes(key_bytes) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def _txn(txn): |
|
|
|
|
for batch in batch_iter(server_name_and_key_ids, 50): |
|
|
|
|
_get_keys(txn, batch) |
|
|
|
|
return keys |
|
|
|
|
|
|
|
|
|
return self.runInteraction("get_server_verify_keys", _txn) |
|
|
|
|
|
|
|
|
|
def store_server_verify_key( |
|
|
|
|
self, server_name, from_server, time_now_ms, verify_key |
|
|
|
@ -93,8 +108,11 @@ class KeyStore(SQLBaseStore): |
|
|
|
|
"verify_key": db_binary_type(verify_key.encode()), |
|
|
|
|
}, |
|
|
|
|
) |
|
|
|
|
# invalidate takes a tuple corresponding to the params of |
|
|
|
|
# _get_server_verify_key. _get_server_verify_key only takes one |
|
|
|
|
# param, which is itself the 2-tuple (server_name, key_id). |
|
|
|
|
txn.call_after( |
|
|
|
|
self._get_server_verify_key.invalidate, (server_name, key_id) |
|
|
|
|
self._get_server_verify_key.invalidate, ((server_name, key_id),) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
return self.runInteraction("store_server_verify_key", _txn) |
|
|
|
|