-- Prosody IM
-- Copyright (C) 2016-2018 Kim Alvefur
--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
local t_insert = table.insert ;
local t_concat = table.concat ;
local setmetatable = setmetatable ;
local tostring = tostring ;
local pcall = pcall ;
local type = type ;
local next = next ;
local pairs = pairs ;
local log = require " util.logger " . init ( " server_epoll " ) ;
local socket = require " socket " ;
local luasec = require " ssl " ;
local gettime = require " util.time " . now ;
local indexedbheap = require " util.indexedbheap " ;
local createtable = require " util.table " . create ;
local inet = require " util.net " ;
local inet_pton = inet.pton ;
local _SOCKETINVALID = socket._SOCKETINVALID or - 1 ;
local poller = require " util.poll "
local EEXIST = poller.EEXIST ;
local ENOENT = poller.ENOENT ;
local poll = assert ( poller.new ( ) ) ;
local _ENV = nil ;
-- luacheck: std none
-- Built-in defaults. set_config() installs this table as the metatable of the
-- user-supplied configuration, so any key left unset falls back to these.
local default_config = { __index = {
	-- If a connection is silent for this long, close it unless its
	-- onreadtimeout listener asks to keep it
	read_timeout = 14 * 60,
	-- How long to wait for a socket to become writable after data has been
	-- queued for sending
	send_timeout = 60,
	-- Listen backlog handed to socket.bind(); influences how many pending,
	-- not yet accepted connections can queue up
	tcp_backlog = 128,
	-- If accepting a new incoming connection fails, pause the listener this long
	accept_retry_interval = 10,
	-- If LuaSocket's buffer still holds unread data, wait this long and read again
	read_retry_delay = 1e-06,
	-- Size of the chunks requested from sockets per receive()
	read_size = 8192,
	-- Timeout used between the individual steps of a TLS handshake
	ssl_handshake_timeout = 60,
	-- Sleep used when no timers are pending (max_wait), and the shortest
	-- sleep ever used while waiting for events (min_wait)
	max_wait = 86400,
	min_wait = 1e-06,
} };
local cfg = default_config.__index;
-- Registered connections, keyed by file descriptor number (FD -> conn)
local fds = createtable(10, 0);

-- Timer and scheduling --

-- Heap of pending timers, prioritised by absolute expiry time;
-- runtimers() treats peek() as the earliest expiry
local timers = indexedbheap.create();

-- Does nothing; used to disable callbacks and methods on dead objects
local noop = function () end;
-- Timer objects are arrays { expiry_time, callback } that also carry their
-- heap id and two methods: close() and reschedule(time).

-- Cancel a timer: neutralise it and take it out of the heap
local function closetimer(t)
	t[1], t[2] = 0, noop;
	timers:remove(t.id);
end

-- Move a timer to a new absolute expiry time
local function reschedule(t, time)
	t[1] = time;
	timers:reprioritize(t.id, time);
end

-- Add absolute timer: call f(now) once gettime() reaches `time`
local function at(time, f)
	local timer = { time, f, close = closetimer, reschedule = reschedule };
	timer.id = timers:insert(timer, time);
	return timer;
end

-- Add relative timer: call f(now) `timeout` units of gettime() from now
local function addtimer(timeout, f)
	local due = gettime() + timeout;
	return at(due, f);
end
-- Run callbacks of expired timers
-- Return time until next timeout
--
-- next_delay: value returned when no timers remain (the main loop passes
--             cfg.max_wait)
-- min_wait:   the returned delay is never smaller than this
--
-- Each expired timer's callback is called as callback(now). If it returns a
-- number, the timer is re-armed to fire that long after `now`. A callback that
-- raises an error is logged and not re-armed.
local function runtimers(next_delay, min_wait)
	local now = gettime();
	-- Timers re-armed during this pass. They are only put back into the heap
	-- after the pass, so a callback returning 0 (or less) cannot make this
	-- loop run it again and again within a single tick.
	local readd, n_readd = nil, 0;
	local peek = timers:peek();
	while peek and peek <= now do
		local _, timer = timers:pop();
		local ok, ret = pcall(timer[2], now);
		if not ok then
			log(" error ", " Error in timer: %s ", ret);
		elseif type(ret) == " number " then
			timer[1] = now + ret;
			readd = readd or {};
			n_readd = n_readd + 1;
			readd[n_readd] = timer;
		end
		peek = timers:peek();
	end
	for i = 1, n_readd do
		local timer = readd[i];
		-- Re-insertion assigns a new heap id; store it, because close() and
		-- reschedule() address the heap entry through timer.id
		timer.id = timers:insert(timer, timer[1]);
	end
	if n_readd > 0 then
		peek = timers:peek();
	end
	if peek == nil then
		return next_delay;
	end
	next_delay = peek - now;
	if next_delay < min_wait then
		return min_wait;
	end
	return next_delay;
