diff --git a/websocket_transport.go b/websocket_transport.go
index 391c012..bd8a87b 100644
--- a/websocket_transport.go
+++ b/websocket_transport.go
@@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"io"
- "net"
"strings"
"time"
@@ -14,6 +13,8 @@ import (
"nhooyr.io/websocket"
)
+const maxPacketSize = 32768
+
const pingTimeout = time.Duration(5) * time.Second
var ServerDoesNotSupportXmppOverWebsocket = errors.New("The websocket server does not support the xmpp subprotocol")
@@ -22,17 +23,23 @@ type WebsocketTransport struct {
Config TransportConfiguration
decoder *xml.Decoder
wsConn *websocket.Conn
- netConn net.Conn
+ queue chan []byte
logFile io.Writer
+
+ closeCtx context.Context
+ closeFunc context.CancelFunc
}
func (t *WebsocketTransport) Connect() (string, error) {
- ctx := context.Background()
+ t.queue = make(chan []byte, 256)
+ t.closeCtx, t.closeFunc = context.WithCancel(context.Background())
+ var ctx context.Context
+ ctx = context.Background()
if t.Config.ConnectTimeout > 0 {
- var cancel context.CancelFunc
- ctx, cancel = context.WithTimeout(ctx, time.Duration(t.Config.ConnectTimeout)*time.Second)
- defer cancel()
+ var cancelConnect context.CancelFunc
+ ctx, cancelConnect = context.WithTimeout(t.closeCtx, time.Duration(t.Config.ConnectTimeout)*time.Second)
+ defer cancelConnect()
}
wsConn, response, err := websocket.Dial(ctx, t.Config.Address, &websocket.DialOptions{
@@ -42,28 +49,30 @@ func (t *WebsocketTransport) Connect() (string, error) {
return "", NewConnError(err, true)
}
if response.Header.Get("Sec-WebSocket-Protocol") != "xmpp" {
- _ = wsConn.Close(websocket.StatusBadGateway, "Could not negotiate XMPP subprotocol")
+ t.cleanup(websocket.StatusBadGateway)
return "", NewConnError(ServerDoesNotSupportXmppOverWebsocket, true)
}
+ wsConn.SetReadLimit(maxPacketSize)
t.wsConn = wsConn
- t.netConn = websocket.NetConn(ctx, t.wsConn, websocket.MessageText)
+ t.startReader()
- handshake := fmt.Sprintf("", t.Config.Domain)
+ handshake := fmt.Sprintf(``, t.Config.Domain)
if _, err = t.Write([]byte(handshake)); err != nil {
- _ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
+ t.cleanup(websocket.StatusBadGateway)
return "", NewConnError(err, false)
}
handshakeResponse := make([]byte, 2048)
if _, err = t.Read(handshakeResponse); err != nil {
- _ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
+ t.cleanup(websocket.StatusBadGateway)
+
return "", NewConnError(err, false)
}
var openResponse = stanza.WebsocketOpen{}
if err = xml.Unmarshal(handshakeResponse, &openResponse); err != nil {
- _ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
+ t.cleanup(websocket.StatusBadGateway)
return "", NewConnError(err, false)
}
@@ -73,6 +82,32 @@ func (t *WebsocketTransport) Connect() (string, error) {
return openResponse.Id, nil
}
+// startReader runs a go function that keeps reading from the websocket. This
+// is required to allow Ping() to work: Ping requires a Reader to be running
+// to process incoming control frames.
+func (t WebsocketTransport) startReader() {
+ go func() {
+ buffer := make([]byte, maxPacketSize)
+ for {
+ _, reader, err := t.wsConn.Reader(t.closeCtx)
+ if err != nil {
+ return
+ }
+ n, err := reader.Read(buffer)
+ if err != nil && err != io.EOF {
+ return
+ }
+ if n > 0 {
+ // We need to make a copy, otherwise we will overwrite the slice content
+ // on the next iteration of the for loop.
+ tmp := make([]byte, len(buffer))
+ copy(tmp, buffer)
+ t.queue <- tmp
+ }
+ }
+ }()
+}
+
func (t WebsocketTransport) StartTLS() error {
return TLSNotSupported
}
@@ -90,31 +125,52 @@ func (t WebsocketTransport) IsSecure() bool {
}
func (t WebsocketTransport) Ping() error {
- ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
+ ctx, cancel := context.WithTimeout(t.closeCtx, pingTimeout)
defer cancel()
return t.wsConn.Ping(ctx)
}
-func (t *WebsocketTransport) Read(p []byte) (n int, err error) {
- n, err = t.netConn.Read(p)
- if t.logFile != nil && n > 0 {
- _, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", p)
+func (t *WebsocketTransport) Read(p []byte) (int, error) {
+ select {
+ case <-t.closeCtx.Done():
+ return 0, t.closeCtx.Err()
+ case data := <-t.queue:
+ if t.logFile != nil && len(data) > 0 {
+ _, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", data)
+ }
+ copy(p, data)
+ return len(data), nil
}
- return
}
-func (t WebsocketTransport) Write(p []byte) (n int, err error) {
+func (t WebsocketTransport) Write(p []byte) (int, error) {
if t.logFile != nil {
_, _ = fmt.Fprintf(t.logFile, "SEND:\n%s\n\n", p)
}
- return t.netConn.Write(p)
+ return len(p), t.wsConn.Write(t.closeCtx, websocket.MessageText, p)
}
func (t WebsocketTransport) Close() error {
t.Write([]byte(""))
- return t.netConn.Close()
+ return t.wsConn.Close(websocket.StatusGoingAway, "Done")
}
func (t *WebsocketTransport) LogTraffic(logFile io.Writer) {
t.logFile = logFile
}
+
+func (t *WebsocketTransport) cleanup(code websocket.StatusCode) {
+ if t.queue != nil {
+ close(t.queue)
+ t.queue = nil
+ }
+ if t.wsConn != nil {
+ t.wsConn.Close(websocket.StatusGoingAway, "Done")
+ t.wsConn = nil
+ }
+ if t.closeFunc != nil {
+ t.closeFunc()
+ t.closeFunc = nil
+ t.closeCtx = nil
+ }
+}