|
|
|
@ -21,7 +21,6 @@ from mock import Mock, patch |
|
|
|
|
import attr |
|
|
|
|
import pymacaroons |
|
|
|
|
|
|
|
|
|
from twisted.internet import defer |
|
|
|
|
from twisted.python.failure import Failure |
|
|
|
|
from twisted.web._newclient import ResponseDone |
|
|
|
|
|
|
|
|
@ -87,6 +86,13 @@ class TestMappingProvider(OidcMappingProvider): |
|
|
|
|
async def map_user_attributes(self, userinfo, token): |
|
|
|
|
return {"localpart": userinfo["username"], "display_name": None} |
|
|
|
|
|
|
|
|
|
# Do not include get_extra_attributes to test backwards compatibility paths. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestMappingProviderExtra(TestMappingProvider): |
|
|
|
|
async def get_extra_attributes(self, userinfo, token): |
|
|
|
|
return {"phone": userinfo["phone"]} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def simple_async_mock(return_value=None, raises=None): |
|
|
|
|
# AsyncMock is not available in python3.5, this mimics part of its behaviour |
|
|
|
@ -126,7 +132,7 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
|
|
|
|
|
config = self.default_config() |
|
|
|
|
config["public_baseurl"] = BASE_URL |
|
|
|
|
oidc_config = config.get("oidc_config", {}) |
|
|
|
|
oidc_config = {} |
|
|
|
|
oidc_config["enabled"] = True |
|
|
|
|
oidc_config["client_id"] = CLIENT_ID |
|
|
|
|
oidc_config["client_secret"] = CLIENT_SECRET |
|
|
|
@ -135,6 +141,10 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
oidc_config["user_mapping_provider"] = { |
|
|
|
|
"module": __name__ + ".TestMappingProvider", |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
# Update this config with what's in the default config so that |
|
|
|
|
# override_config works as expected. |
|
|
|
|
oidc_config.update(config.get("oidc_config", {})) |
|
|
|
|
config["oidc_config"] = oidc_config |
|
|
|
|
|
|
|
|
|
hs = self.setup_test_homeserver( |
|
|
|
@ -165,11 +175,10 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": {"discover": True}}) |
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def test_discovery(self): |
|
|
|
|
"""The handler should discover the endpoints from OIDC discovery document.""" |
|
|
|
|
# This would throw if some metadata were invalid |
|
|
|
|
metadata = yield defer.ensureDeferred(self.handler.load_metadata()) |
|
|
|
|
metadata = self.get_success(self.handler.load_metadata()) |
|
|
|
|
self.http_client.get_json.assert_called_once_with(WELL_KNOWN) |
|
|
|
|
|
|
|
|
|
self.assertEqual(metadata.issuer, ISSUER) |
|
|
|
@ -181,43 +190,40 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
|
|
|
|
|
# subsequent calls should be cached |
|
|
|
|
self.http_client.reset_mock() |
|
|
|
|
yield defer.ensureDeferred(self.handler.load_metadata()) |
|
|
|
|
self.get_success(self.handler.load_metadata()) |
|
|
|
|
self.http_client.get_json.assert_not_called() |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": COMMON_CONFIG}) |
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def test_no_discovery(self): |
|
|
|
|
"""When discovery is disabled, it should not try to load from discovery document.""" |
|
|
|
|
yield defer.ensureDeferred(self.handler.load_metadata()) |
|
|
|
|
self.get_success(self.handler.load_metadata()) |
|
|
|
|
self.http_client.get_json.assert_not_called() |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": COMMON_CONFIG}) |
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def test_load_jwks(self): |
|
|
|
|
"""JWKS loading is done once (then cached) if used.""" |
|
|
|
|
jwks = yield defer.ensureDeferred(self.handler.load_jwks()) |
|
|
|
|
jwks = self.get_success(self.handler.load_jwks()) |
|
|
|
|
self.http_client.get_json.assert_called_once_with(JWKS_URI) |
|
|
|
|
self.assertEqual(jwks, {"keys": []}) |
|
|
|
|
|
|
|
|
|
# subsequent calls should be cached… |
|
|
|
|
self.http_client.reset_mock() |
|
|
|
|
yield defer.ensureDeferred(self.handler.load_jwks()) |
|
|
|
|
self.get_success(self.handler.load_jwks()) |
|
|
|
|
self.http_client.get_json.assert_not_called() |
|
|
|
|
|
|
|
|
|
# …unless forced |
|
|
|
|
self.http_client.reset_mock() |
|
|
|
|
yield defer.ensureDeferred(self.handler.load_jwks(force=True)) |
|
|
|
|
self.get_success(self.handler.load_jwks(force=True)) |
|
|
|
|
self.http_client.get_json.assert_called_once_with(JWKS_URI) |
|
|
|
|
|
|
|
|
|
# Throw if the JWKS uri is missing |
|
|
|
|
with self.metadata_edit({"jwks_uri": None}): |
|
|
|
|
with self.assertRaises(RuntimeError): |
|
|
|
|
yield defer.ensureDeferred(self.handler.load_jwks(force=True)) |
|
|
|
|
self.get_failure(self.handler.load_jwks(force=True), RuntimeError) |
|
|
|
|
|
|
|
|
|
# Return empty key set if JWKS are not used |
|
|
|
|
self.handler._scopes = [] # not asking the openid scope |
|
|
|
|
self.http_client.get_json.reset_mock() |
|
|
|
|
jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True)) |
|
|
|
|
jwks = self.get_success(self.handler.load_jwks(force=True)) |
|
|
|
|
self.http_client.get_json.assert_not_called() |
|
|
|
|
self.assertEqual(jwks, {"keys": []}) |
|
|
|
|
|
|
|
|
@ -299,11 +305,10 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
# This should not throw |
|
|
|
|
self.handler._validate_metadata() |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def test_redirect_request(self): |
|
|
|
|
"""The redirect request has the right arguments & generates a valid session cookie.""" |
|
|
|
|
req = Mock(spec=["addCookie"]) |
|
|
|
|
url = yield defer.ensureDeferred( |
|
|
|
|
url = self.get_success( |
|
|
|
|
self.handler.handle_redirect_request(req, b"http://client/redirect") |
|
|
|
|
) |
|
|
|
|
url = urlparse(url) |
|
|
|
@ -343,20 +348,18 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
self.assertEqual(params["nonce"], [nonce]) |
|
|
|
|
self.assertEqual(redirect, "http://client/redirect") |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def test_callback_error(self): |
|
|
|
|
"""Errors from the provider returned in the callback are displayed.""" |
|
|
|
|
self.handler._render_error = Mock() |
|
|
|
|
request = Mock(args={}) |
|
|
|
|
request.args[b"error"] = [b"invalid_client"] |
|
|
|
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("invalid_client", "") |
|
|
|
|
|
|
|
|
|
request.args[b"error_description"] = [b"some description"] |
|
|
|
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("invalid_client", "some description") |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def test_callback(self): |
|
|
|
|
"""Code callback works and display errors if something went wrong. |
|
|
|
|
|
|
|
|
@ -377,7 +380,7 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
"sub": "foo", |
|
|
|
|
"preferred_username": "bar", |
|
|
|
|
} |
|
|
|
|
user_id = UserID("foo", "domain.org") |
|
|
|
|
user_id = "@foo:domain.org" |
|
|
|
|
self.handler._render_error = Mock(return_value=None) |
|
|
|
|
self.handler._exchange_code = simple_async_mock(return_value=token) |
|
|
|
|
self.handler._parse_id_token = simple_async_mock(return_value=userinfo) |
|
|
|
@ -394,13 +397,12 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
client_redirect_url = "http://client/redirect" |
|
|
|
|
user_agent = "Browser" |
|
|
|
|
ip_address = "10.0.0.1" |
|
|
|
|
session = self.handler._generate_oidc_session_token( |
|
|
|
|
request.getCookie.return_value = self.handler._generate_oidc_session_token( |
|
|
|
|
state=state, |
|
|
|
|
nonce=nonce, |
|
|
|
|
client_redirect_url=client_redirect_url, |
|
|
|
|
ui_auth_session_id=None, |
|
|
|
|
) |
|
|
|
|
request.getCookie.return_value = session |
|
|
|
|
|
|
|
|
|
request.args = {} |
|
|
|
|
request.args[b"code"] = [code.encode("utf-8")] |
|
|
|
@ -410,10 +412,10 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] |
|
|
|
|
request.getClientIP.return_value = ip_address |
|
|
|
|
|
|
|
|
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
|
|
|
|
|
self.handler._auth_handler.complete_sso_login.assert_called_once_with( |
|
|
|
|
user_id, request, client_redirect_url, |
|
|
|
|
user_id, request, client_redirect_url, {}, |
|
|
|
|
) |
|
|
|
|
self.handler._exchange_code.assert_called_once_with(code) |
|
|
|
|
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) |
|
|
|
@ -427,13 +429,13 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
self.handler._map_userinfo_to_user = simple_async_mock( |
|
|
|
|
raises=MappingException() |
|
|
|
|
) |
|
|
|
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("mapping_error") |
|
|
|
|
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) |
|
|
|
|
|
|
|
|
|
# Handle ID token errors |
|
|
|
|
self.handler._parse_id_token = simple_async_mock(raises=Exception()) |
|
|
|
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("invalid_token") |
|
|
|
|
|
|
|
|
|
self.handler._auth_handler.complete_sso_login.reset_mock() |
|
|
|
@ -444,10 +446,10 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
|
|
|
|
|
# With userinfo fetching |
|
|
|
|
self.handler._scopes = [] # do not ask the "openid" scope |
|
|
|
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
|
|
|
|
|
self.handler._auth_handler.complete_sso_login.assert_called_once_with( |
|
|
|
|
user_id, request, client_redirect_url, |
|
|
|
|
user_id, request, client_redirect_url, {}, |
|
|
|
|
) |
|
|
|
|
self.handler._exchange_code.assert_called_once_with(code) |
|
|
|
|
self.handler._parse_id_token.assert_not_called() |
|
|
|
@ -459,17 +461,16 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
|
|
|
|
|
# Handle userinfo fetching error |
|
|
|
|
self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) |
|
|
|
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("fetch_error") |
|
|
|
|
|
|
|
|
|
# Handle code exchange failure |
|
|
|
|
self.handler._exchange_code = simple_async_mock( |
|
|
|
|
raises=OidcError("invalid_request") |
|
|
|
|
) |
|
|
|
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("invalid_request") |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def test_callback_session(self): |
|
|
|
|
"""The callback verifies the session presence and validity""" |
|
|
|
|
self.handler._render_error = Mock(return_value=None) |
|
|
|
@ -478,20 +479,20 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
# Missing cookie |
|
|
|
|
request.args = {} |
|
|
|
|
request.getCookie.return_value = None |
|
|
|
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("missing_session", "No session cookie found") |
|
|
|
|
|
|
|
|
|
# Missing session parameter |
|
|
|
|
request.args = {} |
|
|
|
|
request.getCookie.return_value = "session" |
|
|
|
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("invalid_request", "State parameter is missing") |
|
|
|
|
|
|
|
|
|
# Invalid cookie |
|
|
|
|
request.args = {} |
|
|
|
|
request.args[b"state"] = [b"state"] |
|
|
|
|
request.getCookie.return_value = "session" |
|
|
|
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("invalid_session") |
|
|
|
|
|
|
|
|
|
# Mismatching session |
|
|
|
@ -504,18 +505,17 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
request.args = {} |
|
|
|
|
request.args[b"state"] = [b"mismatching state"] |
|
|
|
|
request.getCookie.return_value = session |
|
|
|
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("mismatching_session") |
|
|
|
|
|
|
|
|
|
# Valid session |
|
|
|
|
request.args = {} |
|
|
|
|
request.args[b"state"] = [b"state"] |
|
|
|
|
request.getCookie.return_value = session |
|
|
|
|
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("invalid_request") |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) |
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def test_exchange_code(self): |
|
|
|
|
"""Code exchange behaves correctly and handles various error scenarios.""" |
|
|
|
|
token = {"type": "bearer"} |
|
|
|
@ -524,7 +524,7 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) |
|
|
|
|
) |
|
|
|
|
code = "code" |
|
|
|
|
ret = yield defer.ensureDeferred(self.handler._exchange_code(code)) |
|
|
|
|
ret = self.get_success(self.handler._exchange_code(code)) |
|
|
|
|
kwargs = self.http_client.request.call_args[1] |
|
|
|
|
|
|
|
|
|
self.assertEqual(ret, token) |
|
|
|
@ -546,10 +546,9 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
body=b'{"error": "foo", "error_description": "bar"}', |
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
with self.assertRaises(OidcError) as exc: |
|
|
|
|
yield defer.ensureDeferred(self.handler._exchange_code(code)) |
|
|
|
|
self.assertEqual(exc.exception.error, "foo") |
|
|
|
|
self.assertEqual(exc.exception.error_description, "bar") |
|
|
|
|
exc = self.get_failure(self.handler._exchange_code(code), OidcError) |
|
|
|
|
self.assertEqual(exc.value.error, "foo") |
|
|
|
|
self.assertEqual(exc.value.error_description, "bar") |
|
|
|
|
|
|
|
|
|
# Internal server error with no JSON body |
|
|
|
|
self.http_client.request = simple_async_mock( |
|
|
|
@ -557,9 +556,8 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
code=500, phrase=b"Internal Server Error", body=b"Not JSON", |
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
with self.assertRaises(OidcError) as exc: |
|
|
|
|
yield defer.ensureDeferred(self.handler._exchange_code(code)) |
|
|
|
|
self.assertEqual(exc.exception.error, "server_error") |
|
|
|
|
exc = self.get_failure(self.handler._exchange_code(code), OidcError) |
|
|
|
|
self.assertEqual(exc.value.error, "server_error") |
|
|
|
|
|
|
|
|
|
# Internal server error with JSON body |
|
|
|
|
self.http_client.request = simple_async_mock( |
|
|
|
@ -569,17 +567,16 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
body=b'{"error": "internal_server_error"}', |
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
with self.assertRaises(OidcError) as exc: |
|
|
|
|
yield defer.ensureDeferred(self.handler._exchange_code(code)) |
|
|
|
|
self.assertEqual(exc.exception.error, "internal_server_error") |
|
|
|
|
|
|
|
|
|
exc = self.get_failure(self.handler._exchange_code(code), OidcError) |
|
|
|
|
self.assertEqual(exc.value.error, "internal_server_error") |
|
|
|
|
|
|
|
|
|
# 4xx error without "error" field |
|
|
|
|
self.http_client.request = simple_async_mock( |
|
|
|
|
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) |
|
|
|
|
) |
|
|
|
|
with self.assertRaises(OidcError) as exc: |
|
|
|
|
yield defer.ensureDeferred(self.handler._exchange_code(code)) |
|
|
|
|
self.assertEqual(exc.exception.error, "server_error") |
|
|
|
|
exc = self.get_failure(self.handler._exchange_code(code), OidcError) |
|
|
|
|
self.assertEqual(exc.value.error, "server_error") |
|
|
|
|
|
|
|
|
|
# 2xx error with "error" field |
|
|
|
|
self.http_client.request = simple_async_mock( |
|
|
|
@ -587,9 +584,62 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
code=200, phrase=b"OK", body=b'{"error": "some_error"}', |
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
with self.assertRaises(OidcError) as exc: |
|
|
|
|
yield defer.ensureDeferred(self.handler._exchange_code(code)) |
|
|
|
|
self.assertEqual(exc.exception.error, "some_error") |
|
|
|
|
exc = self.get_failure(self.handler._exchange_code(code), OidcError) |
|
|
|
|
self.assertEqual(exc.value.error, "some_error") |
|
|
|
|
|
|
|
|
|
@override_config( |
|
|
|
|
{ |
|
|
|
|
"oidc_config": { |
|
|
|
|
"user_mapping_provider": { |
|
|
|
|
"module": __name__ + ".TestMappingProviderExtra" |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
) |
|
|
|
|
def test_extra_attributes(self): |
|
|
|
|
""" |
|
|
|
|
Login while using a mapping provider that implements get_extra_attributes. |
|
|
|
|
""" |
|
|
|
|
token = { |
|
|
|
|
"type": "bearer", |
|
|
|
|
"id_token": "id_token", |
|
|
|
|
"access_token": "access_token", |
|
|
|
|
} |
|
|
|
|
userinfo = { |
|
|
|
|
"sub": "foo", |
|
|
|
|
"phone": "1234567", |
|
|
|
|
} |
|
|
|
|
user_id = "@foo:domain.org" |
|
|
|
|
self.handler._exchange_code = simple_async_mock(return_value=token) |
|
|
|
|
self.handler._parse_id_token = simple_async_mock(return_value=userinfo) |
|
|
|
|
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) |
|
|
|
|
self.handler._auth_handler.complete_sso_login = simple_async_mock() |
|
|
|
|
request = Mock( |
|
|
|
|
spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
state = "state" |
|
|
|
|
client_redirect_url = "http://client/redirect" |
|
|
|
|
request.getCookie.return_value = self.handler._generate_oidc_session_token( |
|
|
|
|
state=state, |
|
|
|
|
nonce="nonce", |
|
|
|
|
client_redirect_url=client_redirect_url, |
|
|
|
|
ui_auth_session_id=None, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
request.args = {} |
|
|
|
|
request.args[b"code"] = [b"code"] |
|
|
|
|
request.args[b"state"] = [state.encode("utf-8")] |
|
|
|
|
|
|
|
|
|
request.requestHeaders = Mock(spec=["getRawHeaders"]) |
|
|
|
|
request.requestHeaders.getRawHeaders.return_value = [b"Browser"] |
|
|
|
|
request.getClientIP.return_value = "10.0.0.1" |
|
|
|
|
|
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
|
|
|
|
|
self.handler._auth_handler.complete_sso_login.assert_called_once_with( |
|
|
|
|
user_id, request, client_redirect_url, {"phone": "1234567"}, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def test_map_userinfo_to_user(self): |
|
|
|
|
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" |
|
|
|
|