local t_insert , t_concat = table.insert , table.concat ;
local parser_new = require " prosody.net.http.parser " . new ;
local events = require " prosody.util.events " . new ( ) ;
local addserver = require " prosody.net.server " . addserver ;
local logger = require " prosody.util.logger " ;
local log = logger.init ( " http.server " ) ;
local os_date = os.date ;
local pairs = pairs ;
local s_upper = string.upper ;
local setmetatable = setmetatable ;
local cache = require " prosody.util.cache " ;
local codes = require " prosody.net.http.codes " ;
local promise = require " prosody.util.promise " ;
local errors = require " prosody.util.error " ;
local blocksize = 2 ^ 16 ;
local async = require " prosody.util.async " ;
local id = require " prosody.util.id " ;
-- Public module table, returned at the end of this file.
local _M = {};

-- Per-connection bookkeeping, keyed by connection object:
--   sessions:   the session table (parser, async runner) built in onconnect
--   incomplete: responses that are still streaming a file body (send_file)
local sessions, incomplete = {}, {};

-- Network listener callbacks handed to the server backend.
local listener = {};

-- Set of accepted virtual host names, and the host used as a fallback when
-- the request's Host header is not in that set.
local hosts = {};
local default_host = nil;

-- Table returned by each connection parser's options callback.
local options = {};
-- Is `event` a wildcard registration? Compares the event name's last two
-- characters against the wildcard suffix literal.
local function is_wildcard_event(event)
	local tail = event:sub(-2);
	return tail == " /* ";
end
-- Does `event` fall under `wildcard_event`? True when the two share the
-- wildcard's prefix, i.e. everything in wildcard_event except its final
-- character.
local function is_wildcard_match(wildcard_event, event)
	local prefix_len = #wildcard_event - 1;
	return event:sub(1, prefix_len) == wildcard_event:sub(1, prefix_len);
end
-- Raw handler table of this module's events object. The metatable set
-- below makes a lookup of an event name with no stored entry build, sort and
-- memoize the list of matching handlers (exact and wildcard registrations).
local _handlers = events._handlers;
-- Bounded record (10000 entries) of event names whose memoized handler list
-- was built without an exact registration (see the end of __index).
-- The callback removes the memoized list for `key` from the raw table.
-- Presumably util.cache invokes it when an entry is evicted — TODO confirm.
local recent_wildcard_events = cache.new(10000, function(key, value) -- luacheck: ignore 212/value
	rawset(_handlers, key, nil);
end);
-- Registered handlers: event name -> set of (handler => priority),
-- iterated that way in __index below.
local event_map = events._event_map;
setmetatable(events._handlers, {
	-- Called when firing an event that doesn't exist (but may match a wildcard handler)
	__index = function(handlers, curr_event)
		if is_wildcard_event(curr_event) then return; end -- Wildcard events cannot be fired
		-- Find all handlers that could match this event, sort them
		-- and then put the array into handlers[curr_event] (and return it)
		local matching_handlers_set = {}; -- handler -> score tuple
		local handlers_array = {};
		for event, handlers_set in pairs(event_map) do
			if event == curr_event or
			is_wildcard_event(event) and is_wildcard_match(event, curr_event) then
				for handler, priority in pairs(handlers_set) do
					-- Score tuple, compared element by element:
					--   [1] match count returned by gsub over the registered event name
					--       (the replacement is the whole match, so only the count matters)
					--   [2] 1 for an exact (non-wildcard) registration, 0 for a wildcard one
					--   [3] the priority the handler was registered with
					-- NOTE(review): a handler registered under several matching events is
					-- appended once per registration, and its score is taken from the
					-- last one visited by pairs().
					matching_handlers_set[handler] = { (select(2, event:gsub(" / ", " %1 "))), is_wildcard_event(event) and 0 or 1, priority };
					table.insert(handlers_array, handler);
				end
			end
		end
		if #handlers_array > 0 then
			-- Sort by score, highest first: the comparator's parameters are
			-- named (b, a), so it returns true when its FIRST argument's score
			-- is strictly greater at the first differing position.
			table.sort(handlers_array, function(b, a)
				local a_score, b_score = matching_handlers_set[a], matching_handlers_set[b];
				for i = 1, #a_score do
					if a_score[i] ~= b_score[i] then -- If equal, compare next score value
						return a_score[i] < b_score[i];
					end
				end
				return false;
			end);
		else
			-- Memoize "no handlers" as false: a non-nil raw value means later
			-- lookups of this name no longer reach __index.
			handlers_array = false;
		end
		rawset(handlers, curr_event, handlers_array);
		if not event_map[curr_event] then -- Only wildcard handlers match, if any
			-- Record it so the memoized entry can later be dropped by the
			-- cache callback above.
			recent_wildcard_events:set(curr_event, true);
		end
		return handlers_array;
	end;
	-- Called on assignment to a key that is not present in the raw table.
	__newindex = function(handlers, curr_event, handlers_array)
		if handlers_array == nil
		and is_wildcard_event(curr_event) then
			-- Invalidate the indexes of all matching events
			for event in pairs(handlers) do
				if is_wildcard_match(curr_event, event) then
					-- The key exists, so this is a plain removal (no __newindex);
					-- clearing existing fields during pairs() is allowed.
					handlers[event] = nil;
				end
			end
		end
		rawset(handlers, curr_event, handlers_array);
	end;
});
-- Forward declaration: handle_request is defined further down, but the
-- request runner created in listener.onconnect refers to it.
local handle_request;

