|
|
|
@ -12,16 +12,19 @@ |
|
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
|
from typing import Dict, Iterable, List, Tuple |
|
|
|
|
from typing import Collection, Dict, List, Tuple |
|
|
|
|
|
|
|
|
|
from unpaddedbase64 import encode_base64 |
|
|
|
|
|
|
|
|
|
from synapse.storage._base import SQLBaseStore |
|
|
|
|
from synapse.storage.types import Cursor |
|
|
|
|
from synapse.crypto.event_signing import compute_event_reference_hash |
|
|
|
|
from synapse.storage.databases.main.events_worker import ( |
|
|
|
|
EventRedactBehaviour, |
|
|
|
|
EventsWorkerStore, |
|
|
|
|
) |
|
|
|
|
from synapse.util.caches.descriptors import cached, cachedList |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SignatureWorkerStore(SQLBaseStore): |
|
|
|
|
class SignatureWorkerStore(EventsWorkerStore): |
|
|
|
|
@cached() |
|
|
|
|
def get_event_reference_hash(self, event_id): |
|
|
|
|
# This is a dummy function to allow get_event_reference_hashes |
|
|
|
@ -32,7 +35,7 @@ class SignatureWorkerStore(SQLBaseStore): |
|
|
|
|
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 |
|
|
|
|
) |
|
|
|
|
async def get_event_reference_hashes( |
|
|
|
|
self, event_ids: Iterable[str] |
|
|
|
|
self, event_ids: Collection[str] |
|
|
|
|
) -> Dict[str, Dict[str, bytes]]: |
|
|
|
|
"""Get all hashes for given events. |
|
|
|
|
|
|
|
|
@ -41,18 +44,27 @@ class SignatureWorkerStore(SQLBaseStore): |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
A mapping of event ID to a mapping of algorithm to hash. |
|
|
|
|
Returns an empty dict for a given event id if that event is unknown. |
|
|
|
|
""" |
|
|
|
|
events = await self.get_events( |
|
|
|
|
event_ids, |
|
|
|
|
redact_behaviour=EventRedactBehaviour.AS_IS, |
|
|
|
|
allow_rejected=True, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def f(txn): |
|
|
|
|
return { |
|
|
|
|
event_id: self._get_event_reference_hashes_txn(txn, event_id) |
|
|
|
|
for event_id in event_ids |
|
|
|
|
} |
|
|
|
|
hashes: Dict[str, Dict[str, bytes]] = {} |
|
|
|
|
for event_id in event_ids: |
|
|
|
|
event = events.get(event_id) |
|
|
|
|
if event is None: |
|
|
|
|
hashes[event_id] = {} |
|
|
|
|
else: |
|
|
|
|
ref_alg, ref_hash_bytes = compute_event_reference_hash(event) |
|
|
|
|
hashes[event_id] = {ref_alg: ref_hash_bytes} |
|
|
|
|
|
|
|
|
|
return await self.db_pool.runInteraction("get_event_reference_hashes", f) |
|
|
|
|
return hashes |
|
|
|
|
|
|
|
|
|
async def add_event_hashes( |
|
|
|
|
self, event_ids: Iterable[str] |
|
|
|
|
self, event_ids: Collection[str] |
|
|
|
|
) -> List[Tuple[str, Dict[str, str]]]: |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
@ -70,24 +82,6 @@ class SignatureWorkerStore(SQLBaseStore): |
|
|
|
|
|
|
|
|
|
return list(encoded_hashes.items()) |
|
|
|
|
|
|
|
|
|
def _get_event_reference_hashes_txn( |
|
|
|
|
self, txn: Cursor, event_id: str |
|
|
|
|
) -> Dict[str, bytes]: |
|
|
|
|
"""Get all the hashes for a given PDU. |
|
|
|
|
Args: |
|
|
|
|
txn: |
|
|
|
|
event_id: Id for the Event. |
|
|
|
|
Returns: |
|
|
|
|
A mapping of algorithm -> hash. |
|
|
|
|
""" |
|
|
|
|
query = ( |
|
|
|
|
"SELECT algorithm, hash" |
|
|
|
|
" FROM event_reference_hashes" |
|
|
|
|
" WHERE event_id = ?" |
|
|
|
|
) |
|
|
|
|
txn.execute(query, (event_id,)) |
|
|
|
|
return {k: v for k, v in txn} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SignatureStore(SignatureWorkerStore): |
|
|
|
|
"""Persistence for event signatures and hashes""" |
|
|
|
|