end
-- Socket handler interface: methods shared by connection and server objects
local interface = {};
local interface_mt = { __index = interface };

-- Describe the object by file descriptor plus whichever addresses are known
function interface_mt:__tostring()
	local fd = self:getfd();
	local peer, sock = self.peername, self.sockname;
	if peer and sock then
		return (" FD %d (%s, %d, %s, %d) "):format(fd, peer, self.peerport, sock, self.sockport);
	end
	if peer or sock then
		return (" FD %d (%s, %d) "):format(fd, sock or peer, self.sockport or self.peerport);
	end
	return (" FD %d "):format(fd);
end
-- Replace the listener and tell the old one
-- Fires "detach" on the current listeners, installs `listeners`, then fires
-- "attach" (passing `data`) on the new ones.
function interface:setlistener(listeners, data)
	self:on(" detach ");
	self.listeners = listeners;
	self:on(" attach ", data);
end

-- Call a listener callback
-- Calls listeners["on" .. what](self, ...) and returns its first result.
-- Returns nil when there are no listeners or no such callback. If the
-- callback raises, the error is logged and nil, errmsg is returned.
function interface:on(what, ...)
	local listeners = self.listeners;
	if not listeners then
		log(" error ", " %s has no listeners ", self);
		return;
	end
	local listener = listeners[" on " .. what];
	if not listener then
		-- log("debug", "Missing listener 'on%s'", what); -- uncomment for development and debugging
		return;
	end
	local ok, ret = pcall(listener, self, ...);
	if not ok then
		log(" error ", " Error calling on%s: %s ", what, ret);
		-- nil first: callers act on a truthy result (e.g. the read-timeout
		-- timer keeps the connection open), so an error message must not
		-- be mistaken for the listener's answer
		return nil, ret;
	end
	return ret;
end
-- Return the file descriptor number, or _SOCKETINVALID when there is no socket
function interface:getfd()
	if self.conn then
		return self.conn:getfd();
	end
	return _SOCKETINVALID;
end

-- Return the server object that accepted this connection; servers and
-- outgoing connections (which have no _server) return themselves
function interface:server()
	return self._server or self;
end

-- Get IP address (remote address if known, otherwise the local one)
function interface:ip()
	return self.peername or self.sockname;
end

-- Get a port number, doesn't matter which (local if known, otherwise remote)
function interface:port()
	return self.sockport or self.peerport;
end

-- Get local port number
function interface:clientport()
	return self.sockport;
end

-- Get the server-side port: the local port if known, otherwise the port of
-- the server object that accepted this connection; nil if neither is known
function interface:serverport()
	if self.sockport then
		return self.sockport;
	elseif self._server then
		-- Fall back to the listening server's port
		return self._server:port();
	end
end

-- Return underlying socket (the LuaSocket or LuaSec object)
function interface:socket()
	return self.conn;
end
-- Set the amount requested per receive() call; stored as read_size, and
-- onreadable falls back to cfg.read_size when it is nil
function interface:set_mode(new_mode)
	self.read_size = new_mode;
end

-- Set a socket option where the underlying socket supports it.
-- LuaSec doesn't expose setoption :( so this does nothing on TLS-wrapped
-- sockets. It also does nothing once the connection has been destroyed
-- (self.conn is nil).
-- Returns whatever the socket's setoption returns, or nil when skipped.
function interface:setoption(k, v)
	local conn = self.conn;
	if conn and conn.setoption then
		return conn:setoption(k, v);
	end