-- Default (priority -1) "http-error" handler: returns a plain-text body
-- naming the status. When the code has no entry in the status-code table,
-- the bare code is used instead of raising
-- "attempt to concatenate a nil value".
events.add_handler(" http-error ", function(err_event)
	local status = codes[err_event.code] or tostring(err_event.code);
	return " Error processing request: " .. status .. " . Check your error log for more information. ";
end, -1);
-- Callbacks for the per-connection async runner created in
-- listener.onconnect. `self` is the runner. self.data is presumably the
-- session table passed as async.runner's third argument. Its `conn` field
-- is the client connection.
local runner_callbacks = {};

-- Runner is ready to proceed: resume the connection.
function runner_callbacks:ready()
	self.data.conn:resume();
end

-- Runner is waiting: pause the connection meanwhile.
function runner_callbacks:waiting()
	self.data.conn:pause();
end

-- An error escaped request handling. Log it, then reply with a bare
-- HTTP/1.0 500 whose body comes from the "http-error" event, then close
-- the connection.
function runner_callbacks:error(err)
	log(" error ", " Traceback[httpserver]: %s ", err);
	local response = { headers = { content_type = " text/plain " }; body = " " };
	response.body = events.fire_event(" http-error ", { code = 500; private_message = err; response = response });
	local conn = self.data.conn;
	-- Status line and headers are assembled from single-line literals with
	-- real CRLF escapes.
	conn:write("HTTP/1.0 500 Internal Server Error\r\n"
		.. "X-Content-Type-Options: nosniff\r\n"
		.. "Content-Type: " .. response.headers.content_type .. "\r\n\r\n");
	-- If no handler produced a body, send the headers only.
	if response.body then
		conn:write(response.body);
	end
	conn:close();
end
-- Does nothing. listener.onconnect uses it for both halves of the
-- (wait, done) pair when a request is only partially received.
local noop = function() end
-- New client connection. Creates the session: an async runner that hands
-- each parsed request to handle_request, and an HTTP parser whose
-- callbacks feed that runner. Registers the session under `conn`.
function listener.onconnect(conn)
	local session = { conn = conn };
	local secure = conn:ssl() and true or nil;
	local ip = conn:ip();

	-- Runner body, executed once per request.
	local function process(request)
		-- Header-only (partial) request: the rest is received once the
		-- handler has decided where the data should go, so don't block.
		local wait, done = noop, noop;
		if request.partial ~= true then
			-- Entire request received: hold off on further incoming
			-- requests until this one has been handled.
			wait, done = async.waiter();
		end
		handle_request(conn, request, done);
		wait();
	end
	session.thread = async.runner(process, runner_callbacks, session);

	-- Parser success callback: tag the request and pass it to the runner.
	local function on_request(request)
		request.id = id.short();
		request.log = logger.init(" http. " .. request.method .. " - " .. request.id);
		request.ip = ip;
		request.secure = secure;
		session.thread:run(request);
	end

	-- Parser error callback.
	local function on_error(err)
		log(" debug ", " error_cb: %s ", err or " <nil> ");
		-- FIXME don't close immediately, wait until we process current stuff
		-- FIXME if err, send off a bad-request response
		conn:close();
	end

	-- Parser options callback: always the module-wide options table.
	local function get_options()
		return options;
	end

	session.parser = parser_new(on_request, on_error, " server ", get_options);
	sessions[conn] = session;
end
-- Connection closed. If a response is still open on it and has an
-- on_destroy hook, mark it finished and run the hook. Then forget all
-- state kept for the connection.
function listener.ondisconnect(conn)
	local open_response = conn._http_open_response;
	local has_hook = open_response and open_response.on_destroy;
	if has_hook then
		open_response.finished = true;
		open_response:on_destroy();
	end
	incomplete[conn] = nil;
	sessions[conn] = nil;
