|
|
|
@ -15,15 +15,19 @@ |
|
|
|
|
|
|
|
|
|
import logging |
|
|
|
|
from collections import namedtuple |
|
|
|
|
from typing import List, Tuple |
|
|
|
|
from typing import TYPE_CHECKING, List, Set, Tuple |
|
|
|
|
|
|
|
|
|
from synapse.api.errors import AuthError, SynapseError |
|
|
|
|
from synapse.logging.context import run_in_background |
|
|
|
|
from synapse.metrics.background_process_metrics import run_as_background_process |
|
|
|
|
from synapse.replication.tcp.streams import TypingStream |
|
|
|
|
from synapse.types import UserID, get_domain_from_id |
|
|
|
|
from synapse.util.caches.stream_change_cache import StreamChangeCache |
|
|
|
|
from synapse.util.metrics import Measure |
|
|
|
|
from synapse.util.wheel_timer import WheelTimer |
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
|
from synapse.server import HomeServer |
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -39,48 +43,48 @@ FEDERATION_TIMEOUT = 60 * 1000 |
|
|
|
|
FEDERATION_PING_INTERVAL = 40 * 1000 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TypingHandler(object): |
|
|
|
|
def __init__(self, hs): |
|
|
|
|
class FollowerTypingHandler: |
|
|
|
|
"""A typing handler on a different process than the writer that is updated |
|
|
|
|
via replication. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, hs: "HomeServer"): |
|
|
|
|
self.store = hs.get_datastore() |
|
|
|
|
self.server_name = hs.config.server_name |
|
|
|
|
self.auth = hs.get_auth() |
|
|
|
|
self.is_mine_id = hs.is_mine_id |
|
|
|
|
self.notifier = hs.get_notifier() |
|
|
|
|
self.state = hs.get_state_handler() |
|
|
|
|
|
|
|
|
|
self.hs = hs |
|
|
|
|
|
|
|
|
|
self.clock = hs.get_clock() |
|
|
|
|
self.wheel_timer = WheelTimer(bucket_size=5000) |
|
|
|
|
self.is_mine_id = hs.is_mine_id |
|
|
|
|
|
|
|
|
|
self.federation = None |
|
|
|
|
if hs.should_send_federation(): |
|
|
|
|
self.federation = hs.get_federation_sender() |
|
|
|
|
|
|
|
|
|
hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) |
|
|
|
|
if hs.config.worker.writers.typing != hs.get_instance_name(): |
|
|
|
|
hs.get_federation_registry().register_instance_for_edu( |
|
|
|
|
"m.typing", hs.config.worker.writers.typing, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
hs.get_distributor().observe("user_left_room", self.user_left_room) |
|
|
|
|
# map room IDs to serial numbers |
|
|
|
|
self._room_serials = {} |
|
|
|
|
# map room IDs to sets of users currently typing |
|
|
|
|
self._room_typing = {} |
|
|
|
|
|
|
|
|
|
self._member_typing_until = {} # clock time we expect to stop |
|
|
|
|
self._member_last_federation_poke = {} |
|
|
|
|
|
|
|
|
|
self.wheel_timer = WheelTimer(bucket_size=5000) |
|
|
|
|
self._latest_room_serial = 0 |
|
|
|
|
self._reset() |
|
|
|
|
|
|
|
|
|
# caches which room_ids changed at which serials |
|
|
|
|
self._typing_stream_change_cache = StreamChangeCache( |
|
|
|
|
"TypingStreamChangeCache", self._latest_room_serial |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
self.clock.looping_call(self._handle_timeouts, 5000) |
|
|
|
|
|
|
|
|
|
def _reset(self): |
|
|
|
|
""" |
|
|
|
|
Reset the typing handler's data caches. |
|
|
|
|
"""Reset the typing handler's data caches. |
|
|
|
|
""" |
|
|
|
|
# map room IDs to serial numbers |
|
|
|
|
self._room_serials = {} |
|
|
|
|
# map room IDs to sets of users currently typing |
|
|
|
|
self._room_typing = {} |
|
|
|
|
|
|
|
|
|
self._member_last_federation_poke = {} |
|
|
|
|
self.wheel_timer = WheelTimer(bucket_size=5000) |
|
|
|
|
|
|
|
|
|
def _handle_timeouts(self): |
|
|
|
|
logger.debug("Checking for typing timeouts") |
|
|
|
|
|
|
|
|
@ -89,22 +93,21 @@ class TypingHandler(object): |
|
|
|
|
members = set(self.wheel_timer.fetch(now)) |
|
|
|
|
|
|
|
|
|
for member in members: |
|
|
|
|
self._handle_timeout_for_member(now, member) |
|
|
|
|
|
|
|
|
|
def _handle_timeout_for_member(self, now: int, member: RoomMember): |
|
|
|
|
if not self.is_typing(member): |
|
|
|
|
# Nothing to do if they're no longer typing |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
until = self._member_typing_until.get(member, None) |
|
|
|
|
if not until or until <= now: |
|
|
|
|
logger.info("Timing out typing for: %s", member.user_id) |
|
|
|
|
self._stopped_typing(member) |
|
|
|
|
continue |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
# Check if we need to resend a keep alive over federation for this |
|
|
|
|
# user. |
|
|
|
|
if self.hs.is_mine_id(member.user_id): |
|
|
|
|
if self.federation and self.is_mine_id(member.user_id): |
|
|
|
|
last_fed_poke = self._member_last_federation_poke.get(member, None) |
|
|
|
|
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: |
|
|
|
|
run_in_background(self._push_remote, member=member, typing=True) |
|
|
|
|
run_as_background_process( |
|
|
|
|
"typing._push_remote", self._push_remote, member=member, typing=True |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Add a paranoia timer to ensure that we always have a timer for |
|
|
|
|
# each person typing. |
|
|
|
@ -113,6 +116,117 @@ class TypingHandler(object): |
|
|
|
|
def is_typing(self, member): |
|
|
|
|
return member.user_id in self._room_typing.get(member.room_id, []) |
|
|
|
|
|
|
|
|
|
async def _push_remote(self, member, typing): |
|
|
|
|
if not self.federation: |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
users = await self.store.get_users_in_room(member.room_id) |
|
|
|
|
self._member_last_federation_poke[member] = self.clock.time_msec() |
|
|
|
|
|
|
|
|
|
now = self.clock.time_msec() |
|
|
|
|
self.wheel_timer.insert( |
|
|
|
|
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
for domain in {get_domain_from_id(u) for u in users}: |
|
|
|
|
if domain != self.server_name: |
|
|
|
|
logger.debug("sending typing update to %s", domain) |
|
|
|
|
self.federation.build_and_send_edu( |
|
|
|
|
destination=domain, |
|
|
|
|
edu_type="m.typing", |
|
|
|
|
content={ |
|
|
|
|
"room_id": member.room_id, |
|
|
|
|
"user_id": member.user_id, |
|
|
|
|
"typing": typing, |
|
|
|
|
}, |
|
|
|
|
key=member, |
|
|
|
|
) |
|
|
|
|
except Exception: |
|
|
|
|
logger.exception("Error pushing typing notif to remotes") |
|
|
|
|
|
|
|
|
|
def process_replication_rows( |
|
|
|
|
self, token: int, rows: List[TypingStream.TypingStreamRow] |
|
|
|
|
): |
|
|
|
|
"""Should be called whenever we receive updates for typing stream. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
if self._latest_room_serial > token: |
|
|
|
|
# The master has gone backwards. To prevent inconsistent data, just |
|
|
|
|
# clear everything. |
|
|
|
|
self._reset() |
|
|
|
|
|
|
|
|
|
# Set the latest serial token to whatever the server gave us. |
|
|
|
|
self._latest_room_serial = token |
|
|
|
|
|
|
|
|
|
for row in rows: |
|
|
|
|
self._room_serials[row.room_id] = token |
|
|
|
|
|
|
|
|
|
prev_typing = set(self._room_typing.get(row.room_id, [])) |
|
|
|
|
now_typing = set(row.user_ids) |
|
|
|
|
self._room_typing[row.room_id] = row.user_ids |
|
|
|
|
|
|
|
|
|
run_as_background_process( |
|
|
|
|
"_handle_change_in_typing", |
|
|
|
|
self._handle_change_in_typing, |
|
|
|
|
row.room_id, |
|
|
|
|
prev_typing, |
|
|
|
|
now_typing, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
async def _handle_change_in_typing( |
|
|
|
|
self, room_id: str, prev_typing: Set[str], now_typing: Set[str] |
|
|
|
|
): |
|
|
|
|
"""Process a change in typing of a room from replication, sending EDUs |
|
|
|
|
for any local users. |
|
|
|
|
""" |
|
|
|
|
for user_id in now_typing - prev_typing: |
|
|
|
|
if self.is_mine_id(user_id): |
|
|
|
|
await self._push_remote(RoomMember(room_id, user_id), True) |
|
|
|
|
|
|
|
|
|
for user_id in prev_typing - now_typing: |
|
|
|
|
if self.is_mine_id(user_id): |
|
|
|
|
await self._push_remote(RoomMember(room_id, user_id), False) |
|
|
|
|
|
|
|
|
|
def get_current_token(self): |
|
|
|
|
return self._latest_room_serial |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TypingWriterHandler(FollowerTypingHandler): |
|
|
|
|
def __init__(self, hs): |
|
|
|
|
super().__init__(hs) |
|
|
|
|
|
|
|
|
|
assert hs.config.worker.writers.typing == hs.get_instance_name() |
|
|
|
|
|
|
|
|
|
self.auth = hs.get_auth() |
|
|
|
|
self.notifier = hs.get_notifier() |
|
|
|
|
|
|
|
|
|
self.hs = hs |
|
|
|
|
|
|
|
|
|
hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) |
|
|
|
|
|
|
|
|
|
hs.get_distributor().observe("user_left_room", self.user_left_room) |
|
|
|
|
|
|
|
|
|
self._member_typing_until = {} # clock time we expect to stop |
|
|
|
|
|
|
|
|
|
# caches which room_ids changed at which serials |
|
|
|
|
self._typing_stream_change_cache = StreamChangeCache( |
|
|
|
|
"TypingStreamChangeCache", self._latest_room_serial |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def _handle_timeout_for_member(self, now: int, member: RoomMember): |
|
|
|
|
super()._handle_timeout_for_member(now, member) |
|
|
|
|
|
|
|
|
|
if not self.is_typing(member): |
|
|
|
|
# Nothing to do if they're no longer typing |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
until = self._member_typing_until.get(member, None) |
|
|
|
|
if not until or until <= now: |
|
|
|
|
logger.info("Timing out typing for: %s", member.user_id) |
|
|
|
|
self._stopped_typing(member) |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
async def started_typing(self, target_user, auth_user, room_id, timeout): |
|
|
|
|
target_user_id = target_user.to_string() |
|
|
|
|
auth_user_id = auth_user.to_string() |
|
|
|
@ -179,35 +293,11 @@ class TypingHandler(object): |
|
|
|
|
def _push_update(self, member, typing): |
|
|
|
|
if self.hs.is_mine_id(member.user_id): |
|
|
|
|
# Only send updates for changes to our own users. |
|
|
|
|
run_in_background(self._push_remote, member, typing) |
|
|
|
|
|
|
|
|
|
self._push_update_local(member=member, typing=typing) |
|
|
|
|
|
|
|
|
|
async def _push_remote(self, member, typing): |
|
|
|
|
try: |
|
|
|
|
users = await self.store.get_users_in_room(member.room_id) |
|
|
|
|
self._member_last_federation_poke[member] = self.clock.time_msec() |
|
|
|
|
|
|
|
|
|
now = self.clock.time_msec() |
|
|
|
|
self.wheel_timer.insert( |
|
|
|
|
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL |
|
|
|
|
run_as_background_process( |
|
|
|
|
"typing._push_remote", self._push_remote, member, typing |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
for domain in {get_domain_from_id(u) for u in users}: |
|
|
|
|
if domain != self.server_name: |
|
|
|
|
logger.debug("sending typing update to %s", domain) |
|
|
|
|
self.federation.build_and_send_edu( |
|
|
|
|
destination=domain, |
|
|
|
|
edu_type="m.typing", |
|
|
|
|
content={ |
|
|
|
|
"room_id": member.room_id, |
|
|
|
|
"user_id": member.user_id, |
|
|
|
|
"typing": typing, |
|
|
|
|
}, |
|
|
|
|
key=member, |
|
|
|
|
) |
|
|
|
|
except Exception: |
|
|
|
|
logger.exception("Error pushing typing notif to remotes") |
|
|
|
|
self._push_update_local(member=member, typing=typing) |
|
|
|
|
|
|
|
|
|
async def _recv_edu(self, origin, content): |
|
|
|
|
room_id = content["room_id"] |
|
|
|
@ -304,8 +394,11 @@ class TypingHandler(object): |
|
|
|
|
|
|
|
|
|
return rows, current_id, limited |
|
|
|
|
|
|
|
|
|
def get_current_token(self): |
|
|
|
|
return self._latest_room_serial |
|
|
|
|
def process_replication_rows( |
|
|
|
|
self, token: int, rows: List[TypingStream.TypingStreamRow] |
|
|
|
|
): |
|
|
|
|
# The writing process should never get updates from replication. |
|
|
|
|
raise Exception("Typing writer instance got typing info over replication") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TypingNotificationEventSource(object): |
|
|
|
|