From 7cc9509eca0d754b763253dd3c25cec688b47639 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 18 Dec 2020 12:13:03 +0000 Subject: [PATCH] Extract OIDCProviderConfig object Collect all the config options which related to an OIDC provider into a single object. --- synapse/config/oidc_config.py | 165 ++++++++++++++++++++++--------- synapse/handlers/oidc_handler.py | 37 +++---- 2 files changed, 140 insertions(+), 62 deletions(-) diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 4e3055282..9f36e6384 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2020 Quentin Gliech +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Type + +import attr + from synapse.python_dependencies import DependencyException, check_requirements +from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module from ._base import Config, ConfigError @@ -25,65 +31,29 @@ class OIDCConfig(Config): section = "oidc" def read_config(self, config, **kwargs): - self.oidc_enabled = False + self.oidc_provider = None # type: Optional[OidcProviderConfig] oidc_config = config.get("oidc_config") + if oidc_config and oidc_config.get("enabled", False): + self.oidc_provider = _parse_oidc_config_dict(oidc_config) - if not oidc_config or not oidc_config.get("enabled", False): + if not self.oidc_provider: return try: check_requirements("oidc") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError(e.message) from e public_baseurl = self.public_baseurl if public_baseurl is None: raise ConfigError("oidc_config requires a public_baseurl to be set") self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" - self.oidc_enabled = True - self.oidc_discover = oidc_config.get("discover", True) - self.oidc_issuer = oidc_config["issuer"] - self.oidc_client_id = oidc_config["client_id"] - self.oidc_client_secret = oidc_config["client_secret"] - self.oidc_client_auth_method = oidc_config.get( - "client_auth_method", "client_secret_basic" - ) - self.oidc_scopes = oidc_config.get("scopes", ["openid"]) - self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint") - self.oidc_token_endpoint = oidc_config.get("token_endpoint") - self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") - self.oidc_jwks_uri = oidc_config.get("jwks_uri") - self.oidc_skip_verification = oidc_config.get("skip_verification", False) - self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto") - self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False) - - ump_config = oidc_config.get("user_mapping_provider", {}) - ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) - ump_config.setdefault("config", {}) - - ( - self.oidc_user_mapping_provider_class, - self.oidc_user_mapping_provider_config, - ) = load_module(ump_config, ("oidc_config", "user_mapping_provider")) - - # Ensure loaded user mapping module has defined all necessary methods - required_methods = [ - "get_remote_user_id", - "map_user_attributes", - ] - missing_methods = [ - method - for method in required_methods - if not hasattr(self.oidc_user_mapping_provider_class, method) - ] - if missing_methods: - raise ConfigError( - "Class specified by oidc_config." - "user_mapping_provider.module is missing required " - "methods: %s" % (", ".join(missing_methods),) - ) + @property + def oidc_enabled(self) -> bool: + # OIDC is enabled if we have a provider + return bool(self.oidc_provider) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ @@ -224,3 +194,108 @@ class OIDCConfig(Config): """.format( mapping_provider=DEFAULT_USER_MAPPING_PROVIDER ) + + +def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig": + """Take the configuration dict and parse it into an OidcProviderConfig + + Raises: + ConfigError if the configuration is malformed. + """ + ump_config = oidc_config.get("user_mapping_provider", {}) + ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + ump_config.setdefault("config", {}) + + (user_mapping_provider_class, user_mapping_provider_config,) = load_module( + ump_config, ("oidc_config", "user_mapping_provider") + ) + + # Ensure loaded user mapping module has defined all necessary methods + required_methods = [ + "get_remote_user_id", + "map_user_attributes", + ] + missing_methods = [ + method + for method in required_methods + if not hasattr(user_mapping_provider_class, method) + ] + if missing_methods: + raise ConfigError( + "Class specified by oidc_config." + "user_mapping_provider.module is missing required " + "methods: %s" % (", ".join(missing_methods),) + ) + + return OidcProviderConfig( + discover=oidc_config.get("discover", True), + issuer=oidc_config["issuer"], + client_id=oidc_config["client_id"], + client_secret=oidc_config["client_secret"], + client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), + scopes=oidc_config.get("scopes", ["openid"]), + authorization_endpoint=oidc_config.get("authorization_endpoint"), + token_endpoint=oidc_config.get("token_endpoint"), + userinfo_endpoint=oidc_config.get("userinfo_endpoint"), + jwks_uri=oidc_config.get("jwks_uri"), + skip_verification=oidc_config.get("skip_verification", False), + user_profile_method=oidc_config.get("user_profile_method", "auto"), + allow_existing_users=oidc_config.get("allow_existing_users", False), + user_mapping_provider_class=user_mapping_provider_class, + user_mapping_provider_config=user_mapping_provider_config, + ) + + +@attr.s +class OidcProviderConfig: + # whether the OIDC discovery mechanism is used to discover endpoints + discover = attr.ib(type=bool) + + # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to + # discover the provider's endpoints. + issuer = attr.ib(type=str) + + # oauth2 client id to use + client_id = attr.ib(type=str) + + # oauth2 client secret to use + client_secret = attr.ib(type=str) + + # auth method to use when exchanging the token. + # Valid values are 'client_secret_basic', 'client_secret_post' and + # 'none'. + client_auth_method = attr.ib(type=str) + + # list of scopes to request + scopes = attr.ib(type=Collection[str]) + + # the oauth2 authorization endpoint. Required if discovery is disabled. + authorization_endpoint = attr.ib(type=Optional[str]) + + # the oauth2 token endpoint. Required if discovery is disabled. + token_endpoint = attr.ib(type=Optional[str]) + + # the OIDC userinfo endpoint. Required if discovery is disabled and the + # "openid" scope is not requested. + userinfo_endpoint = attr.ib(type=Optional[str]) + + # URI where to fetch the JWKS. Required if discovery is disabled and the + # "openid" scope is used. + jwks_uri = attr.ib(type=Optional[str]) + + # Whether to skip metadata verification + skip_verification = attr.ib(type=bool) + + # Whether to fetch the user profile from the userinfo endpoint. Valid + # values are: "auto" or "userinfo_endpoint". + user_profile_method = attr.ib(type=str) + + # whether to allow a user logging in via OIDC to match a pre-existing account + # instead of failing + allow_existing_users = attr.ib(type=bool) + + # the class of the user mapping provider + user_mapping_provider_class = attr.ib(type=Type) + + # the config of the user mapping provider + user_mapping_provider_config = attr.ib() diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 88097639e..84754e5c9 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -94,27 +94,30 @@ class OidcHandler: self._token_generator = OidcSessionTokenGenerator(hs) self._callback_url = hs.config.oidc_callback_url # type: str - self._scopes = hs.config.oidc_scopes # type: List[str] - self._user_profile_method = hs.config.oidc_user_profile_method # type: str + + provider = hs.config.oidc.oidc_provider + # we should not have been instantiated if there is no configured provider. + assert provider is not None + + self._scopes = provider.scopes + self._user_profile_method = provider.user_profile_method self._client_auth = ClientAuth( - hs.config.oidc_client_id, - hs.config.oidc_client_secret, - hs.config.oidc_client_auth_method, + provider.client_id, provider.client_secret, provider.client_auth_method, ) # type: ClientAuth - self._client_auth_method = hs.config.oidc_client_auth_method # type: str + self._client_auth_method = provider.client_auth_method self._provider_metadata = OpenIDProviderMetadata( - issuer=hs.config.oidc_issuer, - authorization_endpoint=hs.config.oidc_authorization_endpoint, - token_endpoint=hs.config.oidc_token_endpoint, - userinfo_endpoint=hs.config.oidc_userinfo_endpoint, - jwks_uri=hs.config.oidc_jwks_uri, + issuer=provider.issuer, + authorization_endpoint=provider.authorization_endpoint, + token_endpoint=provider.token_endpoint, + userinfo_endpoint=provider.userinfo_endpoint, + jwks_uri=provider.jwks_uri, ) # type: OpenIDProviderMetadata - self._provider_needs_discovery = hs.config.oidc_discover # type: bool - self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class( - hs.config.oidc_user_mapping_provider_config - ) # type: OidcMappingProvider - self._skip_verification = hs.config.oidc_skip_verification # type: bool - self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool + self._provider_needs_discovery = provider.discover + self._user_mapping_provider = provider.user_mapping_provider_class( + provider.user_mapping_provider_config + ) + self._skip_verification = provider.skip_verification + self._allow_existing_users = provider.allow_existing_users self._http_client = hs.get_proxied_http_client() self._server_name = hs.config.server_name # type: str