end
-- Connection detached from this listener (presumably taken over elsewhere).
-- Drop its state without touching any open response.
function listener.ondetach(conn)
	incomplete[conn] = nil;
	sessions[conn] = nil;
end
-- Bytes received: feed them to this connection's HTTP parser.
function listener.onincoming(conn, data)
	local session = sessions[conn];
	session.parser:feed(data);
end
-- Output buffer drained: if a file response is streaming on this
-- connection, push its next chunk.
function listener.ondrain(conn)
	local response = incomplete[conn];
	local send_more = response and response._send_more;
	if send_more then
		send_more();
	end
end
-- Memoizing map from a response.headers key to the line prefix written
-- before its value. The key is rewritten by the two gsub passes below: a
-- separator substitution, then s_upper applied to the character following
-- each word frontier. The result is wrapped in the line-break and
-- separator literals. Each prefix is computed on first use and then
-- stored in the table itself.
local headerfix = setmetatable({}, {
	__index = function(memo, key)
		local name = key:gsub(" _ ", " - "):gsub(" %f[%w]. ", s_upper);
		local prefix = " \r \n " .. name .. " : ";
		memo[key] = prefix;
		return prefix;
	end
});
-- Turn a handler's return value into a response:
--   true           -> nothing to do (the handler took care of the response)
--   nil            -> treated as 404
--   number         -> status code; codes >= 400 get an "http-error" body
--   string         -> response body
--   util.error     -> status from its code (default 500) plus "http-error" body
--   promise        -> handled again once settled (500 on rejection); returns true
--   other table    -> fields copied onto response, "headers" merged key by key
-- Otherwise returns whatever response:send() returns.
local function handle_result(request, response, result)
	if result == true then
		return;
	elseif result == nil then
		result = 404;
	end

	local kind = type(result);
	local body;

	if kind == " number " then
		response.status_code = result;
		if result >= 400 then
			body = events.fire_event(" http-error ", { request = request, response = response, code = result });
		end
	elseif kind == " string " then
		body = result;
	elseif errors.is_error(result) then
		local code = result.code or 500;
		response.status_code = code;
		body = events.fire_event(" http-error ", { request = request, response = response, code = code, error = result });
	elseif promise.is_promise(result) then
		local function on_resolved(value)
			handle_result(request, response, value);
		end
		local function on_rejected(err)
			response.status_code = 500;
			handle_result(request, response, err or 500);
		end
		result:next(on_resolved, on_rejected);
		return true;
	elseif kind == " table " then
		for key, value in pairs(result) do
			if key == " headers " then
				for header_name, header_value in pairs(value) do
					response.headers[header_name] = header_value;
				end
			else
				response[key] = value;
			end
		end
	end

	return response:send(body);
end
-- Placeholder for taking over a response's connection; not implemented yet.
-- Always raises.
_M.hijack_response = function(response, listener) -- luacheck: ignore
	error(" TODO ");