end
-- Timeout for detecting dead or idle sockets
-- t == false cancels it; nil uses cfg.read_timeout; otherwise (re)arm for t.
-- When it fires, the "readtimeout" listener may return true to keep the
-- connection (re-armed for cfg.read_timeout); otherwise it is disconnected.
function interface:setreadtimeout(t)
	local timer = self._readtimeout;
	if t == false then
		if timer then
			timer:close();
			self._readtimeout = nil;
		end
		return;
	end
	local delay = t or cfg.read_timeout;
	if timer then
		timer:reschedule(gettime() + delay);
		return;
	end
	self._readtimeout = addtimer(delay, function ()
		if not self:on(" readtimeout ") then
			self:on(" disconnect ", " read timeout ");
			self:destroy();
			return;
		end
		return cfg.read_timeout;
	end);
end

-- Timeout for detecting dead sockets
-- t == false cancels it; nil uses cfg.send_timeout; otherwise (re)arm for t.
-- When it fires the connection is disconnected and destroyed.
function interface:setwritetimeout(t)
	local timer = self._writetimeout;
	if t == false then
		if timer then
			timer:close();
			self._writetimeout = nil;
		end
		return;
	end
	local delay = t or cfg.send_timeout;
	if timer then
		timer:reschedule(gettime() + delay);
		return;
	end
	self._writetimeout = addtimer(delay, function ()
		self:on(" disconnect ", " write timeout ");
		self:destroy();
	end);
end
-- Register with the poller for the given read/write interest (nil keeps the
-- currently recorded flag) and record the object in `fds`.
-- If the fd is already registered, its flags are updated instead.
function interface:add(r, w)
	local fd = self:getfd();
	if fd < 0 then return nil, " invalid fd "; end
	if r == nil then r = self._wantread; end
	if w == nil then w = self._wantwrite; end
	local ok, err, errno = poll:add(fd, r, w);
	if ok then
		self._wantread, self._wantwrite = r, w;
		fds[fd] = self;
		log(" debug ", " Watching %s ", self);
		return true;
	end
	if errno == EEXIST then
		log(" debug ", " %s already registered! ", self);
		return self:set(r, w); -- So try to change its flags
	end
	log(" error ", " Could not register %s: %s(%d) ", self, err, errno);
	return ok, err;
end

-- Change read/write interest of an already registered fd (nil = keep flag)
function interface:set(r, w)
	local fd = self:getfd();
	if fd < 0 then return nil, " invalid fd "; end
	if r == nil then r = self._wantread; end
	if w == nil then w = self._wantwrite; end
	local ok, err, errno = poll:set(fd, r, w);
	if ok then
		self._wantread, self._wantwrite = r, w;
		return true;
	end
	log(" error ", " Could not update poller state %s: %s(%d) ", self, err, errno);
	return ok, err;
end

-- Unregister from the poller and from `fds`; an fd the poller has already
-- forgotten (ENOENT) still counts as success
function interface:del()
	local fd = self:getfd();
	if fd < 0 then return nil, " invalid fd "; end
	if fds[fd] ~= self then return nil, " unregistered fd "; end
	local ok, err, errno = poll:del(fd);
	if ok or errno == ENOENT then
		self._wantread, self._wantwrite = nil, nil;
		fds[fd] = nil;
		log(" debug ", " Unwatched %s ", self);
		return true;
	end
	log(" error ", " Could not unregister %s: %s(%d) ", self, err, errno);
	return ok, err;
end

-- Add, update or remove the registration depending on current and wanted flags
function interface:setflags(r, w)
	local registered = self._wantread or self._wantwrite;
	local wanted = r or w;
	if registered then
		if wanted then return self:set(r, w); end
		return self:del();
	end
	if wanted then return self:add(r, w); end
	return true; -- no change
