Add startup background job for account validity

If account validity is enabled in the server's configuration, this job will run at startup as a background job and will stick an expiration date to any registered account missing one.
pull/14/head
Brendan Abolivier 6 years ago
parent 8f9ce1a8a2
commit ad5b4074e1
No known key found for this signature in database
GPG Key ID: 1E015C145F1916CD
  1. 62
      synapse/storage/_base.py
  2. 16
      synapse/storage/registration.py
  3. 55
      tests/rest/client/v2_alpha/test_register.py

@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -227,6 +229,8 @@ class SQLBaseStore(object):
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
self._account_validity = self.hs.config.account_validity
# We add the user_directory_search table to the blacklist on SQLite
# because the existing search table does not have an index, making it
# unsafe to use native upserts.
@ -243,6 +247,14 @@ class SQLBaseStore(object):
self._check_safe_to_upsert,
)
if self._account_validity.enabled:
self._clock.call_later(
0.0,
run_as_background_process,
"account_validity_set_expiration_dates",
self._set_expiration_date_when_missing,
)
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
@ -275,6 +287,56 @@ class SQLBaseStore(object):
self._check_safe_to_upsert,
)
@defer.inlineCallbacks
def _set_expiration_date_when_missing(self):
"""
Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them.
"""
def select_users_with_no_expiration_date_txn(txn):
"""Retrieves the list of registered users with no expiration date from the
database.
"""
sql = (
"SELECT users.name FROM users"
" LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
" WHERE account_validity.user_id is NULL;"
)
txn.execute(sql, [])
return self.cursor_to_dict(txn)
res = yield self.runInteraction(
"get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn,
)
if res:
for user in res:
self.runInteraction(
"set_expiration_date_for_user_background",
self.set_expiration_date_for_user_txn,
user["name"],
)
def set_expiration_date_for_user_txn(self, txn, user_id):
"""Sets an expiration date to the account with the given user ID.
Args:
user_id (str): User ID to set an expiration date for.
"""
now_ms = self._clock.time_msec()
expiration_ts = now_ms + self._account_validity.period
self._simple_insert_txn(
txn,
"account_validity",
values={
"user_id": user_id,
"expiration_ts_ms": expiration_ts,
"email_sent": False,
},
)
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()

@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -725,17 +727,7 @@ class RegistrationStore(
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
if self._account_validity.enabled:
now_ms = self.clock.time_msec()
expiration_ts = now_ms + self._account_validity.period
self._simple_insert_txn(
txn,
"account_validity",
values={
"user_id": user_id,
"expiration_ts_ms": expiration_ts,
"email_sent": False,
}
)
self.set_expiration_date_for_user_txn(txn, user_id)
if token:
# it's possible for this to get a conflict, but only for a single user

@ -1,3 +1,20 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import json
import os
@ -409,3 +426,41 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertEqual(len(self.email_attempts), 1)
class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
]
def make_homeserver(self, reactor, clock):
self.validity_period = 10
config = self.default_config()
config["enable_registration"] = True
config["account_validity"] = {
"enabled": False,
}
self.hs = self.setup_test_homeserver(config=config)
self.hs.config.account_validity.period = self.validity_period
self.store = self.hs.get_datastore()
return self.hs
def test_background_job(self):
"""
Tests whether the account validity startup background job does the right thing,
which is sticking an expiration date to every account that doesn't already have
one.
"""
user_id = self.register_user("kermit", "user")
now_ms = self.hs.clock.time_msec()
self.get_success(self.store._set_expiration_date_when_missing())
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
self.assertEqual(res, now_ms + self.validity_period)

Loading…
Cancel
Save