end
-- Dispatch one parsed request. This assigns the `handle_request` local
-- forward-declared earlier in the file.
-- Steps:
--   1. Re-key the request headers.
--   2. Work out connection persistence.
--   3. Build the response object and record it on `conn`.
--   4. Fire events until one returns non-nil:
--        method .. " " .. path, then (HEAD only) the same with "GET"
--        method .. " " .. host .. path, then (HEAD only) the same with "GET"
--      where path is request.path up to the first "?".
-- The result goes to handle_result. finish_cb is stored on the response;
-- finish_response (its `done` method) calls it.
function handle_request(conn, request, finish_cb)
	--log("debug", "handler: %s", request.path);
	-- Copy headers under keys passed through gsub (intended: "-" -> "_",
	-- matching lookups such as headers.connection / headers.host below).
	-- As a table index, only gsub's first return value is used.
	local headers = {};
	for k, v in pairs(request.headers) do headers[k:gsub(" - ", " _ ")] = v; end
	request.headers = headers;
	request.conn = conn;
	request.log(" debug ", " %s %s HTTP/%s ", request.method, request.path, request.httpversion);
	-- Value for the Date response header.
	local date_header = os_date(' !%a, %d %b %Y %H:%M:%S GMT '); -- FIXME use
	-- Normalize the Connection header: strip the whitespace class, lowercase
	-- it and wrap it in separators so whole tokens can be located with a
	-- plain find(); an empty literal when the header is absent.
	local conn_header = request.headers.connection;
	conn_header = conn_header and " , " .. conn_header:gsub(" [ \t ] ", " "):lower() .. " , " or " "
	local httpversion = request.httpversion
	-- Persistent if keep-alive was requested, or HTTP/1.1 without "close".
	local persistent = conn_header:find(" ,keep-alive, ", 1, true)
	or (httpversion == " 1.1 " and not conn_header:find(" ,close, ", 1, true));
	-- Connection response header. A nil value is skipped when headers are
	-- serialized with pairs(), so the header is then omitted.
	local response_conn_header;
	if persistent then
		response_conn_header = " Keep-Alive ";
	else
		response_conn_header = httpversion == " 1.1 " and " close " or nil
	end
	local is_head_request = request.method == " HEAD ";
	-- The response object. Its methods are the module's send/finish functions.
	local response = {
		id = request.id;
		log = request.log;
		request = request;
		is_head_request = is_head_request;
		status_code = 200;
		headers = { date = date_header; connection = response_conn_header; x_request_id = request.id };
		persistent = persistent;
		conn = conn;
		send = _M.send_response;
		write_headers = _M.write_headers;
		send_file = _M.send_file;
		done = _M.finish_response;
		finish_cb = finish_cb;
	};
	-- Looked up by listener.ondisconnect and _M.get_request_from_conn;
	-- cleared by finish_response.
	conn._http_open_response = response;
	-- Host header with any trailing ":port" removed.
	local host = request.headers.host;
	if host then host = host:gsub(" :%d+$ ", " "); end
	-- Some sanity checking
	local err_code, err;
	if not request.path then
		err_code, err = 400, " Invalid path ";
	end
	if err then
		response.status_code = err_code;
		response:send(events.fire_event(" http-error ", { code = err_code, message = err, response = response }));
		return;
	end
	-- 1) Host-independent handlers (HEAD falls back to GET handlers).
	local global_event = request.method .. " " .. request.path:match(" [^?]* ");
	local payload = { request = request, response = response };
	local result = events.fire_event(global_event, payload);
	if result == nil and is_head_request then
		local global_head_event = " GET " .. request.path:match(" [^?]* ");
		result = events.fire_event(global_head_event, payload);
	end
	-- 2) Host-specific handlers. An unknown host is replaced by default_host
	--    when that one is registered; otherwise the request is rejected.
	if result == nil then
		if not hosts[host] then
			if hosts[default_host] then
				host = default_host;
			elseif host then
				err_code, err = 404, " Unknown host: " .. host;
			else
				err_code, err = 400, " Missing or invalid 'Host' header ";
			end
		end
		if err then
			response.status_code = err_code;
			response:send(events.fire_event(" http-error ", { code = err_code, message = err, response = response }));
			return;
		end
		local host_event = request.method .. " " .. host .. request.path:match(" [^?]* ");
		result = events.fire_event(host_event, payload);
		if result == nil and is_head_request then
			local host_head_event = " GET " .. host .. request.path:match(" [^?]* ");
			result = events.fire_event(host_head_event, payload);
		end
	end
	return handle_result(request, response, result);
end
-- Build the status line and header lines of `response` as an array of
-- strings; table.concat of the array yields the header block. The last
-- element is the header-terminator literal.
-- The reason text comes from response.status, else from the status-code
-- table. For a code missing from that table, the bare code is used instead
-- of raising "attempt to concatenate a nil value".
-- Headers whose value is nil are skipped, because pairs() never yields them.
local function prepare_header(response)
	local status_code = response.status_code;
	local reason = response.status or codes[status_code] or tostring(status_code);
	local status_line = " HTTP/ " .. response.request.httpversion .. " " .. reason;
	response.log(" debug ", " %s ", status_line);
	local output = { status_line };
	for k, v in pairs(response.headers) do
		t_insert(output, headerfix[k] .. v);
	end
	t_insert(output, " \r \n \r \n ");
	return output;
end
_M.prepare_header = prepare_header;

-- Write only the status line and headers of `response` to its connection.
-- Does nothing once the response is finished.
function _M.write_headers(response)
	if not response.finished then
		local header_block = t_concat(prepare_header(response));
		response.conn:write(header_block);
	end
end
-- Complete a response with headers only (used for HEAD requests), then
-- call done(). Does nothing once the response is finished.
function _M.send_head_response(response)
	if not response.finished then
		_M.write_headers(response);
		response:done();
	end
