|
|
|
@ -146,7 +146,7 @@ class LoginRestServlet(ClientV1RestServlet): |
|
|
|
|
yield auth_handler.validate_short_term_login_token_and_get_user_id(token) |
|
|
|
|
) |
|
|
|
|
user_id, access_token, refresh_token = ( |
|
|
|
|
yield auth_handler.login_with_user_id(user_id) |
|
|
|
|
yield auth_handler.get_login_tuple_for_user_id(user_id) |
|
|
|
|
) |
|
|
|
|
result = { |
|
|
|
|
"user_id": user_id, # may have changed |
|
|
|
@ -179,7 +179,7 @@ class LoginRestServlet(ClientV1RestServlet): |
|
|
|
|
user_exists = yield auth_handler.does_user_exist(user_id) |
|
|
|
|
if user_exists: |
|
|
|
|
user_id, access_token, refresh_token = ( |
|
|
|
|
yield auth_handler.login_with_user_id(user_id) |
|
|
|
|
yield auth_handler.get_login_tuple_for_user_id(user_id) |
|
|
|
|
) |
|
|
|
|
result = { |
|
|
|
|
"user_id": user_id, # may have changed |
|
|
|
@ -304,7 +304,6 @@ class CasRedirectServlet(ClientV1RestServlet): |
|
|
|
|
}) |
|
|
|
|
request.redirect("%s?%s" % (self.cas_server_url, serviceParam)) |
|
|
|
|
request.finish() |
|
|
|
|
defer.returnValue(None) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CasTicketServlet(ClientV1RestServlet): |
|
|
|
@ -318,21 +317,19 @@ class CasTicketServlet(ClientV1RestServlet): |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def on_GET(self, request): |
|
|
|
|
clientRedirectUrl = request.args["redirectUrl"][0] |
|
|
|
|
# TODO: get this from the homeserver rather than creating a new one for |
|
|
|
|
# each request |
|
|
|
|
http_client = SimpleHttpClient(self.hs) |
|
|
|
|
client_redirect_url = request.args["redirectUrl"][0] |
|
|
|
|
http_client = self.hs.get_simple_http_client() |
|
|
|
|
uri = self.cas_server_url + "/proxyValidate" |
|
|
|
|
args = { |
|
|
|
|
"ticket": request.args["ticket"], |
|
|
|
|
"service": self.cas_service_url |
|
|
|
|
} |
|
|
|
|
body = yield http_client.get_raw(uri, args) |
|
|
|
|
result = yield self.handle_cas_response(request, body, clientRedirectUrl) |
|
|
|
|
result = yield self.handle_cas_response(request, body, client_redirect_url) |
|
|
|
|
defer.returnValue(result) |
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
def handle_cas_response(self, request, cas_response_body, clientRedirectUrl): |
|
|
|
|
def handle_cas_response(self, request, cas_response_body, client_redirect_url): |
|
|
|
|
user, attributes = self.parse_cas_response(cas_response_body) |
|
|
|
|
|
|
|
|
|
for required_attribute, required_value in self.cas_required_attributes.items(): |
|
|
|
@ -351,15 +348,15 @@ class CasTicketServlet(ClientV1RestServlet): |
|
|
|
|
auth_handler = self.handlers.auth_handler |
|
|
|
|
user_exists = yield auth_handler.does_user_exist(user_id) |
|
|
|
|
if not user_exists: |
|
|
|
|
user_id, ignored = ( |
|
|
|
|
user_id, _ = ( |
|
|
|
|
yield self.handlers.registration_handler.register(localpart=user) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
login_token = auth_handler.generate_short_term_login_token(user_id) |
|
|
|
|
redirectUrl = self.add_login_token_to_redirect_url(clientRedirectUrl, login_token) |
|
|
|
|
request.redirect(redirectUrl) |
|
|
|
|
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, |
|
|
|
|
login_token) |
|
|
|
|
request.redirect(redirect_url) |
|
|
|
|
request.finish() |
|
|
|
|
defer.returnValue(None) |
|
|
|
|
|
|
|
|
|
def add_login_token_to_redirect_url(self, url, token): |
|
|
|
|
url_parts = list(urlparse.urlparse(url)) |
|
|
|
|