|
|
|
@ -35,28 +35,22 @@ logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
RoomsForUser = namedtuple( |
|
|
|
|
"RoomsForUser", |
|
|
|
|
("room_id", "sender", "membership", "event_id", "stream_ordering") |
|
|
|
|
"RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering") |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
GetRoomsForUserWithStreamOrdering = namedtuple( |
|
|
|
|
"_GetRoomsForUserWithStreamOrdering", |
|
|
|
|
("room_id", "stream_ordering",) |
|
|
|
|
"_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering") |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# We store this using a namedtuple so that we save about 3x space over using a |
|
|
|
|
# dict. |
|
|
|
|
ProfileInfo = namedtuple( |
|
|
|
|
"ProfileInfo", ("avatar_url", "display_name") |
|
|
|
|
) |
|
|
|
|
ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name")) |
|
|
|
|
|
|
|
|
|
# "members" points to a truncated list of (user_id, event_id) tuples for users of |
|
|
|
|
# a given membership type, suitable for use in calculating heroes for a room. |
|
|
|
|
# "count" points to the total numberr of users of a given membership type. |
|
|
|
|
MemberSummary = namedtuple( |
|
|
|
|
"MemberSummary", ("members", "count") |
|
|
|
|
) |
|
|
|
|
MemberSummary = namedtuple("MemberSummary", ("members", "count")) |
|
|
|
|
|
|
|
|
|
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" |
|
|
|
|
|
|
|
|
@ -67,7 +61,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
"""Returns the set of all hosts currently in the room |
|
|
|
|
""" |
|
|
|
|
user_ids = yield self.get_users_in_room( |
|
|
|
|
room_id, on_invalidate=cache_context.invalidate, |
|
|
|
|
room_id, on_invalidate=cache_context.invalidate |
|
|
|
|
) |
|
|
|
|
hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) |
|
|
|
|
defer.returnValue(hosts) |
|
|
|
@ -84,8 +78,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
" WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
txn.execute(sql, (room_id, Membership.JOIN,)) |
|
|
|
|
txn.execute(sql, (room_id, Membership.JOIN)) |
|
|
|
|
return [to_ascii(r[0]) for r in txn] |
|
|
|
|
|
|
|
|
|
return self.runInteraction("get_users_in_room", f) |
|
|
|
|
|
|
|
|
|
@cached(max_entries=100000) |
|
|
|
@ -156,9 +151,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
A deferred list of RoomsForUser. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
return self.get_rooms_for_user_where_membership_is( |
|
|
|
|
user_id, [Membership.INVITE] |
|
|
|
|
) |
|
|
|
|
return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE]) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def get_invite_for_user_in_room(self, user_id, room_id): |
|
|
|
@ -196,11 +189,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
return self.runInteraction( |
|
|
|
|
"get_rooms_for_user_where_membership_is", |
|
|
|
|
self._get_rooms_for_user_where_membership_is_txn, |
|
|
|
|
user_id, membership_list |
|
|
|
|
user_id, |
|
|
|
|
membership_list, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id, |
|
|
|
|
membership_list): |
|
|
|
|
def _get_rooms_for_user_where_membership_is_txn( |
|
|
|
|
self, txn, user_id, membership_list |
|
|
|
|
): |
|
|
|
|
|
|
|
|
|
do_invite = Membership.INVITE in membership_list |
|
|
|
|
membership_list = [m for m in membership_list if m != Membership.INVITE] |
|
|
|
@ -227,9 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
) % (where_clause,) |
|
|
|
|
|
|
|
|
|
txn.execute(sql, args) |
|
|
|
|
results = [ |
|
|
|
|
RoomsForUser(**r) for r in self.cursor_to_dict(txn) |
|
|
|
|
] |
|
|
|
|
results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] |
|
|
|
|
|
|
|
|
|
if do_invite: |
|
|
|
|
sql = ( |
|
|
|
@ -241,13 +234,16 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
txn.execute(sql, (user_id,)) |
|
|
|
|
results.extend(RoomsForUser( |
|
|
|
|
room_id=r["room_id"], |
|
|
|
|
sender=r["inviter"], |
|
|
|
|
event_id=r["event_id"], |
|
|
|
|
stream_ordering=r["stream_ordering"], |
|
|
|
|
membership=Membership.INVITE, |
|
|
|
|
) for r in self.cursor_to_dict(txn)) |
|
|
|
|
results.extend( |
|
|
|
|
RoomsForUser( |
|
|
|
|
room_id=r["room_id"], |
|
|
|
|
sender=r["inviter"], |
|
|
|
|
event_id=r["event_id"], |
|
|
|
|
stream_ordering=r["stream_ordering"], |
|
|
|
|
membership=Membership.INVITE, |
|
|
|
|
) |
|
|
|
|
for r in self.cursor_to_dict(txn) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
return results |
|
|
|
|
|
|
|
|
@ -264,19 +260,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
of the most recent join for that user and room. |
|
|
|
|
""" |
|
|
|
|
rooms = yield self.get_rooms_for_user_where_membership_is( |
|
|
|
|
user_id, membership_list=[Membership.JOIN], |
|
|
|
|
user_id, membership_list=[Membership.JOIN] |
|
|
|
|
) |
|
|
|
|
defer.returnValue( |
|
|
|
|
frozenset( |
|
|
|
|
GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) |
|
|
|
|
for r in rooms |
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
defer.returnValue(frozenset( |
|
|
|
|
GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) |
|
|
|
|
for r in rooms |
|
|
|
|
)) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def get_rooms_for_user(self, user_id, on_invalidate=None): |
|
|
|
|
"""Returns a set of room_ids the user is currently joined to |
|
|
|
|
""" |
|
|
|
|
rooms = yield self.get_rooms_for_user_with_stream_ordering( |
|
|
|
|
user_id, on_invalidate=on_invalidate, |
|
|
|
|
user_id, on_invalidate=on_invalidate |
|
|
|
|
) |
|
|
|
|
defer.returnValue(frozenset(r.room_id for r in rooms)) |
|
|
|
|
|
|
|
|
@ -285,13 +283,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
"""Returns the set of users who share a room with `user_id` |
|
|
|
|
""" |
|
|
|
|
room_ids = yield self.get_rooms_for_user( |
|
|
|
|
user_id, on_invalidate=cache_context.invalidate, |
|
|
|
|
user_id, on_invalidate=cache_context.invalidate |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
user_who_share_room = set() |
|
|
|
|
for room_id in room_ids: |
|
|
|
|
user_ids = yield self.get_users_in_room( |
|
|
|
|
room_id, on_invalidate=cache_context.invalidate, |
|
|
|
|
room_id, on_invalidate=cache_context.invalidate |
|
|
|
|
) |
|
|
|
|
user_who_share_room.update(user_ids) |
|
|
|
|
|
|
|
|
@ -309,9 +307,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
|
|
|
|
|
current_state_ids = yield context.get_current_state_ids(self) |
|
|
|
|
result = yield self._get_joined_users_from_context( |
|
|
|
|
event.room_id, state_group, current_state_ids, |
|
|
|
|
event=event, |
|
|
|
|
context=context, |
|
|
|
|
event.room_id, state_group, current_state_ids, event=event, context=context |
|
|
|
|
) |
|
|
|
|
defer.returnValue(result) |
|
|
|
|
|
|
|
|
@ -325,13 +321,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
state_group = object() |
|
|
|
|
|
|
|
|
|
return self._get_joined_users_from_context( |
|
|
|
|
room_id, state_group, state_entry.state, context=state_entry, |
|
|
|
|
room_id, state_group, state_entry.state, context=state_entry |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
@cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, |
|
|
|
|
max_entries=100000) |
|
|
|
|
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, |
|
|
|
|
cache_context, event=None, context=None): |
|
|
|
|
@cachedInlineCallbacks( |
|
|
|
|
num_args=2, cache_context=True, iterable=True, max_entries=100000 |
|
|
|
|
) |
|
|
|
|
def _get_joined_users_from_context( |
|
|
|
|
self, |
|
|
|
|
room_id, |
|
|
|
|
state_group, |
|
|
|
|
current_state_ids, |
|
|
|
|
cache_context, |
|
|
|
|
event=None, |
|
|
|
|
context=None, |
|
|
|
|
): |
|
|
|
|
# We don't use `state_group`, it's there so that we can cache based |
|
|
|
|
# on it. However, it's important that it's never None, since two current_states |
|
|
|
|
# with a state_group of None are likely to be different. |
|
|
|
@ -371,9 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
# the hit ratio counts. After all, we don't populate the cache if we |
|
|
|
|
# miss it here |
|
|
|
|
event_map = self._get_events_from_cache( |
|
|
|
|
member_event_ids, |
|
|
|
|
allow_rejected=False, |
|
|
|
|
update_metrics=False, |
|
|
|
|
member_event_ids, allow_rejected=False, update_metrics=False |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
missing_member_event_ids = [] |
|
|
|
@ -397,21 +399,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
table="room_memberships", |
|
|
|
|
column="event_id", |
|
|
|
|
iterable=missing_member_event_ids, |
|
|
|
|
retcols=('user_id', 'display_name', 'avatar_url',), |
|
|
|
|
keyvalues={ |
|
|
|
|
"membership": Membership.JOIN, |
|
|
|
|
}, |
|
|
|
|
retcols=('user_id', 'display_name', 'avatar_url'), |
|
|
|
|
keyvalues={"membership": Membership.JOIN}, |
|
|
|
|
batch_size=500, |
|
|
|
|
desc="_get_joined_users_from_context", |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
users_in_room.update({ |
|
|
|
|
to_ascii(row["user_id"]): ProfileInfo( |
|
|
|
|
avatar_url=to_ascii(row["avatar_url"]), |
|
|
|
|
display_name=to_ascii(row["display_name"]), |
|
|
|
|
) |
|
|
|
|
for row in rows |
|
|
|
|
}) |
|
|
|
|
users_in_room.update( |
|
|
|
|
{ |
|
|
|
|
to_ascii(row["user_id"]): ProfileInfo( |
|
|
|
|
avatar_url=to_ascii(row["avatar_url"]), |
|
|
|
|
display_name=to_ascii(row["display_name"]), |
|
|
|
|
) |
|
|
|
|
for row in rows |
|
|
|
|
} |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if event is not None and event.type == EventTypes.Member: |
|
|
|
|
if event.membership == Membership.JOIN: |
|
|
|
@ -505,7 +507,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
state_group = object() |
|
|
|
|
|
|
|
|
|
return self._get_joined_hosts( |
|
|
|
|
room_id, state_group, state_entry.state, state_entry=state_entry, |
|
|
|
|
room_id, state_group, state_entry.state, state_entry=state_entry |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) |
|
|
|
@ -531,6 +533,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
"""Returns whether user_id has elected to discard history for room_id. |
|
|
|
|
|
|
|
|
|
Returns False if they have since re-joined.""" |
|
|
|
|
|
|
|
|
|
def f(txn): |
|
|
|
|
sql = ( |
|
|
|
|
"SELECT" |
|
|
|
@ -547,6 +550,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
txn.execute(sql, (user_id, room_id)) |
|
|
|
|
rows = txn.fetchall() |
|
|
|
|
return rows[0][0] |
|
|
|
|
|
|
|
|
|
count = yield self.runInteraction("did_forget_membership", f) |
|
|
|
|
defer.returnValue(count == 0) |
|
|
|
|
|
|
|
|
@ -575,13 +579,14 @@ class RoomMemberStore(RoomMemberWorkerStore): |
|
|
|
|
"avatar_url": event.content.get("avatar_url", None), |
|
|
|
|
} |
|
|
|
|
for event in events |
|
|
|
|
] |
|
|
|
|
], |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
for event in events: |
|
|
|
|
txn.call_after( |
|
|
|
|
self._membership_stream_cache.entity_has_changed, |
|
|
|
|
event.state_key, event.internal_metadata.stream_ordering |
|
|
|
|
event.state_key, |
|
|
|
|
event.internal_metadata.stream_ordering, |
|
|
|
|
) |
|
|
|
|
txn.call_after( |
|
|
|
|
self.get_invited_rooms_for_user.invalidate, (event.state_key,) |
|
|
|
@ -607,7 +612,7 @@ class RoomMemberStore(RoomMemberWorkerStore): |
|
|
|
|
"inviter": event.sender, |
|
|
|
|
"room_id": event.room_id, |
|
|
|
|
"stream_id": event.internal_metadata.stream_ordering, |
|
|
|
|
} |
|
|
|
|
}, |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
sql = ( |
|
|
|
@ -616,12 +621,15 @@ class RoomMemberStore(RoomMemberWorkerStore): |
|
|
|
|
" AND replaced_by is NULL" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
txn.execute(sql, ( |
|
|
|
|
event.internal_metadata.stream_ordering, |
|
|
|
|
event.event_id, |
|
|
|
|
event.room_id, |
|
|
|
|
event.state_key, |
|
|
|
|
)) |
|
|
|
|
txn.execute( |
|
|
|
|
sql, |
|
|
|
|
( |
|
|
|
|
event.internal_metadata.stream_ordering, |
|
|
|
|
event.event_id, |
|
|
|
|
event.room_id, |
|
|
|
|
event.state_key, |
|
|
|
|
), |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def locally_reject_invite(self, user_id, room_id): |
|
|
|
@ -632,18 +640,14 @@ class RoomMemberStore(RoomMemberWorkerStore): |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def f(txn, stream_ordering): |
|
|
|
|
txn.execute(sql, ( |
|
|
|
|
stream_ordering, |
|
|
|
|
True, |
|
|
|
|
room_id, |
|
|
|
|
user_id, |
|
|
|
|
)) |
|
|
|
|
txn.execute(sql, (stream_ordering, True, room_id, user_id)) |
|
|
|
|
|
|
|
|
|
with self._stream_id_gen.get_next() as stream_ordering: |
|
|
|
|
yield self.runInteraction("locally_reject_invite", f, stream_ordering) |
|
|
|
|
|
|
|
|
|
def forget(self, user_id, room_id): |
|
|
|
|
"""Indicate that user_id wishes to discard history for room_id.""" |
|
|
|
|
|
|
|
|
|
def f(txn): |
|
|
|
|
sql = ( |
|
|
|
|
"UPDATE" |
|
|
|
@ -657,9 +661,8 @@ class RoomMemberStore(RoomMemberWorkerStore): |
|
|
|
|
) |
|
|
|
|
txn.execute(sql, (user_id, room_id)) |
|
|
|
|
|
|
|
|
|
self._invalidate_cache_and_stream( |
|
|
|
|
txn, self.did_forget, (user_id, room_id,), |
|
|
|
|
) |
|
|
|
|
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) |
|
|
|
|
|
|
|
|
|
return self.runInteraction("forget_membership", f) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
@ -674,7 +677,7 @@ class RoomMemberStore(RoomMemberWorkerStore): |
|
|
|
|
INSERT_CLUMP_SIZE = 1000 |
|
|
|
|
|
|
|
|
|
def add_membership_profile_txn(txn): |
|
|
|
|
sql = (""" |
|
|
|
|
sql = """ |
|
|
|
|
SELECT stream_ordering, event_id, events.room_id, event_json.json |
|
|
|
|
FROM events |
|
|
|
|
INNER JOIN event_json USING (event_id) |
|
|
|
@ -683,7 +686,7 @@ class RoomMemberStore(RoomMemberWorkerStore): |
|
|
|
|
AND type = 'm.room.member' |
|
|
|
|
ORDER BY stream_ordering DESC |
|
|
|
|
LIMIT ? |
|
|
|
|
""") |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) |
|
|
|
|
|
|
|
|
@ -707,16 +710,14 @@ class RoomMemberStore(RoomMemberWorkerStore): |
|
|
|
|
avatar_url = content.get("avatar_url", None) |
|
|
|
|
|
|
|
|
|
if display_name or avatar_url: |
|
|
|
|
to_update.append(( |
|
|
|
|
display_name, avatar_url, event_id, room_id |
|
|
|
|
)) |
|
|
|
|
to_update.append((display_name, avatar_url, event_id, room_id)) |
|
|
|
|
|
|
|
|
|
to_update_sql = (""" |
|
|
|
|
to_update_sql = """ |
|
|
|
|
UPDATE room_memberships SET display_name = ?, avatar_url = ? |
|
|
|
|
WHERE event_id = ? AND room_id = ? |
|
|
|
|
""") |
|
|
|
|
""" |
|
|
|
|
for index in range(0, len(to_update), INSERT_CLUMP_SIZE): |
|
|
|
|
clump = to_update[index:index + INSERT_CLUMP_SIZE] |
|
|
|
|
clump = to_update[index : index + INSERT_CLUMP_SIZE] |
|
|
|
|
txn.executemany(to_update_sql, clump) |
|
|
|
|
|
|
|
|
|
progress = { |
|
|
|
@ -789,7 +790,7 @@ class _JoinedHostsCache(object): |
|
|
|
|
self.hosts_to_joined_users.pop(host, None) |
|
|
|
|
else: |
|
|
|
|
joined_users = yield self.store.get_joined_users_from_state( |
|
|
|
|
self.room_id, state_entry, |
|
|
|
|
self.room_id, state_entry |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
self.hosts_to_joined_users = {} |
|
|
|
|