diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 1a76b22e4..a8f8989e3 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,9 +15,7 @@ import logging from synapse.api.errors import StoreError -from synapse.events import FrozenEvent -from synapse.events.utils import prune_event -from synapse.util import unwrap_deferred + from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache @@ -29,7 +27,6 @@ from twisted.internet import defer from collections import namedtuple, OrderedDict import functools -import simplejson as json import sys import time import threading @@ -868,234 +865,6 @@ class SQLBaseStore(object): return self.runInteraction("_simple_max_id", func) - @defer.inlineCallbacks - def _get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False, txn=None): - if not event_ids: - defer.returnValue([]) - - event_map = self._get_events_from_cache( - event_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - missing_events_ids = [e for e in event_ids if e not in event_map] - - if not missing_events_ids: - defer.returnValue([ - event_map[e_id] for e_id in event_ids - if e_id in event_map and event_map[e_id] - ]) - - def get_missing(txn): - missing_events = unwrap_deferred(self._fetch_events( - txn, - missing_events_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - )) - - event_map.update(missing_events) - - return [ - event_map[e_id] for e_id in event_ids - if e_id in event_map and event_map[e_id] - ] - - if not txn: - res = yield self.runInteraction("_get_events", get_missing) - defer.returnValue(res) - else: - defer.returnValue(get_missing(txn)) - - def _get_events_txn(self, txn, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): - return unwrap_deferred(self._get_events( - event_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - txn=txn, - )) - - def _invalidate_get_event_cache(self, event_id): - for check_redacted in (False, True): - for get_prev_content in (False, True): - self._get_event_cache.invalidate(event_id, check_redacted, - get_prev_content) - - def _get_event_txn(self, txn, event_id, check_redacted=True, - get_prev_content=False, allow_rejected=False): - - events = self._get_events_txn( - txn, [event_id], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - return events[0] if events else None - - def _get_events_from_cache(self, events, check_redacted, get_prev_content, - allow_rejected): - event_map = {} - - for event_id in events: - try: - ret = self._get_event_cache.get( - event_id, check_redacted, get_prev_content - ) - - if allow_rejected or not ret.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None - except KeyError: - pass - - return event_map - - @defer.inlineCallbacks - def _fetch_events(self, txn, events, check_redacted=True, - get_prev_content=False, allow_rejected=False): - if not events: - defer.returnValue({}) - - rows = [] - N = 200 - for i in range(1 + len(events) / N): - evs = events[i*N:(i + 1)*N] - if not evs: - break - - sql = ( - "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id " - " FROM event_json as e" - " LEFT JOIN rejections as rej USING (event_id)" - " LEFT JOIN redactions as r ON e.event_id = r.redacts" - " WHERE e.event_id IN (%s)" - ) % (",".join(["?"]*len(evs)),) - - if txn: - txn.execute(sql, evs) - rows.extend(txn.fetchall()) - else: - res = yield self._execute("_fetch_events", None, sql, *evs) - rows.extend(res) - - res = yield defer.gatherResults( - [ - defer.maybeDeferred( - self._get_event_from_row, - txn, - row[0], row[1], row[2], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - rejected_reason=row[3], - ) - for row in rows - ], - consumeErrors=True, - ) - - for e in res: - self._get_event_cache.prefill( - e.event_id, check_redacted, get_prev_content, e - ) - - defer.returnValue({ - e.event_id: e - for e in res if e - }) - - @defer.inlineCallbacks - def _get_event_from_row(self, txn, internal_metadata, js, redacted, - check_redacted=True, get_prev_content=False, - rejected_reason=None): - d = json.loads(js) - internal_metadata = json.loads(internal_metadata) - - def select(txn, *args, **kwargs): - if txn: - return self._simple_select_one_onecol_txn(txn, *args, **kwargs) - else: - return self._simple_select_one_onecol( - *args, - desc="_get_event_from_row", **kwargs - ) - - def get_event(txn, *args, **kwargs): - if txn: - return self._get_event_txn(txn, *args, **kwargs) - else: - return self.get_event(*args, **kwargs) - - if rejected_reason: - rejected_reason = yield select( - txn, - table="rejections", - keyvalues={"event_id": rejected_reason}, - retcol="reason", - ) - - ev = FrozenEvent( - d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - if check_redacted and redacted: - ev = prune_event(ev) - - redaction_id = yield select( - txn, - table="redactions", - keyvalues={"redacts": ev.event_id}, - retcol="event_id", - ) - - ev.unsigned["redacted_by"] = redaction_id - # Get the redaction event. - - because = yield get_event( - txn, - redaction_id, - check_redacted=False - ) - - if because: - ev.unsigned["redacted_because"] = because - - if get_prev_content and "replaces_state" in ev.unsigned: - prev = yield get_event( - txn, - ev.unsigned["replaces_state"], - get_prev_content=False, - ) - if prev: - ev.unsigned["prev_content"] = prev.get_dict()["content"] - - defer.returnValue(ev) - - def _parse_events(self, rows): - return self.runInteraction( - "_parse_events", self._parse_events_txn, rows - ) - - def _parse_events_txn(self, txn, rows): - event_ids = [r["event_id"] for r in rows] - - return self._get_events_txn(txn, event_ids) - - def _has_been_redacted_txn(self, txn, event): - sql = "SELECT event_id FROM redactions WHERE redacts = ?" - txn.execute(sql, (event.event_id,)) - result = txn.fetchone() - return result[0] if result else None - def get_next_stream_id(self): with self._next_stream_id_lock: i = self._next_stream_id diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 9242b0a84..f960ef835 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -17,6 +17,10 @@ from _base import SQLBaseStore, _RollbackButIsFineException from twisted.internet import defer +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event +from synapse.util import unwrap_deferred + from synapse.util.logutils import log_function from synapse.api.constants import EventTypes from synapse.crypto.event_signing import compute_event_reference_hash @@ -26,6 +30,7 @@ from syutil.jsonutil import encode_canonical_json from contextlib import contextmanager import logging +import simplejson as json logger = logging.getLogger(__name__) @@ -393,3 +398,244 @@ class EventsStore(SQLBaseStore): return self.runInteraction( "have_events", f, ) + + @defer.inlineCallbacks + def _get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False, txn=None): + if not event_ids: + defer.returnValue([]) + + event_map = self._get_events_from_cache( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + missing_events_ids = [e for e in event_ids if e not in event_map] + + if not missing_events_ids: + defer.returnValue([ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ]) + + if not txn: + missing_events = yield self.runInteraction( + "_get_events", + self._fetch_events_txn, + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + else: + missing_events = yield self._fetch_events( + txn, + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + event_map.update(missing_events) + + defer.returnValue([ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ]) + + + def _get_events_txn(self, txn, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + return unwrap_deferred(self._get_events( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + txn=txn, + )) + + def _invalidate_get_event_cache(self, event_id): + for check_redacted in (False, True): + for get_prev_content in (False, True): + self._get_event_cache.invalidate(event_id, check_redacted, + get_prev_content) + + def _get_event_txn(self, txn, event_id, check_redacted=True, + get_prev_content=False, allow_rejected=False): + + events = self._get_events_txn( + txn, [event_id], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + return events[0] if events else None + + def _get_events_from_cache(self, events, check_redacted, get_prev_content, + allow_rejected): + event_map = {} + + for event_id in events: + try: + ret = self._get_event_cache.get( + event_id, check_redacted, get_prev_content + ) + + if allow_rejected or not ret.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None + except KeyError: + pass + + return event_map + + def _fetch_events_txn(self, txn, events, check_redacted=True, + get_prev_content=False, allow_rejected=False): + return unwrap_deferred(self._fetch_events( + txn, events, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + )) + + @defer.inlineCallbacks + def _fetch_events(self, txn, events, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not events: + defer.returnValue({}) + + rows = [] + N = 200 + for i in range(1 + len(events) / N): + evs = events[i*N:(i + 1)*N] + if not evs: + break + + sql = ( + "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id " + " FROM event_json as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE e.event_id IN (%s)" + ) % (",".join(["?"]*len(evs)),) + + if txn: + txn.execute(sql, evs) + rows.extend(txn.fetchall()) + else: + res = yield self._execute("_fetch_events", None, sql, *evs) + rows.extend(res) + + res = yield defer.gatherResults( + [ + defer.maybeDeferred( + self._get_event_from_row, + txn, + row[0], row[1], row[2], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + rejected_reason=row[3], + ) + for row in rows + ], + consumeErrors=True, + ) + + for e in res: + self._get_event_cache.prefill( + e.event_id, check_redacted, get_prev_content, e + ) + + defer.returnValue({ + e.event_id: e + for e in res if e + }) + + @defer.inlineCallbacks + def _get_event_from_row(self, txn, internal_metadata, js, redacted, + check_redacted=True, get_prev_content=False, + rejected_reason=None): + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) + + def select(txn, *args, **kwargs): + if txn: + return self._simple_select_one_onecol_txn(txn, *args, **kwargs) + else: + return self._simple_select_one_onecol( + *args, + desc="_get_event_from_row", **kwargs + ) + + def get_event(txn, *args, **kwargs): + if txn: + return self._get_event_txn(txn, *args, **kwargs) + else: + return self.get_event(*args, **kwargs) + + if rejected_reason: + rejected_reason = yield select( + txn, + table="rejections", + keyvalues={"event_id": rejected_reason}, + retcol="reason", + ) + + ev = FrozenEvent( + d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + if check_redacted and redacted: + ev = prune_event(ev) + + redaction_id = yield select( + txn, + table="redactions", + keyvalues={"redacts": ev.event_id}, + retcol="event_id", + ) + + ev.unsigned["redacted_by"] = redaction_id + # Get the redaction event. + + because = yield get_event( + txn, + redaction_id, + check_redacted=False + ) + + if because: + ev.unsigned["redacted_because"] = because + + if get_prev_content and "replaces_state" in ev.unsigned: + prev = yield get_event( + txn, + ev.unsigned["replaces_state"], + get_prev_content=False, + ) + if prev: + ev.unsigned["prev_content"] = prev.get_dict()["content"] + + defer.returnValue(ev) + + def _parse_events(self, rows): + return self.runInteraction( + "_parse_events", self._parse_events_txn, rows + ) + + def _parse_events_txn(self, txn, rows): + event_ids = [r["event_id"] for r in rows] + + return self._get_events_txn(txn, event_ids) + + def _has_been_redacted_txn(self, txn, event): + sql = "SELECT event_id FROM redactions WHERE redacts = ?" + txn.execute(sql, (event.event_id,)) + result = txn.fetchone() + return result[0] if result else None