From 433b1528ade79e10c76befad6a5566675832e49a Mon Sep 17 00:00:00 2001 From: nikky Date: Fri, 7 Aug 2020 22:52:53 +0200 Subject: [PATCH] improve websocket handling replace Ring with channel add pings and pongHandler close websockets properly --- bridge/api/api.go | 108 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 90 insertions(+), 18 deletions(-) diff --git a/bridge/api/api.go b/bridge/api/api.go index 62336881..bd0b63c6 100644 --- a/bridge/api/api.go +++ b/bridge/api/api.go @@ -1,7 +1,9 @@ package api import ( + "bytes" "encoding/json" + "io" "net/http" "sync" "time" @@ -11,11 +13,21 @@ import ( "github.com/gorilla/websocket" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" - ring "github.com/zfjagann/golang-ring" +) + +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 10 * time.Second // TODO: 60 seconds + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 ) type API struct { - Messages ring.Ring + send chan config.Message sync.RWMutex *bridge.Config } @@ -33,10 +45,7 @@ func New(cfg *bridge.Config) bridge.Bridger { e := echo.New() e.HideBanner = true e.HidePort = true - b.Messages = ring.Ring{} - if b.GetInt("Buffer") != 0 { - b.Messages.SetCapacity(b.GetInt("Buffer")) - } + b.send = make(chan config.Message, b.GetInt("Buffer")) if b.GetString("Token") != "" { e.Use(middleware.KeyAuth(func(key string, c echo.Context) (bool, error) { return key == b.GetString("Token"), nil @@ -83,7 +92,7 @@ func (b *API) Send(msg config.Message) (string, error) { if msg.Event == config.EventMsgDelete { return "", nil } - b.Messages.Enqueue(&msg) + b.send <- msg return "", nil } @@ -110,8 +119,21 @@ func (b *API) handlePostMessage(c echo.Context) error { func (b *API) handleMessages(c echo.Context) error { b.Lock() defer b.Unlock() - c.JSONPretty(http.StatusOK, b.Messages.Values(), " ") - b.Messages = ring.Ring{} + // collect all messages until the channel has no more messages in the buffer + var messages []config.Message + loop: for { + select { + case msg := <- b.send: + messages = append(messages, msg) + default: + break loop + } + } + // TODO: get all messages from send channel + c.JSONPretty(http.StatusOK, messages, " ") + // TODO: clear send channel ? + //b.send = make(chan config.Message, b.GetInt("Buffer")) + //b.Messages = ring.Ring{} return nil } @@ -131,8 +153,9 @@ func (b *API) handleStream(c echo.Context) error { } c.Response().Flush() for { - msg := b.Messages.Dequeue() - if msg != nil { + select { + // block until channel has message + case msg := <- b.send: if err := json.NewEncoder(c.Response()).Encode(msg); err != nil { return err } @@ -154,30 +177,79 @@ func (b *API) handleWebsocketMessage(message config.Message) { } func (b *API) writePump(conn *websocket.Conn) { + ticker := time.NewTicker(pingPeriod) + defer func() { + b.Log.Debug("closing websocket") + ticker.Stop() + conn.Close() + }() + for { - msg := b.Messages.Dequeue() - if msg != nil { + select { + case msg := <-b.send: + conn.SetWriteDeadline(time.Now().Add(writeWait)) err := conn.WriteJSON(msg) if err != nil { - break + b.Log.Errorf("error: %v", err) + return + } + case <-ticker.C: + b.Log.Debug("sending ping") + conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + b.Log.Errorf("error: %v", err) + return } } } } func (b *API) readPump(conn *websocket.Conn) { - for { + defer func() { + b.Log.Debug("closing websocket") + conn.Close() + }() + + _ = conn.SetReadDeadline(time.Now().Add(pongWait)) + conn.SetPongHandler( + func(string) error { + b.Log.Debug("received pong") + conn.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }, + ) + + for { message := config.Message{} - err := conn.ReadJSON(&message) + //err := conn.ReadJSON(&message) + //if err != nil { + // b.Log.Errorf("error: %v", err) + // return + //} + _, messageBytes, err := conn.ReadMessage() if err != nil { - break + b.Log.Errorf("error: %v", err) + return + } + err = json.NewDecoder(bytes.NewReader(messageBytes)).Decode(&message) + if err != nil { + if err == io.EOF { + // One value is expected in the message. + err = io.ErrUnexpectedEOF + } + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + b.Log.Errorf("Websocket closed unexpectedly: %v", err) + } + return } b.handleWebsocketMessage(message) } } func (b *API) handleWebsocket(c echo.Context) error { - conn, err := websocket.Upgrade(c.Response().Writer, c.Request(), nil, 1024, 1024) + u := websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024} + conn, err := u.Upgrade(c.Response().Writer, c.Request(), nil) + //websocket.Upgrade(c.Response().Writer, c.Request(), nil, 1024, 1024) if err != nil { return err }