|
|
|
@ -14,15 +14,18 @@ |
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
|
# limitations under the License. |
|
|
|
|
from typing import Dict, List |
|
|
|
|
|
|
|
|
|
from six import iteritems |
|
|
|
|
|
|
|
|
|
from canonicaljson import encode_canonical_json, json |
|
|
|
|
|
|
|
|
|
from twisted.enterprise.adbapi import Connection |
|
|
|
|
from twisted.internet import defer |
|
|
|
|
|
|
|
|
|
from synapse.logging.opentracing import log_kv, set_tag, trace |
|
|
|
|
from synapse.storage._base import SQLBaseStore, db_to_json |
|
|
|
|
from synapse.util.caches.descriptors import cached |
|
|
|
|
from synapse.util.caches.descriptors import cached, cachedList |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EndToEndKeyWorkerStore(SQLBaseStore): |
|
|
|
@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): |
|
|
|
|
Args: |
|
|
|
|
txn (twisted.enterprise.adbapi.Connection): db connection |
|
|
|
|
user_id (str): the user whose key is being requested |
|
|
|
|
key_type (str): the type of key that is being set: either 'master' |
|
|
|
|
key_type (str): the type of key that is being requested: either 'master' |
|
|
|
|
for a master key, 'self_signing' for a self-signing key, or |
|
|
|
|
'user_signing' for a user-signing key |
|
|
|
|
from_user_id (str): if specified, signatures made by this user on |
|
|
|
@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): |
|
|
|
|
"""Returns a user's cross-signing key. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
user_id (str): the user whose self-signing key is being requested |
|
|
|
|
key_type (str): the type of cross-signing key to get |
|
|
|
|
user_id (str): the user whose key is being requested |
|
|
|
|
key_type (str): the type of key that is being requested: either 'master' |
|
|
|
|
for a master key, 'self_signing' for a self-signing key, or |
|
|
|
|
'user_signing' for a user-signing key |
|
|
|
|
from_user_id (str): if specified, signatures made by this user on |
|
|
|
|
the self-signing key will be included in the result |
|
|
|
|
|
|
|
|
@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore): |
|
|
|
|
from_user_id, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
@cached(num_args=1) |
|
|
|
|
def _get_bare_e2e_cross_signing_keys(self, user_id): |
|
|
|
|
"""Dummy function. Only used to make a cache for |
|
|
|
|
_get_bare_e2e_cross_signing_keys_bulk. |
|
|
|
|
""" |
|
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
|
|
@cachedList( |
|
|
|
|
cached_method_name="_get_bare_e2e_cross_signing_keys", |
|
|
|
|
list_name="user_ids", |
|
|
|
|
num_args=1, |
|
|
|
|
) |
|
|
|
|
def _get_bare_e2e_cross_signing_keys_bulk( |
|
|
|
|
self, user_ids: List[str] |
|
|
|
|
) -> Dict[str, Dict[str, dict]]: |
|
|
|
|
"""Returns the cross-signing keys for a set of users. The output of this |
|
|
|
|
function should be passed to _get_e2e_cross_signing_signatures_txn if |
|
|
|
|
the signatures for the calling user need to be fetched. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
user_ids (list[str]): the users whose keys are being requested |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
dict[str, dict[str, dict]]: mapping from user ID to key type to key |
|
|
|
|
data. If a user's cross-signing keys were not found, either |
|
|
|
|
their user ID will not be in the dict, or their user ID will map |
|
|
|
|
to None. |
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
return self.db.runInteraction( |
|
|
|
|
"get_bare_e2e_cross_signing_keys_bulk", |
|
|
|
|
self._get_bare_e2e_cross_signing_keys_bulk_txn, |
|
|
|
|
user_ids, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def _get_bare_e2e_cross_signing_keys_bulk_txn( |
|
|
|
|
self, txn: Connection, user_ids: List[str], |
|
|
|
|
) -> Dict[str, Dict[str, dict]]: |
|
|
|
|
"""Returns the cross-signing keys for a set of users. The output of this |
|
|
|
|
function should be passed to _get_e2e_cross_signing_signatures_txn if |
|
|
|
|
the signatures for the calling user need to be fetched. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
txn (twisted.enterprise.adbapi.Connection): db connection |
|
|
|
|
user_ids (list[str]): the users whose keys are being requested |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
dict[str, dict[str, dict]]: mapping from user ID to key type to key |
|
|
|
|
data. If a user's cross-signing keys were not found, their user |
|
|
|
|
ID will not be in the dict. |
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
result = {} |
|
|
|
|
|
|
|
|
|
batch_size = 100 |
|
|
|
|
chunks = [ |
|
|
|
|
user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size) |
|
|
|
|
] |
|
|
|
|
for user_chunk in chunks: |
|
|
|
|
sql = """ |
|
|
|
|
SELECT k.user_id, k.keytype, k.keydata, k.stream_id |
|
|
|
|
FROM e2e_cross_signing_keys k |
|
|
|
|
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id |
|
|
|
|
FROM e2e_cross_signing_keys |
|
|
|
|
GROUP BY user_id, keytype) s |
|
|
|
|
USING (user_id, stream_id, keytype) |
|
|
|
|
WHERE k.user_id IN (%s) |
|
|
|
|
""" % ( |
|
|
|
|
",".join("?" for u in user_chunk), |
|
|
|
|
) |
|
|
|
|
query_params = [] |
|
|
|
|
query_params.extend(user_chunk) |
|
|
|
|
|
|
|
|
|
txn.execute(sql, query_params) |
|
|
|
|
rows = self.db.cursor_to_dict(txn) |
|
|
|
|
|
|
|
|
|
for row in rows: |
|
|
|
|
user_id = row["user_id"] |
|
|
|
|
key_type = row["keytype"] |
|
|
|
|
key = json.loads(row["keydata"]) |
|
|
|
|
user_info = result.setdefault(user_id, {}) |
|
|
|
|
user_info[key_type] = key |
|
|
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
def _get_e2e_cross_signing_signatures_txn( |
|
|
|
|
self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str, |
|
|
|
|
) -> Dict[str, Dict[str, dict]]: |
|
|
|
|
"""Returns the cross-signing signatures made by a user on a set of keys. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
txn (twisted.enterprise.adbapi.Connection): db connection |
|
|
|
|
keys (dict[str, dict[str, dict]]): a map of user ID to key type to |
|
|
|
|
key data. This dict will be modified to add signatures. |
|
|
|
|
from_user_id (str): fetch the signatures made by this user |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
dict[str, dict[str, dict]]: mapping from user ID to key type to key |
|
|
|
|
data. The return value will be the same as the keys argument, |
|
|
|
|
with the modifications included. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
# find out what cross-signing keys (a.k.a. devices) we need to get |
|
|
|
|
# signatures for. This is a map of (user_id, device_id) to key type |
|
|
|
|
# (device_id is the key's public part). |
|
|
|
|
devices = {} |
|
|
|
|
|
|
|
|
|
for user_id, user_info in keys.items(): |
|
|
|
|
if user_info is None: |
|
|
|
|
continue |
|
|
|
|
for key_type, key in user_info.items(): |
|
|
|
|
device_id = None |
|
|
|
|
for k in key["keys"].values(): |
|
|
|
|
device_id = k |
|
|
|
|
devices[(user_id, device_id)] = key_type |
|
|
|
|
|
|
|
|
|
device_list = list(devices) |
|
|
|
|
|
|
|
|
|
# split into batches |
|
|
|
|
batch_size = 100 |
|
|
|
|
chunks = [ |
|
|
|
|
device_list[i : i + batch_size] |
|
|
|
|
for i in range(0, len(device_list), batch_size) |
|
|
|
|
] |
|
|
|
|
for user_chunk in chunks: |
|
|
|
|
sql = """ |
|
|
|
|
SELECT target_user_id, target_device_id, key_id, signature |
|
|
|
|
FROM e2e_cross_signing_signatures |
|
|
|
|
WHERE user_id = ? |
|
|
|
|
AND (%s) |
|
|
|
|
""" % ( |
|
|
|
|
" OR ".join( |
|
|
|
|
"(target_user_id = ? AND target_device_id = ?)" for d in devices |
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
query_params = [from_user_id] |
|
|
|
|
for item in devices: |
|
|
|
|
# item is a (user_id, device_id) tuple |
|
|
|
|
query_params.extend(item) |
|
|
|
|
|
|
|
|
|
txn.execute(sql, query_params) |
|
|
|
|
rows = self.db.cursor_to_dict(txn) |
|
|
|
|
|
|
|
|
|
# and add the signatures to the appropriate keys |
|
|
|
|
for row in rows: |
|
|
|
|
key_id = row["key_id"] |
|
|
|
|
target_user_id = row["target_user_id"] |
|
|
|
|
target_device_id = row["target_device_id"] |
|
|
|
|
key_type = devices[(target_user_id, target_device_id)] |
|
|
|
|
# We need to copy everything, because the result may have come |
|
|
|
|
# from the cache. dict.copy only does a shallow copy, so we |
|
|
|
|
# need to recursively copy the dicts that will be modified. |
|
|
|
|
user_info = keys[target_user_id] = keys[target_user_id].copy() |
|
|
|
|
target_user_key = user_info[key_type] = user_info[key_type].copy() |
|
|
|
|
if "signatures" in target_user_key: |
|
|
|
|
signatures = target_user_key["signatures"] = target_user_key[ |
|
|
|
|
"signatures" |
|
|
|
|
].copy() |
|
|
|
|
if from_user_id in signatures: |
|
|
|
|
user_sigs = signatures[from_user_id] = signatures[from_user_id] |
|
|
|
|
user_sigs[key_id] = row["signature"] |
|
|
|
|
else: |
|
|
|
|
signatures[from_user_id] = {key_id: row["signature"]} |
|
|
|
|
else: |
|
|
|
|
target_user_key["signatures"] = { |
|
|
|
|
from_user_id: {key_id: row["signature"]} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return keys |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def get_e2e_cross_signing_keys_bulk( |
|
|
|
|
self, user_ids: List[str], from_user_id: str = None |
|
|
|
|
) -> defer.Deferred: |
|
|
|
|
"""Returns the cross-signing keys for a set of users. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
user_ids (list[str]): the users whose keys are being requested |
|
|
|
|
from_user_id (str): if specified, signatures made by this user on |
|
|
|
|
the self-signing keys will be included in the result |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to |
|
|
|
|
key data. If a user's cross-signing keys were not found, either |
|
|
|
|
their user ID will not be in the dict, or their user ID will map |
|
|
|
|
to None. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) |
|
|
|
|
|
|
|
|
|
if from_user_id: |
|
|
|
|
result = yield self.db.runInteraction( |
|
|
|
|
"get_e2e_cross_signing_signatures", |
|
|
|
|
self._get_e2e_cross_signing_signatures_txn, |
|
|
|
|
result, |
|
|
|
|
from_user_id, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
def get_all_user_signature_changes_for_remotes(self, from_key, to_key): |
|
|
|
|
"""Return a list of changes from the user signature stream to notify remotes. |
|
|
|
|
Note that the user signature stream represents when a user signs their |
|
|
|
@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): |
|
|
|
|
}, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
self._invalidate_cache_and_stream( |
|
|
|
|
txn, self._get_bare_e2e_cross_signing_keys, (user_id,) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def set_e2e_cross_signing_key(self, user_id, key_type, key): |
|
|
|
|
"""Set a user's cross-signing key. |
|
|
|
|
|
|
|
|
|