|
|
|
@ -125,6 +125,34 @@ class LoginRestServlet(ClientV1RestServlet): |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def do_cas_login(self, cas_response_body): |
|
|
|
|
(user, attributes) = self.parse_cas_response(cas_response_body) |
|
|
|
|
user_id = UserID.create(user, self.hs.hostname).to_string() |
|
|
|
|
auth_handler = self.handlers.auth_handler |
|
|
|
|
user_exists = yield auth_handler.does_user_exist(user_id) |
|
|
|
|
if user_exists: |
|
|
|
|
user_id, access_token, refresh_token = ( |
|
|
|
|
yield auth_handler.login_with_cas_user_id(user_id) |
|
|
|
|
) |
|
|
|
|
result = { |
|
|
|
|
"user_id": user_id, # may have changed |
|
|
|
|
"access_token": access_token, |
|
|
|
|
"refresh_token": refresh_token, |
|
|
|
|
"home_server": self.hs.hostname, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
user_id, access_token = ( |
|
|
|
|
yield self.handlers.registration_handler.register(localpart=user) |
|
|
|
|
) |
|
|
|
|
result = { |
|
|
|
|
"user_id": user_id, # may have changed |
|
|
|
|
"access_token": access_token, |
|
|
|
|
"home_server": self.hs.hostname, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
defer.returnValue((200, result)) |
|
|
|
|
|
|
|
|
|
def parse_cas_response(self, cas_response_body): |
|
|
|
|
root = ET.fromstring(cas_response_body) |
|
|
|
|
if not root.tag.endswith("serviceResponse"): |
|
|
|
|
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) |
|
|
|
@ -133,33 +161,17 @@ class LoginRestServlet(ClientV1RestServlet): |
|
|
|
|
for child in root[0]: |
|
|
|
|
if child.tag.endswith("user"): |
|
|
|
|
user = child.text |
|
|
|
|
user_id = UserID.create(user, self.hs.hostname).to_string() |
|
|
|
|
auth_handler = self.handlers.auth_handler |
|
|
|
|
user_exists = yield auth_handler.does_user_exist(user_id) |
|
|
|
|
if user_exists: |
|
|
|
|
user_id, access_token, refresh_token = ( |
|
|
|
|
yield auth_handler.login_with_cas_user_id(user_id) |
|
|
|
|
) |
|
|
|
|
result = { |
|
|
|
|
"user_id": user_id, # may have changed |
|
|
|
|
"access_token": access_token, |
|
|
|
|
"refresh_token": refresh_token, |
|
|
|
|
"home_server": self.hs.hostname, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
user_id, access_token = ( |
|
|
|
|
yield self.handlers.registration_handler.register(localpart=user) |
|
|
|
|
) |
|
|
|
|
result = { |
|
|
|
|
"user_id": user_id, # may have changed |
|
|
|
|
"access_token": access_token, |
|
|
|
|
"home_server": self.hs.hostname, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
defer.returnValue((200, result)) |
|
|
|
|
if child.tag.endswith("attributes"): |
|
|
|
|
attributes = {} |
|
|
|
|
for attribute in child: |
|
|
|
|
if "}" in attribute.tag: |
|
|
|
|
attributes[attribute.tag.split("}")[1]] = attribute.text |
|
|
|
|
else: |
|
|
|
|
attributes[attribute.tag] = attribute.text |
|
|
|
|
if user is None or attributes is None: |
|
|
|
|
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) |
|
|
|
|
|
|
|
|
|
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) |
|
|
|
|
return (user, attributes) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoginFallbackRestServlet(ClientV1RestServlet): |
|
|
|
|