@@ -13,14 +13,26 @@
 # limitations under the License.

 import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
+
+from typing_extensions import TypedDict

 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
-from synapse.types import UserID
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_tuple_comparison_clause,
+)
+from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
+from synapse.storage.types import Connection
+from synapse.types import JsonDict, UserID
 from synapse.util.caches.lrucache import LruCache

+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)

 # Number of msec of granularity to store the user IP 'last seen' time. Smaller
@@ -29,8 +41,31 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120 * 1000


+class DeviceLastConnectionInfo(TypedDict):
+    """Metadata for the last connection seen for a user and device combination"""
+
+    # These types must match the columns in the `devices` table
+    user_id: str
+    device_id: str
+
+    ip: Optional[str]
+    user_agent: Optional[str]
+    last_seen: Optional[int]
+
+
+class LastConnectionInfo(TypedDict):
+    """Metadata for the last connection seen for an access token and IP combination"""
+
+    # These types must match the columns in the `user_ips` table
+    access_token: str
+    ip: str
+
+    user_agent: str
+    last_seen: int
+
+
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)

         self.db_pool.updates.register_background_index_update(
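The two TypedDicts above give the type checker per-key value types for the row dicts this module passes around, where a plain `dict` return erased them. A minimal standalone sketch (not part of the diff) of what that buys; the literal values are invented for illustration:

    from typing import Optional
    from typing_extensions import TypedDict

    class DeviceLastConnectionInfo(TypedDict):
        user_id: str
        device_id: str
        ip: Optional[str]
        user_agent: Optional[str]
        last_seen: Optional[int]

    info: DeviceLastConnectionInfo = {
        "user_id": "@alice:example.com",  # invented example values
        "device_id": "ABCDEFGH",
        "ip": "10.0.0.1",
        "user_agent": "Mozilla/5.0",
        "last_seen": 1_630_000_000_000,
    }

    # mypy knows info["last_seen"] is Optional[int], so it insists on the
    # None check before arithmetic; with Dict[str, Any] neither line is checked.
    if info["last_seen"] is not None:
        age_ms = 1_630_000_300_000 - info["last_seen"]
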
@@ -81,8 +116,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             "devices_last_seen", self._devices_last_seen_update
         )

-    async def _remove_user_ip_nonunique(self, progress, batch_size):
-        def f(conn):
+    async def _remove_user_ip_nonunique(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def f(conn: LoggingDatabaseConnection) -> None:
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
             txn.close()
@@ -93,14 +130,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         )
         return 1

-    async def _analyze_user_ip(self, progress, batch_size):
+    async def _analyze_user_ip(self, progress: JsonDict, batch_size: int) -> int:
         # Background update to analyze user_ips table before we run the
         # deduplication background update. The table may not have been analyzed
         # for ages due to the table locks.
         #
         # This will lock out the naive upserts to user_ips while it happens, but
         # the analyze should be quick (28GB table takes ~10s)
-        def user_ips_analyze(txn):
+        def user_ips_analyze(txn: LoggingTransaction) -> None:
             txn.execute("ANALYZE user_ips")

         await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
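`ANALYZE` only refreshes the query planner's statistics rather than rewriting any data, which is why it is cheap relative to the dedup scans that follow and worth running first on a table whose stats may be stale. A runnable illustration of the statement using `sqlite3` (PostgreSQL's `ANALYZE` serves the same purpose; table shape invented for the example):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE user_ips (user_id TEXT, ip TEXT, last_seen BIGINT)")
    conn.execute("CREATE INDEX user_ips_last_seen ON user_ips (last_seen)")
    conn.executemany(
        "INSERT INTO user_ips VALUES (?, ?, ?)",
        [(f"@u{i}:example.com", "10.0.0.1", i) for i in range(1000)],
    )
    # Refresh optimizer statistics; no table data is rewritten.
    conn.execute("ANALYZE")
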
@@ -109,16 +146,16 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):

         return 1

-    async def _remove_user_ip_dupes(self, progress, batch_size):
+    async def _remove_user_ip_dupes(self, progress: JsonDict, batch_size: int) -> int:
         # This function works by scanning the user_ips table in batches
         # based on `last_seen`. For each row in a batch it searches the rest of
         # the table to see if there are any duplicates, if there are then they
         # are removed and replaced with a suitable row.

         # Fetch the start of the batch
-        begin_last_seen = progress.get("last_seen", 0)
+        begin_last_seen: int = progress.get("last_seen", 0)

-        def get_last_seen(txn):
+        def get_last_seen(txn: LoggingTransaction) -> Optional[int]:
             txn.execute(
                 """
                 SELECT last_seen FROM user_ips
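The new signature documents the contract of a Synapse background update: `progress` is the JSON blob persisted between runs, and the returned `int` is the number of items processed in this run, so a restart resumes from the stored cursor instead of starting over. A schematic of that resumable-batch shape, with hypothetical `process_batch`/`save_progress` helpers standing in for the real bookkeeping calls, which this hunk does not show:

    from typing import Any, Dict

    JsonDict = Dict[str, Any]

    def process_batch(start: int, batch_size: int) -> int:
        """Hypothetical: dedup one batch, return the last_seen we reached."""
        return start + batch_size

    def save_progress(progress: JsonDict) -> None:
        """Hypothetical: persist progress so a restart resumes, not restarts."""

    def remove_dupes_batch(progress: JsonDict, batch_size: int) -> int:
        begin_last_seen: int = progress.get("last_seen", 0)  # 0 on first run
        end_last_seen = process_batch(begin_last_seen, batch_size)
        save_progress({"last_seen": end_last_seen})
        return batch_size  # items handled; the updater schedules the next run
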
@@ -129,7 +166,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 """,
                 (begin_last_seen, batch_size),
             )
-            row = txn.fetchone()
+            row = cast(Optional[Tuple[int]], txn.fetchone())
             if row:
                 return row[0]
             else:
@@ -149,7 +186,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             end_last_seen,
         )

-        def remove(txn):
+        def remove(txn: LoggingTransaction) -> None:
             # This works by looking at all entries in the given time span, and
             # then for each (user_id, access_token, ip) tuple in that range
             # checking for any duplicates in the rest of the table (via a join).
@@ -161,10 +198,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):

             # Define the search space, which requires handling the last batch in
             # a different way
+            args: Tuple[int, ...]
             if last:
                 clause = "? <= last_seen"
                 args = (begin_last_seen,)
             else:
+                assert end_last_seen is not None
                 clause = "? <= last_seen AND last_seen < ?"
                 args = (begin_last_seen, end_last_seen)

@@ -189,7 +228,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 ),
                 args,
             )
-            res = txn.fetchall()
+            res = cast(
+                List[Tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall()
+            )

             # We've got some duplicates
             for i in res:
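The `cast(...)` calls in this and the surrounding hunks all follow one idiom: the DB-API cursor's `fetchall()`/`fetchone()` are typed too loosely to be useful, so the caller asserts the row shape that the `SELECT` list guarantees. `cast` has no runtime effect, so the annotation must be kept in sync with the query by hand. A self-contained illustration using `sqlite3` (invented table and values):

    import sqlite3
    from typing import List, Tuple, cast

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE user_ips (user_id TEXT, ip TEXT, last_seen BIGINT)")
    conn.execute("INSERT INTO user_ips VALUES ('@a:example.com', '10.0.0.1', 1000)")

    cur = conn.execute("SELECT user_id, ip, last_seen FROM user_ips")
    # No conversion happens here; we are promising the checker that each row
    # matches the SELECT list: (user_id, ip, last_seen).
    rows = cast(List[Tuple[str, str, int]], cur.fetchall())
    for user_id, ip, last_seen in rows:  # tuple unpacking now type-checks
        print(user_id, ip, last_seen)
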
@@ -278,13 +319,15 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):

         return batch_size

-    async def _devices_last_seen_update(self, progress, batch_size):
+    async def _devices_last_seen_update(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to insert last seen info into devices table"""

-        last_user_id = progress.get("last_user_id", "")
-        last_device_id = progress.get("last_device_id", "")
+        last_user_id: str = progress.get("last_user_id", "")
+        last_device_id: str = progress.get("last_device_id", "")

-        def _devices_last_seen_update_txn(txn):
+        def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:
             # This consists of two queries:
             #
             # 1. The sub-query searches for the next N devices and joins
@@ -296,6 +339,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             # we'll just end up updating the same device row multiple
             # times, which is fine.

+            where_args: List[Union[str, int]]
             where_clause, where_args = make_tuple_comparison_clause(
                 [("user_id", last_user_id), ("device_id", last_device_id)],
             )
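`make_tuple_comparison_clause` implements keyset pagination: each batch resumes strictly after the last `(user_id, device_id)` pair the previous batch processed, which is why the progress dict stores both values. The separate `where_args` annotation widens the element type to `Union[str, int]` so the later `where_args + [batch_size]` concatenation type-checks. The predicate the helper builds is equivalent to a SQL row-value comparison, roughly as below (a sketch of the shape with invented values, not the helper's literal output):

    # Illustrative only: the helper returns (clause, args) for a query like
    sql = """
        SELECT user_id, device_id FROM devices
        WHERE (user_id, device_id) > (?, ?)
        ORDER BY user_id ASC, device_id ASC
        LIMIT ?
    """
    args = ["@last_user:example.com", "LASTDEVICE", 100]  # invented values
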
@@ -319,7 +363,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             }
             txn.execute(sql, where_args + [batch_size])

-            rows = txn.fetchall()
+            rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
             if not rows:
                 return 0

@@ -350,7 +394,7 @@


 class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)

         self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -359,7 +403,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)

     @wrap_as_background_process("prune_old_user_ips")
-    async def _prune_old_user_ips(self):
+    async def _prune_old_user_ips(self) -> None:
         """Removes entries in user IPs older than the configured period."""

         if self.user_ips_max_age is None:
@@ -394,9 +438,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             )
         """

-        timestamp = self.clock.time_msec() - self.user_ips_max_age
+        timestamp = self._clock.time_msec() - self.user_ips_max_age

-        def _prune_old_user_ips_txn(txn):
+        def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None:
             txn.execute(sql, (timestamp,))

         await self.db_pool.runInteraction(
@@ -405,7 +449,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):

     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
-    ) -> Dict[Tuple[str, str], dict]:
+    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on.

         The result might be slightly out of date as client IPs are inserted in batches.
@@ -423,26 +467,32 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
         if device_id is not None:
             keyvalues["device_id"] = device_id

-        res = await self.db_pool.simple_select_list(
-            table="devices",
-            keyvalues=keyvalues,
-            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+        res = cast(
+            List[DeviceLastConnectionInfo],
+            await self.db_pool.simple_select_list(
+                table="devices",
+                keyvalues=keyvalues,
+                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+            ),
         )

         return {(d["user_id"], d["device_id"]): d for d in res}


-class ClientIpStore(ClientIpWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):

-        self.client_ip_last_seen = LruCache(
+        # (user_id, access_token, ip,) -> last_seen
+        self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
             cache_name="client_ip_last_seen", max_size=50000
         )

         super().__init__(database, db_conn, hs)

         # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
-        self._batch_row_update = {}
+        self._batch_row_update: Dict[
+            Tuple[str, str, str], Tuple[str, Optional[str], int]
+        ] = {}

         self._client_ip_looper = self._clock.looping_call(
             self._update_client_ips_batch, 5 * 1000
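`LruCache` is generic, and the PR parametrizes it here instead of leaving key and value as `Any`. A short sketch of the difference; it assumes a Synapse checkout is importable, and the constructor arguments, key shape, and value meaning mirror the comment and call in this hunk:

    from typing import Tuple

    from synapse.util.caches.lrucache import LruCache

    # Key: (user_id, access_token, ip); value: last_seen timestamp in ms.
    client_ip_last_seen = LruCache[Tuple[str, str, str], int](
        cache_name="client_ip_last_seen", max_size=50000
    )

    key = ("@alice:example.com", "syt_sometoken", "10.0.0.1")  # invented
    client_ip_last_seen.set(key, 1_630_000_000_000)
    # The checker now sees Optional[int] here instead of Any:
    last_seen = client_ip_last_seen.get(key, None)
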
@@ -452,8 +502,14 @@ class ClientIpStore(ClientIpWorkerStore):
         )

     async def insert_client_ip(
-        self, user_id, access_token, ip, user_agent, device_id, now=None
-    ):
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: Optional[str],
+        now: Optional[int] = None,
+    ) -> None:
         if not now:
             now = int(self._clock.time_msec())
         key = (user_id, access_token, ip)
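`insert_client_ip` does not touch the database directly: it drops the update into `_batch_row_update`, keyed on the same `(user_id, access_token, ip)` tuple as the cache, and the looping call flushes the batch every five seconds, while `client_ip_last_seen` plus `LAST_SEEN_GRANULARITY` suppress rewrites of a key seen recently. A condensed, self-contained sketch of that coalescing logic, with plain dicts standing in for the real cache and store:

    from typing import Dict, Optional, Tuple

    LAST_SEEN_GRANULARITY = 120 * 1000  # ms, as defined at the top of the file

    last_seen_cache: Dict[Tuple[str, str, str], int] = {}
    batch: Dict[Tuple[str, str, str], Tuple[str, Optional[str], int]] = {}

    def record_client_ip(
        user_id: str,
        access_token: str,
        ip: str,
        user_agent: str,
        device_id: Optional[str],
        now: int,
    ) -> None:
        key = (user_id, access_token, ip)
        last = last_seen_cache.get(key)
        # Seen recently enough: skip the write entirely.
        if last is not None and now - last < LAST_SEEN_GRANULARITY:
            return
        last_seen_cache[key] = now
        # Later updates for the same key overwrite earlier ones; a periodic
        # flush then upserts the whole batch in one transaction.
        batch[key] = (user_agent, device_id, now)
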
@@ -485,7 +541,11 @@ class ClientIpStore(ClientIpWorkerStore):
             "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
         )

-    def _update_client_ips_batch_txn(self, txn, to_update):
+    def _update_client_ips_batch_txn(
+        self,
+        txn: LoggingTransaction,
+        to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
+    ) -> None:
         if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
             not self.database_engine.can_native_upsert
         ):
@@ -525,7 +585,7 @@ class ClientIpStore(ClientIpWorkerStore):

     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
-    ) -> Dict[Tuple[str, str], dict]:
+    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on

         Args:
@@ -561,12 +621,12 @@ class ClientIpStore(ClientIpWorkerStore):

     async def get_user_ip_and_agents(
         self, user: UserID, since_ts: int = 0
-    ) -> List[Dict[str, Union[str, int]]]:
+    ) -> List[LastConnectionInfo]:
         """
         Fetch IP/User Agent connections since a given timestamp.
         """
         user_id = user.to_string()
-        results = {}
+        results: Dict[Tuple[str, str], Tuple[str, int]] = {}

         for key in self._batch_row_update:
             (
@@ -579,7 +639,7 @@ class ClientIpStore(ClientIpWorkerStore):
             if last_seen >= since_ts:
                 results[(access_token, ip)] = (user_agent, last_seen)

-        def get_recent(txn):
+        def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
             txn.execute(
                 """
                 SELECT access_token, ip, user_agent, last_seen FROM user_ips
@@ -589,7 +649,7 @@ class ClientIpStore(ClientIpWorkerStore):
                 """,
                 (since_ts, user_id),
             )
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, str, int]], txn.fetchall())

         rows = await self.db_pool.runInteraction(
             desc="get_user_ip_and_agents", func=get_recent