@ -17,6 +17,7 @@ import time
import urllib . parse
from html . parser import HTMLParser
from typing import Any , Dict , Iterable , List , Optional , Tuple , Union
from urllib . parse import parse_qs , urlencode , urlparse
from mock import Mock
@ -30,13 +31,14 @@ from synapse.rest.client.v1 import login, logout
from synapse . rest . client . v2_alpha import devices , register
from synapse . rest . client . v2_alpha . account import WhoamiRestServlet
from synapse . rest . synapse . client . pick_idp import PickIdpResource
from synapse . rest . synapse . client . pick_username import pick_username_resource
from synapse . types import create_requester
from tests import unittest
from tests . handlers . test_oidc import HAS_OIDC
from tests . handlers . test_saml import has_saml2
from tests . rest . client . v1 . utils import TEST_OIDC_AUTH_ENDPOINT , TEST_OIDC_CONFIG
from tests . unittest import override_config , skip_unless
from tests . unittest import HomeserverTestCase , override_config , skip_unless
try :
import jwt
@ -1060,3 +1062,104 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self . make_request ( b " POST " , LOGIN_URL , params )
self . assertEquals ( channel . result [ " code " ] , b " 401 " , channel . result )
@skip_unless ( HAS_OIDC , " requires OIDC " )
class UsernamePickerTestCase ( HomeserverTestCase ) :
""" Tests for the username picker flow of SSO login """
servlets = [ login . register_servlets ]
def default_config ( self ) :
config = super ( ) . default_config ( )
config [ " public_baseurl " ] = BASE_URL
config [ " oidc_config " ] = { }
config [ " oidc_config " ] . update ( TEST_OIDC_CONFIG )
config [ " oidc_config " ] [ " user_mapping_provider " ] = {
" config " : { " display_name_template " : " {{ user.displayname }} " }
}
# whitelist this client URI so we redirect straight to it rather than
# serving a confirmation page
config [ " sso " ] = { " client_whitelist " : [ " https://whitelisted.client " ] }
return config
def create_resource_dict ( self ) - > Dict [ str , Resource ] :
from synapse . rest . oidc import OIDCResource
d = super ( ) . create_resource_dict ( )
d [ " /_synapse/client/pick_username " ] = pick_username_resource ( self . hs )
d [ " /_synapse/oidc " ] = OIDCResource ( self . hs )
return d
def test_username_picker ( self ) :
""" Test the happy path of a username picker flow. """
client_redirect_url = " https://whitelisted.client "
# do the start of the login flow
channel = self . helper . auth_via_oidc (
{ " sub " : " tester " , " displayname " : " Jonny " } , client_redirect_url
)
# that should redirect to the username picker
self . assertEqual ( channel . code , 302 , channel . result )
picker_url = channel . headers . getRawHeaders ( " Location " ) [ 0 ]
self . assertEqual ( picker_url , " /_synapse/client/pick_username " )
# ... with a username_mapping_session cookie
cookies = { } # type: Dict[str,str]
channel . extract_cookies ( cookies )
self . assertIn ( " username_mapping_session " , cookies )
session_id = cookies [ " username_mapping_session " ]
# introspect the sso handler a bit to check that the username mapping session
# looks ok.
username_mapping_sessions = self . hs . get_sso_handler ( ) . _username_mapping_sessions
self . assertIn (
session_id , username_mapping_sessions , " session id not found in map " ,
)
session = username_mapping_sessions [ session_id ]
self . assertEqual ( session . remote_user_id , " tester " )
self . assertEqual ( session . display_name , " Jonny " )
self . assertEqual ( session . client_redirect_url , client_redirect_url )
# the expiry time should be about 15 minutes away
expected_expiry = self . clock . time_msec ( ) + ( 15 * 60 * 1000 )
self . assertApproximates ( session . expiry_time_ms , expected_expiry , tolerance = 1000 )
# Now, submit a username to the username picker, which should serve a redirect
# back to the client
submit_path = picker_url + " /submit "
content = urlencode ( { b " username " : b " bobby " } ) . encode ( " utf8 " )
chan = self . make_request (
" POST " ,
path = submit_path ,
content = content ,
content_is_form = True ,
custom_headers = [
( " Cookie " , " username_mapping_session= " + session_id ) ,
# old versions of twisted don't do form-parsing without a valid
# content-length header.
( " Content-Length " , str ( len ( content ) ) ) ,
] ,
)
self . assertEqual ( chan . code , 302 , chan . result )
location_headers = chan . headers . getRawHeaders ( " Location " )
# ensure that the returned location starts with the requested redirect URL
self . assertEqual (
location_headers [ 0 ] [ : len ( client_redirect_url ) ] , client_redirect_url
)
# fish the login token out of the returned redirect uri
parts = urlparse ( location_headers [ 0 ] )
query = parse_qs ( parts . query )
login_token = query [ " loginToken " ] [ 0 ]
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.
chan = self . make_request (
" POST " , " /login " , content = { " type " : " m.login.token " , " token " : login_token } ,
)
self . assertEqual ( chan . code , 200 , chan . result )
self . assertEqual ( chan . json_body [ " user_id " ] , " @bobby:test " )