end
-- Called when socket is readable
-- Reads up to self.read_size (or cfg.read_size) bytes and hands the data to
-- the "incoming" listener. "wantread"/"wantwrite" errors (as reported by
-- LuaSec during TLS) adjust poller interest and are then treated like
-- "timeout". Any other error delivers remaining partial data, then
-- disconnects and destroys the connection.
function interface : onreadable ( )
local data , err , partial = self.conn : receive ( self.read_size or cfg.read_size ) ;
if data then
-- First data also fires the one-shot "connect" event
self : onconnect ( ) ;
self : on ( " incoming " , data ) ;
else
-- Retry once the socket is in the state the TLS layer asked for
if err == " wantread " then
self : set ( true , nil ) ;
err = " timeout " ;
elseif err == " wantwrite " then
self : set ( nil , true ) ;
err = " timeout " ;
end
-- Deliver whatever was received before the error/timeout
if partial and partial ~= " " then
self : onconnect ( ) ;
self : on ( " incoming " , partial , err ) ;
end
-- A real error (not just "no more data right now") ends the connection
if err ~= " timeout " then
self : on ( " disconnect " , err ) ;
self : destroy ( )
return ;
end
end
-- The incoming listener may have closed the connection
if not self.conn then return ; end
-- If the socket object still buffers unread data, no poller event will
-- announce it: pause briefly and let the pausefor() callback read again.
-- Otherwise (re)arm the idle read timeout.
if self._wantread and self.conn : dirty ( ) then
self : setreadtimeout ( false ) ;
self : pausefor ( cfg.read_retry_delay ) ;
else
self : setreadtimeout ( ) ;
end
end
-- Called when socket is writable
-- Sends the whole write buffer in one send() call.
--  * Everything sent: stop watching for writability, cancel the write
--    timeout and call ondrain().
--  * Partial send: keep only the unsent remainder and re-arm the write timeout.
-- Then "wantwrite"/"timeout" keep waiting for writability, "wantread" waits
-- for readability, and any other error disconnects and destroys.
function interface:onwritable()
	self:onconnect();
	if not self.conn then return; end -- could have been closed in onconnect
	local buffer = self.writebuffer;
	local data = t_concat(buffer);
	local ok, err, partial = self.conn:send(data);
	if ok then
		self:set(nil, false);
		for i = #buffer, 1, -1 do
			buffer[i] = nil;
		end
		self:setwritetimeout(false);
		self:ondrain(); -- Be aware of writes in ondrain
		return;
	elseif partial then
		-- `partial` is the index of the last byte sent (as LuaSocket's send()
		-- reports it); collapse the buffer to the unsent remainder
		buffer[1] = data:sub(partial + 1);
		for i = #buffer, 2, -1 do
			buffer[i] = nil;
		end
		self:setwritetimeout();
	end
	if err == " wantwrite " or err == " timeout " then
		self:set(nil, true);
	elseif err == " wantread " then
		self:set(true, nil);
	else
		-- "timeout" was handled above, so this is a genuine error
		self:on(" disconnect ", err);
		self:destroy();
	end
end
-- The write buffer has been successfully emptied: notify the "drain" listener
function interface:ondrain()
	return self:on(" drain ");
end

-- Add data to write buffer and set flag for wanting to write.
-- The actual sending happens in onwritable(); this only queues `data`,
-- (re)arms the write timeout and asks the poller for writability.
-- Returns the length of `data`.
function interface:write(data)
	local buffer = self.writebuffer;
	if not buffer then
		self.writebuffer = { data };
	else
		t_insert(buffer, data);
	end
	self:setwritetimeout();
	self:set(nil, true);
	return #data;
end
-- send() is an alias of write()
interface.send = interface.write;
-- Close, possibly after writing is done
-- With data still queued: stop reading, keep flushing, refuse further writes
-- and run close() again from ondrain once the buffer is empty.
-- Otherwise: notify the "disconnect" listener and destroy the connection.
function interface:close()
	if self.writebuffer and self.writebuffer[1] then
		self:set(false, true); -- Flush final buffer contents
		self.write, self.send = noop, noop; -- No more writing
		log(" debug ", " Close %s after writing ", self);
		self.ondrain = interface.close;
	else
		log(" debug ", " Close %s now ", self);
		self.write, self.send = noop, noop;
		self.close = noop;
		self:on(" disconnect ");
		self:destroy();
	end
