@ -260,10 +260,12 @@ type Conn struct {
newCompressionWriter func ( io . WriteCloser , int ) io . WriteCloser
// Read fields
reader io . ReadCloser // the current reader returned to the application
readErr error
br * bufio . Reader
readRemaining int64 // bytes remaining in current frame.
reader io . ReadCloser // the current reader returned to the application
readErr error
br * bufio . Reader
// bytes remaining in current frame.
// set setReadRemaining to safely update this value and prevent overflow
readRemaining int64
readFinal bool // true the current message has more frames.
readLength int64 // Message size.
readLimit int64 // Maximum message size.
@ -320,6 +322,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
return c
}
// setReadRemaining tracks the number of bytes remaining on the connection. If n
// overflows, an ErrReadLimit is returned.
func ( c * Conn ) setReadRemaining ( n int64 ) error {
if n < 0 {
return ErrReadLimit
}
c . readRemaining = n
return nil
}
// Subprotocol returns the negotiated protocol for the connection.
func ( c * Conn ) Subprotocol ( ) string {
return c . subprotocol
@ -451,7 +464,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
return err
}
func ( c * Conn ) prepWrite ( messageType int ) error {
// beginMessage prepares a connection and message writer for a new message.
func ( c * Conn ) beginMessage ( mw * messageWriter , messageType int ) error {
// Close previous writer if not already closed by the application. It's
// probably better to return an error in this situation, but we cannot
// change this without breaking existing applications.
@ -471,6 +485,10 @@ func (c *Conn) prepWrite(messageType int) error {
return err
}
mw . c = c
mw . frameType = messageType
mw . pos = maxFrameHeaderSize
if c . writeBuf == nil {
wpd , ok := c . writePool . Get ( ) . ( writePoolData )
if ok {
@ -491,16 +509,11 @@ func (c *Conn) prepWrite(messageType int) error {
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
// PongMessage) are supported.
func ( c * Conn ) NextWriter ( messageType int ) ( io . WriteCloser , error ) {
if err := c . prepWrite ( messageType ) ; err != nil {
var mw messageWriter
if err := c . beginMessage ( & mw , messageType ) ; err != nil {
return nil , err
}
mw := & messageWriter {
c : c ,
frameType : messageType ,
pos : maxFrameHeaderSize ,
}
c . writer = mw
c . writer = & mw
if c . newCompressionWriter != nil && c . enableWriteCompression && isData ( messageType ) {
w := c . newCompressionWriter ( c . writer , c . compressionLevel )
mw . compress = true
@ -517,10 +530,16 @@ type messageWriter struct {
err error
}
func ( w * messageWriter ) fatal ( err error ) error {
func ( w * messageWriter ) endMessage ( err error ) error {
if w . err != nil {
w . err = err
w . c . writer = nil
return err
}
c := w . c
w . err = err
c . writer = nil
if c . writePool != nil {
c . writePool . Put ( writePoolData { buf : c . writeBuf } )
c . writeBuf = nil
}
return err
}
@ -534,7 +553,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
// Check for invalid control frames.
if isControl ( w . frameType ) &&
( ! final || length > maxControlFramePayloadSize ) {
return w . fatal ( errInvalidControlFrame )
return w . endMessage ( errInvalidControlFrame )
}
b0 := byte ( w . frameType )
@ -579,7 +598,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
copy ( c . writeBuf [ maxFrameHeaderSize - 4 : ] , key [ : ] )
maskBytes ( key , 0 , c . writeBuf [ maxFrameHeaderSize : w . pos ] )
if len ( extra ) > 0 {
return c . writeFatal ( errors . New ( "websocket: internal error, extra used in client mode" ) )
return w . endMessage ( c . writeFatal ( errors . New ( "websocket: internal error, extra used in client mode" ) ) )
}
}
@ -600,15 +619,11 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
c . isWriting = false
if err != nil {
return w . fatal ( err )
return w . endMessage ( err )
}
if final {
c . writer = nil
if c . writePool != nil {
c . writePool . Put ( writePoolData { buf : c . writeBuf } )
c . writeBuf = nil
}
w . endMessage ( errWriteClosed )
return nil
}
@ -706,11 +721,7 @@ func (w *messageWriter) Close() error {
if w . err != nil {
return w . err
}
if err := w . flushFrame ( true , nil ) ; err != nil {
return err
}
w . err = errWriteClosed
return nil
return w . flushFrame ( true , nil )
}
// WritePreparedMessage writes prepared message into connection.
@ -742,10 +753,10 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
if c . isServer && ( c . newCompressionWriter == nil || ! c . enableWriteCompression ) {
// Fast path with no allocations and single frame.
if err := c . prepWrite ( messageType ) ; err != nil {
var mw messageWriter
if err := c . beginMessage ( & mw , messageType ) ; err != nil {
return err
}
mw := messageWriter { c : c , frameType : messageType , pos : maxFrameHeaderSize }
n := copy ( c . writeBuf [ mw . pos : ] , data )
mw . pos += n
data = data [ n : ]
@ -792,7 +803,7 @@ func (c *Conn) advanceFrame() (int, error) {
final := p [ 0 ] & finalBit != 0
frameType := int ( p [ 0 ] & 0xf )
mask := p [ 1 ] & maskBit != 0
c . readRemaining = int64 ( p [ 1 ] & 0x7f )
c . setReadRemaining ( int64 ( p [ 1 ] & 0x7f ) )
c . readDecompress = false
if c . newDecompressionReader != nil && ( p [ 0 ] & rsv1Bit ) != 0 {
@ -826,7 +837,17 @@ func (c *Conn) advanceFrame() (int, error) {
return noFrame , c . handleProtocolError ( "unknown opcode " + strconv . Itoa ( frameType ) )
}
// 3. Read and parse frame length.
// 3. Read and parse frame length as per
// https://tools.ietf.org/html/rfc6455#section-5.2
//
// The length of the "Payload data", in bytes: if 0-125, that is the payload
// length.
// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
// integer are the payload length.
// - If 127, the following 8 bytes interpreted as
// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
// payload length. Multibyte length quantities are expressed in network byte
// order.
switch c . readRemaining {
case 126 :
@ -834,13 +855,19 @@ func (c *Conn) advanceFrame() (int, error) {
if err != nil {
return noFrame , err
}
c . readRemaining = int64 ( binary . BigEndian . Uint16 ( p ) )
if err := c . setReadRemaining ( int64 ( binary . BigEndian . Uint16 ( p ) ) ) ; err != nil {
return noFrame , err
}
case 127 :
p , err := c . read ( 8 )
if err != nil {
return noFrame , err
}
c . readRemaining = int64 ( binary . BigEndian . Uint64 ( p ) )
if err := c . setReadRemaining ( int64 ( binary . BigEndian . Uint64 ( p ) ) ) ; err != nil {
return noFrame , err
}
}
// 4. Handle frame masking.
@ -863,6 +890,12 @@ func (c *Conn) advanceFrame() (int, error) {
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
c . readLength += c . readRemaining
// Don't allow readLength to overflow in the presence of a large readRemaining
// counter.
if c . readLength < 0 {
return noFrame , ErrReadLimit
}
if c . readLimit > 0 && c . readLength > c . readLimit {
c . WriteControl ( CloseMessage , FormatCloseMessage ( CloseMessageTooBig , "" ) , time . Now ( ) . Add ( writeWait ) )
return noFrame , ErrReadLimit
@ -876,7 +909,7 @@ func (c *Conn) advanceFrame() (int, error) {
var payload [ ] byte
if c . readRemaining > 0 {
payload , err = c . read ( int ( c . readRemaining ) )
c . readRemaining = 0
c . setReadRemaining ( 0 )
if err != nil {
return noFrame , err
}
@ -949,6 +982,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c . readErr = hideTempErr ( err )
break
}
if frameType == TextMessage || frameType == BinaryMessage {
c . messageReader = & messageReader { c }
c . reader = c . messageReader
@ -989,7 +1023,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
if c . isServer {
c . readMaskPos = maskBytes ( c . readMaskKey , c . readMaskPos , b [ : n ] )
}
c . readRemaining -= int64 ( n )
rem := c . readRemaining
rem -= int64 ( n )
c . setReadRemaining ( rem )
if c . readRemaining > 0 && c . readErr == io . EOF {
c . readErr = errUnexpectedEOF
}
@ -1041,7 +1077,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
return c . conn . SetReadDeadline ( t )
}
// SetReadLimit sets the maximum size for a message read from the peer. If a
// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a
// message exceeds the limit, the connection sends a close message to the peer
// and returns ErrReadLimit to the application.
func ( c * Conn ) SetReadLimit ( limit int64 ) {