diff --git a/CHANGELOG.md b/CHANGELOG.md
index 07598b8..41f58ae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,21 @@
# Fluux XMPP Changelog
+## v0.5.0
+
+### Changes
+
+- Added support for XEP-0198 (Stream management)
+- Added message queue : when using "SendX" methods on a client, messages are also stored in a queue. When requesting
+acks from the server, sent messages will be discarded, and unsent ones will be sent again. (see https://xmpp.org/extensions/xep-0198.html#acking)
+- Added support for stanza_errors (see https://xmpp.org/rfcs/rfc3920.html#def C.2. Stream error namespace and https://xmpp.org/rfcs/rfc6120.html#schemas-streamerror)
+- Added separate hooks for connection and reconnection on the client. One can now specify different actions to get triggered on client connect
+and reconnect, at client init time.
+- Client state update is now thread safe
+- Changed the Config struct to use pointer semantics
+- Tests
+- Refactoring, including removing some Fprintf statements in favor of Marshal + Write and using structs from the library
+instead of strings
+
## v0.4.0
### Changes
diff --git a/_examples/go.sum b/_examples/go.sum
index 286bc95..467aab6 100644
--- a/_examples/go.sum
+++ b/_examples/go.sum
@@ -99,7 +99,9 @@ github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/9
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/processone/mpg123 v1.0.0 h1:o2WOyGZRM255or1Zc/LtF/jARn51B+9aQl72Qace0GA=
github.com/processone/mpg123 v1.0.0/go.mod h1:X/FeL+h8vD1bYsG9tIWV3M2c4qNTZOficyvPVBP08go=
+github.com/processone/soundcloud v1.0.0 h1:/+i6+Yveb7Y6IFGDSkesYI+HddblzcRTQClazzVHxoE=
github.com/processone/soundcloud v1.0.0/go.mod h1:kDLeWpkRtN3C8kIReQdxoiRi92P9xR6yW6qLOJnNWfY=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
@@ -155,6 +157,7 @@ golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/net v0.0.0-20190110200230-915654e7eabc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859 h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
diff --git a/_examples/xmpp_echo/xmpp_echo.go b/_examples/xmpp_echo/xmpp_echo.go
index b6c6766..d7a35c8 100644
--- a/_examples/xmpp_echo/xmpp_echo.go
+++ b/_examples/xmpp_echo/xmpp_echo.go
@@ -28,7 +28,7 @@ func main() {
router := xmpp.NewRouter()
router.HandleFunc("message", handleMessage)
- client, err := xmpp.NewClient(config, router, errorHandler)
+ client, err := xmpp.NewClient(&config, router, errorHandler)
if err != nil {
log.Fatalf("%+v", err)
}
diff --git a/_examples/xmpp_jukebox/xmpp_jukebox.go b/_examples/xmpp_jukebox/xmpp_jukebox.go
index b8075a2..31f000a 100644
--- a/_examples/xmpp_jukebox/xmpp_jukebox.go
+++ b/_examples/xmpp_jukebox/xmpp_jukebox.go
@@ -54,7 +54,7 @@ func main() {
handleIQ(s, p, player)
})
- client, err := xmpp.NewClient(config, router, errorHandler)
+ client, err := xmpp.NewClient(&config, router, errorHandler)
if err != nil {
log.Fatalf("%+v", err)
}
diff --git a/_examples/xmpp_oauth2/xmpp_oauth2.go b/_examples/xmpp_oauth2/xmpp_oauth2.go
index 89b2639..a993049 100644
--- a/_examples/xmpp_oauth2/xmpp_oauth2.go
+++ b/_examples/xmpp_oauth2/xmpp_oauth2.go
@@ -28,7 +28,7 @@ func main() {
router := xmpp.NewRouter()
router.HandleFunc("message", handleMessage)
- client, err := xmpp.NewClient(config, router, errorHandler)
+ client, err := xmpp.NewClient(&config, router, errorHandler)
if err != nil {
log.Fatalf("%+v", err)
}
diff --git a/_examples/xmpp_pubsub_client/xmpp_ps_client.go b/_examples/xmpp_pubsub_client/xmpp_ps_client.go
index b2e9cf6..2308071 100644
--- a/_examples/xmpp_pubsub_client/xmpp_ps_client.go
+++ b/_examples/xmpp_pubsub_client/xmpp_ps_client.go
@@ -38,7 +38,7 @@ func main() {
log.Println("Received a message ! => \n" + string(data))
})
- client, err := xmpp.NewClient(config, router, func(err error) { log.Println(err) })
+ client, err := xmpp.NewClient(&config, router, func(err error) { log.Println(err) })
if err != nil {
log.Fatalf("%+v", err)
}
diff --git a/_examples/xmpp_websocket/xmpp_websocket.go b/_examples/xmpp_websocket/xmpp_websocket.go
index c8c0620..3a0c1ba 100644
--- a/_examples/xmpp_websocket/xmpp_websocket.go
+++ b/_examples/xmpp_websocket/xmpp_websocket.go
@@ -26,7 +26,7 @@ func main() {
router := xmpp.NewRouter()
router.HandleFunc("message", handleMessage)
- client, err := xmpp.NewClient(config, router, errorHandler)
+ client, err := xmpp.NewClient(&config, router, errorHandler)
if err != nil {
log.Fatalf("%+v", err)
}
diff --git a/auth.go b/auth.go
index b8d20b9..902371b 100644
--- a/auth.go
+++ b/auth.go
@@ -60,10 +60,21 @@ func authPlain(socket io.ReadWriter, decoder *xml.Decoder, mech string, user str
raw := "\x00" + user + "\x00" + secret
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
base64.StdEncoding.Encode(enc, []byte(raw))
- _, err := fmt.Fprintf(socket, "%s", stanza.NSSASL, mech, enc)
+
+ a := stanza.SASLAuth{
+ Mechanism: mech,
+ Value: string(enc),
+ }
+ data, err := xml.Marshal(a)
if err != nil {
return err
}
+ n, err := socket.Write(data)
+ if err != nil {
+ return err
+ } else if n == 0 {
+ return errors.New("failed to write authSASL nonza to socket : wrote 0 bytes")
+ }
// Next message should be either success or failure.
val, err := stanza.NextPacket(decoder)
diff --git a/client.go b/client.go
index bd40c38..446dcbe 100644
--- a/client.go
+++ b/client.go
@@ -6,6 +6,7 @@ import (
"errors"
"io"
"net"
+ "sync"
"time"
"gosrc.io/xmpp/stanza"
@@ -14,15 +15,36 @@ import (
//=============================================================================
// EventManager
-// ConnState represents the current connection state.
+// SyncConnState represents the current connection state.
+type SyncConnState struct {
+ sync.RWMutex
+ // Current state of the client. Please use the dedicated getter and setter for this field as they are thread safe.
+ state ConnState
+}
type ConnState = uint8
+// getState is a thread-safe getter for the current state
+func (scs *SyncConnState) getState() ConnState {
+ var res ConnState
+ scs.RLock()
+ res = scs.state
+ scs.RUnlock()
+ return res
+}
+
+// setState is a thread-safe setter for the current
+func (scs *SyncConnState) setState(cs ConnState) {
+ scs.Lock()
+ scs.state = cs
+ scs.Unlock()
+}
+
// This is a the list of events happening on the connection that the
// client can be notified about.
const (
InitialPresence = ""
StateDisconnected ConnState = iota
- StateConnected
+ StateResuming
StateSessionEstablished
StateStreamError
StatePermanentError
@@ -31,7 +53,7 @@ const (
// Event is a structure use to convey event changes related to client state. This
// is for example used to notify the client when the client get disconnected.
type Event struct {
- State ConnState
+ State SyncConnState
Description string
StreamError string
SMState SMState
@@ -44,7 +66,16 @@ type SMState struct {
Id string
// Inbound stanza count
Inbound uint
- // TODO Store location for IP affinity
+
+ // IP affinity
+ preferredReconAddr string
+
+ // Error
+ StreamErrorGroup stanza.StanzaErrorGroup
+
+ // Track sent stanzas
+ *stanza.UnAckQueue
+
// TODO Store max and timestamp, to check if we should retry resumption or not
}
@@ -53,29 +84,35 @@ type SMState struct {
type EventHandler func(Event) error
type EventManager struct {
- // Store current state
- CurrentState ConnState
+ // Store current state. Please use "getState" and "setState" to access and/or modify this.
+ CurrentState SyncConnState
// Callback used to propagate connection state changes
Handler EventHandler
}
+// updateState changes the CurrentState in the event manager. The state read is threadsafe but there is no guarantee
+// regarding the triggered callback function.
func (em *EventManager) updateState(state ConnState) {
- em.CurrentState = state
+ em.CurrentState.setState(state)
if em.Handler != nil {
em.Handler(Event{State: em.CurrentState})
}
}
+// disconnected changes the CurrentState in the event manager to "disconnected". The state read is threadsafe but there is no guarantee
+// regarding the triggered callback function.
func (em *EventManager) disconnected(state SMState) {
- em.CurrentState = StateDisconnected
+ em.CurrentState.setState(StateDisconnected)
if em.Handler != nil {
em.Handler(Event{State: em.CurrentState, SMState: state})
}
}
+// streamError changes the CurrentState in the event manager to "streamError". The state read is threadsafe but there is no guarantee
+// regarding the triggered callback function.
func (em *EventManager) streamError(error, desc string) {
- em.CurrentState = StateStreamError
+ em.CurrentState.setState(StateStreamError)
if em.Handler != nil {
em.Handler(Event{State: em.CurrentState, StreamError: error, Description: desc})
}
@@ -90,7 +127,7 @@ var ErrCanOnlySendGetOrSetIq = errors.New("SendIQ can only send get and set IQ s
// server.
type Client struct {
// Store user defined options and states
- config Config
+ config *Config
// Session gather data that can be accessed by users of this library
Session *Session
transport Transport
@@ -100,6 +137,12 @@ type Client struct {
EventManager
// Handle errors from client execution
ErrorHandler func(error)
+
+ // Post connection hook. This will be executed on first connection
+ PostConnectHook func() error
+
+ // Post resume hook. This will be executed after the client resumes a lost connection using StreamManagement (XEP-0198)
+ PostResumeHook func() error
}
/*
@@ -107,9 +150,9 @@ Setting up the client / Checking the parameters
*/
// NewClient generates a new XMPP client, based on Config passed as parameters.
-// If host is not specified, the DNS SRV should be used to find the host from the domainpart of the Jid.
+// If host is not specified, the DNS SRV should be used to find the host from the domain part of the Jid.
// Default the port to 5222.
-func NewClient(config Config, r *Router, errorHandler func(error)) (c *Client, err error) {
+func NewClient(config *Config, r *Router, errorHandler func(error)) (c *Client, err error) {
if config.KeepaliveInterval == 0 {
config.KeepaliveInterval = time.Second * 30
}
@@ -169,26 +212,45 @@ func NewClient(config Config, r *Router, errorHandler func(error)) (c *Client, e
return
}
-// Connect triggers actual TCP connection, based on previously defined parameters.
-// Connect simply triggers resumption, with an empty session state.
+// Connect establishes a first time connection to a XMPP server.
+// It calls the PostConnectHook
func (c *Client) Connect() error {
- var state SMState
- return c.Resume(state)
+ err := c.connect()
+ if err != nil {
+ return err
+ }
+ // TODO: Do we always want to send initial presence automatically ?
+ // Do we need an option to avoid that or do we rely on client to send the presence itself ?
+ err = c.sendWithWriter(c.transport, []byte(InitialPresence))
+ // Execute the post first connection hook. Typically this holds "ask for roster" and this type of actions.
+ if c.PostConnectHook != nil {
+ err = c.PostConnectHook()
+ if err != nil {
+ return err
+ }
+ }
+
+ // Start the keepalive go routine
+ keepaliveQuit := make(chan struct{})
+ go keepalive(c.transport, c.config.KeepaliveInterval, keepaliveQuit)
+ // Start the receiver go routine
+ go c.recv(keepaliveQuit)
+ return err
}
-// Resume attempts resuming a Stream Managed session, based on the provided stream management
-// state.
-func (c *Client) Resume(state SMState) error {
+// connect establishes an actual TCP connection, based on previously defined parameters, as well as a XMPP session
+func (c *Client) connect() error {
+ var state SMState
var err error
-
+ // This is the TCP connection
streamId, err := c.transport.Connect()
if err != nil {
return err
}
- c.updateState(StateConnected)
- // Client is ok, we now open XMPP session
- if c.Session, err = NewSession(c.transport, c.config, state); err != nil {
+ // Client is ok, we now open XMPP session with TLS negotiation if possible and session resume or binding
+ // depending on state.
+ if c.Session, err = NewSession(c, state); err != nil {
// Try to get the stream close tag from the server.
go func() {
for {
@@ -212,22 +274,26 @@ func (c *Client) Resume(state SMState) error {
c.Session.StreamId = streamId
c.updateState(StateSessionEstablished)
- // Start the keepalive go routine
- keepaliveQuit := make(chan struct{})
- go keepalive(c.transport, c.config.KeepaliveInterval, keepaliveQuit)
- // Start the receiver go routine
- state = c.Session.SMState
- go c.recv(state, keepaliveQuit)
-
- // We're connected and can now receive and send messages.
- //fmt.Fprintf(client.conn, "%s%s", "chat", "Online")
- // TODO: Do we always want to send initial presence automatically ?
- // Do we need an option to avoid that or do we rely on client to send the presence itself ?
- err = c.sendWithWriter(c.transport, []byte(InitialPresence))
-
return err
}
+// Resume attempts resuming a Stream Managed session, based on the provided stream management
+// state. See XEP-0198
+func (c *Client) Resume() error {
+ c.EventManager.updateState(StateResuming)
+ err := c.connect()
+ if err != nil {
+ return err
+ }
+ // Execute post reconnect hook. This can be different from the first connection hook, and not trigger roster retrival
+ // for example.
+ if c.PostResumeHook != nil {
+ err = c.PostResumeHook()
+ }
+ return err
+}
+
+// Disconnect disconnects the client from the server, sending a stream close nonza and closing the TCP connection.
func (c *Client) Disconnect() error {
if c.transport != nil {
return c.transport.Close()
@@ -252,6 +318,15 @@ func (c *Client) Send(packet stanza.Packet) error {
return errors.New("cannot marshal packet " + err.Error())
}
+ // Store stanza as non-acked as part of stream management
+ // See https://xmpp.org/extensions/xep-0198.html#scenarios
+ if c.config.StreamManagementEnable {
+ if _, ok := packet.(stanza.SMRequest); !ok {
+ toStore := stanza.UnAckedStz{Stz: string(data)}
+ c.Session.SMState.UnAckQueue.Push(&toStore)
+ }
+ }
+
return c.sendWithWriter(c.transport, data)
}
@@ -284,6 +359,12 @@ func (c *Client) SendRaw(packet string) error {
return errors.New("client is not connected")
}
+ // Store stanza as non-acked as part of stream management
+ // See https://xmpp.org/extensions/xep-0198.html#scenarios
+ if c.config.StreamManagementEnable {
+ toStore := stanza.UnAckedStz{Stz: packet}
+ c.Session.SMState.UnAckQueue.Push(&toStore)
+ }
return c.sendWithWriter(c.transport, []byte(packet))
}
@@ -297,13 +378,13 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error {
// Go routines
// Loop: Receive data from server
-func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) {
+func (c *Client) recv(keepaliveQuit chan<- struct{}) {
for {
val, err := stanza.NextPacket(c.transport.GetDecoder())
if err != nil {
c.ErrorHandler(err)
close(keepaliveQuit)
- c.disconnected(state)
+ c.disconnected(c.Session.SMState)
return
}
@@ -321,7 +402,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) {
answer := stanza.SMAnswer{XMLName: xml.Name{
Space: stanza.NSStreamManagement,
Local: "a",
- }, H: state.Inbound}
+ }, H: c.Session.SMState.Inbound}
err = c.Send(answer)
if err != nil {
c.ErrorHandler(err)
@@ -332,7 +413,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) {
c.transport.ReceivedStreamClose()
return
default:
- state.Inbound++
+ c.Session.SMState.Inbound++
}
// Do normal route processing in a go-routine so we can immediately
// start receiving other stanzas. This also allows route handlers to
diff --git a/client_internal_test.go b/client_internal_test.go
index 6daef09..140eab7 100644
--- a/client_internal_test.go
+++ b/client_internal_test.go
@@ -2,7 +2,16 @@ package xmpp
import (
"bytes"
+ "encoding/xml"
+ "fmt"
+ "gosrc.io/xmpp/stanza"
+ "strconv"
"testing"
+ "time"
+)
+
+const (
+ streamManagementID = "test-stream_management-id"
)
func TestClient_Send(t *testing.T) {
@@ -17,3 +26,583 @@ func TestClient_Send(t *testing.T) {
t.Errorf("Incorrect value sent to buffer: '%s'", buffer.String())
}
}
+
+// Stream management test.
+// Connection is established, then the server sends supported features and so on.
+// After the bind, client attempts a stream management enablement, and server replies in kind.
+func Test_StreamManagement(t *testing.T) {
+ serverDone := make(chan struct{})
+ clientDone := make(chan struct{})
+
+ client, mock := initSrvCliForResumeTests(t, func(t *testing.T, sc *ServerConn) {
+ checkClientOpenStream(t, sc)
+
+ sendStreamFeatures(t, sc) // Send initial features
+ readAuth(t, sc.decoder)
+ sc.connection.Write([]byte(""))
+
+ checkClientOpenStream(t, sc) // Reset stream
+ sendFeaturesStreamManagment(t, sc) // Send post auth features
+ bind(t, sc)
+ enableStreamManagement(t, sc, false, true)
+ serverDone <- struct{}{}
+ }, testClientStreamManagement, true, true)
+ go func() {
+ var state SMState
+ var err error
+ // Client is ok, we now open XMPP session
+ if client.Session, err = NewSession(client, state); err != nil {
+ t.Fatalf("failed to open XMPP session: %s", err)
+ }
+ clientDone <- struct{}{}
+ }()
+
+ waitForEntity(t, clientDone)
+ waitForEntity(t, serverDone)
+ mock.Stop()
+}
+
+// Absence of stream management test.
+// Connection is established, then the server sends supported features and so on.
+// Client has stream management disabled in its config, and should not ask for it. Server is not set up to reply.
+func Test_NoStreamManagement(t *testing.T) {
+ serverDone := make(chan struct{})
+ clientDone := make(chan struct{})
+
+ // Setup Mock server
+ client, mock := initSrvCliForResumeTests(t, func(t *testing.T, sc *ServerConn) {
+ checkClientOpenStream(t, sc)
+
+ sendStreamFeatures(t, sc) // Send initial features
+ readAuth(t, sc.decoder)
+ sc.connection.Write([]byte(""))
+
+ checkClientOpenStream(t, sc) // Reset stream
+ sendFeaturesNoStreamManagment(t, sc) // Send post auth features
+ bind(t, sc)
+ serverDone <- struct{}{}
+ }, testClientStreamManagement, true, false)
+
+ go func() {
+ var state SMState
+
+ // Client is ok, we now open XMPP session
+ var err error
+ if client.Session, err = NewSession(client, state); err != nil {
+ t.Fatalf("failed to open XMPP session: %s", err)
+ }
+ clientDone <- struct{}{}
+ }()
+
+ waitForEntity(t, clientDone)
+ waitForEntity(t, serverDone)
+
+ mock.Stop()
+}
+
+func Test_StreamManagementNotSupported(t *testing.T) {
+ serverDone := make(chan struct{})
+ clientDone := make(chan struct{})
+
+ client, mock := initSrvCliForResumeTests(t, func(t *testing.T, sc *ServerConn) {
+ checkClientOpenStream(t, sc)
+
+ sendStreamFeatures(t, sc) // Send initial features
+ readAuth(t, sc.decoder)
+ sc.connection.Write([]byte(""))
+
+ checkClientOpenStream(t, sc) // Reset stream
+ sendFeaturesNoStreamManagment(t, sc) // Send post auth features
+ bind(t, sc)
+ serverDone <- struct{}{}
+ }, testClientStreamManagement, true, true)
+
+ go func() {
+ var state SMState
+ var err error
+ // Client is ok, we now open XMPP session
+ if client.Session, err = NewSession(client, state); err != nil {
+ t.Fatalf("failed to open XMPP session: %s", err)
+ }
+ clientDone <- struct{}{}
+ }()
+
+ // Wait for client
+ waitForEntity(t, clientDone)
+
+ // Check if client got a positive stream management response from the server
+ if client.Session.Features.DoesStreamManagement() {
+ t.Fatalf("server does not provide stream management")
+ }
+
+ // Wait for server
+ waitForEntity(t, serverDone)
+ mock.Stop()
+}
+
+func Test_StreamManagementNoResume(t *testing.T) {
+ serverDone := make(chan struct{})
+ clientDone := make(chan struct{})
+
+ client, mock := initSrvCliForResumeTests(t, func(t *testing.T, sc *ServerConn) {
+ checkClientOpenStream(t, sc)
+
+ sendStreamFeatures(t, sc) // Send initial features
+ readAuth(t, sc.decoder)
+ sc.connection.Write([]byte(""))
+
+ checkClientOpenStream(t, sc) // Reset stream
+ sendFeaturesStreamManagment(t, sc) // Send post auth features
+ bind(t, sc)
+ enableStreamManagement(t, sc, false, false)
+ serverDone <- struct{}{}
+ }, testClientStreamManagement, true, true)
+
+ go func() {
+ var state SMState
+ var err error
+ // Client is ok, we now open XMPP session
+ if client.Session, err = NewSession(client, state); err != nil {
+ t.Fatalf("failed to open XMPP session: %s", err)
+ }
+ clientDone <- struct{}{}
+ }()
+ waitForEntity(t, clientDone)
+ if IsStreamResumable(client) {
+ t.Fatalf("server does not support resumption but client says stream is resumable")
+ }
+ waitForEntity(t, serverDone)
+ mock.Stop()
+}
+
+func Test_StreamManagementResume(t *testing.T) {
+ serverDone := make(chan struct{})
+ clientDone := make(chan struct{})
+ // Setup Mock server
+ mock := ServerMock{}
+ mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) {
+ checkClientOpenStream(t, sc)
+
+ sendStreamFeatures(t, sc) // Send initial features
+ readAuth(t, sc.decoder)
+ sc.connection.Write([]byte(""))
+
+ checkClientOpenStream(t, sc) // Reset stream
+ sendFeaturesStreamManagment(t, sc) // Send post auth features
+ bind(t, sc)
+ enableStreamManagement(t, sc, false, true)
+ discardPresence(t, sc)
+ serverDone <- struct{}{}
+ })
+
+ // Test / Check result
+ config := Config{
+ TransportConfiguration: TransportConfiguration{
+ Address: testXMPPAddress,
+ },
+ Jid: "test@localhost",
+ Credential: Password("test"),
+ Insecure: true,
+ StreamManagementEnable: true,
+ streamManagementResume: true} // Enable stream management
+
+ var client *Client
+ router := NewRouter()
+ client, err := NewClient(&config, router, clientDefaultErrorHandler)
+ if err != nil {
+ t.Errorf("connect create XMPP client: %s", err)
+ }
+
+ // =================================================================
+ // Connect client, then disconnect it so we can resume the session
+ go func() {
+ err = client.Connect()
+ if err != nil {
+ t.Fatalf("could not connect client to mock server: %s", err)
+ }
+ clientDone <- struct{}{}
+ }()
+
+ waitForEntity(t, clientDone)
+
+ // ===========================================================================================
+ // Check that the client correctly went into "disconnected" state, after being disconnected
+ statusCorrectChan := make(chan struct{})
+ kill := make(chan struct{})
+
+ transp, ok := client.transport.(*XMPPTransport)
+ if !ok {
+ t.Fatalf("problem with client transport ")
+ }
+
+ transp.conn.Close()
+
+ waitForEntity(t, serverDone)
+ mock.Stop()
+
+ go checkClientResumeStatus(client, statusCorrectChan, kill)
+ select {
+ case <-statusCorrectChan:
+ // Test passed
+ case <-time.After(5 * time.Second):
+ kill <- struct{}{}
+ t.Fatalf("Client is not in disconnected state while it should be. Timed out")
+ }
+
+ // Check if the client can have its connection resumed using its state but also its configuration
+ if !IsStreamResumable(client) {
+ t.Fatalf("should support resumption")
+ }
+
+ // Reboot server. We need to make a new one because (at least for now) the mock server can only have one handler
+ // and they should be different between a first connection and a stream resume since exchanged messages
+ // are different (See XEP-0198)
+ mock2 := ServerMock{}
+ mock2.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) {
+ // Reconnect
+ checkClientOpenStream(t, sc)
+
+ sendStreamFeatures(t, sc) // Send initial features
+ readAuth(t, sc.decoder)
+ sc.connection.Write([]byte(""))
+
+ checkClientOpenStream(t, sc) // Reset stream
+ sendFeaturesStreamManagment(t, sc) // Send post auth features
+ resumeStream(t, sc)
+ serverDone <- struct{}{}
+ })
+
+ // Reconnect
+ go func() {
+ err = client.Resume()
+ if err != nil {
+ t.Fatalf("could not connect client to mock server: %s", err)
+ }
+ clientDone <- struct{}{}
+ }()
+
+ waitForEntity(t, clientDone)
+ waitForEntity(t, serverDone)
+
+ mock2.Stop()
+}
+
+func Test_StreamManagementFail(t *testing.T) {
+ serverDone := make(chan struct{})
+ clientDone := make(chan struct{})
+ // Setup Mock server
+ mock := ServerMock{}
+ mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) {
+ checkClientOpenStream(t, sc)
+
+ sendStreamFeatures(t, sc) // Send initial features
+ readAuth(t, sc.decoder)
+ sc.connection.Write([]byte(""))
+
+ checkClientOpenStream(t, sc) // Reset stream
+ sendFeaturesStreamManagment(t, sc) // Send post auth features
+ bind(t, sc)
+ enableStreamManagement(t, sc, true, true)
+ serverDone <- struct{}{}
+ })
+
+ // Test / Check result
+ config := Config{
+ TransportConfiguration: TransportConfiguration{
+ Address: testXMPPAddress,
+ },
+ Jid: "test@localhost",
+ Credential: Password("test"),
+ Insecure: true,
+ StreamManagementEnable: true,
+ streamManagementResume: true} // Enable stream management
+
+ var client *Client
+ router := NewRouter()
+ client, err := NewClient(&config, router, clientDefaultErrorHandler)
+ if err != nil {
+ t.Errorf("connect create XMPP client: %s", err)
+ }
+
+ var state SMState
+ go func() {
+ _, err = client.transport.Connect()
+ if err != nil {
+ return
+ }
+
+ // Client is ok, we now open XMPP session
+ if client.Session, err = NewSession(client, state); err == nil {
+ t.Fatalf("test is supposed to err")
+ }
+ if client.Session.SMState.StreamErrorGroup == nil {
+ t.Fatalf("error was not stored correctly in session state")
+ }
+ clientDone <- struct{}{}
+ }()
+
+ waitForEntity(t, serverDone)
+ waitForEntity(t, clientDone)
+
+ mock.Stop()
+}
+
+func Test_SendStanzaQueueWithSM(t *testing.T) {
+ serverDone := make(chan struct{})
+ clientDone := make(chan struct{})
+ // Setup Mock server
+ mock := ServerMock{}
+ mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) {
+ checkClientOpenStream(t, sc)
+
+ sendStreamFeatures(t, sc) // Send initial features
+ readAuth(t, sc.decoder)
+ sc.connection.Write([]byte(""))
+
+ checkClientOpenStream(t, sc) // Reset stream
+ sendFeaturesStreamManagment(t, sc) // Send post auth features
+ bind(t, sc)
+ enableStreamManagement(t, sc, false, true)
+
+ // Ignore the initial presence sent to the server by the client so we can move on to the next packet.
+ discardPresence(t, sc)
+
+ // Used here to silently discard the IQ sent by the client, in order to later trigger a resend
+ skipPacket(t, sc)
+ // Respond to the client ACK request with a number of processed stanzas of 0. This should trigger a resend
+ // of previously ignored stanza to the server, which this handler element will be expecting.
+ respondWithAck(t, sc, 0)
+ serverDone <- struct{}{}
+ })
+
+ // Test / Check result
+ config := Config{
+ TransportConfiguration: TransportConfiguration{
+ Address: testXMPPAddress,
+ },
+ Jid: "test@localhost",
+ Credential: Password("test"),
+ Insecure: true,
+ StreamManagementEnable: true,
+ streamManagementResume: true} // Enable stream management
+
+ var client *Client
+ router := NewRouter()
+ client, err := NewClient(&config, router, clientDefaultErrorHandler)
+ if err != nil {
+ t.Errorf("connect create XMPP client: %s", err)
+ }
+
+ go func() {
+ err = client.Connect()
+
+ client.SendRaw(`
+
+
+`)
+
+ // Last stanza was discarded silently by the server. Let's ask an ack for it. This should trigger resend as the server
+ // will respond with an acknowledged number of stanzas of 0.
+ r := stanza.SMRequest{}
+ client.Send(r)
+ clientDone <- struct{}{}
+ }()
+ waitForEntity(t, serverDone)
+ waitForEntity(t, clientDone)
+
+ mock.Stop()
+}
+
+//========================================================================
+// Helper functions for tests
+
+func skipPacket(t *testing.T, sc *ServerConn) {
+ var p stanza.IQ
+ se, err := stanza.NextStart(sc.decoder)
+
+ if err != nil {
+ t.Fatalf("cannot read packet: %s", err)
+ return
+ }
+ if err := sc.decoder.DecodeElement(&p, &se); err != nil {
+ t.Fatalf("cannot decode packet: %s", err)
+ return
+ }
+}
+
+func respondWithAck(t *testing.T, sc *ServerConn, h int) {
+
+ // Mock server reads the ack request
+ var p stanza.SMRequest
+ se, err := stanza.NextStart(sc.decoder)
+
+ if err != nil {
+ t.Fatalf("cannot read packet: %s", err)
+ return
+ }
+ if err := sc.decoder.DecodeElement(&p, &se); err != nil {
+ t.Fatalf("cannot decode packet: %s", err)
+ return
+ }
+
+ // Mock server sends the ack response
+ a := stanza.SMAnswer{
+ H: uint(h),
+ }
+ data, err := xml.Marshal(a)
+ _, err = sc.connection.Write(data)
+ if err != nil {
+ t.Fatalf("failed to send response ack")
+ }
+
+ // Mock server reads the re-sent stanza that was previously discarded intentionally
+ var p2 stanza.IQ
+ nse, err := stanza.NextStart(sc.decoder)
+
+ if err != nil {
+ t.Fatalf("cannot read packet: %s", err)
+ return
+ }
+ if err := sc.decoder.DecodeElement(&p2, &nse); err != nil {
+ t.Fatalf("cannot decode packet: %s", err)
+ return
+ }
+}
+
+func sendFeaturesStreamManagment(t *testing.T, sc *ServerConn) {
+ // This is a basic server, supporting only 2 features after auth: stream management & session binding
+ features := `
+
+
+`
+ if _, err := fmt.Fprintln(sc.connection, features); err != nil {
+ t.Fatalf("cannot send stream feature: %s", err)
+ }
+}
+
+func sendFeaturesNoStreamManagment(t *testing.T, sc *ServerConn) {
+ // This is a basic server, supporting only 2 features after auth: stream management & session binding
+ features := `
+
+`
+ if _, err := fmt.Fprintln(sc.connection, features); err != nil {
+ t.Fatalf("cannot send stream feature: %s", err)
+ }
+}
+
+// enableStreamManagement is a function for the mock server that can either mock a successful session, or fail depending on
+// the value of the "fail" boolean. True means the session should fail.
+func enableStreamManagement(t *testing.T, sc *ServerConn, fail bool, resume bool) {
+ // Decode element into pointer storage
+ var ed stanza.SMEnable
+ se, err := stanza.NextStart(sc.decoder)
+
+ if err != nil {
+ t.Fatalf("cannot read stream management enable: %s", err)
+ return
+ }
+ if err := sc.decoder.DecodeElement(&ed, &se); err != nil {
+ t.Fatalf("cannot decode stream management enable: %s", err)
+ return
+ }
+
+ if fail {
+ f := stanza.SMFailed{
+ H: nil,
+ StreamErrorGroup: &stanza.UnexpectedRequest{},
+ }
+ data, err := xml.Marshal(f)
+ if err != nil {
+ t.Fatalf("failed to marshall error response: %s", err)
+ }
+ sc.connection.Write(data)
+ } else {
+ e := &stanza.SMEnabled{
+ Resume: strconv.FormatBool(resume),
+ Id: streamManagementID,
+ }
+ data, err := xml.Marshal(e)
+ if err != nil {
+ t.Fatalf("failed to marshall error response: %s", err)
+ }
+ sc.connection.Write(data)
+ }
+}
+
+func resumeStream(t *testing.T, sc *ServerConn) {
+ h := uint(0)
+ response := stanza.SMResumed{
+ PrevId: streamManagementID,
+ H: &h,
+ }
+
+ data, err := xml.Marshal(response)
+ if err != nil {
+ t.Fatalf("failed to marshall stream management enabled response : %s", err)
+ }
+
+ writtenChan := make(chan struct{})
+
+ go func() {
+ sc.connection.Write(data)
+ writtenChan <- struct{}{}
+ }()
+ select {
+ case <-writtenChan:
+ // We're done here
+ return
+ case <-time.After(defaultTimeout):
+ t.Fatalf("failed to write enabled nonza to client")
+ }
+}
+
+func checkClientResumeStatus(client *Client, statusCorrectChan chan struct{}, killChan chan struct{}) {
+ for {
+ if client.CurrentState.getState() == StateDisconnected {
+ statusCorrectChan <- struct{}{}
+ }
+ select {
+ case <-killChan:
+ return
+ case <-time.After(time.Millisecond * 10):
+ // Keep checking status value
+ }
+ }
+}
+
+func initSrvCliForResumeTests(t *testing.T, serverHandler func(*testing.T, *ServerConn), port int, StreamManagementEnable, StreamManagementResume bool) (*Client, *ServerMock) {
+ mock := &ServerMock{}
+ testServerAddress := fmt.Sprintf("%s:%d", testClientDomain, port)
+
+ mock.Start(t, testServerAddress, serverHandler)
+ config := Config{
+ TransportConfiguration: TransportConfiguration{
+ Address: testServerAddress,
+ },
+ Jid: "test@localhost",
+ Credential: Password("test"),
+ Insecure: true,
+ StreamManagementEnable: StreamManagementEnable,
+ streamManagementResume: StreamManagementResume}
+
+ var client *Client
+ var err error
+ router := NewRouter()
+ if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil {
+ t.Fatalf("connect create XMPP client: %s", err)
+ }
+
+ if _, err = client.transport.Connect(); err != nil {
+ t.Fatalf("XMPP connection failed: %s", err)
+ }
+
+ return client, mock
+}
+
+func waitForEntity(t *testing.T, entityDone chan struct{}) {
+ select {
+ case <-entityDone:
+ case <-time.After(defaultTimeout):
+ t.Fatalf("test timed out")
+ }
+}
diff --git a/client_test.go b/client_test.go
index 00c956a..4f30c0e 100644
--- a/client_test.go
+++ b/client_test.go
@@ -20,18 +20,20 @@ const (
func TestEventManager(t *testing.T) {
mgr := EventManager{}
- mgr.updateState(StateConnected)
- if mgr.CurrentState != StateConnected {
+ mgr.updateState(StateResuming)
+ if mgr.CurrentState.getState() != StateResuming {
t.Fatal("CurrentState not updated by updateState()")
}
mgr.disconnected(SMState{})
- if mgr.CurrentState != StateDisconnected {
+
+ if mgr.CurrentState.getState() != StateDisconnected {
t.Fatalf("CurrentState not reset by disconnected()")
}
mgr.streamError(ErrTLSNotSupported.Error(), "")
- if mgr.CurrentState != StateStreamError {
+
+ if mgr.CurrentState.getState() != StateStreamError {
t.Fatalf("CurrentState not set by streamError()")
}
}
@@ -53,7 +55,7 @@ func TestClient_Connect(t *testing.T) {
var client *Client
var err error
router := NewRouter()
- if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
+ if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil {
t.Errorf("connect create XMPP client: %s", err)
}
@@ -84,7 +86,7 @@ func TestClient_NoInsecure(t *testing.T) {
var client *Client
var err error
router := NewRouter()
- if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
+ if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil {
t.Errorf("cannot create XMPP client: %s", err)
}
@@ -117,7 +119,7 @@ func TestClient_FeaturesTracking(t *testing.T) {
var client *Client
var err error
router := NewRouter()
- if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
+ if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil {
t.Errorf("cannot create XMPP client: %s", err)
}
@@ -147,7 +149,7 @@ func TestClient_RFC3921Session(t *testing.T) {
var client *Client
var err error
router := NewRouter()
- if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
+ if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil {
t.Errorf("connect create XMPP client: %s", err)
}
@@ -366,7 +368,7 @@ func TestClient_DisconnectStreamManager(t *testing.T) {
var client *Client
var err error
router := NewRouter()
- if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
+ if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil {
t.Errorf("cannot create XMPP client: %s", err)
}
@@ -386,6 +388,162 @@ func TestClient_DisconnectStreamManager(t *testing.T) {
mock.Stop()
}
+func Test_ClientPostConnectHook(t *testing.T) {
+ done := make(chan struct{})
+ // Handler for Mock server
+ h := func(t *testing.T, sc *ServerConn) {
+ handlerClientConnectSuccess(t, sc)
+ done <- struct{}{}
+ }
+
+ hookChan := make(chan struct{})
+ mock := &ServerMock{}
+ testServerAddress := fmt.Sprintf("%s:%d", testClientDomain, testClientPostConnectHook)
+
+ mock.Start(t, testServerAddress, h)
+ config := Config{
+ TransportConfiguration: TransportConfiguration{
+ Address: testServerAddress,
+ },
+ Jid: "test@localhost",
+ Credential: Password("test"),
+ Insecure: true}
+
+ var client *Client
+ var err error
+ router := NewRouter()
+ if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil {
+ t.Errorf("connect create XMPP client: %s", err)
+ }
+
+ // The post connection client hook should just write to a channel that we will read later.
+ client.PostConnectHook = func() error {
+ go func() {
+ hookChan <- struct{}{}
+ }()
+ return nil
+ }
+ // Handle a possible error
+ errChan := make(chan error)
+ errorHandler := func(err error) {
+ errChan <- err
+ }
+ client.ErrorHandler = errorHandler
+ if err = client.Connect(); err != nil {
+ t.Errorf("XMPP connection failed: %s", err)
+ }
+
+ // Check if the post connection client hook was correctly called
+ select {
+ case err := <-errChan: // If the server sends an error, or there is a connection error
+ t.Fatal(err.Error())
+ case <-time.After(defaultChannelTimeout): // If we timeout
+ t.Fatal("Failed to call post connection client hook")
+ case <-hookChan:
+ // Test succeeded, channel was written to.
+ }
+
+ select {
+ case <-done:
+ mock.Stop()
+ case <-time.After(defaultChannelTimeout):
+ t.Fatal("The mock server failed to finish its job !")
+ }
+}
+
+func Test_ClientPostReconnectHook(t *testing.T) {
+ hookChan := make(chan struct{})
+ // Setup Mock server
+ mock := ServerMock{}
+ mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) {
+ checkClientOpenStream(t, sc)
+
+ sendStreamFeatures(t, sc) // Send initial features
+ readAuth(t, sc.decoder)
+ sc.connection.Write([]byte(""))
+
+ checkClientOpenStream(t, sc) // Reset stream
+ sendFeaturesStreamManagment(t, sc) // Send post auth features
+ bind(t, sc)
+ enableStreamManagement(t, sc, false, true)
+ })
+
+ // Test / Check result
+ config := Config{
+ TransportConfiguration: TransportConfiguration{
+ Address: testXMPPAddress,
+ },
+ Jid: "test@localhost",
+ Credential: Password("test"),
+ Insecure: true,
+ StreamManagementEnable: true,
+ streamManagementResume: true} // Enable stream management
+
+ var client *Client
+ router := NewRouter()
+ client, err := NewClient(&config, router, clientDefaultErrorHandler)
+ if err != nil {
+ t.Errorf("connect create XMPP client: %s", err)
+ }
+
+ client.PostResumeHook = func() error {
+ go func() {
+ hookChan <- struct{}{}
+ }()
+ return nil
+ }
+
+ err = client.Connect()
+ if err != nil {
+ t.Fatalf("could not connect client to mock server: %s", err)
+ }
+
+ transp, ok := client.transport.(*XMPPTransport)
+ if !ok {
+ t.Fatalf("problem with client transport ")
+ }
+
+ transp.conn.Close()
+ mock.Stop()
+
+ // Check if the client can have its connection resumed using its state but also its configuration
+ if !IsStreamResumable(client) {
+ t.Fatalf("should support resumption")
+ }
+
+ // Reboot server. We need to make a new one because (at least for now) the mock server can only have one handler
+ // and they should be different between a first connection and a stream resume since exchanged messages
+ // are different (See XEP-0198)
+ mock2 := ServerMock{}
+ mock2.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) {
+ // Reconnect
+ checkClientOpenStream(t, sc)
+
+ sendStreamFeatures(t, sc) // Send initial features
+ readAuth(t, sc.decoder)
+ sc.connection.Write([]byte(""))
+
+ checkClientOpenStream(t, sc) // Reset stream
+ sendFeaturesStreamManagment(t, sc) // Send post auth features
+ resumeStream(t, sc)
+ })
+
+ // Reconnect
+ err = client.Resume()
+ if err != nil {
+ t.Fatalf("could not connect client to mock server: %s", err)
+ }
+
+ select {
+ case <-time.After(defaultChannelTimeout): // If we timeout
+ t.Fatal("Failed to call post connection client hook")
+ case <-hookChan:
+ // Test succeeded, channel was written to.
+ }
+
+ mock2.Stop()
+}
+
//=============================================================================
// Basic XMPP Server Mock Handlers.
@@ -449,7 +607,7 @@ func checkClientOpenStream(t *testing.T, sc *ServerConn) {
var token xml.Token
token, err := sc.decoder.Token()
if err != nil {
- t.Errorf("cannot read next token: %s", err)
+ t.Fatalf("cannot read next token: %s", err)
}
switch elem := token.(type) {
@@ -464,6 +622,7 @@ func checkClientOpenStream(t *testing.T, sc *ServerConn) {
}
return
}
+
}
}
@@ -472,7 +631,6 @@ func mockClientConnection(t *testing.T, serverHandler func(*testing.T, *ServerCo
testServerAddress := fmt.Sprintf("%s:%d", testClientDomain, port)
mock.Start(t, testServerAddress, serverHandler)
-
config := Config{
TransportConfiguration: TransportConfiguration{
Address: testServerAddress,
@@ -484,7 +642,7 @@ func mockClientConnection(t *testing.T, serverHandler func(*testing.T, *ServerCo
var client *Client
var err error
router := NewRouter()
- if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil {
+ if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil {
t.Errorf("connect create XMPP client: %s", err)
}
diff --git a/cmd/fluuxmpp/send.go b/cmd/fluuxmpp/send.go
index 7e7ed97..27c1a67 100644
--- a/cmd/fluuxmpp/send.go
+++ b/cmd/fluuxmpp/send.go
@@ -32,7 +32,7 @@ func sendxmpp(cmd *cobra.Command, args []string) {
msgText := args[1]
var err error
- client, err := xmpp.NewClient(xmpp.Config{
+ client, err := xmpp.NewClient(&xmpp.Config{
TransportConfiguration: xmpp.TransportConfiguration{
Address: viper.GetString("addr"),
},
diff --git a/cmd/go.sum b/cmd/go.sum
index c7e00fa..8398605 100644
--- a/cmd/go.sum
+++ b/cmd/go.sum
@@ -65,6 +65,7 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20190908185732-236ed259b199/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
+github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
@@ -92,6 +93,7 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/magiconair/properties v1.8.0 h1:LLgXmsheXeRoUOBOjtwPQCWIYqM/LU1ayDtDePerRcY=
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
+github.com/magiconair/properties v1.8.1 h1:ZC2Vc7/ZFkGmsVC9KvOjumD+G5lXy2RtTKyzRKO2BQ4=
github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
github.com/mailru/easyjson v0.0.0-20190403194419-1ea4449da983/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
@@ -148,10 +150,12 @@ github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb6
github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
+github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
github.com/spf13/viper v1.4.0 h1:yXHLWeravcrgGyFSyCgdYpXQ9dR9c/WED3pg1RhxqEU=
github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE=
+github.com/spf13/viper v1.6.1 h1:VPZzIkznI1YhVMRi6vNFLHSwhnhReBfgTxIPccpfdZk=
github.com/spf13/viper v1.6.1/go.mod h1:t3iDnF5Jlj76alVNuyFBk5oUMCvsrkbvZK0WQdfDi5k=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -159,6 +163,7 @@ github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/twitchyliquid64/golang-asm v0.0.0-20190126203739-365674df15fc/go.mod h1:NoCfSFWosfqMqmmD7hApkirIK9ozpHjxRnRxs1l413A=
@@ -207,6 +212,7 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190618155005-516e3c20635f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190927073244-c990c680b611 h1:q9u40nxWT5zRClI/uU9dHCiYGottAg6Nzz4YUQyHxdA=
golang.org/x/sys v0.0.0-20190927073244-c990c680b611/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -231,6 +237,7 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo=
+gopkg.in/ini.v1 v1.51.0 h1:AQvPpx3LzTDM0AjnIRlVFwFFGC+npRopjZxLJj6gdno=
gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
@@ -238,9 +245,11 @@ gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bl
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gotest.tools v2.1.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
gotest.tools/gotestsum v0.3.5/go.mod h1:Mnf3e5FUzXbkCfynWBGOwLssY7gTQgCHObK9tMpAriY=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
mvdan.cc/sh v2.6.4+incompatible/go.mod h1:IeeQbZq+x2SUGBensq/jge5lLQbS3XT2ktyp3wrt4x8=
+nhooyr.io/websocket v1.6.5 h1:8TzpkldRfefda5JST+CnOH135bzVPz5uzfn/AF+gVKg=
nhooyr.io/websocket v1.6.5/go.mod h1:F259lAzPRAH0htX2y3ehpJe09ih1aSHN7udWki1defY=
diff --git a/component.go b/component.go
index a57b07b..ec0e8df 100644
--- a/component.go
+++ b/component.go
@@ -60,11 +60,10 @@ func NewComponent(opts ComponentOptions, r *Router, errorHandler func(error)) (*
// Connect triggers component connection to XMPP server component port.
// TODO: Failed handshake should be a permanent error
func (c *Component) Connect() error {
- var state SMState
- return c.Resume(state)
+ return c.Resume()
}
-func (c *Component) Resume(sm SMState) error {
+func (c *Component) Resume() error {
var err error
var streamId string
if c.ComponentOptions.TransportConfiguration.Domain == "" {
@@ -73,16 +72,13 @@ func (c *Component) Resume(sm SMState) error {
c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
if err != nil {
c.updateState(StatePermanentError)
-
return NewConnError(err, true)
}
if streamId, err = c.transport.Connect(); err != nil {
c.updateState(StatePermanentError)
-
return NewConnError(err, true)
}
- c.updateState(StateConnected)
// Authentication
if err := c.sendWithWriter(c.transport, []byte(fmt.Sprintf("%s", c.handshake(streamId)))); err != nil {
diff --git a/component_test.go b/component_test.go
index f2b5a2f..59ac08e 100644
--- a/component_test.go
+++ b/component_test.go
@@ -38,6 +38,8 @@ func TestHandshake(t *testing.T) {
// Tests connection process with a handshake exchange
// Tests multiple session IDs. All serverConnections should generate a unique stream ID
func TestGenerateHandshakeId(t *testing.T) {
+ clientDone := make(chan struct{})
+ serverDone := make(chan struct{})
// Using this array with a channel to make a queue of values to test
// These are stream IDs that will be used to test the connection process, mixing them with the "secret" to generate
// some handshake value
@@ -57,11 +59,10 @@ func TestGenerateHandshakeId(t *testing.T) {
// Performs a Component connection with a handshake. It expects to have an ID sent its way through the "uchan"
// channel of this file. Otherwise it will hang for ever.
h := func(t *testing.T, sc *ServerConn) {
-
checkOpenStreamHandshakeID(t, sc, <-uchan)
readHandshakeComponent(t, sc.decoder)
sc.connection.Write([]byte("")) // That's all the server needs to return (see xep-0114)
- return
+ serverDone <- struct{}{}
}
// Init mock server
@@ -92,14 +93,45 @@ func TestGenerateHandshakeId(t *testing.T) {
}
// Try connecting, and storing the resulting streamID in a map.
- m := make(map[string]bool)
- for range uuidsArray {
- streamId, _ := c.transport.Connect()
- m[c.handshake(streamId)] = true
- }
- if len(uuidsArray) != len(m) {
- t.Errorf("Handshake does not produce a unique id. Expected: %d unique ids, got: %d", len(uuidsArray), len(m))
- }
+ go func() {
+ m := make(map[string]bool)
+ for range uuidsArray {
+ idChan := make(chan string)
+ go func() {
+ streamId, err := c.transport.Connect()
+ if err != nil {
+ t.Fatalf("failed to mock component connection to get a handshake: %s", err)
+ }
+ idChan <- streamId
+ }()
+
+ var streamId string
+ select {
+ case streamId = <-idChan:
+ case <-time.After(defaultTimeout):
+ t.Fatalf("test timed out")
+ }
+
+ hs := stanza.Handshake{
+ Value: c.handshake(streamId),
+ }
+ m[hs.Value] = true
+ hsRaw, err := xml.Marshal(hs)
+ if err != nil {
+ t.Fatalf("could not marshal handshake: %s", err)
+ }
+ c.SendRaw(string(hsRaw))
+ waitForEntity(t, serverDone)
+ c.transport.Close()
+ }
+ if len(uuidsArray) != len(m) {
+ t.Errorf("Handshake does not produce a unique id. Expected: %d unique ids, got: %d", len(uuidsArray), len(m))
+ }
+ clientDone <- struct{}{}
+ }()
+
+ waitForEntity(t, clientDone)
+ mock.Stop()
}
// Test that NewStreamManager can accept a Component.
@@ -121,10 +153,11 @@ func TestDecoder(t *testing.T) {
// Tests sending an IQ to the server, and getting the response
func TestSendIq(t *testing.T) {
- done := make(chan struct{})
+ serverDone := make(chan struct{})
+ clientDone := make(chan struct{})
h := func(t *testing.T, sc *ServerConn) {
handlerForComponentIQSend(t, sc)
- done <- struct{}{}
+ serverDone <- struct{}{}
}
//Connecting to a mock server, initialized with given port and handler function
@@ -145,24 +178,23 @@ func TestSendIq(t *testing.T) {
}
c.ErrorHandler = errorHandler
- var res chan stanza.IQ
- res, _ = c.SendIQ(ctx, iqReq)
+ go func() {
+ var res chan stanza.IQ
+ res, _ = c.SendIQ(ctx, iqReq)
- select {
- case <-res:
- case err := <-errChan:
- t.Errorf(err.Error())
- case <-time.After(defaultChannelTimeout):
- t.Errorf("Failed to receive response, to sent IQ, from mock server")
- }
+ select {
+ case <-res:
+ case err := <-errChan:
+ t.Fatalf(err.Error())
+ }
+ clientDone <- struct{}{}
+ }()
+
+ waitForEntity(t, clientDone)
+ waitForEntity(t, serverDone)
- select {
- case <-done:
- m.Stop()
- case <-time.After(defaultChannelTimeout):
- t.Errorf("The mock server failed to finish its job !")
- }
cancel()
+ m.Stop()
}
// Checking that error handling is done properly client side when an invalid IQ is sent and the server responds in kind.
diff --git a/config.go b/config.go
index 178da2e..4609a0a 100644
--- a/config.go
+++ b/config.go
@@ -21,4 +21,15 @@ type Config struct {
// Insecure can be set to true to allow to open a session without TLS. If TLS
// is supported on the server, we will still try to use it.
Insecure bool
+
+ // Activate stream management process during session
+ StreamManagementEnable bool
+ // Enable stream management resume capability
+ streamManagementResume bool
+}
+
+// IsStreamResumable tells if a stream session is resumable by reading the "config" part of a client.
+// It checks if stream management is enabled, and if stream resumption was set and accepted by the server.
+func IsStreamResumable(c *Client) bool {
+ return c.config.StreamManagementEnable && c.config.streamManagementResume
}
diff --git a/router.go b/router.go
index f20af5b..7bba8b9 100644
--- a/router.go
+++ b/router.go
@@ -42,6 +42,17 @@ func NewRouter() *Router {
// route is called by the XMPP client to dispatch stanza received using the set up routes.
// It is also used by test, but is not supposed to be used directly by users of the library.
func (r *Router) route(s Sender, p stanza.Packet) {
+ a, isA := p.(stanza.SMAnswer)
+ if isA {
+ switch tt := s.(type) {
+ case *Client:
+ lastAcked := a.H
+ SendMissingStz(int(lastAcked), s, tt.Session.SMState.UnAckQueue)
+ case *Component:
+ // TODO
+ default:
+ }
+ }
iq, isIq := p.(*stanza.IQ)
if isIq {
r.IQResultRouteLock.RLock()
@@ -70,6 +81,33 @@ func (r *Router) route(s Sender, p stanza.Packet) {
}
}
+// SendMissingStz sends all stanzas that did not reach the server, according to the response to an ack request (see XEP-0198, acks)
+func SendMissingStz(lastSent int, s Sender, uaq *stanza.UnAckQueue) error {
+ uaq.RWMutex.Lock()
+ if len(uaq.Uslice) <= 0 {
+ uaq.RWMutex.Unlock()
+ return nil
+ }
+ last := uaq.Uslice[len(uaq.Uslice)-1]
+ if last.Id > lastSent {
+ // Remove sent stanzas from the queue
+ uaq.PopN(lastSent - last.Id)
+ // Re-send non acknowledged stanzas
+ for _, elt := range uaq.PopN(len(uaq.Uslice)) {
+ eltStz := elt.(*stanza.UnAckedStz)
+ err := s.SendRaw(eltStz.Stz)
+ if err != nil {
+ return err
+ }
+
+ }
+ // Ask for updates on stanzas we just sent to the entity. Not sure I should leave this. Maybe let users call ack again by themselves ?
+ s.Send(stanza.SMRequest{})
+ }
+ uaq.RWMutex.Unlock()
+ return nil
+}
+
func iqNotImplemented(s Sender, iq *stanza.IQ) {
err := stanza.Err{
XMLName: xml.Name{Local: "error"},
diff --git a/session.go b/session.go
index 182e32b..05bdce3 100644
--- a/session.go
+++ b/session.go
@@ -1,10 +1,11 @@
package xmpp
import (
+ "encoding/xml"
"errors"
"fmt"
-
"gosrc.io/xmpp/stanza"
+ "strconv"
)
type Session struct {
@@ -23,44 +24,67 @@ type Session struct {
err error
}
-func NewSession(transport Transport, o Config, state SMState) (*Session, error) {
- s := new(Session)
- s.transport = transport
- s.SMState = state
- s.init(o)
+func NewSession(c *Client, state SMState) (*Session, error) {
+ var s *Session
+ if c.Session == nil {
+ s = new(Session)
+ s.transport = c.transport
+ s.SMState = state
+ s.init()
+ } else {
+ s = c.Session
+ // We keep information about the previously set session, like the session ID, but we read server provided
+ // info again in case it changed between session break and resume, such as features.
+ s.init()
+ }
if s.err != nil {
return nil, NewConnError(s.err, true)
}
- if !transport.IsSecure() {
- s.startTlsIfSupported(o)
+ if !c.transport.IsSecure() {
+ s.startTlsIfSupported(c.config)
}
- if !transport.IsSecure() && !o.Insecure {
+ if !c.transport.IsSecure() && !c.config.Insecure {
err := fmt.Errorf("failed to negotiate TLS session : %s", s.err)
return nil, NewConnError(err, true)
}
if s.TlsEnabled {
- s.reset(o)
+ s.reset()
}
// auth
- s.auth(o)
- s.reset(o)
+ s.auth(c.config)
+ if s.err != nil {
+ return s, s.err
+ }
+ s.reset()
+ if s.err != nil {
+ return s, s.err
+ }
// attempt resumption
- if s.resume(o) {
+ if s.resume(c.config) {
return s, s.err
}
// otherwise, bind resource and 'start' XMPP session
- s.bind(o)
- s.rfc3921Session(o)
+ s.bind(c.config)
+ if s.err != nil {
+ return s, s.err
+ }
+ s.rfc3921Session()
+ if s.err != nil {
+ return s, s.err
+ }
// Enable stream management if supported
- s.EnableStreamManagement(o)
+ s.EnableStreamManagement(c.config)
+ if s.err != nil {
+ return s, s.err
+ }
return s, s.err
}
@@ -70,19 +94,20 @@ func (s *Session) PacketId() string {
return fmt.Sprintf("%x", s.lastPacketId)
}
-func (s *Session) init(o Config) {
- s.Features = s.open(o.parsedJid.Domain)
+// init gathers information on the session such as stream features from the server.
+func (s *Session) init() {
+ s.Features = s.extractStreamFeatures()
}
-func (s *Session) reset(o Config) {
+func (s *Session) reset() {
if s.StreamId, s.err = s.transport.StartStream(); s.err != nil {
return
}
- s.Features = s.open(o.parsedJid.Domain)
+ s.Features = s.extractStreamFeatures()
}
-func (s *Session) open(domain string) (f stanza.StreamFeatures) {
+func (s *Session) extractStreamFeatures() (f stanza.StreamFeatures) {
// extract stream features
if s.err = s.transport.GetDecoder().Decode(&f); s.err != nil {
s.err = errors.New("stream open decode features: " + s.err.Error())
@@ -90,7 +115,7 @@ func (s *Session) open(domain string) (f stanza.StreamFeatures) {
return
}
-func (s *Session) startTlsIfSupported(o Config) {
+func (s *Session) startTlsIfSupported(o *Config) {
if s.err != nil {
return
}
@@ -125,7 +150,7 @@ func (s *Session) startTlsIfSupported(o Config) {
}
}
-func (s *Session) auth(o Config) {
+func (s *Session) auth(o *Config) {
if s.err != nil {
return
}
@@ -134,7 +159,7 @@ func (s *Session) auth(o Config) {
}
// Attempt to resume session using stream management
-func (s *Session) resume(o Config) bool {
+func (s *Session) resume(o *Config) bool {
if !s.Features.DoesStreamManagement() {
return false
}
@@ -142,9 +167,16 @@ func (s *Session) resume(o Config) bool {
return false
}
- fmt.Fprintf(s.transport, "",
- stanza.NSStreamManagement, s.SMState.Inbound, s.SMState.Id)
+ rsm := stanza.SMResume{
+ PrevId: s.SMState.Id,
+ H: &s.SMState.Inbound,
+ }
+ data, err := xml.Marshal(rsm)
+ _, err = s.transport.Write(data)
+ if err != nil {
+ return false
+ }
var packet stanza.Packet
packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
if s.err == nil {
@@ -165,20 +197,48 @@ func (s *Session) resume(o Config) bool {
return false
}
-func (s *Session) bind(o Config) {
+func (s *Session) bind(o *Config) {
if s.err != nil {
return
}
// Send IQ message asking to bind to the local user name.
var resource = o.parsedJid.Resource
- if resource != "" {
- fmt.Fprintf(s.transport, "%s",
- s.PacketId(), stanza.NSBind, resource)
- } else {
- fmt.Fprintf(s.transport, "", s.PacketId(), stanza.NSBind)
+ iqB, err := stanza.NewIQ(stanza.Attrs{
+ Type: stanza.IQTypeSet,
+ Id: s.PacketId(),
+ })
+ if err != nil {
+ s.err = err
+ return
}
+ // Check if we already have a resource name, and include it in the request if so
+ if resource != "" {
+ iqB.Payload = &stanza.Bind{
+ Resource: resource,
+ }
+ } else {
+ iqB.Payload = &stanza.Bind{}
+
+ }
+
+ // Send the bind request IQ
+ data, err := xml.Marshal(iqB)
+ if err != nil {
+ s.err = err
+ return
+ }
+ n, err := s.transport.Write(data)
+ if err != nil {
+ s.err = err
+ return
+ } else if n == 0 {
+ s.err = errors.New("failed to write bind iq stanza to the server : wrote 0 bytes")
+ return
+ }
+
+ // Check the server response
var iq stanza.IQ
if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
s.err = errors.New("error decoding iq bind result: " + s.err.Error())
@@ -197,7 +257,7 @@ func (s *Session) bind(o Config) {
}
// After the bind, if the session is not optional (as per old RFC 3921), we send the session open iq.
-func (s *Session) rfc3921Session(o Config) {
+func (s *Session) rfc3921Session() {
if s.err != nil {
return
}
@@ -205,7 +265,29 @@ func (s *Session) rfc3921Session(o Config) {
var iq stanza.IQ
// We only negotiate session binding if it is mandatory, we skip it when optional.
if !s.Features.Session.IsOptional() {
- fmt.Fprintf(s.transport, "", s.PacketId(), stanza.NSSession)
+ se, err := stanza.NewIQ(stanza.Attrs{
+ Type: stanza.IQTypeSet,
+ Id: s.PacketId(),
+ })
+ if err != nil {
+ s.err = err
+ return
+ }
+ se.Payload = &stanza.StreamSession{}
+ data, err := xml.Marshal(se)
+ if err != nil {
+ s.err = err
+ return
+ }
+ n, err := s.transport.Write(data)
+ if err != nil {
+ s.err = err
+ return
+ } else if n == 0 {
+ s.err = errors.New("there was a problem marshaling the session IQ : wrote 0 bytes to server")
+ return
+ }
+
if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
s.err = errors.New("expecting iq result after session open: " + s.err.Error())
return
@@ -214,28 +296,47 @@ func (s *Session) rfc3921Session(o Config) {
}
// Enable stream management, with session resumption, if supported.
-func (s *Session) EnableStreamManagement(o Config) {
+func (s *Session) EnableStreamManagement(o *Config) {
if s.err != nil {
return
}
- if !s.Features.DoesStreamManagement() {
+ if !s.Features.DoesStreamManagement() || !o.StreamManagementEnable {
+ return
+ }
+ q := stanza.NewUnAckQueue()
+ ebleNonza := stanza.SMEnable{Resume: &o.streamManagementResume}
+ pktStr, err := xml.Marshal(ebleNonza)
+ if err != nil {
+ s.err = err
+ return
+ }
+ _, err = s.transport.Write(pktStr)
+ if err != nil {
+ s.err = err
return
}
-
- fmt.Fprintf(s.transport, "", stanza.NSStreamManagement)
var packet stanza.Packet
packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
if s.err == nil {
switch p := packet.(type) {
case stanza.SMEnabled:
- s.SMState = SMState{Id: p.Id}
+ // Server allows resumption or not using SMEnabled attribute "resume". We must read the server response
+ // and update config accordingly
+ b, err := strconv.ParseBool(p.Resume)
+ if err != nil || !b {
+ o.StreamManagementEnable = false
+ }
+ s.SMState = SMState{Id: p.Id, preferredReconAddr: p.Location}
+ s.SMState.UnAckQueue = q
case stanza.SMFailed:
// TODO: Store error in SMState, for later inspection
+ s.SMState = SMState{StreamErrorGroup: p.StreamErrorGroup}
+ s.SMState.UnAckQueue = q
+ s.err = errors.New("failed to establish session : " + s.SMState.StreamErrorGroup.GroupErrorName())
default:
s.err = errors.New("unexpected reply to SM enable")
}
}
-
return
}
diff --git a/stanza/fifo_queue.go b/stanza/fifo_queue.go
new file mode 100644
index 0000000..ca28810
--- /dev/null
+++ b/stanza/fifo_queue.go
@@ -0,0 +1,34 @@
+package stanza
+
+// FIFO queue for string contents
+// Implementations have no guarantee regarding thread safety !
+type FifoQueue interface {
+ // Pop returns the first inserted element still in queue and deletes it from queue. If queue is empty, returns nil
+ // No guarantee regarding thread safety !
+ Pop() Queueable
+
+ // PopN returns the N first inserted elements still in queue and deletes them from queue. If queue is empty or i<=0, returns nil
+ // If number to pop is greater than queue length, returns all queue elements
+ // No guarantee regarding thread safety !
+ PopN(i int) []Queueable
+
+ // Peek returns a copy of the first inserted element in queue without deleting it. If queue is empty, returns nil
+ // No guarantee regarding thread safety !
+ Peek() Queueable
+
+ // Peek returns a copy of the first inserted element in queue without deleting it. If queue is empty or i<=0, returns nil.
+ // If number to peek is greater than queue length, returns all queue elements
+ // No guarantee regarding thread safety !
+ PeekN() []Queueable
+ // Push adds an element to the queue
+ // No guarantee regarding thread safety !
+ Push(s Queueable) error
+
+ // Empty returns true if queue is empty
+ // No guarantee regarding thread safety !
+ Empty() bool
+}
+
+type Queueable interface {
+ QueueableName() string
+}
diff --git a/stanza/sasl_auth.go b/stanza/sasl_auth.go
index 9dfe557..2fb660e 100644
--- a/stanza/sasl_auth.go
+++ b/stanza/sasl_auth.go
@@ -93,8 +93,8 @@ func (b *Bind) GetSet() *ResultSet {
// This is the draft defining how to handle the transition:
// https://tools.ietf.org/html/draft-cridland-xmpp-session-01
type StreamSession struct {
- XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-session session"`
- Optional bool // If element does exist, it mean we are not required to open session
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-session session"`
+ Optional *struct{} // If element does exist, it mean we are not required to open session
// Result sets
ResultSet *ResultSet `xml:"set,omitempty"`
}
@@ -109,7 +109,7 @@ func (s *StreamSession) GetSet() *ResultSet {
func (s *StreamSession) IsOptional() bool {
if s.XMLName.Local == "session" {
- return s.Optional
+ return s.Optional != nil
}
// If session element is missing, then we should not use session
return true
diff --git a/stanza/sasl_auth_test.go b/stanza/sasl_auth_test.go
index 3e37453..600035c 100644
--- a/stanza/sasl_auth_test.go
+++ b/stanza/sasl_auth_test.go
@@ -9,7 +9,7 @@ import (
// Check that we can detect optional session from advertised stream features
func TestSessionFeatures(t *testing.T) {
- streamFeatures := stanza.StreamFeatures{Session: stanza.StreamSession{Optional: true}}
+ streamFeatures := stanza.StreamFeatures{Session: stanza.StreamSession{Optional: &struct{}{}}}
data, err := xml.Marshal(streamFeatures)
if err != nil {
@@ -32,7 +32,7 @@ func TestSessionIQ(t *testing.T) {
if err != nil {
t.Fatalf("failed to create IQ: %v", err)
}
- iq.Payload = &stanza.StreamSession{XMLName: xml.Name{Local: "session"}, Optional: true}
+ iq.Payload = &stanza.StreamSession{XMLName: xml.Name{Local: "session"}, Optional: &struct{}{}}
data, err := xml.Marshal(iq)
if err != nil {
diff --git a/stanza/stanza_errors.go b/stanza/stanza_errors.go
new file mode 100644
index 0000000..c66ea33
--- /dev/null
+++ b/stanza/stanza_errors.go
@@ -0,0 +1,171 @@
+package stanza
+
+import (
+ "encoding/xml"
+)
+
+type StanzaErrorGroup interface {
+ GroupErrorName() string
+}
+
+type BadFormat struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas bad-format"`
+}
+
+func (e *BadFormat) GroupErrorName() string { return "bad-format" }
+
+type BadNamespacePrefix struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas bad-namespace-prefix"`
+}
+
+func (e *BadNamespacePrefix) GroupErrorName() string { return "bad-namespace-prefix" }
+
+type Conflict struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas conflict"`
+}
+
+func (e *Conflict) GroupErrorName() string { return "conflict" }
+
+type ConnectionTimeout struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas connection-timeout"`
+}
+
+func (e *ConnectionTimeout) GroupErrorName() string { return "connection-timeout" }
+
+type HostGone struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas host-gone"`
+}
+
+func (e *HostGone) GroupErrorName() string { return "host-gone" }
+
+type HostUnknown struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas host-unknown"`
+}
+
+func (e *HostUnknown) GroupErrorName() string { return "host-unknown" }
+
+type ImproperAddressing struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas improper-addressing"`
+}
+
+func (e *ImproperAddressing) GroupErrorName() string { return "improper-addressing" }
+
+type InternalServerError struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas internal-server-error"`
+}
+
+func (e *InternalServerError) GroupErrorName() string { return "internal-server-error" }
+
+type InvalidForm struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas invalid-from"`
+}
+
+func (e *InvalidForm) GroupErrorName() string { return "invalid-from" }
+
+type InvalidId struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas invalid-id"`
+}
+
+func (e *InvalidId) GroupErrorName() string { return "invalid-id" }
+
+type InvalidNamespace struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas invalid-namespace"`
+}
+
+func (e *InvalidNamespace) GroupErrorName() string { return "invalid-namespace" }
+
+type InvalidXML struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas invalid-xml"`
+}
+
+func (e *InvalidXML) GroupErrorName() string { return "invalid-xml" }
+
+type NotAuthorized struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas not-authorized"`
+}
+
+func (e *NotAuthorized) GroupErrorName() string { return "not-authorized" }
+
+type NotWellFormed struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas not-well-formed"`
+}
+
+func (e *NotWellFormed) GroupErrorName() string { return "not-well-formed" }
+
+type PolicyViolation struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas policy-violation"`
+}
+
+func (e *PolicyViolation) GroupErrorName() string { return "policy-violation" }
+
+type RemoteConnectionFailed struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas remote-connection-failed"`
+}
+
+func (e *RemoteConnectionFailed) GroupErrorName() string { return "remote-connection-failed" }
+
+type Reset struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas reset"`
+}
+
+func (e *Reset) GroupErrorName() string { return "reset" }
+
+type ResourceConstraint struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas resource-constraint"`
+}
+
+func (e *ResourceConstraint) GroupErrorName() string { return "resource-constraint" }
+
+type RestrictedXML struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas restricted-xml"`
+}
+
+func (e *RestrictedXML) GroupErrorName() string { return "restricted-xml" }
+
+type SeeOtherHost struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas see-other-host"`
+}
+
+func (e *SeeOtherHost) GroupErrorName() string { return "see-other-host" }
+
+type SystemShutdown struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas system-shutdown"`
+}
+
+func (e *SystemShutdown) GroupErrorName() string { return "system-shutdown" }
+
+type UndefinedCondition struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas undefined-condition"`
+}
+
+func (e *UndefinedCondition) GroupErrorName() string { return "undefined-condition" }
+
+type UnsupportedEncoding struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas unsupported-encoding"`
+}
+
+type UnexpectedRequest struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas unexpected-request"`
+}
+
+func (e *UnexpectedRequest) GroupErrorName() string { return "unexpected-request" }
+
+func (e *UnsupportedEncoding) GroupErrorName() string { return "unsupported-encoding" }
+
+type UnsupportedStanzaType struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas unsupported-stanza-type"`
+}
+
+func (e *UnsupportedStanzaType) GroupErrorName() string { return "unsupported-stanza-type" }
+
+type UnsupportedVersion struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas unsupported-version"`
+}
+
+func (e *UnsupportedVersion) GroupErrorName() string { return "unsupported-version" }
+
+type XMLNotWellFormed struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas xml-not-well-formed"`
+}
+
+func (e *XMLNotWellFormed) GroupErrorName() string { return "xml-not-well-formed" }
diff --git a/stanza/stream_features.go b/stanza/stream_features.go
index d5bed5c..d1b6274 100644
--- a/stanza/stream_features.go
+++ b/stanza/stream_features.go
@@ -118,6 +118,10 @@ type streamManagement struct {
XMLName xml.Name `xml:"urn:xmpp:sm:3 sm"`
}
+func (streamManagement) Name() string {
+ return "streamManagement"
+}
+
func (sf *StreamFeatures) DoesStreamManagement() (isSupported bool) {
if sf.StreamManagement.XMLName.Space+" "+sf.StreamManagement.XMLName.Local == "urn:xmpp:sm:3 sm" {
return true
diff --git a/stanza/stream_management.go b/stanza/stream_management.go
index ddbe9cd..a2a4f0b 100644
--- a/stanza/stream_management.go
+++ b/stanza/stream_management.go
@@ -3,12 +3,19 @@ package stanza
import (
"encoding/xml"
"errors"
+ "sync"
)
const (
NSStreamManagement = "urn:xmpp:sm:3"
)
+type SMEnable struct {
+ XMLName xml.Name `xml:"urn:xmpp:sm:3 enable"`
+ Max *uint `xml:"max,attr,omitempty"`
+ Resume *bool `xml:"resume,attr,omitempty"`
+}
+
// Enabled as defined in Stream Management spec
// Reference: https://xmpp.org/extensions/xep-0198.html#enable
type SMEnabled struct {
@@ -23,6 +30,112 @@ func (SMEnabled) Name() string {
return "Stream Management: enabled"
}
+type UnAckQueue struct {
+ Uslice []*UnAckedStz
+ sync.RWMutex
+}
+type UnAckedStz struct {
+ Id int
+ Stz string
+}
+
+func NewUnAckQueue() *UnAckQueue {
+ return &UnAckQueue{
+ Uslice: make([]*UnAckedStz, 0, 10), // Capacity is 0 to comply with "Push" implementation (so that no reachable element is nil)
+ RWMutex: sync.RWMutex{},
+ }
+}
+
+func (u *UnAckedStz) QueueableName() string {
+ return "Un-acknowledged stanza"
+}
+
+func (uaq *UnAckQueue) PeekN(n int) []Queueable {
+ if uaq == nil {
+ return nil
+ }
+ if n <= 0 {
+ return nil
+ }
+ if len(uaq.Uslice) < n {
+ n = len(uaq.Uslice)
+ }
+
+ if len(uaq.Uslice) == 0 {
+ return nil
+ }
+ var r []Queueable
+ for i := 0; i < n; i++ {
+ r = append(r, uaq.Uslice[i])
+ }
+ return r
+}
+
+// No guarantee regarding thread safety !
+func (uaq *UnAckQueue) Pop() Queueable {
+ if uaq == nil {
+ return nil
+ }
+ r := uaq.Peek()
+ if r != nil {
+ uaq.Uslice = uaq.Uslice[1:]
+ }
+ return r
+}
+
+// No guarantee regarding thread safety !
+func (uaq *UnAckQueue) PopN(n int) []Queueable {
+ if uaq == nil {
+ return nil
+ }
+ r := uaq.PeekN(n)
+ uaq.Uslice = uaq.Uslice[len(r):]
+ return r
+}
+
+func (uaq *UnAckQueue) Peek() Queueable {
+ if uaq == nil {
+ return nil
+ }
+ if len(uaq.Uslice) == 0 {
+ return nil
+ }
+ r := uaq.Uslice[0]
+ return r
+}
+
+func (uaq *UnAckQueue) Push(s Queueable) error {
+ if uaq == nil {
+ return nil
+ }
+ pushIdx := 1
+ if len(uaq.Uslice) != 0 {
+ pushIdx = uaq.Uslice[len(uaq.Uslice)-1].Id + 1
+ }
+
+ sStz, ok := s.(*UnAckedStz)
+ if !ok {
+ return errors.New("element in not compatible with this queue. expected an UnAckedStz")
+ }
+
+ e := UnAckedStz{
+ Id: pushIdx,
+ Stz: sStz.Stz,
+ }
+
+ uaq.Uslice = append(uaq.Uslice, &e)
+
+ return nil
+}
+
+func (uaq *UnAckQueue) Empty() bool {
+ if uaq == nil {
+ return true
+ }
+ r := len(uaq.Uslice)
+ return r == 0
+}
+
// Request as defined in Stream Management spec
// Reference: https://xmpp.org/extensions/xep-0198.html#acking
type SMRequest struct {
@@ -37,7 +150,7 @@ func (SMRequest) Name() string {
// Reference: https://xmpp.org/extensions/xep-0198.html#acking
type SMAnswer struct {
XMLName xml.Name `xml:"urn:xmpp:sm:3 a"`
- H uint `xml:"h,attr,omitempty"`
+ H uint `xml:"h,attr"`
}
func (SMAnswer) Name() string {
@@ -49,24 +162,175 @@ func (SMAnswer) Name() string {
type SMResumed struct {
XMLName xml.Name `xml:"urn:xmpp:sm:3 resumed"`
PrevId string `xml:"previd,attr,omitempty"`
- H uint `xml:"h,attr,omitempty"`
+ H *uint `xml:"h,attr,omitempty"`
}
func (SMResumed) Name() string {
return "Stream Management: resumed"
}
+// Resume as defined in Stream Management spec
+// Reference: https://xmpp.org/extensions/xep-0198.html#acking
+type SMResume struct {
+ XMLName xml.Name `xml:"urn:xmpp:sm:3 resume"`
+ PrevId string `xml:"previd,attr,omitempty"`
+ H *uint `xml:"h,attr,omitempty"`
+}
+
+func (SMResume) Name() string {
+ return "Stream Management: resume"
+}
+
// Failed as defined in Stream Management spec
// Reference: https://xmpp.org/extensions/xep-0198.html#acking
type SMFailed struct {
XMLName xml.Name `xml:"urn:xmpp:sm:3 failed"`
- // TODO: Handle decoding error cause (need custom parsing).
+ H *uint `xml:"h,attr,omitempty"`
+
+ StreamErrorGroup StanzaErrorGroup
}
func (SMFailed) Name() string {
return "Stream Management: failed"
}
+func (smf *SMFailed) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
+ smf.XMLName = start.Name
+
+ // According to https://xmpp.org/rfcs/rfc3920.html#def we should have no attributes aside from the namespace
+ // which we don't use internally
+
+ // decode inner elements
+ for {
+ t, err := d.Token()
+ if err != nil {
+ return err
+ }
+
+ switch tt := t.(type) {
+
+ case xml.StartElement:
+ // Decode sub-elements
+ var err error
+ switch tt.Name.Local {
+ case "bad-format":
+ bf := BadFormat{}
+ err = d.DecodeElement(&bf, &tt)
+ smf.StreamErrorGroup = &bf
+ case "bad-namespace-prefix":
+ bnp := BadNamespacePrefix{}
+ err = d.DecodeElement(&bnp, &tt)
+ smf.StreamErrorGroup = &bnp
+ case "conflict":
+ c := Conflict{}
+ err = d.DecodeElement(&c, &tt)
+ smf.StreamErrorGroup = &c
+ case "connection-timeout":
+ ct := ConnectionTimeout{}
+ err = d.DecodeElement(&ct, &tt)
+ smf.StreamErrorGroup = &ct
+ case "host-gone":
+ hg := HostGone{}
+ err = d.DecodeElement(&hg, &tt)
+ smf.StreamErrorGroup = &hg
+ case "host-unknown":
+ hu := HostUnknown{}
+ err = d.DecodeElement(&hu, &tt)
+ smf.StreamErrorGroup = &hu
+ case "improper-addressing":
+ ia := ImproperAddressing{}
+ err = d.DecodeElement(&ia, &tt)
+ smf.StreamErrorGroup = &ia
+ case "internal-server-error":
+ ise := InternalServerError{}
+ err = d.DecodeElement(&ise, &tt)
+ smf.StreamErrorGroup = &ise
+ case "invalid-from":
+ ifrm := InvalidForm{}
+ err = d.DecodeElement(&ifrm, &tt)
+ smf.StreamErrorGroup = &ifrm
+ case "invalid-id":
+ id := InvalidId{}
+ err = d.DecodeElement(&id, &tt)
+ smf.StreamErrorGroup = &id
+ case "invalid-namespace":
+ ins := InvalidNamespace{}
+ err = d.DecodeElement(&ins, &tt)
+ smf.StreamErrorGroup = &ins
+ case "invalid-xml":
+ ix := InvalidXML{}
+ err = d.DecodeElement(&ix, &tt)
+ smf.StreamErrorGroup = &ix
+ case "not-authorized":
+ na := NotAuthorized{}
+ err = d.DecodeElement(&na, &tt)
+ smf.StreamErrorGroup = &na
+ case "not-well-formed":
+ nwf := NotWellFormed{}
+ err = d.DecodeElement(&nwf, &tt)
+ smf.StreamErrorGroup = &nwf
+ case "policy-violation":
+ pv := PolicyViolation{}
+ err = d.DecodeElement(&pv, &tt)
+ smf.StreamErrorGroup = &pv
+ case "remote-connection-failed":
+ rcf := RemoteConnectionFailed{}
+ err = d.DecodeElement(&rcf, &tt)
+ smf.StreamErrorGroup = &rcf
+ case "resource-constraint":
+ rc := ResourceConstraint{}
+ err = d.DecodeElement(&rc, &tt)
+ smf.StreamErrorGroup = &rc
+ case "restricted-xml":
+ rx := RestrictedXML{}
+ err = d.DecodeElement(&rx, &tt)
+ smf.StreamErrorGroup = &rx
+ case "see-other-host":
+ soh := SeeOtherHost{}
+ err = d.DecodeElement(&soh, &tt)
+ smf.StreamErrorGroup = &soh
+ case "system-shutdown":
+ ss := SystemShutdown{}
+ err = d.DecodeElement(&ss, &tt)
+ smf.StreamErrorGroup = &ss
+ case "undefined-condition":
+ uc := UndefinedCondition{}
+ err = d.DecodeElement(&uc, &tt)
+ smf.StreamErrorGroup = &uc
+ case "unexpected-request":
+ ur := UnexpectedRequest{}
+ err = d.DecodeElement(&ur, &tt)
+ smf.StreamErrorGroup = &ur
+ case "unsupported-encoding":
+ ue := UnsupportedEncoding{}
+ err = d.DecodeElement(&ue, &tt)
+ smf.StreamErrorGroup = &ue
+ case "unsupported-stanza-type":
+ ust := UnsupportedStanzaType{}
+ err = d.DecodeElement(&ust, &tt)
+ smf.StreamErrorGroup = &ust
+ case "unsupported-version":
+ uv := UnsupportedVersion{}
+ err = d.DecodeElement(&uv, &tt)
+ smf.StreamErrorGroup = &uv
+ case "xml-not-well-formed":
+ xnwf := XMLNotWellFormed{}
+ err = d.DecodeElement(&xnwf, &tt)
+ smf.StreamErrorGroup = &xnwf
+ default:
+ return errors.New("error is unknown")
+ }
+ if err != nil {
+ return err
+ }
+ case xml.EndElement:
+ if tt == start.End() {
+ return nil
+ }
+ }
+ }
+}
+
type smDecoder struct{}
var sm smDecoder
@@ -78,9 +342,11 @@ func (s smDecoder) decode(p *xml.Decoder, se xml.StartElement) (Packet, error) {
return s.decodeEnabled(p, se)
case "resumed":
return s.decodeResumed(p, se)
+ case "resume":
+ return s.decodeResume(p, se)
case "r":
return s.decodeRequest(p, se)
- case "h":
+ case "a":
return s.decodeAnswer(p, se)
case "failed":
return s.decodeFailed(p, se)
@@ -102,6 +368,11 @@ func (smDecoder) decodeResumed(p *xml.Decoder, se xml.StartElement) (SMResumed,
return packet, err
}
+func (smDecoder) decodeResume(p *xml.Decoder, se xml.StartElement) (SMResume, error) {
+ var packet SMResume
+ err := p.DecodeElement(&packet, &se)
+ return packet, err
+}
func (smDecoder) decodeRequest(p *xml.Decoder, se xml.StartElement) (SMRequest, error) {
var packet SMRequest
err := p.DecodeElement(&packet, &se)
diff --git a/stanza/stream_management_test.go b/stanza/stream_management_test.go
new file mode 100644
index 0000000..1b3443e
--- /dev/null
+++ b/stanza/stream_management_test.go
@@ -0,0 +1,226 @@
+package stanza_test
+
+import (
+ "gosrc.io/xmpp/stanza"
+ "math/rand"
+ "reflect"
+ "testing"
+ "time"
+)
+
+func TestPopEmptyQueue(t *testing.T) {
+ var uaq stanza.UnAckQueue
+ popped := uaq.Pop()
+ if popped != nil {
+ t.Fatalf("queue is empty but something was popped !")
+ }
+}
+
+func TestPushUnack(t *testing.T) {
+ uaq := initUnAckQueue()
+ toPush := stanza.UnAckedStz{
+ Id: 3,
+ Stz: `
+
+ confucius
+ Qui
+ Kong
+
+`,
+ }
+
+ err := uaq.Push(&toPush)
+ if err != nil {
+ t.Fatalf("could not push element to the queue : %v", err)
+ }
+
+ if len(uaq.Uslice) != 4 {
+ t.Fatalf("push to the non-acked queue failed")
+ }
+ for i := 0; i < 4; i++ {
+ if uaq.Uslice[i].Id != i+1 {
+ t.Fatalf("indexes were not updated correctly. Expected %d got %d", i, uaq.Uslice[i].Id)
+ }
+ }
+
+ // Check that the queue is a fifo : popped element should not be the one we just pushed.
+ popped := uaq.Pop()
+ poppedElt, ok := popped.(*stanza.UnAckedStz)
+ if !ok {
+ t.Fatalf("popped element is not a *stanza.UnAckedStz")
+ }
+
+ if reflect.DeepEqual(*poppedElt, toPush) {
+ t.Fatalf("pushed element is at the top of the fifo queue when it should be at the bottom")
+ }
+
+}
+
+func TestPeekUnack(t *testing.T) {
+ uaq := initUnAckQueue()
+
+ expectedPeek := stanza.UnAckedStz{
+ Id: 1,
+ Stz: `
+
+ Capulet
+
+`,
+ }
+
+ if !reflect.DeepEqual(expectedPeek, *uaq.Uslice[0]) {
+ t.Fatalf("peek failed to return the correct stanza")
+ }
+
+}
+
+func TestPeekNUnack(t *testing.T) {
+ uaq := initUnAckQueue()
+ initLen := len(uaq.Uslice)
+ randPop := rand.Int31n(int32(initLen))
+
+ peeked := uaq.PeekN(int(randPop))
+
+ if len(uaq.Uslice) != initLen {
+ t.Fatalf("queue length changed whith peek n operation : had %d found %d after peek", initLen, len(uaq.Uslice))
+ }
+
+ if len(peeked) != int(randPop) {
+ t.Fatalf("did not peek the correct number of element from queue. Expected %d got %d", randPop, len(peeked))
+ }
+}
+
+func TestPeekNUnackTooLong(t *testing.T) {
+ uaq := initUnAckQueue()
+ initLen := len(uaq.Uslice)
+
+ // Have a random number of elements to peek that's greater than the queue size
+ randPop := rand.Int31n(int32(initLen)) + 1 + int32(initLen)
+
+ peeked := uaq.PeekN(int(randPop))
+
+ if len(uaq.Uslice) != initLen {
+ t.Fatalf("total length changed whith peek n operation : had %d found %d after pop", initLen, len(uaq.Uslice))
+ }
+
+ if len(peeked) != initLen {
+ t.Fatalf("did not peek the correct number of element from queue. Expected %d got %d", initLen, len(peeked))
+ }
+
+}
+
+func TestPopNUnack(t *testing.T) {
+ uaq := initUnAckQueue()
+ initLen := len(uaq.Uslice)
+ randPop := rand.Int31n(int32(initLen))
+
+ popped := uaq.PopN(int(randPop))
+
+ if len(uaq.Uslice)+len(popped) != initLen {
+ t.Fatalf("total length changed whith pop n operation : had %d found %d after pop", initLen, len(uaq.Uslice)+len(popped))
+ }
+
+ for _, elt := range popped {
+ for _, oldElt := range uaq.Uslice {
+ if reflect.DeepEqual(elt, oldElt) {
+ t.Fatalf("pop n operation duplicated some elements")
+ }
+ }
+ }
+}
+
+func TestPopNUnackTooLong(t *testing.T) {
+ uaq := initUnAckQueue()
+ initLen := len(uaq.Uslice)
+
+ // Have a random number of elements to pop that's greater than the queue size
+ randPop := rand.Int31n(int32(initLen)) + 1 + int32(initLen)
+
+ popped := uaq.PopN(int(randPop))
+
+ if len(uaq.Uslice)+len(popped) != initLen {
+ t.Fatalf("total length changed whith pop n operation : had %d found %d after pop", initLen, len(uaq.Uslice)+len(popped))
+ }
+
+ for _, elt := range popped {
+ for _, oldElt := range uaq.Uslice {
+ if reflect.DeepEqual(elt, oldElt) {
+ t.Fatalf("pop n operation duplicated some elements")
+ }
+ }
+ }
+}
+
+func TestPopUnack(t *testing.T) {
+ uaq := initUnAckQueue()
+ initLen := len(uaq.Uslice)
+
+ popped := uaq.Pop()
+
+ if len(uaq.Uslice)+1 != initLen {
+ t.Fatalf("total length changed whith pop operation : had %d found %d after pop", initLen, len(uaq.Uslice)+1)
+ }
+ for _, oldElt := range uaq.Uslice {
+ if reflect.DeepEqual(popped, oldElt) {
+ t.Fatalf("pop n operation duplicated some elements")
+ }
+ }
+
+}
+
+func initUnAckQueue() stanza.UnAckQueue {
+ q := []*stanza.UnAckedStz{
+ {
+ Id: 1,
+ Stz: `
+
+ Capulet
+
+`,
+ },
+ {Id: 2,
+ Stz: `
+
+`},
+ {Id: 3,
+ Stz: `
+
+
+
+ jabber:iq:search
+
+
+ male
+
+
+
+`},
+ }
+
+ return stanza.UnAckQueue{Uslice: q}
+
+}
+
+func init() {
+ rand.Seed(time.Now().UTC().UnixNano())
+}
diff --git a/stream_manager.go b/stream_manager.go
index ebef1fa..da23df1 100644
--- a/stream_manager.go
+++ b/stream_manager.go
@@ -25,7 +25,7 @@ import (
// set callback and trigger reconnection.
type StreamClient interface {
Connect() error
- Resume(state SMState) error
+ Resume() error
Send(packet stanza.Packet) error
SendIQ(ctx context.Context, iq *stanza.IQ) (chan stanza.IQ, error)
SendRaw(packet string) error
@@ -75,9 +75,7 @@ func (sm *StreamManager) Run() error {
}
handler := func(e Event) error {
- switch e.State {
- case StateConnected:
- sm.Metrics.setConnectTime()
+ switch e.State.state {
case StateSessionEstablished:
sm.Metrics.setLoginTime()
case StateDisconnected:
@@ -128,7 +126,7 @@ func (sm *StreamManager) resume(state SMState) error {
// TODO: Make it possible to define logger to log disconnect and reconnection attempts
sm.Metrics = initMetrics()
- if err = sm.client.Resume(state); err != nil {
+ if err = sm.client.Resume(); err != nil {
var actualErr ConnError
if xerrors.As(err, &actualErr) {
if actualErr.Permanent {
@@ -152,11 +150,6 @@ func (sm *StreamManager) resume(state SMState) error {
type Metrics struct {
startTime time.Time
- // ConnectTime returns the duration between client initiation of the TCP/IP
- // connection to the server and actual TCP/IP session establishment.
- // This time includes DNS resolution and can be slightly higher if the DNS
- // resolution result was not in cache.
- ConnectTime time.Duration
// LoginTime returns the between client initiation of the TCP/IP
// connection to the server and the return of the login result.
// This includes ConnectTime, but also XMPP level protocol negotiation
@@ -172,10 +165,6 @@ func initMetrics() *Metrics {
}
}
-func (m *Metrics) setConnectTime() {
- m.ConnectTime = time.Since(m.startTime)
-}
-
func (m *Metrics) setLoginTime() {
m.LoginTime = time.Since(m.startTime)
}
diff --git a/tcp_server_mock.go b/tcp_server_mock.go
index 55740fa..d189e3a 100644
--- a/tcp_server_mock.go
+++ b/tcp_server_mock.go
@@ -36,6 +36,10 @@ const (
testClientRawPort
testClientIqPort
testClientIqFailPort
+ testClientPostConnectHook
+
+ // Client internal tests
+ testClientStreamManagement
)
// ClientHandler is passed by the test client to provide custom behaviour to
diff --git a/xmpp_transport.go b/xmpp_transport.go
index 092b95d..800f1b1 100644
--- a/xmpp_transport.go
+++ b/xmpp_transport.go
@@ -24,7 +24,8 @@ type XMPPTransport struct {
readWriter io.ReadWriter
logFile io.Writer
isSecure bool
- closeChan chan stanza.StreamClosePacket
+ // Used to close TCP connection when a stream close message is received from the server
+ closeChan chan stanza.StreamClosePacket
}
var componentStreamOpen = fmt.Sprintf("", stanza.NSComponent, stanza.NSStream)