From 36e153f9813d1d419dfa86bfe7b328435370ab38 Mon Sep 17 00:00:00 2001
From: Wichert Akkerman <wichert@wiggy.net>
Date: Tue, 15 Oct 2019 20:56:11 +0200
Subject: [PATCH] Allow transports to define their own ping mechanism

---
 client.go              |  4 ++--
 transport.go           |  1 +
 websocket_transport.go | 10 ++++++++++
 xmpp_transport.go      | 12 ++++++++++++
 4 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/client.go b/client.go
index 7f56d0b..9c74da7 100644
--- a/client.go
+++ b/client.go
@@ -276,8 +276,8 @@ func keepalive(transport Transport, quit <-chan struct{}) {
 	for {
 		select {
 		case <-ticker.C:
-			if n, err := fmt.Fprintf(transport, "\n"); err != nil || n != 1 {
-				// When keep alive fails, we force close the transportection. In all cases, the recv will also fail.
+			if err := transport.Ping(); err != nil {
+				// When keepalive fails, we force close the transport. In all cases, the recv will also fail.
 				ticker.Stop()
 				_ = transport.Close()
 				return
diff --git a/transport.go b/transport.go
index 3e0ca0d..6c4b8e0 100644
--- a/transport.go
+++ b/transport.go
@@ -25,6 +25,7 @@ type Transport interface {
 
 	IsSecure() bool
 
+	Ping() error
 	Read(p []byte) (n int, err error)
 	Write(p []byte) (n int, err error)
 	Close() error
diff --git a/websocket_transport.go b/websocket_transport.go
index a526fc4..26ab511 100644
--- a/websocket_transport.go
+++ b/websocket_transport.go
@@ -9,6 +9,8 @@ import (
 	"nhooyr.io/websocket"
 )
 
+const pingTimeout = time.Duration(5) * time.Second
+
 type WebsocketTransport struct {
 	Config  TransportConfiguration
 	wsConn  *websocket.Conn
@@ -46,6 +48,14 @@ func (t WebsocketTransport) IsSecure() bool {
 	return strings.HasPrefix(t.Config.Address, "wss:")
 }
 
+func (t WebsocketTransport) Ping() error {
+	ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
+	defer cancel()
+	// Note that we do not use wsConn.Ping(), because not all websocket servers
+	// (ejabberd for example) implement ping frames
+	return t.wsConn.Write(ctx, websocket.MessageText, []byte(" "))
+}
+
 func (t WebsocketTransport) Read(p []byte) (n int, err error) {
 	return t.netConn.Read(p)
 }
diff --git a/xmpp_transport.go b/xmpp_transport.go
index 2530b82..614a76d 100644
--- a/xmpp_transport.go
+++ b/xmpp_transport.go
@@ -2,6 +2,7 @@ package xmpp
 
 import (
 	"crypto/tls"
+	"errors"
 	"net"
 	"time"
 )
@@ -59,6 +60,17 @@ func (t *XMPPTransport) StartTLS(domain string) error {
 	return nil
 }
 
+func (t XMPPTransport) Ping() error {
+	n, err := t.conn.Write([]byte("\n"))
+	if err != nil {
+		return err
+	}
+	if n != 1 {
+		return errors.New("Could not write ping")
+	}
+	return nil
+}
+
 func (t XMPPTransport) Read(p []byte) (n int, err error) {
 	return t.conn.Read(p)
 }