@ -16,7 +16,7 @@
""" Contains functions for registering clients. """
import logging
from typing import TYPE_CHECKING , Iterable , List , Optional , Tuple
from typing import TYPE_CHECKING , Dict , Iterable , List , Optional , Tuple
from prometheus_client import Counter
@ -82,6 +82,7 @@ class RegistrationHandler(BaseHandler):
)
else :
self . device_handler = hs . get_device_handler ( )
self . _register_device_client = self . register_device_inner
self . pusher_pool = hs . get_pusherpool ( )
self . session_lifetime = hs . config . session_lifetime
@ -678,17 +679,35 @@ class RegistrationHandler(BaseHandler):
Returns :
Tuple of device ID and access token
"""
res = await self . _register_device_client (
user_id = user_id ,
device_id = device_id ,
initial_display_name = initial_display_name ,
is_guest = is_guest ,
is_appservice_ghost = is_appservice_ghost ,
)
if self . hs . config . worker_app :
r = await self . _register_device_client (
user_id = user_id ,
device_id = device_id ,
initial_display_name = initial_display_name ,
is_guest = is_guest ,
is_appservice_ghost = is_appservice_ghost ,
)
return r [ " device_id " ] , r [ " access_token " ]
login_counter . labels (
guest = is_guest ,
auth_provider = ( auth_provider_id or " " ) ,
) . inc ( )
return res [ " device_id " ] , res [ " access_token " ]
async def register_device_inner (
self ,
user_id : str ,
device_id : Optional [ str ] ,
initial_display_name : Optional [ str ] ,
is_guest : bool = False ,
is_appservice_ghost : bool = False ,
) - > Dict [ str , str ] :
""" Helper for register_device
Does the bits that need doing on the main process . Not for use outside this
class and RegisterDeviceReplicationServlet .
"""
assert not self . hs . config . worker_app
valid_until_ms = None
if self . session_lifetime is not None :
if is_guest :
@ -713,12 +732,7 @@ class RegistrationHandler(BaseHandler):
is_appservice_ghost = is_appservice_ghost ,
)
login_counter . labels (
guest = is_guest ,
auth_provider = ( auth_provider_id or " " ) ,
) . inc ( )
return ( registered_device_id , access_token )
return { " device_id " : registered_device_id , " access_token " : access_token }
async def post_registration_actions (
self , user_id : str , auth_result : dict , access_token : Optional [ str ]