end

-- Tear down the connection: unregister from the poller, cancel every timer,
-- turn callbacks and I/O methods into no-ops, and close the socket.
-- Calling it again is harmless because it replaces itself with noop.
function interface:destroy()
	self:del();
	self:setwritetimeout(false);
	self:setreadtimeout(false);
	-- A pending pausefor() timer would otherwise fire later and try to
	-- re-enable reading on the dead connection
	if self._pausefor then
		self._pausefor:close();
		self._pausefor = nil;
	end
	self.onreadable = noop;
	self.onwritable = noop;
	self.destroy = noop;
	self.close = noop;
	self.on = noop;
	-- Like close(): later writes would only queue data and arm a write timer
	self.write, self.send = noop, noop;
	local conn = self.conn;
	if conn then
		conn:close();
	end
	self.conn = nil;
end
-- Truthy once TLS wrapping has begun on this connection (set by tlshandskake)
function interface:ssl()
	return self._tls;
end

-- Upgrade the connection to TLS, using tls_ctx if given, else self.tls_ctx.
-- If output is still queued, the upgrade waits until the buffer drains;
-- otherwise read and write events are routed to tlshandskake, which wraps
-- the socket and runs the handshake.
function interface:starttls(tls_ctx)
	if tls_ctx then
		self.tls_ctx = tls_ctx;
	end
	-- Shadow the method with false; presumably so callers testing
	-- conn.starttls see STARTTLS as no longer available
	self.starttls = false;
	local buffer = self.writebuffer;
	if not (buffer and buffer[1]) then
		if self.ondrain == interface.starttls then
			self.ondrain = nil;
		end
		self.onwritable = interface.tlshandskake;
		self.onreadable = interface.tlshandskake;
		self:set(true, true);
		log(" debug ", " Prepare to start TLS on %s ", self);
		return;
	end
	log(" debug ", " Start TLS on %s after write ", self);
	self.ondrain = interface.starttls;
	self:set(nil, true); -- make sure wantwrite is set
end
-- Drive TLS setup; starttls installs this as both onreadable and onwritable
-- (the misspelled name is what starttls references).
-- First call: wrap the plain socket with LuaSec using self.tls_ctx,
-- re-register it with the poller via init() and fire the "starttls" listener.
-- Later calls: step the handshake with dohandshake(), waiting for whichever
-- readiness LuaSec asks for, with cfg.ssl_handshake_timeout as the timeout.
-- On completion the default onreadable/onwritable methods take over.
-- Any failure disconnects and destroys the connection.
function interface : tlshandskake ( )
-- Handshake steps arm their own timeouts below
self : setwritetimeout ( false ) ;
self : setreadtimeout ( false ) ;
if not self._tls then
self._tls = true ;
log ( " debug " , " Start TLS on %s now " , self ) ;
-- Unregister the plain socket; init() below registers the wrapped one
self : del ( ) ;
local ok , conn , err = pcall ( luasec.wrap , self.conn , self.tls_ctx ) ;
if not ok then
-- luasec.wrap raised: conn becomes false and err the error message
conn , err = ok , conn ;
log ( " error " , " Failed to initialize TLS: %s " , err ) ;
end
if not conn then
self : on ( " disconnect " , err ) ;
self : destroy ( ) ;
return conn , err ;
end
conn : settimeout ( 0 ) ;
self.conn = conn ;
self : on ( " starttls " ) ;
self.ondrain = nil ;
self.onwritable = interface.tlshandskake ;
self.onreadable = interface.tlshandskake ;
return self : init ( ) ;
end
local ok , err = self.conn : dohandshake ( ) ;
if ok then
log ( " debug " , " TLS handshake on %s complete " , self ) ;
-- Remove the per-object overrides so the interface methods apply again
self.onwritable = nil ;
self.onreadable = nil ;
self : on ( " status " , " ssl-handshake-complete " ) ;
self : setwritetimeout ( ) ;
self : set ( true , true ) ;
elseif err == " wantread " then
log ( " debug " , " TLS handshake on %s to wait until readable " , self ) ;
self : set ( true , false ) ;
self : setreadtimeout ( cfg.ssl_handshake_timeout ) ;
elseif err == " wantwrite " then
log ( " debug " , " TLS handshake on %s to wait until writable " , self ) ;
self : set ( false , true ) ;
self : setwritetimeout ( cfg.ssl_handshake_timeout ) ;
else
log ( " debug " , " TLS handshake error on %s: %s " , self , err ) ;
self : on ( " disconnect " , err ) ;
self : destroy ( ) ;
end
end
-- Wrap a LuaSocket client socket in a connection object (interface_mt).
-- server:   the server object that accepted it, or nil for outgoing sockets
-- read_size, tls_ctx: per-connection values, defaulting to the server's
-- The socket is switched to non-blocking mode.
local function wrapsocket(client, server, read_size, listeners, tls_ctx) -- luasocket object -> interface object
	client:settimeout(0);
	local conn = setmetatable({
		conn = client;
		_server = server;
		created = gettime();
		listeners = listeners;
		read_size = read_size or (server and server.read_size);
		writebuffer = {};
		tls_ctx = tls_ctx or (server and server.tls_ctx);
		tls_direct = server and server.tls_direct;
	}, interface_mt);
	conn:updatenames();
	return conn;
