|
|
|
@ -13,6 +13,7 @@ |
|
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
|
# limitations under the License. |
|
|
|
|
import json |
|
|
|
|
import os |
|
|
|
|
from urllib.parse import parse_qs, urlparse |
|
|
|
|
|
|
|
|
|
from mock import ANY, Mock, patch |
|
|
|
@ -50,7 +51,18 @@ WELL_KNOWN = ISSUER + ".well-known/openid-configuration" |
|
|
|
|
JWKS_URI = ISSUER + ".well-known/jwks.json" |
|
|
|
|
|
|
|
|
|
# config for common cases |
|
|
|
|
COMMON_CONFIG = { |
|
|
|
|
DEFAULT_CONFIG = { |
|
|
|
|
"enabled": True, |
|
|
|
|
"client_id": CLIENT_ID, |
|
|
|
|
"client_secret": CLIENT_SECRET, |
|
|
|
|
"issuer": ISSUER, |
|
|
|
|
"scopes": SCOPES, |
|
|
|
|
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
# extends the default config with explicit OAuth2 endpoints instead of using discovery |
|
|
|
|
EXPLICIT_ENDPOINT_CONFIG = { |
|
|
|
|
**DEFAULT_CONFIG, |
|
|
|
|
"discover": False, |
|
|
|
|
"authorization_endpoint": AUTHORIZATION_ENDPOINT, |
|
|
|
|
"token_endpoint": TOKEN_ENDPOINT, |
|
|
|
@ -107,6 +119,32 @@ async def get_json(url): |
|
|
|
|
return {"keys": []} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _key_file_path() -> str: |
|
|
|
|
"""path to a file containing the private half of a test key""" |
|
|
|
|
|
|
|
|
|
# this key was generated with: |
|
|
|
|
# openssl ecparam -name prime256v1 -genkey -noout | |
|
|
|
|
# openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8 |
|
|
|
|
# |
|
|
|
|
# we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because |
|
|
|
|
# that's what Apple use, and we want to be sure that we work with Apple's keys. |
|
|
|
|
# |
|
|
|
|
# (For the record: both PKCS8 and SEC-1 specify (different) ways of representing |
|
|
|
|
# keys using ASN.1. Both are then typically formatted using PEM, which says: use the |
|
|
|
|
# base64-encoded DER encoding of ASN.1, with headers and footers. But we don't |
|
|
|
|
# really need to care about any of that.) |
|
|
|
|
return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _public_key_file_path() -> str: |
|
|
|
|
"""path to a file containing the public half of a test key""" |
|
|
|
|
# this was generated with: |
|
|
|
|
# openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem |
|
|
|
|
# |
|
|
|
|
# See above about where oidc_test_key.p8 came from |
|
|
|
|
return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
if not HAS_OIDC: |
|
|
|
|
skip = "requires OIDC" |
|
|
|
@ -114,20 +152,6 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
def default_config(self): |
|
|
|
|
config = super().default_config() |
|
|
|
|
config["public_baseurl"] = BASE_URL |
|
|
|
|
oidc_config = { |
|
|
|
|
"enabled": True, |
|
|
|
|
"client_id": CLIENT_ID, |
|
|
|
|
"client_secret": CLIENT_SECRET, |
|
|
|
|
"issuer": ISSUER, |
|
|
|
|
"scopes": SCOPES, |
|
|
|
|
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
# Update this config with what's in the default config so that |
|
|
|
|
# override_config works as expected. |
|
|
|
|
oidc_config.update(config.get("oidc_config", {})) |
|
|
|
|
config["oidc_config"] = oidc_config |
|
|
|
|
|
|
|
|
|
return config |
|
|
|
|
|
|
|
|
|
def make_homeserver(self, reactor, clock): |
|
|
|
@ -170,13 +194,14 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
self.render_error.reset_mock() |
|
|
|
|
return args |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG}) |
|
|
|
|
def test_config(self): |
|
|
|
|
"""Basic config correctly sets up the callback URL and client auth correctly.""" |
|
|
|
|
self.assertEqual(self.provider._callback_url, CALLBACK_URL) |
|
|
|
|
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) |
|
|
|
|
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": {"discover": True}}) |
|
|
|
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}}) |
|
|
|
|
def test_discovery(self): |
|
|
|
|
"""The handler should discover the endpoints from OIDC discovery document.""" |
|
|
|
|
# This would throw if some metadata were invalid |
|
|
|
@ -195,13 +220,13 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
self.get_success(self.provider.load_metadata()) |
|
|
|
|
self.http_client.get_json.assert_not_called() |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": COMMON_CONFIG}) |
|
|
|
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) |
|
|
|
|
def test_no_discovery(self): |
|
|
|
|
"""When discovery is disabled, it should not try to load from discovery document.""" |
|
|
|
|
self.get_success(self.provider.load_metadata()) |
|
|
|
|
self.http_client.get_json.assert_not_called() |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": COMMON_CONFIG}) |
|
|
|
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) |
|
|
|
|
def test_load_jwks(self): |
|
|
|
|
"""JWKS loading is done once (then cached) if used.""" |
|
|
|
|
jwks = self.get_success(self.provider.load_jwks()) |
|
|
|
@ -236,6 +261,7 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
self.http_client.get_json.assert_not_called() |
|
|
|
|
self.assertEqual(jwks, {"keys": []}) |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG}) |
|
|
|
|
def test_validate_config(self): |
|
|
|
|
"""Provider metadatas are extensively validated.""" |
|
|
|
|
h = self.provider |
|
|
|
@ -318,13 +344,14 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
# Shouldn't raise with a valid userinfo, even without jwks |
|
|
|
|
force_load_metadata() |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": {"skip_verification": True}}) |
|
|
|
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}}) |
|
|
|
|
def test_skip_verification(self): |
|
|
|
|
"""Provider metadata validation can be disabled by config.""" |
|
|
|
|
with self.metadata_edit({"issuer": "http://insecure"}): |
|
|
|
|
# This should not throw |
|
|
|
|
get_awaitable_result(self.provider.load_metadata()) |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG}) |
|
|
|
|
def test_redirect_request(self): |
|
|
|
|
"""The redirect request has the right arguments & generates a valid session cookie.""" |
|
|
|
|
req = Mock(spec=["cookies"]) |
|
|
|
@ -368,6 +395,7 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
self.assertEqual(params["nonce"], [nonce]) |
|
|
|
|
self.assertEqual(redirect, "http://client/redirect") |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG}) |
|
|
|
|
def test_callback_error(self): |
|
|
|
|
"""Errors from the provider returned in the callback are displayed.""" |
|
|
|
|
request = Mock(args={}) |
|
|
|
@ -379,6 +407,7 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("invalid_client", "some description") |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG}) |
|
|
|
|
def test_callback(self): |
|
|
|
|
"""Code callback works and display errors if something went wrong. |
|
|
|
|
|
|
|
|
@ -480,6 +509,7 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("invalid_request") |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG}) |
|
|
|
|
def test_callback_session(self): |
|
|
|
|
"""The callback verifies the session presence and validity""" |
|
|
|
|
request = Mock(spec=["args", "getCookie", "cookies"]) |
|
|
|
@ -522,7 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
self.get_success(self.handler.handle_oidc_callback(request)) |
|
|
|
|
self.assertRenderedError("invalid_request") |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) |
|
|
|
|
@override_config( |
|
|
|
|
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}} |
|
|
|
|
) |
|
|
|
|
def test_exchange_code(self): |
|
|
|
|
"""Code exchange behaves correctly and handles various error scenarios.""" |
|
|
|
|
token = {"type": "bearer"} |
|
|
|
@ -607,9 +639,105 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
@override_config( |
|
|
|
|
{ |
|
|
|
|
"oidc_config": { |
|
|
|
|
"enabled": True, |
|
|
|
|
"client_id": CLIENT_ID, |
|
|
|
|
"issuer": ISSUER, |
|
|
|
|
"client_auth_method": "client_secret_post", |
|
|
|
|
"client_secret_jwt_key": { |
|
|
|
|
"key_file": _key_file_path(), |
|
|
|
|
"jwt_header": {"alg": "ES256", "kid": "ABC789"}, |
|
|
|
|
"jwt_payload": {"iss": "DEFGHI"}, |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
) |
|
|
|
|
def test_exchange_code_jwt_key(self): |
|
|
|
|
"""Test that code exchange works with a JWK client secret.""" |
|
|
|
|
from authlib.jose import jwt |
|
|
|
|
|
|
|
|
|
token = {"type": "bearer"} |
|
|
|
|
self.http_client.request = simple_async_mock( |
|
|
|
|
return_value=FakeResponse( |
|
|
|
|
code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") |
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
code = "code" |
|
|
|
|
|
|
|
|
|
# advance the clock a bit before we start, so we aren't working with zero |
|
|
|
|
# timestamps. |
|
|
|
|
self.reactor.advance(1000) |
|
|
|
|
start_time = self.reactor.seconds() |
|
|
|
|
ret = self.get_success(self.provider._exchange_code(code)) |
|
|
|
|
|
|
|
|
|
self.assertEqual(ret, token) |
|
|
|
|
|
|
|
|
|
# the request should have hit the token endpoint |
|
|
|
|
kwargs = self.http_client.request.call_args[1] |
|
|
|
|
self.assertEqual(kwargs["method"], "POST") |
|
|
|
|
self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) |
|
|
|
|
|
|
|
|
|
# the client secret provided to the should be a jwt which can be checked with |
|
|
|
|
# the public key |
|
|
|
|
args = parse_qs(kwargs["data"].decode("utf-8")) |
|
|
|
|
secret = args["client_secret"][0] |
|
|
|
|
with open(_public_key_file_path()) as f: |
|
|
|
|
key = f.read() |
|
|
|
|
claims = jwt.decode(secret, key) |
|
|
|
|
self.assertEqual(claims.header["kid"], "ABC789") |
|
|
|
|
self.assertEqual(claims["aud"], ISSUER) |
|
|
|
|
self.assertEqual(claims["iss"], "DEFGHI") |
|
|
|
|
self.assertEqual(claims["sub"], CLIENT_ID) |
|
|
|
|
self.assertEqual(claims["iat"], start_time) |
|
|
|
|
self.assertGreater(claims["exp"], start_time) |
|
|
|
|
|
|
|
|
|
# check the rest of the POSTed data |
|
|
|
|
self.assertEqual(args["grant_type"], ["authorization_code"]) |
|
|
|
|
self.assertEqual(args["code"], [code]) |
|
|
|
|
self.assertEqual(args["client_id"], [CLIENT_ID]) |
|
|
|
|
self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) |
|
|
|
|
|
|
|
|
|
@override_config( |
|
|
|
|
{ |
|
|
|
|
"oidc_config": { |
|
|
|
|
"enabled": True, |
|
|
|
|
"client_id": CLIENT_ID, |
|
|
|
|
"issuer": ISSUER, |
|
|
|
|
"client_auth_method": "none", |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
) |
|
|
|
|
def test_exchange_code_no_auth(self): |
|
|
|
|
"""Test that code exchange works with no client secret.""" |
|
|
|
|
token = {"type": "bearer"} |
|
|
|
|
self.http_client.request = simple_async_mock( |
|
|
|
|
return_value=FakeResponse( |
|
|
|
|
code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") |
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
code = "code" |
|
|
|
|
ret = self.get_success(self.provider._exchange_code(code)) |
|
|
|
|
|
|
|
|
|
self.assertEqual(ret, token) |
|
|
|
|
|
|
|
|
|
# the request should have hit the token endpoint |
|
|
|
|
kwargs = self.http_client.request.call_args[1] |
|
|
|
|
self.assertEqual(kwargs["method"], "POST") |
|
|
|
|
self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) |
|
|
|
|
|
|
|
|
|
# check the POSTed data |
|
|
|
|
args = parse_qs(kwargs["data"].decode("utf-8")) |
|
|
|
|
self.assertEqual(args["grant_type"], ["authorization_code"]) |
|
|
|
|
self.assertEqual(args["code"], [code]) |
|
|
|
|
self.assertEqual(args["client_id"], [CLIENT_ID]) |
|
|
|
|
self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) |
|
|
|
|
|
|
|
|
|
@override_config( |
|
|
|
|
{ |
|
|
|
|
"oidc_config": { |
|
|
|
|
**DEFAULT_CONFIG, |
|
|
|
|
"user_mapping_provider": { |
|
|
|
|
"module": __name__ + ".TestMappingProviderExtra" |
|
|
|
|
} |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
) |
|
|
|
@ -652,6 +780,7 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
new_user=True, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG}) |
|
|
|
|
def test_map_userinfo_to_user(self): |
|
|
|
|
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" |
|
|
|
|
auth_handler = self.hs.get_auth_handler() |
|
|
|
@ -692,7 +821,7 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
"Mapping provider does not support de-duplicating Matrix IDs", |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": {"allow_existing_users": True}}) |
|
|
|
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) |
|
|
|
|
def test_map_userinfo_to_existing_user(self): |
|
|
|
|
"""Existing users can log in with OpenID Connect when allow_existing_users is True.""" |
|
|
|
|
store = self.hs.get_datastore() |
|
|
|
@ -772,6 +901,7 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
"@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG}) |
|
|
|
|
def test_map_userinfo_to_invalid_localpart(self): |
|
|
|
|
"""If the mapping provider generates an invalid localpart it should be rejected.""" |
|
|
|
|
self.get_success( |
|
|
|
@ -782,9 +912,10 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
@override_config( |
|
|
|
|
{ |
|
|
|
|
"oidc_config": { |
|
|
|
|
**DEFAULT_CONFIG, |
|
|
|
|
"user_mapping_provider": { |
|
|
|
|
"module": __name__ + ".TestMappingProviderFailures" |
|
|
|
|
} |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
) |
|
|
|
@ -829,6 +960,7 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
"mapping_error", "Unable to generate a Matrix ID from the SSO response" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG}) |
|
|
|
|
def test_empty_localpart(self): |
|
|
|
|
"""Attempts to map onto an empty localpart should be rejected.""" |
|
|
|
|
userinfo = { |
|
|
|
@ -841,9 +973,10 @@ class OidcHandlerTestCase(HomeserverTestCase): |
|
|
|
|
@override_config( |
|
|
|
|
{ |
|
|
|
|
"oidc_config": { |
|
|
|
|
**DEFAULT_CONFIG, |
|
|
|
|
"user_mapping_provider": { |
|
|
|
|
"config": {"localpart_template": "{{ user.username }}"} |
|
|
|
|
} |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
) |
|
|
|
|