end
-- Send a complete response: status line, headers and body. The body is
-- `body`, else response.body, else the empty literal. Content-Length is
-- set for status codes above 199 except 204. HEAD requests get headers
-- only. Completes via done(). Does nothing once the response is finished.
function _M.send_response(response, body)
	if response.finished then return; end
	local payload = body or response.body or " ";
	local code = response.status_code;
	-- Per RFC 7230, informational (1xx) and 204 (no content) should have no c-l header
	if code > 199 and code ~= 204 then
		response.headers.content_length = (" %d "):format(#payload);
	end
	if response.is_head_request then
		return _M.send_head_response(response)
	end
	local output = prepare_header(response);
	output[#output + 1] = payload;
	response.conn:write(t_concat(output));
	response:done();
end
-- Stream the contents of `f` as the body of `response`. `f` must support
-- :read(n); it is closed through :close() when it has one.
-- Headers are written immediately. The body is read blocksize bytes at a
-- time and sent from listener.ondrain via response._send_more. Without a
-- Content-Length header, chunked transfer encoding is used.
-- Returns true when streaming was started.
-- This function owns `f`: it is closed on every exit path, at most once.
-- The exit paths are: HEAD request, response already finished, end of data,
-- and the response being finished while streaming.
function _M.send_file(response, f)
	local file_closed = false;
	local function close_file()
		if file_closed then return; end
		file_closed = true;
		if f.close then f:close(); end
	end

	if response.is_head_request then
		close_file();
		return _M.send_head_response(response);
	end
	if response.finished then
		-- Nothing will be sent; release the file instead of leaking it.
		close_file();
		return;
	end

	local chunked = not response.headers.content_length;
	if chunked then response.headers.transfer_encoding = " chunked "; end
	incomplete[response.conn] = response;
	-- NOTE(review): if the connection drops mid-transfer, ondisconnect clears
	-- incomplete[conn], so this is not called again and `f` is left for the
	-- garbage collector.
	response._send_more = function()
		if response.finished then
			-- Response was finished elsewhere: stop streaming and release the file.
			incomplete[response.conn] = nil;
			close_file();
			return;
		end
		local chunk = f:read(blocksize);
		if chunk then
			if chunked then
				chunk = (" %x \r \n %s \r \n "):format(#chunk, chunk);
			end
			response.conn:write(chunk);
		else
			-- End of data: release the file, terminate the chunked body, finish.
			incomplete[response.conn] = nil;
			close_file();
			if chunked then
				response.conn:write(" 0 \r \n \r \n ");
			end
			return response:done();
		end
	end
	_M.write_headers(response);
	return true;
end
-- Finish `response` (idempotent). The steps are:
--   1. Mark it finished and detach it from its connection.
--   2. Run its on_destroy hook once, then clear the hook.
--   3. Call finish_cb.
--   4. Close the connection unless it is persistent.
function _M.finish_response(response)
	if response.finished then return; end
	response.finished = true;
	response.conn._http_open_response = nil;
	local destroy_hook = response.on_destroy;
	if destroy_hook then
		response:on_destroy();
		response.on_destroy = nil;
	end
	response:finish_cb();
	if response.persistent then return; end
	response.conn:close();
end
-- Register `handler` for `event` (optionally with a priority) on the
-- server's event bus. Returns nothing.
_M.add_handler = function(event, handler, priority)
	events.add_handler(event, handler, priority);
end

-- Unregister a handler previously registered for `event`. Returns nothing.
_M.remove_handler = function(event, handler)
	events.remove_handler(event, handler);
end
-- Start accepting HTTP connections on `port`. Binds to `interface`, or to
-- the wildcard-address literal when it is nil. `ssl` is passed through.
-- Returns whatever addserver returns.
function _M.listen_on(port, interface, ssl)
	local bind_to = interface or " * ";
	return addserver(bind_to, port, listener, " *a ", ssl);
end
-- Accept requests for virtual host `host`.
_M.add_host = function(host)
	hosts[host] = true;
end

-- Stop accepting requests for virtual host `host`.
_M.remove_host = function(host)
	hosts[host] = nil;
end

-- Host to route to when the request's Host header is not a registered
-- host (used only if the default host is itself registered).
_M.set_default_host = function(host)
	default_host = host;
end
-- Fire `event` on the server's event bus, forwarding all extra arguments
-- and all return values.
_M.fire_event = function(event, ...)
	return events.fire_event(event, ...);
end
-- Set a server option. The options table is what each connection's parser
-- receives from its options callback.
_M.set_option = function(name, value)
	options[name] = value;
end
-- Return the request whose response is still open on `conn`. Returns nil
-- when `conn` is nil/false, has no open response, or the response carries
-- no request.
function _M.get_request_from_conn(conn)
	if not conn then return nil; end
	local response = conn._http_open_response;
	if not response then return nil; end
	return response.request or nil;
end
-- Expose the listener callbacks (for external server setup), the
-- status-code table and the underlying event bus, then return the module.
_M.listener = listener;
_M.codes = codes;
_M._events = events;
return _M;