-- Prosody IM
-- Copyright (C) 2012 Florian Zeitz
-- Copyright (C) 2014 Daurnimator
--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
local softreq = require " util.dependencies " . softreq ;
local random_bytes = require " util.random " . bytes ;
local bit = assert ( softreq " bit " or softreq " bit32 " ,
" No bit module found. See https://prosody.im/doc/depends#bitop " ) ;
local band = bit.band ;
local bor = bit.bor ;
local bxor = bit.bxor ;
local lshift = bit.lshift ;
local rshift = bit.rshift ;
local t_concat = table.concat ;
local s_byte = string.byte ;
local s_char = string.char ;
local s_sub = string.sub ;
local s_pack = string.pack ;
local s_unpack = string.unpack ;
if not s_pack and softreq " struct " then
s_pack = softreq " struct " . pack ;
s_unpack = softreq " struct " . unpack ;
end
--- Decode an unsigned 16-bit big-endian integer.
-- Bit-library-free fallback; replaced further down when string.unpack or
-- struct.unpack is available.
-- @param str source string
-- @param pos index of the most significant byte
-- @return the decoded number
local function read_uint16be(str, pos)
	local hi, lo = s_byte(str, pos, pos + 1)
	return hi * 0x100 + lo
end
--- Decode an unsigned 64-bit big-endian integer.
-- Fallback used when neither string.unpack nor struct.unpack is available.
-- FIXME: this may lose precision: values above 2^53 cannot be represented
-- exactly as a Lua number.
-- Each 32-bit half is assembled with plain arithmetic rather than lshift():
-- LuaJIT's `bit` library returns *signed* 32-bit results, so a leading byte
-- >= 0x80 used to yield a negative half and therefore a wrong total.
-- @param str source string
-- @param pos index of the most significant byte
-- @return the decoded number
local function read_uint64be(str, pos)
	local b1, b2, b3, b4, b5, b6, b7, b8 = str:byte(pos, pos + 7)
	local high = ((b1 * 256 + b2) * 256 + b3) * 256 + b4
	local low = ((b5 * 256 + b6) * 256 + b7) * 256 + b8
	return high * 2 ^ 32 + low
end
--- Encode `x` as a 2-byte unsigned big-endian string (bit-library fallback).
local function pack_uint16be(x)
	local high_byte = rshift(x, 8)
	local low_byte = band(x, 0xFF)
	return s_char(high_byte, low_byte)
end
--- Return the 8 bits of `x` that start `n` bits above the least significant bit.
local function get_byte(x, n)
	local shifted = rshift(x, n)
	return band(shifted, 0xFF)
end
--- Encode `x` as an 8-byte unsigned big-endian string (bit-library fallback).
-- The high word is computed with math.floor so the bit library never sees
-- a fractional argument. Its rounding of non-integers is unspecified, and with
-- round-to-nearest a low word >= 2^31 used to bump the high word by one.
-- The low word is derived arithmetically so only 32-bit-range values reach
-- the bit operations.
-- @param x non-negative integral number
-- @return 8-byte string
local function pack_uint64be(x)
	local high = math.floor(x / 2 ^ 32)
	local low = x - high * 2 ^ 32
	return s_char(get_byte(high, 24), get_byte(high, 16), get_byte(high, 8), band(high, 0xFF),
		get_byte(low, 24), get_byte(low, 16), get_byte(low, 8), band(low, 0xFF))
end
-- When a native packer is available (string.pack in Lua 5.3+, or the
-- struct module), replace the bit-twiddling encoders defined above.
-- These are assignments to the existing locals, not new local functions.
if s_pack then
	pack_uint16be = function(x)
		return s_pack(">I2", x)
	end
	pack_uint64be = function(x)
		return s_pack(">I8", x)
	end
end
-- Likewise, prefer a native unpacker for the decoders defined above.
-- Note: these return the decoded value followed by the position after it;
-- callers in this file keep only the first value.
if s_unpack then
	read_uint16be = function(str, pos)
		return s_unpack(">I2", str, pos)
	end
	read_uint64be = function(str, pos)
		return s_unpack(">I8", str, pos)
	end
end
--- Parse the header of a WebSocket frame (RFC 6455 section 5.2).
-- @param frame string beginning at the first byte of a frame
-- @return a table with FIN, RSV1, RSV2, RSV3, MASK (booleans), opcode,
--   length (payload length) and, when masked, key (array of 4 bytes),
--   followed by the header size in bytes. Returns nothing if `frame`
--   does not yet hold a complete header.
local function parse_frame_header(frame)
	if #frame < 2 then
		return
	end
	local first, second = s_byte(frame, 1, 2)
	local header = {
		FIN = band(first, 0x80) > 0,
		RSV1 = band(first, 0x40) > 0,
		RSV2 = band(first, 0x20) > 0,
		RSV3 = band(first, 0x10) > 0,
		opcode = band(first, 0x0F),
		MASK = band(second, 0x80) > 0,
		length = band(second, 0x7F),
	}
	-- A 7-bit length of 126 or 127 announces a 2- or 8-byte extended length.
	local extended = (header.length == 126 and 2) or (header.length == 127 and 8) or 0
	local mask_bytes = header.MASK and 4 or 0
	local header_length = 2 + extended + mask_bytes
	if #frame < header_length then
		return
	end
	if extended == 2 then
		header.length = read_uint16be(frame, 3)
	elseif extended == 8 then
		header.length = read_uint64be(frame, 3)
	end
	if header.MASK then
		-- The 4-byte masking key sits right after the (extended) length.
		header.key = { s_byte(frame, extended + 3, extended + 6) }
	end
	return header, header_length
