|
|
|
@ -14,13 +14,12 @@ |
|
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
|
import logging |
|
|
|
|
from typing import Iterable, List, TypeVar |
|
|
|
|
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar |
|
|
|
|
|
|
|
|
|
import attr |
|
|
|
|
|
|
|
|
|
from twisted.internet import defer |
|
|
|
|
|
|
|
|
|
from synapse.api.constants import EventTypes |
|
|
|
|
from synapse.events import EventBase |
|
|
|
|
from synapse.types import StateMap |
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
@ -34,16 +33,16 @@ class StateFilter(object): |
|
|
|
|
"""A filter used when querying for state. |
|
|
|
|
|
|
|
|
|
Attributes: |
|
|
|
|
types (dict[str, set[str]|None]): Map from type to set of state keys (or |
|
|
|
|
None). This specifies which state_keys for the given type to fetch |
|
|
|
|
from the DB. If None then all events with that type are fetched. If |
|
|
|
|
the set is empty then no events with that type are fetched. |
|
|
|
|
include_others (bool): Whether to fetch events with types that do not |
|
|
|
|
types: Map from type to set of state keys (or None). This specifies |
|
|
|
|
which state_keys for the given type to fetch from the DB. If None |
|
|
|
|
then all events with that type are fetched. If the set is empty |
|
|
|
|
then no events with that type are fetched. |
|
|
|
|
include_others: Whether to fetch events with types that do not |
|
|
|
|
appear in `types`. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
types = attr.ib() |
|
|
|
|
include_others = attr.ib(default=False) |
|
|
|
|
types = attr.ib(type=Dict[str, Optional[Set[str]]]) |
|
|
|
|
include_others = attr.ib(default=False, type=bool) |
|
|
|
|
|
|
|
|
|
def __attrs_post_init__(self): |
|
|
|
|
# If `include_others` is set we canonicalise the filter by removing |
|
|
|
@ -52,36 +51,35 @@ class StateFilter(object): |
|
|
|
|
self.types = {k: v for k, v in self.types.items() if v is not None} |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def all(): |
|
|
|
|
def all() -> "StateFilter": |
|
|
|
|
"""Creates a filter that fetches everything. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
StateFilter |
|
|
|
|
The new state filter. |
|
|
|
|
""" |
|
|
|
|
return StateFilter(types={}, include_others=True) |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def none(): |
|
|
|
|
def none() -> "StateFilter": |
|
|
|
|
"""Creates a filter that fetches nothing. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
StateFilter |
|
|
|
|
The new state filter. |
|
|
|
|
""" |
|
|
|
|
return StateFilter(types={}, include_others=False) |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def from_types(types): |
|
|
|
|
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": |
|
|
|
|
"""Creates a filter that only fetches the given types |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
types (Iterable[tuple[str, str|None]]): A list of type and state |
|
|
|
|
keys to fetch. A state_key of None fetches everything for |
|
|
|
|
that type |
|
|
|
|
types: A list of type and state keys to fetch. A state_key of None |
|
|
|
|
fetches everything for that type |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
StateFilter |
|
|
|
|
The new state filter. |
|
|
|
|
""" |
|
|
|
|
type_dict = {} |
|
|
|
|
type_dict = {} # type: Dict[str, Optional[Set[str]]] |
|
|
|
|
for typ, s in types: |
|
|
|
|
if typ in type_dict: |
|
|
|
|
if type_dict[typ] is None: |
|
|
|
@ -91,24 +89,24 @@ class StateFilter(object): |
|
|
|
|
type_dict[typ] = None |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
type_dict.setdefault(typ, set()).add(s) |
|
|
|
|
type_dict.setdefault(typ, set()).add(s) # type: ignore |
|
|
|
|
|
|
|
|
|
return StateFilter(types=type_dict) |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def from_lazy_load_member_list(members): |
|
|
|
|
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": |
|
|
|
|
"""Creates a filter that returns all non-member events, plus the member |
|
|
|
|
events for the given users |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
members (iterable[str]): Set of user IDs |
|
|
|
|
members: Set of user IDs |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
StateFilter |
|
|
|
|
The new state filter |
|
|
|
|
""" |
|
|
|
|
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) |
|
|
|
|
|
|
|
|
|
def return_expanded(self): |
|
|
|
|
def return_expanded(self) -> "StateFilter": |
|
|
|
|
"""Creates a new StateFilter where type wild cards have been removed |
|
|
|
|
(except for memberships). The returned filter is a superset of the |
|
|
|
|
current one, i.e. anything that passes the current filter will pass |
|
|
|
@ -130,7 +128,7 @@ class StateFilter(object): |
|
|
|
|
return all non-member events |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
StateFilter |
|
|
|
|
The new state filter. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
if self.is_full(): |
|
|
|
@ -167,7 +165,7 @@ class StateFilter(object): |
|
|
|
|
include_others=True, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def make_sql_filter_clause(self): |
|
|
|
|
def make_sql_filter_clause(self) -> Tuple[str, List[str]]: |
|
|
|
|
"""Converts the filter to an SQL clause. |
|
|
|
|
|
|
|
|
|
For example: |
|
|
|
@ -179,13 +177,12 @@ class StateFilter(object): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
tuple[str, list]: The SQL string (may be empty) and arguments. An |
|
|
|
|
empty SQL string is returned when the filter matches everything |
|
|
|
|
(i.e. is "full"). |
|
|
|
|
The SQL string (may be empty) and arguments. An empty SQL string is |
|
|
|
|
returned when the filter matches everything (i.e. is "full"). |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
where_clause = "" |
|
|
|
|
where_args = [] |
|
|
|
|
where_args = [] # type: List[str] |
|
|
|
|
|
|
|
|
|
if self.is_full(): |
|
|
|
|
return where_clause, where_args |
|
|
|
@ -221,7 +218,7 @@ class StateFilter(object): |
|
|
|
|
|
|
|
|
|
return where_clause, where_args |
|
|
|
|
|
|
|
|
|
def max_entries_returned(self): |
|
|
|
|
def max_entries_returned(self) -> Optional[int]: |
|
|
|
|
"""Returns the maximum number of entries this filter will return if |
|
|
|
|
known, otherwise returns None. |
|
|
|
|
|
|
|
|
@ -260,33 +257,33 @@ class StateFilter(object): |
|
|
|
|
|
|
|
|
|
return filtered_state |
|
|
|
|
|
|
|
|
|
def is_full(self): |
|
|
|
|
def is_full(self) -> bool: |
|
|
|
|
"""Whether this filter fetches everything or not |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
bool |
|
|
|
|
True if the filter fetches everything. |
|
|
|
|
""" |
|
|
|
|
return self.include_others and not self.types |
|
|
|
|
|
|
|
|
|
def has_wildcards(self): |
|
|
|
|
def has_wildcards(self) -> bool: |
|
|
|
|
"""Whether the filter includes wildcards or is attempting to fetch |
|
|
|
|
specific state. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
bool |
|
|
|
|
True if the filter includes wildcards. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
return self.include_others or any( |
|
|
|
|
state_keys is None for state_keys in self.types.values() |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def concrete_types(self): |
|
|
|
|
def concrete_types(self) -> List[Tuple[str, str]]: |
|
|
|
|
"""Returns a list of concrete type/state_keys (i.e. not None) that |
|
|
|
|
will be fetched. This will be a complete list if `has_wildcards` |
|
|
|
|
returns False, but otherwise will be a subset (or even empty). |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
list[tuple[str,str]] |
|
|
|
|
A list of type/state_keys tuples. |
|
|
|
|
""" |
|
|
|
|
return [ |
|
|
|
|
(t, s) |
|
|
|
@ -295,7 +292,7 @@ class StateFilter(object): |
|
|
|
|
for s in state_keys |
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
def get_member_split(self): |
|
|
|
|
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: |
|
|
|
|
"""Return the filter split into two: one which assumes it's exclusively |
|
|
|
|
matching against member state, and one which assumes it's matching |
|
|
|
|
against non member state. |
|
|
|
@ -307,7 +304,7 @@ class StateFilter(object): |
|
|
|
|
state caches). |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
tuple[StateFilter, StateFilter]: The member and non member filters |
|
|
|
|
The member and non member filters |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
if EventTypes.Member in self.types: |
|
|
|
@ -340,6 +337,9 @@ class StateGroupStorage(object): |
|
|
|
|
"""Given a state group try to return a previous group and a delta between |
|
|
|
|
the old and the new. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
state_group: The state group used to retrieve state deltas. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: |
|
|
|
|
(prev_group, delta_ids) |
|
|
|
@ -347,55 +347,59 @@ class StateGroupStorage(object): |
|
|
|
|
|
|
|
|
|
return self.stores.state.get_state_group_delta(state_group) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def get_state_groups_ids(self, _room_id, event_ids): |
|
|
|
|
async def get_state_groups_ids( |
|
|
|
|
self, _room_id: str, event_ids: Iterable[str] |
|
|
|
|
) -> Dict[int, StateMap[str]]: |
|
|
|
|
"""Get the event IDs of all the state for the state groups for the given events |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
_room_id (str): id of the room for these events |
|
|
|
|
event_ids (iterable[str]): ids of the events |
|
|
|
|
_room_id: id of the room for these events |
|
|
|
|
event_ids: ids of the events |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
Deferred[dict[int, StateMap[str]]]: |
|
|
|
|
dict of state_group_id -> (dict of (type, state_key) -> event id) |
|
|
|
|
dict of state_group_id -> (dict of (type, state_key) -> event id) |
|
|
|
|
""" |
|
|
|
|
if not event_ids: |
|
|
|
|
return {} |
|
|
|
|
|
|
|
|
|
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) |
|
|
|
|
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) |
|
|
|
|
|
|
|
|
|
groups = set(event_to_groups.values()) |
|
|
|
|
group_to_state = yield self.stores.state._get_state_for_groups(groups) |
|
|
|
|
group_to_state = await self.stores.state._get_state_for_groups(groups) |
|
|
|
|
|
|
|
|
|
return group_to_state |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def get_state_ids_for_group(self, state_group): |
|
|
|
|
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: |
|
|
|
|
"""Get the event IDs of all the state in the given state group |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
state_group (int) |
|
|
|
|
state_group: A state group for which we want to get the state IDs. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id |
|
|
|
|
Resolves to a map of (type, state_key) -> event_id |
|
|
|
|
""" |
|
|
|
|
group_to_state = yield self._get_state_for_groups((state_group,)) |
|
|
|
|
group_to_state = await self._get_state_for_groups((state_group,)) |
|
|
|
|
|
|
|
|
|
return group_to_state[state_group] |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def get_state_groups(self, room_id, event_ids): |
|
|
|
|
async def get_state_groups( |
|
|
|
|
self, room_id: str, event_ids: Iterable[str] |
|
|
|
|
) -> Dict[int, List[EventBase]]: |
|
|
|
|
""" Get the state groups for the given list of event_ids |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
room_id: ID of the room for these events. |
|
|
|
|
event_ids: The event IDs to retrieve state for. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
Deferred[dict[int, list[EventBase]]]: |
|
|
|
|
dict of state_group_id -> list of state events. |
|
|
|
|
dict of state_group_id -> list of state events. |
|
|
|
|
""" |
|
|
|
|
if not event_ids: |
|
|
|
|
return {} |
|
|
|
|
|
|
|
|
|
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) |
|
|
|
|
group_to_ids = await self.get_state_groups_ids(room_id, event_ids) |
|
|
|
|
|
|
|
|
|
state_event_map = yield self.stores.main.get_events( |
|
|
|
|
state_event_map = await self.stores.main.get_events( |
|
|
|
|
[ |
|
|
|
|
ev_id |
|
|
|
|
for group_ids in group_to_ids.values() |
|
|
|
@ -423,31 +427,34 @@ class StateGroupStorage(object): |
|
|
|
|
groups: list of state group IDs to query |
|
|
|
|
state_filter: The state filter used to fetch state |
|
|
|
|
from the database. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
return self.stores.state._get_state_groups_from_groups(groups, state_filter) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): |
|
|
|
|
async def get_state_for_events( |
|
|
|
|
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() |
|
|
|
|
): |
|
|
|
|
"""Given a list of event_ids and type tuples, return a list of state |
|
|
|
|
dicts for each event. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
event_ids (list[string]) |
|
|
|
|
state_filter (StateFilter): The state filter used to fetch state |
|
|
|
|
from the database. |
|
|
|
|
event_ids: The events to fetch the state of. |
|
|
|
|
state_filter: The state filter used to fetch state. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
deferred: A dict of (event_id) -> (type, state_key) -> [state_events] |
|
|
|
|
A dict of (event_id) -> (type, state_key) -> [state_events] |
|
|
|
|
""" |
|
|
|
|
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) |
|
|
|
|
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) |
|
|
|
|
|
|
|
|
|
groups = set(event_to_groups.values()) |
|
|
|
|
group_to_state = yield self.stores.state._get_state_for_groups( |
|
|
|
|
group_to_state = await self.stores.state._get_state_for_groups( |
|
|
|
|
groups, state_filter |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
state_event_map = yield self.stores.main.get_events( |
|
|
|
|
state_event_map = await self.stores.main.get_events( |
|
|
|
|
[ev_id for sd in group_to_state.values() for ev_id in sd.values()], |
|
|
|
|
get_prev_content=False, |
|
|
|
|
) |
|
|
|
@ -463,24 +470,24 @@ class StateGroupStorage(object): |
|
|
|
|
|
|
|
|
|
return {event: event_to_state[event] for event in event_ids} |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): |
|
|
|
|
async def get_state_ids_for_events( |
|
|
|
|
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() |
|
|
|
|
): |
|
|
|
|
""" |
|
|
|
|
Get the state dicts corresponding to a list of events, containing the event_ids |
|
|
|
|
of the state events (as opposed to the events themselves) |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
event_ids(list(str)): events whose state should be returned |
|
|
|
|
state_filter (StateFilter): The state filter used to fetch state |
|
|
|
|
from the database. |
|
|
|
|
event_ids: events whose state should be returned |
|
|
|
|
state_filter: The state filter used to fetch state from the database. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
A deferred dict from event_id -> (type, state_key) -> event_id |
|
|
|
|
A dict from event_id -> (type, state_key) -> event_id |
|
|
|
|
""" |
|
|
|
|
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) |
|
|
|
|
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) |
|
|
|
|
|
|
|
|
|
groups = set(event_to_groups.values()) |
|
|
|
|
group_to_state = yield self.stores.state._get_state_for_groups( |
|
|
|
|
group_to_state = await self.stores.state._get_state_for_groups( |
|
|
|
|
groups, state_filter |
|
|
|
|
) |
|
|
|
|
|
|
|
|
@ -491,36 +498,36 @@ class StateGroupStorage(object): |
|
|
|
|
|
|
|
|
|
return {event: event_to_state[event] for event in event_ids} |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def get_state_for_event(self, event_id, state_filter=StateFilter.all()): |
|
|
|
|
async def get_state_for_event( |
|
|
|
|
self, event_id: str, state_filter: StateFilter = StateFilter.all() |
|
|
|
|
): |
|
|
|
|
""" |
|
|
|
|
Get the state dict corresponding to a particular event |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
event_id(str): event whose state should be returned |
|
|
|
|
state_filter (StateFilter): The state filter used to fetch state |
|
|
|
|
from the database. |
|
|
|
|
event_id: event whose state should be returned |
|
|
|
|
state_filter: The state filter used to fetch state from the database. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
A deferred dict from (type, state_key) -> state_event |
|
|
|
|
A dict from (type, state_key) -> state_event |
|
|
|
|
""" |
|
|
|
|
state_map = yield self.get_state_for_events([event_id], state_filter) |
|
|
|
|
state_map = await self.get_state_for_events([event_id], state_filter) |
|
|
|
|
return state_map[event_id] |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): |
|
|
|
|
async def get_state_ids_for_event( |
|
|
|
|
self, event_id: str, state_filter: StateFilter = StateFilter.all() |
|
|
|
|
): |
|
|
|
|
""" |
|
|
|
|
Get the state dict corresponding to a particular event |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
event_id(str): event whose state should be returned |
|
|
|
|
state_filter (StateFilter): The state filter used to fetch state |
|
|
|
|
from the database. |
|
|
|
|
event_id: event whose state should be returned |
|
|
|
|
state_filter: The state filter used to fetch state from the database. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
A deferred dict from (type, state_key) -> state_event |
|
|
|
|
""" |
|
|
|
|
state_map = yield self.get_state_ids_for_events([event_id], state_filter) |
|
|
|
|
state_map = await self.get_state_ids_for_events([event_id], state_filter) |
|
|
|
|
return state_map[event_id] |
|
|
|
|
|
|
|
|
|
def _get_state_for_groups( |
|
|
|
@ -530,9 +537,8 @@ class StateGroupStorage(object): |
|
|
|
|
filtering by type/state_key |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
groups (iterable[int]): list of state groups for which we want |
|
|
|
|
to get the state. |
|
|
|
|
state_filter (StateFilter): The state filter used to fetch state |
|
|
|
|
groups: list of state groups for which we want to get the state. |
|
|
|
|
state_filter: The state filter used to fetch state. |
|
|
|
|
from the database. |
|
|
|
|
Returns: |
|
|
|
|
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. |
|
|
|
@ -540,18 +546,23 @@ class StateGroupStorage(object): |
|
|
|
|
return self.stores.state._get_state_for_groups(groups, state_filter) |
|
|
|
|
|
|
|
|
|
def store_state_group( |
|
|
|
|
self, event_id, room_id, prev_group, delta_ids, current_state_ids |
|
|
|
|
self, |
|
|
|
|
event_id: str, |
|
|
|
|
room_id: str, |
|
|
|
|
prev_group: Optional[int], |
|
|
|
|
delta_ids: Optional[dict], |
|
|
|
|
current_state_ids: dict, |
|
|
|
|
): |
|
|
|
|
"""Store a new set of state, returning a newly assigned state group. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
event_id (str): The event ID for which the state was calculated |
|
|
|
|
room_id (str) |
|
|
|
|
prev_group (int|None): A previous state group for the room, optional. |
|
|
|
|
delta_ids (dict|None): The delta between state at `prev_group` and |
|
|
|
|
event_id: The event ID for which the state was calculated. |
|
|
|
|
room_id: ID of the room for which the state was calculated. |
|
|
|
|
prev_group: A previous state group for the room, optional. |
|
|
|
|
delta_ids: The delta between state at `prev_group` and |
|
|
|
|
`current_state_ids`, if `prev_group` was given. Same format as |
|
|
|
|
`current_state_ids`. |
|
|
|
|
current_state_ids (dict): The state to store. Map of (type, state_key) |
|
|
|
|
current_state_ids: The state to store. Map of (type, state_key) |
|
|
|
|
to event_id. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|