@ -151,6 +151,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self . setup_test_homeserver ( proxied_http_client = self . http_client )
self . handler = hs . get_oidc_handler ( )
self . provider = self . handler . _provider
sso_handler = hs . get_sso_handler ( )
# Mock the render error method.
self . render_error = Mock ( return_value = None )
@ -162,9 +163,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs
def metadata_edit ( self , values ) :
return patch . dict ( self . handl er. _provider_metadata , values )
return patch . dict ( self . provid er. _provider_metadata , values )
def assertRenderedError ( self , error , error_description = None ) :
self . render_error . assert_called_once ( )
args = self . render_error . call_args [ 0 ]
self . assertEqual ( args [ 1 ] , error )
if error_description is not None :
@ -175,15 +177,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_config ( self ) :
""" Basic config correctly sets up the callback URL and client auth correctly. """
self . assertEqual ( self . handl er. _callback_url , CALLBACK_URL )
self . assertEqual ( self . handl er. _client_auth . client_id , CLIENT_ID )
self . assertEqual ( self . handl er. _client_auth . client_secret , CLIENT_SECRET )
self . assertEqual ( self . provid er. _callback_url , CALLBACK_URL )
self . assertEqual ( self . provid er. _client_auth . client_id , CLIENT_ID )
self . assertEqual ( self . provid er. _client_auth . client_secret , CLIENT_SECRET )
@override_config ( { " oidc_config " : { " discover " : True } } )
def test_discovery ( self ) :
""" The handler should discover the endpoints from OIDC discovery document. """
# This would throw if some metadata were invalid
metadata = self . get_success ( self . handl er. load_metadata ( ) )
metadata = self . get_success ( self . provid er. load_metadata ( ) )
self . http_client . get_json . assert_called_once_with ( WELL_KNOWN )
self . assertEqual ( metadata . issuer , ISSUER )
@ -195,47 +197,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
# subsequent calls should be cached
self . http_client . reset_mock ( )
self . get_success ( self . handl er. load_metadata ( ) )
self . get_success ( self . provid er. load_metadata ( ) )
self . http_client . get_json . assert_not_called ( )
@override_config ( { " oidc_config " : COMMON_CONFIG } )
def test_no_discovery ( self ) :
""" When discovery is disabled, it should not try to load from discovery document. """
self . get_success ( self . handl er. load_metadata ( ) )
self . get_success ( self . provid er. load_metadata ( ) )
self . http_client . get_json . assert_not_called ( )
@override_config ( { " oidc_config " : COMMON_CONFIG } )
def test_load_jwks ( self ) :
""" JWKS loading is done once (then cached) if used. """
jwks = self . get_success ( self . handl er. load_jwks ( ) )
jwks = self . get_success ( self . provid er. load_jwks ( ) )
self . http_client . get_json . assert_called_once_with ( JWKS_URI )
self . assertEqual ( jwks , { " keys " : [ ] } )
# subsequent calls should be cached…
self . http_client . reset_mock ( )
self . get_success ( self . handl er. load_jwks ( ) )
self . get_success ( self . provid er. load_jwks ( ) )
self . http_client . get_json . assert_not_called ( )
# …unless forced
self . http_client . reset_mock ( )
self . get_success ( self . handl er. load_jwks ( force = True ) )
self . get_success ( self . provid er. load_jwks ( force = True ) )
self . http_client . get_json . assert_called_once_with ( JWKS_URI )
# Throw if the JWKS uri is missing
with self . metadata_edit ( { " jwks_uri " : None } ) :
self . get_failure ( self . handl er. load_jwks ( force = True ) , RuntimeError )
self . get_failure ( self . provid er. load_jwks ( force = True ) , RuntimeError )
# Return empty key set if JWKS are not used
self . handl er. _scopes = [ ] # not asking the openid scope
self . provid er. _scopes = [ ] # not asking the openid scope
self . http_client . get_json . reset_mock ( )
jwks = self . get_success ( self . handl er. load_jwks ( force = True ) )
jwks = self . get_success ( self . provid er. load_jwks ( force = True ) )
self . http_client . get_json . assert_not_called ( )
self . assertEqual ( jwks , { " keys " : [ ] } )
@override_config ( { " oidc_config " : COMMON_CONFIG } )
def test_validate_config ( self ) :
""" Provider metadatas are extensively validated. """
h = self . handl er
h = self . provid er
# Default test config does not throw
h . _validate_metadata ( )
@ -314,13 +316,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
""" Provider metadata validation can be disabled by config. """
with self . metadata_edit ( { " issuer " : " http://insecure " } ) :
# This should not throw
self . handl er. _validate_metadata ( )
self . provid er. _validate_metadata ( )
def test_redirect_request ( self ) :
""" The redirect request has the right arguments & generates a valid session cookie. """
req = Mock ( spec = [ " addCookie " ] )
url = self . get_success (
self . handl er. handle_redirect_request ( req , b " http://client/redirect " )
self . provid er. handle_redirect_request ( req , b " http://client/redirect " )
)
url = urlparse ( url )
auth_endpoint = urlparse ( AUTHORIZATION_ENDPOINT )
@ -388,7 +390,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# ensure that we are correctly testing the fallback when "get_extra_attributes"
# is not implemented.
mapping_provider = self . handl er. _user_mapping_provider
mapping_provider = self . provid er. _user_mapping_provider
with self . assertRaises ( AttributeError ) :
_ = mapping_provider . get_extra_attributes
@ -403,9 +405,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
" username " : username ,
}
expected_user_id = " @ %s : %s " % ( username , self . hs . hostname )
self . handl er. _exchange_code = simple_async_mock ( return_value = token )
self . handl er. _parse_id_token = simple_async_mock ( return_value = userinfo )
self . handl er. _fetch_userinfo = simple_async_mock ( return_value = userinfo )
self . provid er. _exchange_code = simple_async_mock ( return_value = token )
self . provid er. _parse_id_token = simple_async_mock ( return_value = userinfo )
self . provid er. _fetch_userinfo = simple_async_mock ( return_value = userinfo )
auth_handler = self . hs . get_auth_handler ( )
auth_handler . complete_sso_login = simple_async_mock ( )
@ -425,14 +427,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler . complete_sso_login . assert_called_once_with (
expected_user_id , request , client_redirect_url , None ,
)
self . handl er. _exchange_code . assert_called_once_with ( code )
self . handl er. _parse_id_token . assert_called_once_with ( token , nonce = nonce )
self . handl er. _fetch_userinfo . assert_not_called ( )
self . provid er. _exchange_code . assert_called_once_with ( code )
self . provid er. _parse_id_token . assert_called_once_with ( token , nonce = nonce )
self . provid er. _fetch_userinfo . assert_not_called ( )
self . render_error . assert_not_called ( )
# Handle mapping errors
with patch . object (
self . handl er,
self . provid er,
" _remote_id_from_userinfo " ,
new = Mock ( side_effect = MappingException ( ) ) ,
) :
@ -440,36 +442,36 @@ class OidcHandlerTestCase(HomeserverTestCase):
self . assertRenderedError ( " mapping_error " )
# Handle ID token errors
self . handl er. _parse_id_token = simple_async_mock ( raises = Exception ( ) )
self . provid er. _parse_id_token = simple_async_mock ( raises = Exception ( ) )
self . get_success ( self . handler . handle_oidc_callback ( request ) )
self . assertRenderedError ( " invalid_token " )
auth_handler . complete_sso_login . reset_mock ( )
self . handl er. _exchange_code . reset_mock ( )
self . handl er. _parse_id_token . reset_mock ( )
self . handl er. _fetch_userinfo . reset_mock ( )
self . provid er. _exchange_code . reset_mock ( )
self . provid er. _parse_id_token . reset_mock ( )
self . provid er. _fetch_userinfo . reset_mock ( )
# With userinfo fetching
self . handl er. _scopes = [ ] # do not ask the "openid" scope
self . provid er. _scopes = [ ] # do not ask the "openid" scope
self . get_success ( self . handler . handle_oidc_callback ( request ) )
auth_handler . complete_sso_login . assert_called_once_with (
expected_user_id , request , client_redirect_url , None ,
)
self . handl er. _exchange_code . assert_called_once_with ( code )
self . handl er. _parse_id_token . assert_not_called ( )
self . handl er. _fetch_userinfo . assert_called_once_with ( token )
self . provid er. _exchange_code . assert_called_once_with ( code )
self . provid er. _parse_id_token . assert_not_called ( )
self . provid er. _fetch_userinfo . assert_called_once_with ( token )
self . render_error . assert_not_called ( )
# Handle userinfo fetching error
self . handl er. _fetch_userinfo = simple_async_mock ( raises = Exception ( ) )
self . provid er. _fetch_userinfo = simple_async_mock ( raises = Exception ( ) )
self . get_success ( self . handler . handle_oidc_callback ( request ) )
self . assertRenderedError ( " fetch_error " )
# Handle code exchange failure
from synapse . handlers . oidc_handler import OidcError
self . handl er. _exchange_code = simple_async_mock (
self . provid er. _exchange_code = simple_async_mock (
raises = OidcError ( " invalid_request " )
)
self . get_success ( self . handler . handle_oidc_callback ( request ) )
@ -524,7 +526,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return_value = FakeResponse ( code = 200 , phrase = b " OK " , body = token_json )
)
code = " code "
ret = self . get_success ( self . handl er. _exchange_code ( code ) )
ret = self . get_success ( self . provid er. _exchange_code ( code ) )
kwargs = self . http_client . request . call_args [ 1 ]
self . assertEqual ( ret , token )
@ -548,7 +550,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
from synapse . handlers . oidc_handler import OidcError
exc = self . get_failure ( self . handl er. _exchange_code ( code ) , OidcError )
exc = self . get_failure ( self . provid er. _exchange_code ( code ) , OidcError )
self . assertEqual ( exc . value . error , " foo " )
self . assertEqual ( exc . value . error_description , " bar " )
@ -558,7 +560,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
code = 500 , phrase = b " Internal Server Error " , body = b " Not JSON " ,
)
)
exc = self . get_failure ( self . handl er. _exchange_code ( code ) , OidcError )
exc = self . get_failure ( self . provid er. _exchange_code ( code ) , OidcError )
self . assertEqual ( exc . value . error , " server_error " )
# Internal server error with JSON body
@ -570,14 +572,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
)
exc = self . get_failure ( self . handl er. _exchange_code ( code ) , OidcError )
exc = self . get_failure ( self . provid er. _exchange_code ( code ) , OidcError )
self . assertEqual ( exc . value . error , " internal_server_error " )
# 4xx error without "error" field
self . http_client . request = simple_async_mock (
return_value = FakeResponse ( code = 400 , phrase = b " Bad request " , body = b " {} " , )
)
exc = self . get_failure ( self . handl er. _exchange_code ( code ) , OidcError )
exc = self . get_failure ( self . provid er. _exchange_code ( code ) , OidcError )
self . assertEqual ( exc . value . error , " server_error " )
# 2xx error with "error" field
@ -586,7 +588,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
code = 200 , phrase = b " OK " , body = b ' { " error " : " some_error " } ' ,
)
)
exc = self . get_failure ( self . handl er. _exchange_code ( code ) , OidcError )
exc = self . get_failure ( self . provid er. _exchange_code ( code ) , OidcError )
self . assertEqual ( exc . value . error , " some_error " )
@override_config (
@ -612,8 +614,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
" username " : " foo " ,
" phone " : " 1234567 " ,
}
self . handl er. _exchange_code = simple_async_mock ( return_value = token )
self . handl er. _parse_id_token = simple_async_mock ( return_value = userinfo )
self . provid er. _exchange_code = simple_async_mock ( return_value = token )
self . provid er. _parse_id_token = simple_async_mock ( return_value = userinfo )
auth_handler = self . hs . get_auth_handler ( )
auth_handler . complete_sso_login = simple_async_mock ( )
@ -979,9 +981,10 @@ async def _make_callback_with_userinfo(
from synapse . handlers . oidc_handler import OidcSessionData
handler = hs . get_oidc_handler ( )
handler . _exchange_code = simple_async_mock ( return_value = { } )
handler . _parse_id_token = simple_async_mock ( return_value = userinfo )
handler . _fetch_userinfo = simple_async_mock ( return_value = userinfo )
provider = handler . _provider
provider . _exchange_code = simple_async_mock ( return_value = { } )
provider . _parse_id_token = simple_async_mock ( return_value = userinfo )
provider . _fetch_userinfo = simple_async_mock ( return_value = userinfo )
state = " state "
session = handler . _token_generator . generate_oidc_session_token (