From 7845f62c2207e9fa51f7a0aa7b60b49cf6436696 Mon Sep 17 00:00:00 2001 From: Steven Hammerton Date: Mon, 12 Oct 2015 10:52:43 +0100 Subject: [PATCH] Parse both user and attributes from CAS response --- synapse/rest/client/v1/login.py | 64 +++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a99dcaab6..0e12880ab 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -125,6 +125,34 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def do_cas_login(self, cas_response_body): + (user, attributes) = self.parse_cas_response(cas_response_body) + user_id = UserID.create(user, self.hs.hostname).to_string() + auth_handler = self.handlers.auth_handler + user_exists = yield auth_handler.does_user_exist(user_id) + if user_exists: + user_id, access_token, refresh_token = ( + yield auth_handler.login_with_cas_user_id(user_id) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "refresh_token": refresh_token, + "home_server": self.hs.hostname, + } + + else: + user_id, access_token = ( + yield self.handlers.registration_handler.register(localpart=user) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + + def parse_cas_response(self, cas_response_body): root = ET.fromstring(cas_response_body) if not root.tag.endswith("serviceResponse"): raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) @@ -133,33 +161,17 @@ class LoginRestServlet(ClientV1RestServlet): for child in root[0]: if child.tag.endswith("user"): user = child.text - user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.login_with_cas_user_id(user_id) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "refresh_token": refresh_token, - "home_server": self.hs.hostname, - } - - else: - user_id, access_token = ( - yield self.handlers.registration_handler.register(localpart=user) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "home_server": self.hs.hostname, - } - - defer.returnValue((200, result)) + if child.tag.endswith("attributes"): + attributes = {} + for attribute in child: + if "}" in attribute.tag: + attributes[attribute.tag.split("}")[1]] = attribute.text + else: + attributes[attribute.tag] = attribute.text + if user is None or attributes is None: + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + return (user, attributes) class LoginFallbackRestServlet(ClientV1RestServlet):