end

-- Refresh peername/peerport and sockname/sockport from the socket.
-- A lookup that raises, or that yields no address, leaves the fields
-- untouched. For example, getpeername() on a socket whose connect is still in
-- progress returns nil plus an error message, which must not end up stored as
-- the peer "port".
function interface:updatenames()
	local conn = self.conn;
	if not conn then return; end
	local peer_ok, peername, peerport = pcall(conn.getpeername, conn);
	if peer_ok and peername then
		self.peername, self.peerport = peername, peerport;
	end
	local sock_ok, sockname, sockport = pcall(conn.getsockname, conn);
	if sock_ok and sockname then
		self.sockname, self.sockport = sockname, sockport;
	end
end
-- A server interface has new incoming connections waiting
-- This replaces the onreadable callback
-- Accepts one connection, wraps it with the server's listeners and registers
-- it with the poller. Servers created with a TLS context start TLS right away.
function interface:onacceptable()
	local conn, err = self.conn:accept();
	if not conn then
		if err == " timeout " then
			-- Nothing left to accept (e.g. the client already went away);
			-- not a reason to stop listening
			return;
		end
		log(" debug ", " Error accepting new client: %s, server will be paused for %ds ", err, cfg.accept_retry_interval);
		self:pausefor(cfg.accept_retry_interval);
		return;
	end
	local client = wrapsocket(conn, self, nil, self.listeners);
	log(" debug ", " New connection %s ", tostring(client));
	local ok, init_err = client:init();
	if not ok then
		-- Could not be registered with the poller: release the socket
		-- instead of leaking it
		log(" error ", " Could not initialize new connection %s: %s ", tostring(client), init_err);
		client:destroy();
		return;
	end
	if self.tls_direct then
		client:starttls(self.tls_ctx);
	end
end
-- Initialization
-- Arms the write timeout and registers for both read and write events.
-- Returns true, or a falsy value plus an error if registration failed.
function interface:init()
	self:setwritetimeout();
	return self:add(true, true);
end

-- Stop watching for readability (write interest is left as it was)
function interface:pause()
	return self:set(false);
end

-- Watch for readability again (write interest is left as it was)
function interface:resume()
	return self:set(true);
end

-- Pause connection for some time
-- Stops reading for `t`, then resumes and immediately processes data the
-- socket object already buffers. pausefor(false) only cancels a pending
-- pause timer; it does not re-enable reading.
function interface:pausefor(t)
	local pending = self._pausefor;
	if pending then
		pending:close();
		-- Drop the cancelled timer so it is never closed a second time
		self._pausefor = nil;
	end
	if t == false then return; end
	self:set(false);
	self._pausefor = addtimer(t, function ()
		self._pausefor = nil;
		self:set(true);
		if self.conn and self.conn:dirty() then
			self:onreadable();
		end
	end);
