@ -23,8 +23,12 @@ from twisted.web.client import PartialDownloadError
from synapse . api . errors import Codes , LoginError , SynapseError
from synapse . http . server import finish_request
from synapse . http . servlet import RestServlet , parse_json_object_from_request
from synapse . types import UserID
from synapse . http . servlet import (
RestServlet ,
parse_json_object_from_request ,
parse_string ,
)
from synapse . types import UserID , map_username_to_mxid_localpart
from synapse . util . msisdn import phone_number_to_msisdn
from . base import ClientV1RestServlet , client_path_patterns
@ -358,17 +362,15 @@ class CasTicketServlet(ClientV1RestServlet):
self . cas_server_url = hs . config . cas_server_url
self . cas_service_url = hs . config . cas_service_url
self . cas_required_attributes = hs . config . cas_required_attributes
self . auth_handler = hs . get_auth_handler ( )
self . handlers = hs . get_handlers ( )
self . macaroon_gen = hs . get_macaroon_generator ( )
self . _sso_auth_handler = SSOAuthHandler ( hs )
@defer . inlineCallbacks
def on_GET ( self , request ) :
client_redirect_url = request . args [ b " redirectUrl " ] [ 0 ]
client_redirect_url = parse_string ( request , " redirectUrl " , required = True )
http_client = self . hs . get_simple_http_client ( )
uri = self . cas_server_url + " /proxyValidate "
args = {
" ticket " : request . args [ b " ticket " ] [ 0 ] . decode ( ' ascii ' ) ,
" ticket " : parse_string ( request , " ticket " , required = True ) ,
" service " : self . cas_service_url
}
try :
@ -380,7 +382,6 @@ class CasTicketServlet(ClientV1RestServlet):
result = yield self . handle_cas_response ( request , body , client_redirect_url )
defer . returnValue ( result )
@defer . inlineCallbacks
def handle_cas_response ( self , request , cas_response_body , client_redirect_url ) :
user , attributes = self . parse_cas_response ( cas_response_body )
@ -396,28 +397,9 @@ class CasTicketServlet(ClientV1RestServlet):
if required_value != actual_value :
raise LoginError ( 401 , " Unauthorized " , errcode = Codes . UNAUTHORIZED )
user_id = UserID ( user , self . hs . hostname ) . to_string ( )
auth_handler = self . auth_handler
registered_user_id = yield auth_handler . check_user_exists ( user_id )
if not registered_user_id :
registered_user_id , _ = (
yield self . handlers . registration_handler . register ( localpart = user )
)
login_token = self . macaroon_gen . generate_short_term_login_token (
registered_user_id
return self . _sso_auth_handler . on_successful_auth (
user , request , client_redirect_url ,
)
redirect_url = self . add_login_token_to_redirect_url ( client_redirect_url ,
login_token )
request . redirect ( redirect_url )
finish_request ( request )
def add_login_token_to_redirect_url ( self , url , token ) :
url_parts = list ( urllib . parse . urlparse ( url ) )
query = dict ( urllib . parse . parse_qsl ( url_parts [ 4 ] ) )
query . update ( { " loginToken " : token } )
url_parts [ 4 ] = urllib . parse . urlencode ( query ) . encode ( ' ascii ' )
return urllib . parse . urlunparse ( url_parts )
def parse_cas_response ( self , cas_response_body ) :
user = None
@ -452,6 +434,71 @@ class CasTicketServlet(ClientV1RestServlet):
return user , attributes
class SSOAuthHandler ( object ) :
"""
Utility class for Resources and Servlets which handle the response from a SSO
service
Args :
hs ( synapse . server . HomeServer )
"""
def __init__ ( self , hs ) :
self . _hostname = hs . hostname
self . _auth_handler = hs . get_auth_handler ( )
self . _registration_handler = hs . get_handlers ( ) . registration_handler
self . _macaroon_gen = hs . get_macaroon_generator ( )
@defer . inlineCallbacks
def on_successful_auth (
self , username , request , client_redirect_url ,
) :
""" Called once the user has successfully authenticated with the SSO.
Registers the user if necessary , and then returns a redirect ( with
a login token ) to the client .
Args :
username ( unicode | bytes ) : the remote user id . We ' ll map this onto
something sane for a MXID localpath .
request ( SynapseRequest ) : the incoming request from the browser . We ' ll
respond to it with a redirect .
client_redirect_url ( unicode ) : the redirect_url the client gave us when
it first started the process .
Returns :
Deferred [ none ] : Completes once we have handled the request .
"""
localpart = map_username_to_mxid_localpart ( username )
user_id = UserID ( localpart , self . _hostname ) . to_string ( )
registered_user_id = yield self . _auth_handler . check_user_exists ( user_id )
if not registered_user_id :
registered_user_id , _ = (
yield self . _registration_handler . register (
localpart = localpart ,
generate_token = False ,
)
)
login_token = self . _macaroon_gen . generate_short_term_login_token (
registered_user_id
)
redirect_url = self . _add_login_token_to_redirect_url (
client_redirect_url , login_token
)
request . redirect ( redirect_url )
finish_request ( request )
@staticmethod
def _add_login_token_to_redirect_url ( url , token ) :
url_parts = list ( urllib . parse . urlparse ( url ) )
query = dict ( urllib . parse . parse_qsl ( url_parts [ 4 ] ) )
query . update ( { " loginToken " : token } )
url_parts [ 4 ] = urllib . parse . urlencode ( query )
return urllib . parse . urlunparse ( url_parts )
def register_servlets ( hs , http_server ) :
LoginRestServlet ( hs ) . register ( http_server )
if hs . config . cas_enabled :