|
|
|
@ -18,7 +18,7 @@ from mock import Mock |
|
|
|
|
from twisted.internet import defer |
|
|
|
|
|
|
|
|
|
from synapse.api.constants import UserTypes |
|
|
|
|
from synapse.api.errors import ResourceLimitError, SynapseError |
|
|
|
|
from synapse.api.errors import Codes, ResourceLimitError, SynapseError |
|
|
|
|
from synapse.handlers.register import RegistrationHandler |
|
|
|
|
from synapse.types import RoomAlias, UserID, create_requester |
|
|
|
|
|
|
|
|
@ -67,7 +67,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): |
|
|
|
|
user_id = frank.to_string() |
|
|
|
|
requester = create_requester(user_id) |
|
|
|
|
result_user_id, result_token = self.get_success( |
|
|
|
|
self.handler.get_or_create_user(requester, frank.localpart, "Frankie") |
|
|
|
|
self.get_or_create_user(requester, frank.localpart, "Frankie") |
|
|
|
|
) |
|
|
|
|
self.assertEquals(result_user_id, user_id) |
|
|
|
|
self.assertTrue(result_token is not None) |
|
|
|
@ -87,7 +87,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): |
|
|
|
|
user_id = frank.to_string() |
|
|
|
|
requester = create_requester(user_id) |
|
|
|
|
result_user_id, result_token = self.get_success( |
|
|
|
|
self.handler.get_or_create_user(requester, local_part, None) |
|
|
|
|
self.get_or_create_user(requester, local_part, None) |
|
|
|
|
) |
|
|
|
|
self.assertEquals(result_user_id, user_id) |
|
|
|
|
self.assertTrue(result_token is not None) |
|
|
|
@ -95,9 +95,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): |
|
|
|
|
def test_mau_limits_when_disabled(self): |
|
|
|
|
self.hs.config.limit_usage_by_mau = False |
|
|
|
|
# Ensure does not throw exception |
|
|
|
|
self.get_success( |
|
|
|
|
self.handler.get_or_create_user(self.requester, "a", "display_name") |
|
|
|
|
) |
|
|
|
|
self.get_success(self.get_or_create_user(self.requester, "a", "display_name")) |
|
|
|
|
|
|
|
|
|
def test_get_or_create_user_mau_not_blocked(self): |
|
|
|
|
self.hs.config.limit_usage_by_mau = True |
|
|
|
@ -105,7 +103,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): |
|
|
|
|
return_value=defer.succeed(self.hs.config.max_mau_value - 1) |
|
|
|
|
) |
|
|
|
|
# Ensure does not throw exception |
|
|
|
|
self.get_success(self.handler.get_or_create_user(self.requester, "c", "User")) |
|
|
|
|
self.get_success(self.get_or_create_user(self.requester, "c", "User")) |
|
|
|
|
|
|
|
|
|
def test_get_or_create_user_mau_blocked(self): |
|
|
|
|
self.hs.config.limit_usage_by_mau = True |
|
|
|
@ -113,7 +111,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): |
|
|
|
|
return_value=defer.succeed(self.lots_of_users) |
|
|
|
|
) |
|
|
|
|
self.get_failure( |
|
|
|
|
self.handler.get_or_create_user(self.requester, "b", "display_name"), |
|
|
|
|
self.get_or_create_user(self.requester, "b", "display_name"), |
|
|
|
|
ResourceLimitError, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
@ -121,7 +119,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): |
|
|
|
|
return_value=defer.succeed(self.hs.config.max_mau_value) |
|
|
|
|
) |
|
|
|
|
self.get_failure( |
|
|
|
|
self.handler.get_or_create_user(self.requester, "b", "display_name"), |
|
|
|
|
self.get_or_create_user(self.requester, "b", "display_name"), |
|
|
|
|
ResourceLimitError, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
@ -232,3 +230,55 @@ class RegistrationTestCase(unittest.HomeserverTestCase): |
|
|
|
|
def test_invalid_user_id_length(self): |
|
|
|
|
invalid_user_id = "x" * 256 |
|
|
|
|
self.get_failure(self.handler.register(localpart=invalid_user_id), SynapseError) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def get_or_create_user(self, requester, localpart, displayname, password_hash=None): |
|
|
|
|
"""Creates a new user if the user does not exist, |
|
|
|
|
else revokes all previous access tokens and generates a new one. |
|
|
|
|
|
|
|
|
|
XXX: this used to be in the main codebase, but was only used by this file, |
|
|
|
|
so got moved here. TODO: get rid of it, probably |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
localpart : The local part of the user ID to register. If None, |
|
|
|
|
one will be randomly generated. |
|
|
|
|
Returns: |
|
|
|
|
A tuple of (user_id, access_token). |
|
|
|
|
Raises: |
|
|
|
|
RegistrationError if there was a problem registering. |
|
|
|
|
""" |
|
|
|
|
if localpart is None: |
|
|
|
|
raise SynapseError(400, "Request must include user id") |
|
|
|
|
yield self.hs.get_auth().check_auth_blocking() |
|
|
|
|
need_register = True |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
yield self.handler.check_username(localpart) |
|
|
|
|
except SynapseError as e: |
|
|
|
|
if e.errcode == Codes.USER_IN_USE: |
|
|
|
|
need_register = False |
|
|
|
|
else: |
|
|
|
|
raise |
|
|
|
|
|
|
|
|
|
user = UserID(localpart, self.hs.hostname) |
|
|
|
|
user_id = user.to_string() |
|
|
|
|
token = self.macaroon_generator.generate_access_token(user_id) |
|
|
|
|
|
|
|
|
|
if need_register: |
|
|
|
|
yield self.handler.register_with_store( |
|
|
|
|
user_id=user_id, |
|
|
|
|
token=token, |
|
|
|
|
password_hash=password_hash, |
|
|
|
|
create_profile_with_displayname=user.localpart, |
|
|
|
|
) |
|
|
|
|
else: |
|
|
|
|
yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) |
|
|
|
|
yield self.store.add_access_token_to_user(user_id=user_id, token=token) |
|
|
|
|
|
|
|
|
|
if displayname is not None: |
|
|
|
|
# logger.info("setting user display name: %s -> %s", user_id, displayname) |
|
|
|
|
yield self.hs.get_profile_handler().set_displayname( |
|
|
|
|
user, requester, displayname, by_admin=True |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
defer.returnValue((user_id, token)) |
|
|
|
|