end

-- Connected!
-- One-shot (replaces itself with noop): records the peer address if still
-- unknown and fires the "connect" listener.
function interface:onconnect()
	local conn = self.conn;
	if conn and not self.peername and conn.getpeername then
		local peername, peerport = conn:getpeername();
		-- On failure getpeername() yields nil plus an error message; don't
		-- store that message as the peer port
		if peername then
			self.peername, self.peerport = peername, peerport;
		end
	end
	self.onconnect = noop;
	self:on(" connect ");
end
-- Create a listening TCP server on addr:port.
-- Accepted connections get `listeners` (see interface.onacceptable); with
-- tls_ctx, TLS is started on each accepted connection immediately.
-- Returns the server object, or nil plus an error message.
local function addserver(addr, port, listeners, read_size, tls_ctx)
	local conn, err = socket.bind(addr, port, cfg.tcp_backlog);
	if not conn then return conn, err; end
	conn:settimeout(0);
	local server = setmetatable({
		conn = conn;
		created = gettime();
		listeners = listeners;
		read_size = read_size;
		onreadable = interface.onacceptable;
		tls_ctx = tls_ctx;
		tls_direct = tls_ctx and true or false;
		sockname = addr;
		sockport = port;
	}, interface_mt);
	local ok, add_err = server:add(true, false);
	if not ok then
		-- Not watched by the poller it could never accept anything; close
		-- the listening socket instead of returning a dead server
		conn:close();
		return nil, add_err;
	end
	return server;
end
-- COMPAT
-- Wrap an existing client socket (connected or connecting to addr:port).
-- addr/port serve as the peer address if the socket cannot report one.
-- Starts TLS when tls_ctx is given. Returns the connection object, or a
-- falsy value plus an error if poller registration failed.
local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
	local wrapped = wrapsocket(conn, nil, read_size, listeners, tls_ctx);
	if not wrapped.peername then
		wrapped.peername, wrapped.peerport = addr, port;
	end
	local ok, err = wrapped:init();
	if ok then
		if tls_ctx then
			wrapped:starttls(tls_ctx);
		end
		return wrapped;
	end
	return ok, err;
end
-- New outgoing TCP connection
-- Starts a non-blocking connect to addr:port. Unless `typ` names the
-- LuaSocket constructor to use, addr must be an IP address literal.
-- Returns the connection object and the raw socket, or a falsy value plus an
-- error. The socket is closed on every failure path after its creation.
local function addclient(addr, port, listeners, read_size, tls_ctx, typ)
	if not typ then
		local n = inet_pton(addr);
		if not n then return nil, " invalid-ip "; end
		-- A 16-byte packed address is IPv6
		if #n == 16 then
			typ = " tcp6 ";
		else
			typ = " tcp4 ";
		end
	end
	local create = socket[typ];
	if type(create) ~= " function " then
		return nil, " invalid socket type ";
	end
	local conn, err = create();
	if not conn then return nil, err; end
	local ok;
	ok, err = conn:settimeout(0);
	if not ok then
		conn:close();
		return ok, err;
	end
	ok, err = conn:setpeername(addr, port);
	-- "timeout" just means the non-blocking connect is still in progress
	if not ok and err ~= " timeout " then
		conn:close();
		return ok, err;
	end
	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx);
	ok, err = client:init();
	if not ok then
		-- Cancels the write timeout armed by init() and closes the socket
		client:destroy();
		return ok, err;
	end
	if tls_ctx then
		client:starttls(tls_ctx);
	end
	return client, conn;
end
-- Watch an arbitrary descriptor for readability/writability. The main loop
-- calls the given callbacks as conn:onreadable() / conn:onwritable().
-- Closing the returned object only unregisters it from the poller.
local function watchfd(fd, onreadable, onwritable)
	local watcher = setmetatable({
		conn = fd,
		onreadable = onreadable,
		onwritable = onwritable,
		close = function (self)
			self:del();
		end,
	}, interface_mt);
	-- A plain fd number gets its own getfd(); anything else needs to be
	-- LuaSocket-compatible (providing :getfd()) instead
	if type(fd) == " number " then
		function watcher.getfd()
			return fd;
		end
	end
	watcher:add(onreadable, onwritable);
	return watcher;