end
--- XOR bytes `from`..`to` of `str` with the repeating byte array `key`.
-- The key index restarts at `from`, so the first processed byte uses key[1].
-- Negative `from`/`to` count back from the end of `str` (-1 is the last byte).
-- @param str string to mask or unmask
-- @param key array of byte values
-- @param from first index to process (default 1)
-- @param to last index to process (default #str)
-- @return the transformed bytes as a new string
-- TODO: optimize
local function apply_mask(str, key, from, to)
	local len = #str
	from = from or 1
	if from < 0 then
		from = len + from + 1 -- negative indices
	end
	to = to or len
	if to < 0 then
		to = len + to + 1 -- negative indices
	end
	local key_len = #key
	local out = {}
	local n = 0
	for i = from, to do
		local k = key[n % key_len + 1]
		n = n + 1
		out[n] = s_char(bxor(k, s_byte(str, i)))
	end
	return t_concat(out)
end
--- Extract a frame's payload, unmasking it when the header says it is masked.
-- @param frame string containing the frame
-- @param header table produced by parse_frame_header
-- @param pos index in `frame` of the first payload byte
-- @return the payload string
local function parse_frame_body(frame, header, pos)
	local last = pos + header.length - 1
	if not header.MASK then
		return frame:sub(pos, last)
	end
	return apply_mask(frame, header.key, pos, last)
end
--- Parse one complete frame from the start of `frame`.
-- @return the header table (see parse_frame_header) with the payload stored
--   in `.data`, plus the total number of bytes the frame occupies; or
--   nothing if `frame` does not yet contain the whole frame.
local function parse_frame(frame)
	local header, header_length = parse_frame_header(frame)
	if header == nil then
		return
	end
	local frame_length = header_length + header.length
	if #frame < frame_length then
		return
	end
	header.data = parse_frame_body(frame, header, header_length + 1)
	return header, frame_length
end
--- Serialise a WebSocket frame (RFC 6455 section 5.2).
-- @param desc table with fields:
--   opcode  (required) number in 0x0..0xF
--   FIN, RSV1, RSV2, RSV3, MASK  truthy flags
--   key     optional array of 4 byte values used when MASK is set;
--           a random key is generated when absent
--   data    payload string (default "")
-- @return the encoded frame as a string
local function build_frame(desc)
	local data = desc.data or ""
	assert(desc.opcode and desc.opcode >= 0 and desc.opcode <= 0xF, "Invalid WebSocket opcode")
	if desc.opcode >= 0x8 then
		-- RFC 6455 5.5
		assert(#data <= 125, "WebSocket control frames MUST have a payload length of 125 bytes or less.")
	end
	local b1 = bor(desc.opcode,
		desc.FIN and 0x80 or 0,
		desc.RSV1 and 0x40 or 0,
		desc.RSV2 and 0x20 or 0,
		desc.RSV3 and 0x10 or 0)
	local b2 = #data
	local length_extra
	if b2 <= 125 then -- 7-bit length
		length_extra = ""
	elseif b2 <= 0xFFFF then -- 2-byte length
		b2 = 126
		length_extra = pack_uint16be(#data)
	else -- 8-byte length
		b2 = 127
		length_extra = pack_uint64be(#data)
	end
	local key = ""
	if desc.MASK then
		local key_a = desc.key
		if key_a then
			-- `unpack` is a global only in Lua 5.1/LuaJIT; Lua 5.2+ provides
			-- table.unpack instead, so the bare global crashed there.
			key = s_char((table.unpack or unpack)(key_a, 1, 4))
		else
			key = random_bytes(4)
			key_a = { key:byte(1, 4) }
		end
		b2 = bor(b2, 0x80)
		data = apply_mask(data, key_a)
	end
	return s_char(b1, b2) .. length_extra .. key .. data
end
--- Decode the payload of a Close frame.
-- @param data payload string
-- @return the 16-bit status code (nil when the payload is shorter than
--   2 bytes) and the reason text (nil when nothing follows the code)
local function parse_close(data)
	if #data < 2 then
		return nil, nil
	end
	local code = read_uint16be(data, 1)
	local message
	if #data > 2 then
		message = s_sub(data, 3)
	end
	return code, message
end
--- Build a Close control frame (opcode 0x8, FIN set).
-- @param code status code, encoded as 2 bytes big-endian
-- @param message optional reason string of at most 123 bytes
-- @param mask truthy to mask the frame
-- @return the encoded frame
local function build_close(code, message, mask)
	local payload = pack_uint16be(code)
	if message then
		-- The 125-byte control-frame limit minus the 2-byte status code.
		assert(#message <= 123, "Close reason must be <=123 bytes")
		payload = payload .. message
	end
	return build_frame({
		opcode = 0x8,
		FIN = true,
		MASK = mask,
		data = payload,
	})
end
-- Public API of the module.
local exports = {
	parse_header = parse_frame_header,
	parse_body = parse_frame_body,
	parse = parse_frame,
	build = build_frame,
	parse_close = parse_close,
	build_close = build_close,
}
return exports