@ -14,7 +14,9 @@
import logging
import re
from typing import TYPE_CHECKING , Awaitable , Callable , Dict , List , Optional
from typing import TYPE_CHECKING , Any , Awaitable , Callable , Dict , List , Optional
from typing_extensions import TypedDict
from synapse . api . errors import Codes , LoginError , SynapseError
from synapse . api . ratelimiting import Ratelimiter
@ -25,6 +27,8 @@ from synapse.http import get_request_uri
from synapse . http . server import HttpServer , finish_request
from synapse . http . servlet import (
RestServlet ,
assert_params_in_dict ,
parse_boolean ,
parse_bytes_from_args ,
parse_json_object_from_request ,
parse_string ,
@ -40,6 +44,21 @@ if TYPE_CHECKING:
logger = logging . getLogger ( __name__ )
LoginResponse = TypedDict (
" LoginResponse " ,
{
" user_id " : str ,
" access_token " : str ,
" home_server " : str ,
" expires_in_ms " : Optional [ int ] ,
" refresh_token " : Optional [ str ] ,
" device_id " : str ,
" well_known " : Optional [ Dict [ str , Any ] ] ,
} ,
total = False ,
)
class LoginRestServlet ( RestServlet ) :
PATTERNS = client_patterns ( " /login$ " , v1 = True )
CAS_TYPE = " m.login.cas "
@ -48,6 +67,7 @@ class LoginRestServlet(RestServlet):
JWT_TYPE = " org.matrix.login.jwt "
JWT_TYPE_DEPRECATED = " m.login.jwt "
APPSERVICE_TYPE = " uk.half-shot.msc2778.login.application_service "
REFRESH_TOKEN_PARAM = " org.matrix.msc2918.refresh_token "
def __init__ ( self , hs : " HomeServer " ) :
super ( ) . __init__ ( )
@ -65,9 +85,12 @@ class LoginRestServlet(RestServlet):
self . cas_enabled = hs . config . cas_enabled
self . oidc_enabled = hs . config . oidc_enabled
self . _msc2858_enabled = hs . config . experimental . msc2858_enabled
self . _msc2918_enabled = hs . config . access_token_lifetime is not None
self . auth = hs . get_auth ( )
self . clock = hs . get_clock ( )
self . auth_handler = self . hs . get_auth_handler ( )
self . registration_handler = hs . get_registration_handler ( )
self . _sso_handler = hs . get_sso_handler ( )
@ -138,6 +161,15 @@ class LoginRestServlet(RestServlet):
async def on_POST ( self , request : SynapseRequest ) :
login_submission = parse_json_object_from_request ( request )
if self . _msc2918_enabled :
# Check if this login should also issue a refresh token, as per
# MSC2918
should_issue_refresh_token = parse_boolean (
request , name = LoginRestServlet . REFRESH_TOKEN_PARAM , default = False
)
else :
should_issue_refresh_token = False
try :
if login_submission [ " type " ] == LoginRestServlet . APPSERVICE_TYPE :
appservice = self . auth . get_appservice_by_req ( request )
@ -147,19 +179,32 @@ class LoginRestServlet(RestServlet):
None , request . getClientIP ( )
)
result = await self . _do_appservice_login ( login_submission , appservice )
result = await self . _do_appservice_login (
login_submission ,
appservice ,
should_issue_refresh_token = should_issue_refresh_token ,
)
elif self . jwt_enabled and (
login_submission [ " type " ] == LoginRestServlet . JWT_TYPE
or login_submission [ " type " ] == LoginRestServlet . JWT_TYPE_DEPRECATED
) :
await self . _address_ratelimiter . ratelimit ( None , request . getClientIP ( ) )
result = await self . _do_jwt_login ( login_submission )
result = await self . _do_jwt_login (
login_submission ,
should_issue_refresh_token = should_issue_refresh_token ,
)
elif login_submission [ " type " ] == LoginRestServlet . TOKEN_TYPE :
await self . _address_ratelimiter . ratelimit ( None , request . getClientIP ( ) )
result = await self . _do_token_login ( login_submission )
result = await self . _do_token_login (
login_submission ,
should_issue_refresh_token = should_issue_refresh_token ,
)
else :
await self . _address_ratelimiter . ratelimit ( None , request . getClientIP ( ) )
result = await self . _do_other_login ( login_submission )
result = await self . _do_other_login (
login_submission ,
should_issue_refresh_token = should_issue_refresh_token ,
)
except KeyError :
raise SynapseError ( 400 , " Missing JSON keys. " )
@ -169,7 +214,10 @@ class LoginRestServlet(RestServlet):
return 200 , result
async def _do_appservice_login (
self , login_submission : JsonDict , appservice : ApplicationService
self ,
login_submission : JsonDict ,
appservice : ApplicationService ,
should_issue_refresh_token : bool = False ,
) :
identifier = login_submission . get ( " identifier " )
logger . info ( " Got appservice login request with identifier: %r " , identifier )
@ -198,14 +246,21 @@ class LoginRestServlet(RestServlet):
raise LoginError ( 403 , " Invalid access_token " , errcode = Codes . FORBIDDEN )
return await self . _complete_login (
qualified_user_id , login_submission , ratelimit = appservice . is_rate_limited ( )
qualified_user_id ,
login_submission ,
ratelimit = appservice . is_rate_limited ( ) ,
should_issue_refresh_token = should_issue_refresh_token ,
)
async def _do_other_login ( self , login_submission : JsonDict ) - > Dict [ str , str ] :
async def _do_other_login (
self , login_submission : JsonDict , should_issue_refresh_token : bool = False
) - > LoginResponse :
""" Handle non-token/saml/jwt logins
Args :
login_submission :
should_issue_refresh_token : True if this login should issue
a refresh token alongside the access token .
Returns :
HTTP response
@ -224,7 +279,10 @@ class LoginRestServlet(RestServlet):
login_submission , ratelimit = True
)
result = await self . _complete_login (
canonical_user_id , login_submission , callback
canonical_user_id ,
login_submission ,
callback ,
should_issue_refresh_token = should_issue_refresh_token ,
)
return result
@ -232,11 +290,12 @@ class LoginRestServlet(RestServlet):
self ,
user_id : str ,
login_submission : JsonDict ,
callback : Optional [ Callable [ [ Dict [ str , str ] ] , Awaitable [ None ] ] ] = None ,
callback : Optional [ Callable [ [ LoginResponse ] , Awaitable [ None ] ] ] = None ,
create_non_existent_users : bool = False ,
ratelimit : bool = True ,
auth_provider_id : Optional [ str ] = None ,
) - > Dict [ str , str ] :
should_issue_refresh_token : bool = False ,
) - > LoginResponse :
""" Called when we ' ve successfully authed the user and now need to
actually login them in ( e . g . create devices ) . This gets called on
all successful logins .
@ -253,6 +312,8 @@ class LoginRestServlet(RestServlet):
ratelimit : Whether to ratelimit the login request .
auth_provider_id : The SSO IdP the user used , if any ( just used for the
prometheus metrics ) .
should_issue_refresh_token : True if this login should issue
a refresh token alongside the access token .
Returns :
result : Dictionary of account information after successful login .
@ -274,28 +335,48 @@ class LoginRestServlet(RestServlet):
device_id = login_submission . get ( " device_id " )
initial_display_name = login_submission . get ( " initial_device_display_name " )
device_id , access_token = await self . registration_handler . register_device (
user_id , device_id , initial_display_name , auth_provider_id = auth_provider_id
(
device_id ,
access_token ,
valid_until_ms ,
refresh_token ,
) = await self . registration_handler . register_device (
user_id ,
device_id ,
initial_display_name ,
auth_provider_id = auth_provider_id ,
should_issue_refresh_token = should_issue_refresh_token ,
)
result = {
" user_id " : user_id ,
" access_token " : access_token ,
" home_server " : self . hs . hostname ,
" device_id " : device_id ,
}
result = LoginResponse (
user_id = user_id ,
access_token = access_token ,
home_server = self . hs . hostname ,
device_id = device_id ,
)
if valid_until_ms is not None :
expires_in_ms = valid_until_ms - self . clock . time_msec ( )
result [ " expires_in_ms " ] = expires_in_ms
if refresh_token is not None :
result [ " refresh_token " ] = refresh_token
if callback is not None :
await callback ( result )
return result
async def _do_token_login ( self , login_submission : JsonDict ) - > Dict [ str , str ] :
async def _do_token_login (
self , login_submission : JsonDict , should_issue_refresh_token : bool = False
) - > LoginResponse :
"""
Handle the final stage of SSO login .
Args :
login_submission : The JSON request body .
login_submission : The JSON request body .
should_issue_refresh_token : True if this login should issue
a refresh token alongside the access token .
Returns :
The body of the JSON response .
@ -309,9 +390,12 @@ class LoginRestServlet(RestServlet):
login_submission ,
self . auth_handler . _sso_login_callback ,
auth_provider_id = res . auth_provider_id ,
should_issue_refresh_token = should_issue_refresh_token ,
)
async def _do_jwt_login ( self , login_submission : JsonDict ) - > Dict [ str , str ] :
async def _do_jwt_login (
self , login_submission : JsonDict , should_issue_refresh_token : bool = False
) - > LoginResponse :
token = login_submission . get ( " token " , None )
if token is None :
raise LoginError (
@ -342,7 +426,10 @@ class LoginRestServlet(RestServlet):
user_id = UserID ( user , self . hs . hostname ) . to_string ( )
result = await self . _complete_login (
user_id , login_submission , create_non_existent_users = True
user_id ,
login_submission ,
create_non_existent_users = True ,
should_issue_refresh_token = should_issue_refresh_token ,
)
return result
@ -371,6 +458,42 @@ def _get_auth_flow_dict_for_idp(
return e
class RefreshTokenServlet ( RestServlet ) :
PATTERNS = client_patterns (
" /org.matrix.msc2918.refresh_token/refresh$ " , releases = ( ) , unstable = True
)
def __init__ ( self , hs : " HomeServer " ) :
self . _auth_handler = hs . get_auth_handler ( )
self . _clock = hs . get_clock ( )
self . access_token_lifetime = hs . config . access_token_lifetime
async def on_POST (
self ,
request : SynapseRequest ,
) :
refresh_submission = parse_json_object_from_request ( request )
assert_params_in_dict ( refresh_submission , [ " refresh_token " ] )
token = refresh_submission [ " refresh_token " ]
if not isinstance ( token , str ) :
raise SynapseError ( 400 , " Invalid param: refresh_token " , Codes . INVALID_PARAM )
valid_until_ms = self . _clock . time_msec ( ) + self . access_token_lifetime
access_token , refresh_token = await self . _auth_handler . refresh_token (
token , valid_until_ms
)
expires_in_ms = valid_until_ms - self . _clock . time_msec ( )
return (
200 ,
{
" access_token " : access_token ,
" refresh_token " : refresh_token ,
" expires_in_ms " : expires_in_ms ,
} ,
)
class SsoRedirectServlet ( RestServlet ) :
PATTERNS = list ( client_patterns ( " /login/(cas|sso)/redirect$ " , v1 = True ) ) + [
re . compile (
@ -477,6 +600,8 @@ class CasTicketServlet(RestServlet):
def register_servlets ( hs , http_server ) :
LoginRestServlet ( hs ) . register ( http_server )
if hs . config . access_token_lifetime is not None :
RefreshTokenServlet ( hs ) . register ( http_server )
SsoRedirectServlet ( hs ) . register ( http_server )
if hs . config . cas_enabled :
CasTicketServlet ( hs ) . register ( http_server )