end;
-- Dump all data from one connection into another.
-- Each chunk read from `from` is written to `to`, and reading on `from` is
-- paused until `to` drains (simple flow control). Other listener callbacks
-- fall through to the previous listeners via __index.
local function link(from, to)
	local function forward(_, data)
		from:pause();
		to:write(data);
	end
	local function drained()
		from:resume();
	end
	from.listeners = setmetatable({ onincoming = forward }, { __index = from.listeners });
	to.listeners = setmetatable({ ondrain = drained }, { __index = to.listeners });
	from:set(true, nil);
	to:set(nil, true);
end
-- COMPAT
-- net.adns calls this but then replaces :send so this can be a noop
function interface:set_send(new_send) -- luacheck: ignore 212
end

-- Close all connections and servers.
-- The registered objects are collected first and closed afterwards. Closing
-- runs disconnect listeners, which may register new connections, and adding
-- keys to `fds` while traversing it with pairs() is undefined behaviour.
local function closeall()
	local conns, n = {}, 0;
	for _, conn in pairs(fds) do
		n = n + 1;
		conns[n] = conn;
	end
	for i = 1, n do
		conns[i]:close();
	end
end

local quitting = nil;
-- Signal main loop about shutdown via above upvalue; starting to quit also
-- closes every registered connection and server
local function setquitting(quit)
	if quit then
		quitting = " quitting ";
		closeall();
	else
		quitting = nil;
	end
end
-- Main loop
-- Each iteration runs due timers, waits for one poller event (no longer than
-- the delay runtimers() returned) and dispatches it to the registered object.
-- once: run a single iteration. Otherwise loop until setquitting() has been
-- called and no descriptors remain registered. Returns the quitting state.
local function loop(once)
	repeat
		local timeout = runtimers(cfg.max_wait, cfg.min_wait);
		local fd, r, w = poll:wait(timeout);
		if not fd then
			-- r is the error, w the errno; timeouts and signals are routine
			if r ~= " timeout " and r ~= " signal " then
				log(" debug ", " epoll_wait error: %s[%d] ", r, w);
			end
		else
			local conn = fds[fd];
			if not conn then
				log(" debug ", " Removing unknown fd %d ", fd);
				poll:del(fd);
			else
				if r then conn:onreadable(); end
				if w then conn:onwritable(); end
			end
		end
	until once or (quitting and next(fds) == nil);
	return quitting;
end
-- libevent emulation: watch fd for mode "r", "w" or "rw" and call
-- callback(conn) on every event. A return of -1 (EV_LEAVE) clears both read
-- and write interest. Any other truthy return re-applies the requested mode.
-- A falsy return leaves the flags as they are.
local function addevent(fd, mode, callback)
	local want_read = mode == " r " or mode == " rw ";
	local want_write = mode == " w " or mode == " rw ";
	local function onevent(self)
		local ret = self:callback();
		if ret == -1 then
			self:set(false, false);
		elseif ret then
			self:set(want_read, want_write);
		end
	end
	local conn = setmetatable({
		getfd = function () return fd; end,
		callback = callback,
		onreadable = onevent,
		onwritable = onevent,
		close = function (self)
			self:del();
			fds[fd] = nil;
		end,
	}, interface_mt);
	local ok, err = conn:add(want_read, want_write);
	if not ok then return ok, err; end
	return conn;
end

return {
	get_backend = function () return " epoll "; end,
	addserver = addserver,
	addclient = addclient,
	add_task = addtimer,
	at = at,
	loop = loop,
	closeall = closeall,
	setquitting = setquitting,
	wrapclient = wrapclient,
	watchfd = watchfd,
	link = link,
	-- Install a new configuration; unset keys fall back to the defaults
	set_config = function (newconfig)
		cfg = setmetatable(newconfig, default_config);
	end,
	-- libevent emulation
	event = { EV_READ = " r ", EV_WRITE = " w ", EV_READWRITE = " rw ", EV_LEAVE = -1 },
	addevent = addevent,
};