diff --git a/nbhttp/engine.go b/nbhttp/engine.go index 2804e5a1..eacbc38d 100644 --- a/nbhttp/engine.go +++ b/nbhttp/engine.go @@ -196,9 +196,11 @@ type Config struct { // ReadBufferPool . ReadBufferPool mempool.Allocator + // Deprecated. // WebsocketCompressor . WebsocketCompressor func(w io.WriteCloser, level int) io.WriteCloser + // Deprecated. // WebsocketDecompressor . WebsocketDecompressor func(r io.Reader) io.ReadCloser diff --git a/nbhttp/websocket/conn.go b/nbhttp/websocket/conn.go index 3021839b..3f04d8cd 100644 --- a/nbhttp/websocket/conn.go +++ b/nbhttp/websocket/conn.go @@ -425,8 +425,8 @@ func (c *Conn) Parse(data []byte) error { if c.compress { var b []byte var rc io.ReadCloser - if c.Engine.WebsocketDecompressor != nil { - rc = c.Engine.WebsocketDecompressor(io.MultiReader(bytes.NewBuffer(message), strings.NewReader(flateReaderTail))) + if c.WebsocketDecompressor != nil { + rc = c.WebsocketDecompressor(c, io.MultiReader(bytes.NewBuffer(message), strings.NewReader(flateReaderTail))) } else { rc = decompressReader(io.MultiReader(bytes.NewBuffer(message), strings.NewReader(flateReaderTail))) } @@ -562,8 +562,8 @@ func (c *Conn) WriteMessage(messageType MessageType, data []byte) error { w.Reset() var cw io.WriteCloser - if c.Engine.WebsocketCompressor != nil { - cw = c.Engine.WebsocketCompressor(w, c.compressionLevel) + if c.WebsocketCompressor != nil { + cw = c.WebsocketCompressor(c, w, c.compressionLevel) } else { cw = compressWriter(w, c.compressionLevel) } diff --git a/nbhttp/websocket/upgrader.go b/nbhttp/websocket/upgrader.go index 456c214a..b991d6c9 100644 --- a/nbhttp/websocket/upgrader.go +++ b/nbhttp/websocket/upgrader.go @@ -58,6 +58,9 @@ type commonFields struct { MessageLengthLimit int BlockingModAsyncCloseDelay time.Duration + WebsocketCompressor func(c *Conn, w io.WriteCloser, level int) io.WriteCloser + WebsocketDecompressor func(c *Conn, r io.Reader) io.ReadCloser + pingMessageHandler func(c *Conn, appData string) pongMessageHandler func(c *Conn, appData string) closeMessageHandler func(c *Conn, code int, text string)