|
|
|
|
@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): |
|
|
|
|
account timestamp as milliseconds since the epoch. None if the account |
|
|
|
|
has not been renewed using the current token yet. |
|
|
|
|
""" |
|
|
|
|
ret_dict = await self.db_pool.simple_select_one( |
|
|
|
|
table="account_validity", |
|
|
|
|
keyvalues={"renewal_token": renewal_token}, |
|
|
|
|
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], |
|
|
|
|
desc="get_user_from_renewal_token", |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
return ( |
|
|
|
|
ret_dict["user_id"], |
|
|
|
|
ret_dict["expiration_ts_ms"], |
|
|
|
|
ret_dict["token_used_ts_ms"], |
|
|
|
|
return cast( |
|
|
|
|
Tuple[str, int, Optional[int]], |
|
|
|
|
await self.db_pool.simple_select_one( |
|
|
|
|
table="account_validity", |
|
|
|
|
keyvalues={"renewal_token": renewal_token}, |
|
|
|
|
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], |
|
|
|
|
desc="get_user_from_renewal_token", |
|
|
|
|
), |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
async def get_renewal_token_for_user(self, user_id: str) -> str: |
|
|
|
|
@ -989,16 +986,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): |
|
|
|
|
Returns: |
|
|
|
|
user id, or None if no user id/threepid mapping exists |
|
|
|
|
""" |
|
|
|
|
ret = self.db_pool.simple_select_one_txn( |
|
|
|
|
return self.db_pool.simple_select_one_onecol_txn( |
|
|
|
|
txn, |
|
|
|
|
"user_threepids", |
|
|
|
|
{"medium": medium, "address": address}, |
|
|
|
|
["user_id"], |
|
|
|
|
"user_id", |
|
|
|
|
True, |
|
|
|
|
) |
|
|
|
|
if ret: |
|
|
|
|
return ret["user_id"] |
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
async def user_add_threepid( |
|
|
|
|
self, |
|
|
|
|
@ -1435,16 +1429,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): |
|
|
|
|
if res is None: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
uses_allowed, pending, completed, expiry_time = res |
|
|
|
|
|
|
|
|
|
# Check if the token has expired |
|
|
|
|
now = self._clock.time_msec() |
|
|
|
|
if res["expiry_time"] and res["expiry_time"] < now: |
|
|
|
|
if expiry_time and expiry_time < now: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
# Check if the token has been used up |
|
|
|
|
if ( |
|
|
|
|
res["uses_allowed"] |
|
|
|
|
and res["pending"] + res["completed"] >= res["uses_allowed"] |
|
|
|
|
): |
|
|
|
|
if uses_allowed and pending + completed >= uses_allowed: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
# Otherwise, the token is valid |
|
|
|
|
@ -1490,8 +1483,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): |
|
|
|
|
# Override type because the return type is only optional if |
|
|
|
|
# allow_none is True, and we don't want mypy throwing errors |
|
|
|
|
# about None not being indexable. |
|
|
|
|
res = cast( |
|
|
|
|
Dict[str, Any], |
|
|
|
|
pending, completed = cast( |
|
|
|
|
Tuple[int, int], |
|
|
|
|
self.db_pool.simple_select_one_txn( |
|
|
|
|
txn, |
|
|
|
|
"registration_tokens", |
|
|
|
|
@ -1506,8 +1499,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): |
|
|
|
|
"registration_tokens", |
|
|
|
|
keyvalues={"token": token}, |
|
|
|
|
updatevalues={ |
|
|
|
|
"completed": res["completed"] + 1, |
|
|
|
|
"pending": res["pending"] - 1, |
|
|
|
|
"completed": completed + 1, |
|
|
|
|
"pending": pending - 1, |
|
|
|
|
}, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
@ -1585,13 +1578,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): |
|
|
|
|
Returns: |
|
|
|
|
A dict, or None if token doesn't exist. |
|
|
|
|
""" |
|
|
|
|
return await self.db_pool.simple_select_one( |
|
|
|
|
row = await self.db_pool.simple_select_one( |
|
|
|
|
"registration_tokens", |
|
|
|
|
keyvalues={"token": token}, |
|
|
|
|
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], |
|
|
|
|
allow_none=True, |
|
|
|
|
desc="get_one_registration_token", |
|
|
|
|
) |
|
|
|
|
if row is None: |
|
|
|
|
return None |
|
|
|
|
return { |
|
|
|
|
"token": row[0], |
|
|
|
|
"uses_allowed": row[1], |
|
|
|
|
"pending": row[2], |
|
|
|
|
"completed": row[3], |
|
|
|
|
"expiry_time": row[4], |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
async def generate_registration_token( |
|
|
|
|
self, length: int, chars: str |
|
|
|
|
@ -1714,7 +1716,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): |
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
# Get all info about the token so it can be sent in the response |
|
|
|
|
return self.db_pool.simple_select_one_txn( |
|
|
|
|
result = self.db_pool.simple_select_one_txn( |
|
|
|
|
txn, |
|
|
|
|
"registration_tokens", |
|
|
|
|
keyvalues={"token": token}, |
|
|
|
|
@ -1728,6 +1730,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): |
|
|
|
|
allow_none=True, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if result is None: |
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
return { |
|
|
|
|
"token": result[0], |
|
|
|
|
"uses_allowed": result[1], |
|
|
|
|
"pending": result[2], |
|
|
|
|
"completed": result[3], |
|
|
|
|
"expiry_time": result[4], |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return await self.db_pool.runInteraction( |
|
|
|
|
"update_registration_token", _update_registration_token_txn |
|
|
|
|
) |
|
|
|
|
@ -1939,11 +1952,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): |
|
|
|
|
keyvalues={"token": token}, |
|
|
|
|
updatevalues={"used_ts": ts}, |
|
|
|
|
) |
|
|
|
|
user_id = values["user_id"] |
|
|
|
|
expiry_ts = values["expiry_ts"] |
|
|
|
|
used_ts = values["used_ts"] |
|
|
|
|
auth_provider_id = values["auth_provider_id"] |
|
|
|
|
auth_provider_session_id = values["auth_provider_session_id"] |
|
|
|
|
( |
|
|
|
|
user_id, |
|
|
|
|
expiry_ts, |
|
|
|
|
used_ts, |
|
|
|
|
auth_provider_id, |
|
|
|
|
auth_provider_session_id, |
|
|
|
|
) = values |
|
|
|
|
|
|
|
|
|
# Token was already used |
|
|
|
|
if used_ts is not None: |
|
|
|
|
@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): |
|
|
|
|
# reason, the next check is on the client secret, which is NOT NULL, |
|
|
|
|
# so we don't have to worry about the client secret matching by |
|
|
|
|
# accident. |
|
|
|
|
row = {"client_secret": None, "validated_at": None} |
|
|
|
|
row = None, None |
|
|
|
|
else: |
|
|
|
|
raise ThreepidValidationError("Unknown session_id") |
|
|
|
|
|
|
|
|
|
retrieved_client_secret = row["client_secret"] |
|
|
|
|
validated_at = row["validated_at"] |
|
|
|
|
retrieved_client_secret, validated_at = row |
|
|
|
|
|
|
|
|
|
row = self.db_pool.simple_select_one_txn( |
|
|
|
|
txn, |
|
|
|
|
@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): |
|
|
|
|
raise ThreepidValidationError( |
|
|
|
|
"Validation token not found or has expired" |
|
|
|
|
) |
|
|
|
|
expires = row["expires"] |
|
|
|
|
next_link = row["next_link"] |
|
|
|
|
expires, next_link = row |
|
|
|
|
|
|
|
|
|
if retrieved_client_secret != client_secret: |
|
|
|
|
raise ThreepidValidationError( |
|
|
|
|
|