diff --git a/_examples/go.sum b/_examples/go.sum index 286bc95..467aab6 100644 --- a/_examples/go.sum +++ b/_examples/go.sum @@ -99,7 +99,9 @@ github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/9 github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/processone/mpg123 v1.0.0 h1:o2WOyGZRM255or1Zc/LtF/jARn51B+9aQl72Qace0GA= github.com/processone/mpg123 v1.0.0/go.mod h1:X/FeL+h8vD1bYsG9tIWV3M2c4qNTZOficyvPVBP08go= +github.com/processone/soundcloud v1.0.0 h1:/+i6+Yveb7Y6IFGDSkesYI+HddblzcRTQClazzVHxoE= github.com/processone/soundcloud v1.0.0/go.mod h1:kDLeWpkRtN3C8kIReQdxoiRi92P9xR6yW6qLOJnNWfY= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= @@ -155,6 +157,7 @@ golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190110200230-915654e7eabc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859 h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/_examples/xmpp_echo/xmpp_echo.go b/_examples/xmpp_echo/xmpp_echo.go index b6c6766..d7a35c8 100644 --- a/_examples/xmpp_echo/xmpp_echo.go +++ b/_examples/xmpp_echo/xmpp_echo.go @@ -28,7 +28,7 @@ func main() { router := xmpp.NewRouter() router.HandleFunc("message", handleMessage) - client, err := xmpp.NewClient(config, router, errorHandler) + client, err := xmpp.NewClient(&config, router, errorHandler) if err != nil { log.Fatalf("%+v", err) } diff --git a/_examples/xmpp_jukebox/xmpp_jukebox.go b/_examples/xmpp_jukebox/xmpp_jukebox.go index b8075a2..31f000a 100644 --- a/_examples/xmpp_jukebox/xmpp_jukebox.go +++ b/_examples/xmpp_jukebox/xmpp_jukebox.go @@ -54,7 +54,7 @@ func main() { handleIQ(s, p, player) }) - client, err := xmpp.NewClient(config, router, errorHandler) + client, err := xmpp.NewClient(&config, router, errorHandler) if err != nil { log.Fatalf("%+v", err) } diff --git a/_examples/xmpp_oauth2/xmpp_oauth2.go b/_examples/xmpp_oauth2/xmpp_oauth2.go index 89b2639..a993049 100644 --- a/_examples/xmpp_oauth2/xmpp_oauth2.go +++ b/_examples/xmpp_oauth2/xmpp_oauth2.go @@ -28,7 +28,7 @@ func main() { router := xmpp.NewRouter() router.HandleFunc("message", handleMessage) - client, err := xmpp.NewClient(config, router, errorHandler) + client, err := xmpp.NewClient(&config, router, errorHandler) if err != nil { log.Fatalf("%+v", err) } diff --git a/_examples/xmpp_pubsub_client/xmpp_ps_client.go b/_examples/xmpp_pubsub_client/xmpp_ps_client.go index b2e9cf6..2308071 100644 --- a/_examples/xmpp_pubsub_client/xmpp_ps_client.go +++ b/_examples/xmpp_pubsub_client/xmpp_ps_client.go @@ -38,7 +38,7 @@ func main() { log.Println("Received a message ! => \n" + string(data)) }) - client, err := xmpp.NewClient(config, router, func(err error) { log.Println(err) }) + client, err := xmpp.NewClient(&config, router, func(err error) { log.Println(err) }) if err != nil { log.Fatalf("%+v", err) } diff --git a/_examples/xmpp_websocket/xmpp_websocket.go b/_examples/xmpp_websocket/xmpp_websocket.go index c8c0620..3a0c1ba 100644 --- a/_examples/xmpp_websocket/xmpp_websocket.go +++ b/_examples/xmpp_websocket/xmpp_websocket.go @@ -26,7 +26,7 @@ func main() { router := xmpp.NewRouter() router.HandleFunc("message", handleMessage) - client, err := xmpp.NewClient(config, router, errorHandler) + client, err := xmpp.NewClient(&config, router, errorHandler) if err != nil { log.Fatalf("%+v", err) } diff --git a/auth.go b/auth.go index b8d20b9..902371b 100644 --- a/auth.go +++ b/auth.go @@ -60,10 +60,21 @@ func authPlain(socket io.ReadWriter, decoder *xml.Decoder, mech string, user str raw := "\x00" + user + "\x00" + secret enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw))) base64.StdEncoding.Encode(enc, []byte(raw)) - _, err := fmt.Fprintf(socket, "%s", stanza.NSSASL, mech, enc) + + a := stanza.SASLAuth{ + Mechanism: mech, + Value: string(enc), + } + data, err := xml.Marshal(a) if err != nil { return err } + n, err := socket.Write(data) + if err != nil { + return err + } else if n == 0 { + return errors.New("failed to write authSASL nonza to socket : wrote 0 bytes") + } // Next message should be either success or failure. val, err := stanza.NextPacket(decoder) diff --git a/client.go b/client.go index bd40c38..1be1f4d 100644 --- a/client.go +++ b/client.go @@ -6,6 +6,7 @@ import ( "errors" "io" "net" + "sync" "time" "gosrc.io/xmpp/stanza" @@ -14,15 +15,36 @@ import ( //============================================================================= // EventManager -// ConnState represents the current connection state. +// SyncConnState represents the current connection state. +type SyncConnState struct { + sync.RWMutex + // Current state of the client. Please use the dedicated getter and setter for this field as they are thread safe. + state ConnState +} type ConnState = uint8 +// getState is a thread-safe getter for the current state +func (scs *SyncConnState) getState() ConnState { + var res ConnState + scs.RLock() + res = scs.state + scs.RUnlock() + return res +} + +// setState is a thread-safe setter for the current +func (scs *SyncConnState) setState(cs ConnState) { + scs.Lock() + scs.state = cs + scs.Unlock() +} + // This is a the list of events happening on the connection that the // client can be notified about. const ( InitialPresence = "" StateDisconnected ConnState = iota - StateConnected + StateResuming StateSessionEstablished StateStreamError StatePermanentError @@ -31,7 +53,7 @@ const ( // Event is a structure use to convey event changes related to client state. This // is for example used to notify the client when the client get disconnected. type Event struct { - State ConnState + State SyncConnState Description string StreamError string SMState SMState @@ -44,7 +66,16 @@ type SMState struct { Id string // Inbound stanza count Inbound uint - // TODO Store location for IP affinity + + // IP affinity + preferredReconAddr string + + // Error + StreamErrorGroup stanza.StanzaErrorGroup + + // Track sent stanzas + *stanza.UnAckQueue + // TODO Store max and timestamp, to check if we should retry resumption or not } @@ -53,29 +84,35 @@ type SMState struct { type EventHandler func(Event) error type EventManager struct { - // Store current state - CurrentState ConnState + // Store current state. Please use "getState" and "setState" to access and/or modify this. + CurrentState SyncConnState // Callback used to propagate connection state changes Handler EventHandler } +// updateState changes the CurrentState in the event manager. The state read is threadsafe but there is no guarantee +// regarding the triggered callback function. func (em *EventManager) updateState(state ConnState) { - em.CurrentState = state + em.CurrentState.setState(state) if em.Handler != nil { em.Handler(Event{State: em.CurrentState}) } } +// disconnected changes the CurrentState in the event manager to "disconnected". The state read is threadsafe but there is no guarantee +// regarding the triggered callback function. func (em *EventManager) disconnected(state SMState) { - em.CurrentState = StateDisconnected + em.CurrentState.setState(StateDisconnected) if em.Handler != nil { em.Handler(Event{State: em.CurrentState, SMState: state}) } } +// streamError changes the CurrentState in the event manager to "streamError". The state read is threadsafe but there is no guarantee +// regarding the triggered callback function. func (em *EventManager) streamError(error, desc string) { - em.CurrentState = StateStreamError + em.CurrentState.setState(StateStreamError) if em.Handler != nil { em.Handler(Event{State: em.CurrentState, StreamError: error, Description: desc}) } @@ -90,7 +127,7 @@ var ErrCanOnlySendGetOrSetIq = errors.New("SendIQ can only send get and set IQ s // server. type Client struct { // Store user defined options and states - config Config + config *Config // Session gather data that can be accessed by users of this library Session *Session transport Transport @@ -100,6 +137,12 @@ type Client struct { EventManager // Handle errors from client execution ErrorHandler func(error) + + // Post connection hook. This will be executed on first connection + PostFirstConnHook func() error + + // Post resume hook. This will be executed after the client resumes a lost connection using StreamManagement (XEP-0198) + PostReconnectHook func() error } /* @@ -107,9 +150,9 @@ Setting up the client / Checking the parameters */ // NewClient generates a new XMPP client, based on Config passed as parameters. -// If host is not specified, the DNS SRV should be used to find the host from the domainpart of the Jid. +// If host is not specified, the DNS SRV should be used to find the host from the domain part of the Jid. // Default the port to 5222. -func NewClient(config Config, r *Router, errorHandler func(error)) (c *Client, err error) { +func NewClient(config *Config, r *Router, errorHandler func(error)) (c *Client, err error) { if config.KeepaliveInterval == 0 { config.KeepaliveInterval = time.Second * 30 } @@ -169,26 +212,45 @@ func NewClient(config Config, r *Router, errorHandler func(error)) (c *Client, e return } -// Connect triggers actual TCP connection, based on previously defined parameters. -// Connect simply triggers resumption, with an empty session state. +// Connect establishes a first time connection to a XMPP server. +// It calls the PostFirstConnHook func (c *Client) Connect() error { - var state SMState - return c.Resume(state) + err := c.connect() + if err != nil { + return err + } + // TODO: Do we always want to send initial presence automatically ? + // Do we need an option to avoid that or do we rely on client to send the presence itself ? + err = c.sendWithWriter(c.transport, []byte(InitialPresence)) + // Execute the post first connection hook. Typically this holds "ask for roster" and this type of actions. + if c.PostFirstConnHook != nil { + err = c.PostFirstConnHook() + if err != nil { + return err + } + } + + // Start the keepalive go routine + keepaliveQuit := make(chan struct{}) + go keepalive(c.transport, c.config.KeepaliveInterval, keepaliveQuit) + // Start the receiver go routine + go c.recv(keepaliveQuit) + return err } -// Resume attempts resuming a Stream Managed session, based on the provided stream management -// state. -func (c *Client) Resume(state SMState) error { +// connect establishes an actual TCP connection, based on previously defined parameters, as well as a XMPP session +func (c *Client) connect() error { + var state SMState var err error - + // This is the TCP connection streamId, err := c.transport.Connect() if err != nil { return err } - c.updateState(StateConnected) - // Client is ok, we now open XMPP session - if c.Session, err = NewSession(c.transport, c.config, state); err != nil { + // Client is ok, we now open XMPP session with TLS negotiation if possible and session resume or binding + // depending on state. + if c.Session, err = NewSession(c, state); err != nil { // Try to get the stream close tag from the server. go func() { for { @@ -212,22 +274,26 @@ func (c *Client) Resume(state SMState) error { c.Session.StreamId = streamId c.updateState(StateSessionEstablished) - // Start the keepalive go routine - keepaliveQuit := make(chan struct{}) - go keepalive(c.transport, c.config.KeepaliveInterval, keepaliveQuit) - // Start the receiver go routine - state = c.Session.SMState - go c.recv(state, keepaliveQuit) - - // We're connected and can now receive and send messages. - //fmt.Fprintf(client.conn, "%s%s", "chat", "Online") - // TODO: Do we always want to send initial presence automatically ? - // Do we need an option to avoid that or do we rely on client to send the presence itself ? - err = c.sendWithWriter(c.transport, []byte(InitialPresence)) - return err } +// Resume attempts resuming a Stream Managed session, based on the provided stream management +// state. See XEP-0198 +func (c *Client) Resume() error { + c.EventManager.updateState(StateResuming) + err := c.connect() + if err != nil { + return err + } + // Execute post reconnect hook. This can be different from the first connection hook, and not trigger roster retrival + // for example. + if c.PostReconnectHook != nil { + err = c.PostReconnectHook() + } + return err +} + +// Disconnect disconnects the client from the server, sending a stream close nonza and closing the TCP connection. func (c *Client) Disconnect() error { if c.transport != nil { return c.transport.Close() @@ -252,6 +318,15 @@ func (c *Client) Send(packet stanza.Packet) error { return errors.New("cannot marshal packet " + err.Error()) } + // Store stanza as non-acked as part of stream management + // See https://xmpp.org/extensions/xep-0198.html#scenarios + if c.config.StreamManagementEnable { + if _, ok := packet.(stanza.SMRequest); !ok { + toStore := stanza.UnAckedStz{Stz: string(data)} + c.Session.SMState.UnAckQueue.Push(&toStore) + } + } + return c.sendWithWriter(c.transport, data) } @@ -284,6 +359,12 @@ func (c *Client) SendRaw(packet string) error { return errors.New("client is not connected") } + // Store stanza as non-acked as part of stream management + // See https://xmpp.org/extensions/xep-0198.html#scenarios + if c.config.StreamManagementEnable { + toStore := stanza.UnAckedStz{Stz: packet} + c.Session.SMState.UnAckQueue.Push(&toStore) + } return c.sendWithWriter(c.transport, []byte(packet)) } @@ -297,13 +378,13 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error { // Go routines // Loop: Receive data from server -func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) { +func (c *Client) recv(keepaliveQuit chan<- struct{}) { for { val, err := stanza.NextPacket(c.transport.GetDecoder()) if err != nil { c.ErrorHandler(err) close(keepaliveQuit) - c.disconnected(state) + c.disconnected(c.Session.SMState) return } @@ -321,7 +402,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) { answer := stanza.SMAnswer{XMLName: xml.Name{ Space: stanza.NSStreamManagement, Local: "a", - }, H: state.Inbound} + }, H: c.Session.SMState.Inbound} err = c.Send(answer) if err != nil { c.ErrorHandler(err) @@ -332,7 +413,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) { c.transport.ReceivedStreamClose() return default: - state.Inbound++ + c.Session.SMState.Inbound++ } // Do normal route processing in a go-routine so we can immediately // start receiving other stanzas. This also allows route handlers to diff --git a/client_internal_test.go b/client_internal_test.go index 6daef09..e517a42 100644 --- a/client_internal_test.go +++ b/client_internal_test.go @@ -2,7 +2,16 @@ package xmpp import ( "bytes" + "encoding/xml" + "fmt" + "gosrc.io/xmpp/stanza" + "strconv" "testing" + "time" +) + +const ( + streamManagementID = "test-stream_management-id" ) func TestClient_Send(t *testing.T) { @@ -17,3 +26,554 @@ func TestClient_Send(t *testing.T) { t.Errorf("Incorrect value sent to buffer: '%s'", buffer.String()) } } + +// Stream management test. +// Connection is established, then the server sends supported features and so on. +// After the bind, client attempts a stream management enablement, and server replies in kind. +func Test_StreamManagement(t *testing.T) { + serverDone := make(chan struct{}) + clientDone := make(chan struct{}) + + client, mock := initSrvCliForResumeTests(t, func(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) + + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + sc.connection.Write([]byte("")) + + checkClientOpenStream(t, sc) // Reset stream + sendFeaturesStreamManagment(t, sc) // Send post auth features + bind(t, sc) + enableStreamManagement(t, sc, false, true) + serverDone <- struct{}{} + }, testClientStreamManagement, true, true) + go func() { + var state SMState + var err error + // Client is ok, we now open XMPP session + if client.Session, err = NewSession(client, state); err != nil { + t.Fatalf("failed to open XMPP session: %s", err) + } + clientDone <- struct{}{} + }() + + waitForEntity(t, clientDone) + waitForEntity(t, serverDone) + mock.Stop() +} + +// Absence of stream management test. +// Connection is established, then the server sends supported features and so on. +// Client has stream management disabled in its config, and should not ask for it. Server is not set up to reply. +func Test_NoStreamManagement(t *testing.T) { + serverDone := make(chan struct{}) + clientDone := make(chan struct{}) + + // Setup Mock server + client, mock := initSrvCliForResumeTests(t, func(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) + + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + sc.connection.Write([]byte("")) + + checkClientOpenStream(t, sc) // Reset stream + sendFeaturesNoStreamManagment(t, sc) // Send post auth features + bind(t, sc) + serverDone <- struct{}{} + }, testClientStreamManagement, true, false) + + go func() { + var state SMState + + // Client is ok, we now open XMPP session + var err error + if client.Session, err = NewSession(client, state); err != nil { + t.Fatalf("failed to open XMPP session: %s", err) + } + clientDone <- struct{}{} + }() + + waitForEntity(t, clientDone) + waitForEntity(t, serverDone) + + mock.Stop() +} + +func Test_StreamManagementNotSupported(t *testing.T) { + serverDone := make(chan struct{}) + clientDone := make(chan struct{}) + + client, mock := initSrvCliForResumeTests(t, func(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) + + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + sc.connection.Write([]byte("")) + + checkClientOpenStream(t, sc) // Reset stream + sendFeaturesNoStreamManagment(t, sc) // Send post auth features + bind(t, sc) + serverDone <- struct{}{} + }, testClientStreamManagement, true, true) + + go func() { + var state SMState + var err error + // Client is ok, we now open XMPP session + if client.Session, err = NewSession(client, state); err != nil { + t.Fatalf("failed to open XMPP session: %s", err) + } + clientDone <- struct{}{} + }() + + // Wait for client + waitForEntity(t, clientDone) + + // Check if client got a positive stream management response from the server + if client.Session.Features.DoesStreamManagement() { + t.Fatalf("server does not provide stream management") + } + + // Wait for server + waitForEntity(t, serverDone) + mock.Stop() +} + +func Test_StreamManagementNoResume(t *testing.T) { + serverDone := make(chan struct{}) + clientDone := make(chan struct{}) + + client, mock := initSrvCliForResumeTests(t, func(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) + + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + sc.connection.Write([]byte("")) + + checkClientOpenStream(t, sc) // Reset stream + sendFeaturesStreamManagment(t, sc) // Send post auth features + bind(t, sc) + enableStreamManagement(t, sc, false, false) + serverDone <- struct{}{} + }, testClientStreamManagement, true, true) + + go func() { + var state SMState + var err error + // Client is ok, we now open XMPP session + if client.Session, err = NewSession(client, state); err != nil { + t.Fatalf("failed to open XMPP session: %s", err) + } + clientDone <- struct{}{} + }() + waitForEntity(t, clientDone) + if IsStreamResumable(client) { + t.Fatalf("server does not support resumption but client says stream is resumable") + } + waitForEntity(t, serverDone) + mock.Stop() +} + +func Test_StreamManagementResume(t *testing.T) { + // Setup Mock server + mock := ServerMock{} + mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) + + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + sc.connection.Write([]byte("")) + + checkClientOpenStream(t, sc) // Reset stream + sendFeaturesStreamManagment(t, sc) // Send post auth features + bind(t, sc) + enableStreamManagement(t, sc, false, true) + discardPresence(t, sc) + }) + + // Test / Check result + config := Config{ + TransportConfiguration: TransportConfiguration{ + Address: testXMPPAddress, + }, + Jid: "test@localhost", + Credential: Password("test"), + Insecure: true, + StreamManagementEnable: true, + streamManagementResume: true} // Enable stream management + + var client *Client + router := NewRouter() + client, err := NewClient(&config, router, clientDefaultErrorHandler) + if err != nil { + t.Errorf("connect create XMPP client: %s", err) + } + + err = client.Connect() + if err != nil { + t.Fatalf("could not connect client to mock server: %s", err) + } + + statusCorrectChan := make(chan struct{}) + kill := make(chan struct{}) + + transp, ok := client.transport.(*XMPPTransport) + if !ok { + t.Fatalf("problem with client transport ") + } + + transp.conn.Close() + mock.Stop() + + // Check if status is correctly updated because of the disconnect + go checkClientResumeStatus(client, statusCorrectChan, kill) + select { + case <-statusCorrectChan: + // Test passed + case <-time.After(5 * time.Second): + kill <- struct{}{} + t.Fatalf("Client is not in disconnected state while it should be. Timed out") + } + + // Check if the client can have its connection resumed using its state but also its configuration + if !IsStreamResumable(client) { + t.Fatalf("should support resumption") + } + + // Reboot server. We need to make a new one because (at least for now) the mock server can only have one handler + // and they should be different between a first connection and a stream resume since exchanged messages + // are different (See XEP-0198) + mock2 := ServerMock{} + mock2.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) { + // Reconnect + checkClientOpenStream(t, sc) + + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + sc.connection.Write([]byte("")) + + checkClientOpenStream(t, sc) // Reset stream + sendFeaturesStreamManagment(t, sc) // Send post auth features + resumeStream(t, sc) + }) + + // Reconnect + err = client.Resume() + if err != nil { + t.Fatalf("could not connect client to mock server: %s", err) + } + mock2.Stop() +} + +func Test_StreamManagementFail(t *testing.T) { + // Setup Mock server + mock := ServerMock{} + mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) + + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + sc.connection.Write([]byte("")) + + checkClientOpenStream(t, sc) // Reset stream + sendFeaturesStreamManagment(t, sc) // Send post auth features + bind(t, sc) + enableStreamManagement(t, sc, true, true) + }) + + // Test / Check result + config := Config{ + TransportConfiguration: TransportConfiguration{ + Address: testXMPPAddress, + }, + Jid: "test@localhost", + Credential: Password("test"), + Insecure: true, + StreamManagementEnable: true, + streamManagementResume: true} // Enable stream management + + var client *Client + router := NewRouter() + client, err := NewClient(&config, router, clientDefaultErrorHandler) + if err != nil { + t.Errorf("connect create XMPP client: %s", err) + } + + var state SMState + _, err = client.transport.Connect() + if err != nil { + return + } + + // Client is ok, we now open XMPP session + if client.Session, err = NewSession(client, state); err == nil { + t.Fatalf("test is supposed to err") + } + if client.Session.SMState.StreamErrorGroup == nil { + t.Fatalf("error was not stored correctly in session state") + } + + mock.Stop() +} + +func Test_SendStanzaQueueWithSM(t *testing.T) { + // Setup Mock server + mock := ServerMock{} + serverDone := make(chan struct{}) + mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) + + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + sc.connection.Write([]byte("")) + + checkClientOpenStream(t, sc) // Reset stream + sendFeaturesStreamManagment(t, sc) // Send post auth features + bind(t, sc) + enableStreamManagement(t, sc, false, true) + + // Ignore the initial presence sent to the server by the client so we can move on to the next packet. + discardPresence(t, sc) + + // Used here to silently discard the IQ sent by the client, in order to later trigger a resend + skipPacket(t, sc) + // Respond to the client ACK request with a number of processed stanzas of 0. This should trigger a resend + // of previously ignored stanza to the server, which this handler element will be expecting. + respondWithAck(t, sc, 0, serverDone) + }) + + // Test / Check result + config := Config{ + TransportConfiguration: TransportConfiguration{ + Address: testXMPPAddress, + }, + Jid: "test@localhost", + Credential: Password("test"), + Insecure: true, + StreamManagementEnable: true, + streamManagementResume: true} // Enable stream management + + var client *Client + router := NewRouter() + client, err := NewClient(&config, router, clientDefaultErrorHandler) + if err != nil { + t.Errorf("connect create XMPP client: %s", err) + } + + err = client.Connect() + + client.SendRaw(` + + +`) + + // Last stanza was discarded silently by the server. Let's ask an ack for it. This should trigger resend as the server + // will respond with an acknowledged number of stanzas of 0. + r := stanza.SMRequest{} + client.Send(r) + + select { + case <-time.After(defaultChannelTimeout): + t.Fatalf("server failed to complete the test in time") + case <-serverDone: + // Test completed successfully + } + + mock.Stop() +} + +//======================================================================== +// Helper functions for tests + +func skipPacket(t *testing.T, sc *ServerConn) { + var p stanza.IQ + se, err := stanza.NextStart(sc.decoder) + + if err != nil { + t.Fatalf("cannot read packet: %s", err) + return + } + if err := sc.decoder.DecodeElement(&p, &se); err != nil { + t.Fatalf("cannot decode packet: %s", err) + return + } +} + +func respondWithAck(t *testing.T, sc *ServerConn, h int, serverDone chan struct{}) { + + // Mock server reads the ack request + var p stanza.SMRequest + se, err := stanza.NextStart(sc.decoder) + + if err != nil { + t.Fatalf("cannot read packet: %s", err) + return + } + if err := sc.decoder.DecodeElement(&p, &se); err != nil { + t.Fatalf("cannot decode packet: %s", err) + return + } + + // Mock server sends the ack response + a := stanza.SMAnswer{ + H: uint(h), + } + data, err := xml.Marshal(a) + _, err = sc.connection.Write(data) + if err != nil { + t.Fatalf("failed to send response ack") + } + + // Mock server reads the re-sent stanza that was previously discarded intentionally + var p2 stanza.IQ + nse, err := stanza.NextStart(sc.decoder) + + if err != nil { + t.Fatalf("cannot read packet: %s", err) + return + } + if err := sc.decoder.DecodeElement(&p2, &nse); err != nil { + t.Fatalf("cannot decode packet: %s", err) + return + } + serverDone <- struct{}{} +} + +func sendFeaturesStreamManagment(t *testing.T, sc *ServerConn) { + // This is a basic server, supporting only 2 features after auth: stream management & session binding + features := ` + + +` + if _, err := fmt.Fprintln(sc.connection, features); err != nil { + t.Fatalf("cannot send stream feature: %s", err) + } +} + +func sendFeaturesNoStreamManagment(t *testing.T, sc *ServerConn) { + // This is a basic server, supporting only 2 features after auth: stream management & session binding + features := ` + +` + if _, err := fmt.Fprintln(sc.connection, features); err != nil { + t.Fatalf("cannot send stream feature: %s", err) + } +} + +// enableStreamManagement is a function for the mock server that can either mock a successful session, or fail depending on +// the value of the "fail" boolean. True means the session should fail. +func enableStreamManagement(t *testing.T, sc *ServerConn, fail bool, resume bool) { + // Decode element into pointer storage + var ed stanza.SMEnable + se, err := stanza.NextStart(sc.decoder) + + if err != nil { + t.Fatalf("cannot read stream management enable: %s", err) + return + } + if err := sc.decoder.DecodeElement(&ed, &se); err != nil { + t.Fatalf("cannot decode stream management enable: %s", err) + return + } + + if fail { + f := stanza.SMFailed{ + H: nil, + StreamErrorGroup: &stanza.UnexpectedRequest{}, + } + data, err := xml.Marshal(f) + if err != nil { + t.Fatalf("failed to marshall error response: %s", err) + } + sc.connection.Write(data) + } else { + e := &stanza.SMEnabled{ + Resume: strconv.FormatBool(resume), + Id: streamManagementID, + } + data, err := xml.Marshal(e) + if err != nil { + t.Fatalf("failed to marshall error response: %s", err) + } + sc.connection.Write(data) + } +} + +func resumeStream(t *testing.T, sc *ServerConn) { + h := uint(0) + response := stanza.SMResumed{ + PrevId: streamManagementID, + H: &h, + } + + data, err := xml.Marshal(response) + if err != nil { + t.Fatalf("failed to marshall stream management enabled response : %s", err) + } + + writtenChan := make(chan struct{}) + + go func() { + sc.connection.Write(data) + writtenChan <- struct{}{} + }() + select { + case <-writtenChan: + // We're done here + return + case <-time.After(defaultTimeout): + t.Fatalf("failed to write enabled nonza to client") + } +} + +func checkClientResumeStatus(client *Client, statusCorrectChan chan struct{}, killChan chan struct{}) { + for { + if client.CurrentState.getState() == StateDisconnected { + statusCorrectChan <- struct{}{} + } + select { + case <-killChan: + return + case <-time.After(time.Millisecond * 10): + // Keep checking status value + } + } +} + +func initSrvCliForResumeTests(t *testing.T, serverHandler func(*testing.T, *ServerConn), port int, StreamManagementEnable, StreamManagementResume bool) (*Client, *ServerMock) { + mock := &ServerMock{} + testServerAddress := fmt.Sprintf("%s:%d", testClientDomain, port) + + mock.Start(t, testServerAddress, serverHandler) + config := Config{ + TransportConfiguration: TransportConfiguration{ + Address: testServerAddress, + }, + Jid: "test@localhost", + Credential: Password("test"), + Insecure: true, + StreamManagementEnable: StreamManagementEnable, + streamManagementResume: StreamManagementResume} + + var client *Client + var err error + router := NewRouter() + if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil { + t.Fatalf("connect create XMPP client: %s", err) + } + + if _, err = client.transport.Connect(); err != nil { + t.Fatalf("XMPP connection failed: %s", err) + } + + return client, mock +} + +func waitForEntity(t *testing.T, entityDone chan struct{}) { + select { + case <-entityDone: + case <-time.After(defaultTimeout): + t.Fatalf("test timed out") + } +} diff --git a/client_test.go b/client_test.go index 00c956a..3cdd77a 100644 --- a/client_test.go +++ b/client_test.go @@ -20,18 +20,20 @@ const ( func TestEventManager(t *testing.T) { mgr := EventManager{} - mgr.updateState(StateConnected) - if mgr.CurrentState != StateConnected { + mgr.updateState(StateResuming) + if mgr.CurrentState.getState() != StateResuming { t.Fatal("CurrentState not updated by updateState()") } mgr.disconnected(SMState{}) - if mgr.CurrentState != StateDisconnected { + + if mgr.CurrentState.getState() != StateDisconnected { t.Fatalf("CurrentState not reset by disconnected()") } mgr.streamError(ErrTLSNotSupported.Error(), "") - if mgr.CurrentState != StateStreamError { + + if mgr.CurrentState.getState() != StateStreamError { t.Fatalf("CurrentState not set by streamError()") } } @@ -53,7 +55,7 @@ func TestClient_Connect(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { + if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil { t.Errorf("connect create XMPP client: %s", err) } @@ -84,7 +86,7 @@ func TestClient_NoInsecure(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { + if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil { t.Errorf("cannot create XMPP client: %s", err) } @@ -117,7 +119,7 @@ func TestClient_FeaturesTracking(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { + if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil { t.Errorf("cannot create XMPP client: %s", err) } @@ -147,7 +149,7 @@ func TestClient_RFC3921Session(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { + if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil { t.Errorf("connect create XMPP client: %s", err) } @@ -366,7 +368,7 @@ func TestClient_DisconnectStreamManager(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { + if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil { t.Errorf("cannot create XMPP client: %s", err) } @@ -386,6 +388,162 @@ func TestClient_DisconnectStreamManager(t *testing.T) { mock.Stop() } +func Test_ClientPostConnectHook(t *testing.T) { + done := make(chan struct{}) + // Handler for Mock server + h := func(t *testing.T, sc *ServerConn) { + handlerClientConnectSuccess(t, sc) + done <- struct{}{} + } + + hookChan := make(chan struct{}) + mock := &ServerMock{} + testServerAddress := fmt.Sprintf("%s:%d", testClientDomain, testClientPostConnectHook) + + mock.Start(t, testServerAddress, h) + config := Config{ + TransportConfiguration: TransportConfiguration{ + Address: testServerAddress, + }, + Jid: "test@localhost", + Credential: Password("test"), + Insecure: true} + + var client *Client + var err error + router := NewRouter() + if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil { + t.Errorf("connect create XMPP client: %s", err) + } + + // The post connection client hook should just write to a channel that we will read later. + client.PostFirstConnHook = func() error { + go func() { + hookChan <- struct{}{} + }() + return nil + } + // Handle a possible error + errChan := make(chan error) + errorHandler := func(err error) { + errChan <- err + } + client.ErrorHandler = errorHandler + if err = client.Connect(); err != nil { + t.Errorf("XMPP connection failed: %s", err) + } + + // Check if the post connection client hook was correctly called + select { + case err := <-errChan: // If the server sends an error, or there is a connection error + t.Fatal(err.Error()) + case <-time.After(defaultChannelTimeout): // If we timeout + t.Fatal("Failed to call post connection client hook") + case <-hookChan: + // Test succeeded, channel was written to. + } + + select { + case <-done: + mock.Stop() + case <-time.After(defaultChannelTimeout): + t.Fatal("The mock server failed to finish its job !") + } +} + +func Test_ClientPostReconnectHook(t *testing.T) { + hookChan := make(chan struct{}) + // Setup Mock server + mock := ServerMock{} + mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) + + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + sc.connection.Write([]byte("")) + + checkClientOpenStream(t, sc) // Reset stream + sendFeaturesStreamManagment(t, sc) // Send post auth features + bind(t, sc) + enableStreamManagement(t, sc, false, true) + }) + + // Test / Check result + config := Config{ + TransportConfiguration: TransportConfiguration{ + Address: testXMPPAddress, + }, + Jid: "test@localhost", + Credential: Password("test"), + Insecure: true, + StreamManagementEnable: true, + streamManagementResume: true} // Enable stream management + + var client *Client + router := NewRouter() + client, err := NewClient(&config, router, clientDefaultErrorHandler) + if err != nil { + t.Errorf("connect create XMPP client: %s", err) + } + + client.PostReconnectHook = func() error { + go func() { + hookChan <- struct{}{} + }() + return nil + } + + err = client.Connect() + if err != nil { + t.Fatalf("could not connect client to mock server: %s", err) + } + + transp, ok := client.transport.(*XMPPTransport) + if !ok { + t.Fatalf("problem with client transport ") + } + + transp.conn.Close() + mock.Stop() + + // Check if the client can have its connection resumed using its state but also its configuration + if !IsStreamResumable(client) { + t.Fatalf("should support resumption") + } + + // Reboot server. We need to make a new one because (at least for now) the mock server can only have one handler + // and they should be different between a first connection and a stream resume since exchanged messages + // are different (See XEP-0198) + mock2 := ServerMock{} + mock2.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) { + // Reconnect + checkClientOpenStream(t, sc) + + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + sc.connection.Write([]byte("")) + + checkClientOpenStream(t, sc) // Reset stream + sendFeaturesStreamManagment(t, sc) // Send post auth features + resumeStream(t, sc) + }) + + // Reconnect + err = client.Resume() + if err != nil { + t.Fatalf("could not connect client to mock server: %s", err) + } + + select { + case <-time.After(defaultChannelTimeout): // If we timeout + t.Fatal("Failed to call post connection client hook") + case <-hookChan: + // Test succeeded, channel was written to. + } + + mock2.Stop() +} + //============================================================================= // Basic XMPP Server Mock Handlers. @@ -449,7 +607,7 @@ func checkClientOpenStream(t *testing.T, sc *ServerConn) { var token xml.Token token, err := sc.decoder.Token() if err != nil { - t.Errorf("cannot read next token: %s", err) + t.Fatalf("cannot read next token: %s", err) } switch elem := token.(type) { @@ -464,6 +622,7 @@ func checkClientOpenStream(t *testing.T, sc *ServerConn) { } return } + } } @@ -472,7 +631,6 @@ func mockClientConnection(t *testing.T, serverHandler func(*testing.T, *ServerCo testServerAddress := fmt.Sprintf("%s:%d", testClientDomain, port) mock.Start(t, testServerAddress, serverHandler) - config := Config{ TransportConfiguration: TransportConfiguration{ Address: testServerAddress, @@ -484,7 +642,7 @@ func mockClientConnection(t *testing.T, serverHandler func(*testing.T, *ServerCo var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { + if client, err = NewClient(&config, router, clientDefaultErrorHandler); err != nil { t.Errorf("connect create XMPP client: %s", err) } diff --git a/cmd/fluuxmpp/send.go b/cmd/fluuxmpp/send.go index 7e7ed97..27c1a67 100644 --- a/cmd/fluuxmpp/send.go +++ b/cmd/fluuxmpp/send.go @@ -32,7 +32,7 @@ func sendxmpp(cmd *cobra.Command, args []string) { msgText := args[1] var err error - client, err := xmpp.NewClient(xmpp.Config{ + client, err := xmpp.NewClient(&xmpp.Config{ TransportConfiguration: xmpp.TransportConfiguration{ Address: viper.GetString("addr"), }, diff --git a/cmd/go.sum b/cmd/go.sum index c7e00fa..8398605 100644 --- a/cmd/go.sum +++ b/cmd/go.sum @@ -65,6 +65,7 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190908185732-236ed259b199/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= @@ -92,6 +93,7 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/magiconair/properties v1.8.0 h1:LLgXmsheXeRoUOBOjtwPQCWIYqM/LU1ayDtDePerRcY= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= +github.com/magiconair/properties v1.8.1 h1:ZC2Vc7/ZFkGmsVC9KvOjumD+G5lXy2RtTKyzRKO2BQ4= github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mailru/easyjson v0.0.0-20190403194419-1ea4449da983/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= @@ -148,10 +150,12 @@ github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb6 github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/spf13/viper v1.4.0 h1:yXHLWeravcrgGyFSyCgdYpXQ9dR9c/WED3pg1RhxqEU= github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE= +github.com/spf13/viper v1.6.1 h1:VPZzIkznI1YhVMRi6vNFLHSwhnhReBfgTxIPccpfdZk= github.com/spf13/viper v1.6.1/go.mod h1:t3iDnF5Jlj76alVNuyFBk5oUMCvsrkbvZK0WQdfDi5k= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -159,6 +163,7 @@ github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/twitchyliquid64/golang-asm v0.0.0-20190126203739-365674df15fc/go.mod h1:NoCfSFWosfqMqmmD7hApkirIK9ozpHjxRnRxs1l413A= @@ -207,6 +212,7 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190618155005-516e3c20635f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190927073244-c990c680b611 h1:q9u40nxWT5zRClI/uU9dHCiYGottAg6Nzz4YUQyHxdA= golang.org/x/sys v0.0.0-20190927073244-c990c680b611/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -231,6 +237,7 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo= +gopkg.in/ini.v1 v1.51.0 h1:AQvPpx3LzTDM0AjnIRlVFwFFGC+npRopjZxLJj6gdno= gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= @@ -238,9 +245,11 @@ gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bl gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gotest.tools v2.1.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools/gotestsum v0.3.5/go.mod h1:Mnf3e5FUzXbkCfynWBGOwLssY7gTQgCHObK9tMpAriY= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= mvdan.cc/sh v2.6.4+incompatible/go.mod h1:IeeQbZq+x2SUGBensq/jge5lLQbS3XT2ktyp3wrt4x8= +nhooyr.io/websocket v1.6.5 h1:8TzpkldRfefda5JST+CnOH135bzVPz5uzfn/AF+gVKg= nhooyr.io/websocket v1.6.5/go.mod h1:F259lAzPRAH0htX2y3ehpJe09ih1aSHN7udWki1defY= diff --git a/component.go b/component.go index a57b07b..ec0e8df 100644 --- a/component.go +++ b/component.go @@ -60,11 +60,10 @@ func NewComponent(opts ComponentOptions, r *Router, errorHandler func(error)) (* // Connect triggers component connection to XMPP server component port. // TODO: Failed handshake should be a permanent error func (c *Component) Connect() error { - var state SMState - return c.Resume(state) + return c.Resume() } -func (c *Component) Resume(sm SMState) error { +func (c *Component) Resume() error { var err error var streamId string if c.ComponentOptions.TransportConfiguration.Domain == "" { @@ -73,16 +72,13 @@ func (c *Component) Resume(sm SMState) error { c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration) if err != nil { c.updateState(StatePermanentError) - return NewConnError(err, true) } if streamId, err = c.transport.Connect(); err != nil { c.updateState(StatePermanentError) - return NewConnError(err, true) } - c.updateState(StateConnected) // Authentication if err := c.sendWithWriter(c.transport, []byte(fmt.Sprintf("%s", c.handshake(streamId)))); err != nil { diff --git a/config.go b/config.go index 178da2e..4609a0a 100644 --- a/config.go +++ b/config.go @@ -21,4 +21,15 @@ type Config struct { // Insecure can be set to true to allow to open a session without TLS. If TLS // is supported on the server, we will still try to use it. Insecure bool + + // Activate stream management process during session + StreamManagementEnable bool + // Enable stream management resume capability + streamManagementResume bool +} + +// IsStreamResumable tells if a stream session is resumable by reading the "config" part of a client. +// It checks if stream management is enabled, and if stream resumption was set and accepted by the server. +func IsStreamResumable(c *Client) bool { + return c.config.StreamManagementEnable && c.config.streamManagementResume } diff --git a/router.go b/router.go index f20af5b..d334d36 100644 --- a/router.go +++ b/router.go @@ -42,6 +42,17 @@ func NewRouter() *Router { // route is called by the XMPP client to dispatch stanza received using the set up routes. // It is also used by test, but is not supposed to be used directly by users of the library. func (r *Router) route(s Sender, p stanza.Packet) { + a, isA := p.(stanza.SMAnswer) + if isA { + switch tt := s.(type) { + case *Client: + lastAcked := a.H + SendMissingStz(int(lastAcked), s, tt.Session.SMState.UnAckQueue) + case *Component: + // TODO + default: + } + } iq, isIq := p.(*stanza.IQ) if isIq { r.IQResultRouteLock.RLock() @@ -70,6 +81,32 @@ func (r *Router) route(s Sender, p stanza.Packet) { } } +func SendMissingStz(lastSent int, s Sender, uaq *stanza.UnAckQueue) error { + uaq.RWMutex.Lock() + if len(uaq.Uslice) <= 0 { + uaq.RWMutex.Unlock() + return nil + } + last := uaq.Uslice[len(uaq.Uslice)-1] + if last.Id > lastSent { + // Remove sent stanzas from the queue + uaq.PopN(lastSent - last.Id) + // Re-send non acknowledged stanzas + for _, elt := range uaq.PopN(len(uaq.Uslice)) { + eltStz := elt.(*stanza.UnAckedStz) + err := s.SendRaw(eltStz.Stz) + if err != nil { + return err + } + + } + // Ask for updates on stanzas we just sent to the entity + s.Send(stanza.SMRequest{}) + } + uaq.RWMutex.Unlock() + return nil +} + func iqNotImplemented(s Sender, iq *stanza.IQ) { err := stanza.Err{ XMLName: xml.Name{Local: "error"}, diff --git a/session.go b/session.go index 182e32b..05bdce3 100644 --- a/session.go +++ b/session.go @@ -1,10 +1,11 @@ package xmpp import ( + "encoding/xml" "errors" "fmt" - "gosrc.io/xmpp/stanza" + "strconv" ) type Session struct { @@ -23,44 +24,67 @@ type Session struct { err error } -func NewSession(transport Transport, o Config, state SMState) (*Session, error) { - s := new(Session) - s.transport = transport - s.SMState = state - s.init(o) +func NewSession(c *Client, state SMState) (*Session, error) { + var s *Session + if c.Session == nil { + s = new(Session) + s.transport = c.transport + s.SMState = state + s.init() + } else { + s = c.Session + // We keep information about the previously set session, like the session ID, but we read server provided + // info again in case it changed between session break and resume, such as features. + s.init() + } if s.err != nil { return nil, NewConnError(s.err, true) } - if !transport.IsSecure() { - s.startTlsIfSupported(o) + if !c.transport.IsSecure() { + s.startTlsIfSupported(c.config) } - if !transport.IsSecure() && !o.Insecure { + if !c.transport.IsSecure() && !c.config.Insecure { err := fmt.Errorf("failed to negotiate TLS session : %s", s.err) return nil, NewConnError(err, true) } if s.TlsEnabled { - s.reset(o) + s.reset() } // auth - s.auth(o) - s.reset(o) + s.auth(c.config) + if s.err != nil { + return s, s.err + } + s.reset() + if s.err != nil { + return s, s.err + } // attempt resumption - if s.resume(o) { + if s.resume(c.config) { return s, s.err } // otherwise, bind resource and 'start' XMPP session - s.bind(o) - s.rfc3921Session(o) + s.bind(c.config) + if s.err != nil { + return s, s.err + } + s.rfc3921Session() + if s.err != nil { + return s, s.err + } // Enable stream management if supported - s.EnableStreamManagement(o) + s.EnableStreamManagement(c.config) + if s.err != nil { + return s, s.err + } return s, s.err } @@ -70,19 +94,20 @@ func (s *Session) PacketId() string { return fmt.Sprintf("%x", s.lastPacketId) } -func (s *Session) init(o Config) { - s.Features = s.open(o.parsedJid.Domain) +// init gathers information on the session such as stream features from the server. +func (s *Session) init() { + s.Features = s.extractStreamFeatures() } -func (s *Session) reset(o Config) { +func (s *Session) reset() { if s.StreamId, s.err = s.transport.StartStream(); s.err != nil { return } - s.Features = s.open(o.parsedJid.Domain) + s.Features = s.extractStreamFeatures() } -func (s *Session) open(domain string) (f stanza.StreamFeatures) { +func (s *Session) extractStreamFeatures() (f stanza.StreamFeatures) { // extract stream features if s.err = s.transport.GetDecoder().Decode(&f); s.err != nil { s.err = errors.New("stream open decode features: " + s.err.Error()) @@ -90,7 +115,7 @@ func (s *Session) open(domain string) (f stanza.StreamFeatures) { return } -func (s *Session) startTlsIfSupported(o Config) { +func (s *Session) startTlsIfSupported(o *Config) { if s.err != nil { return } @@ -125,7 +150,7 @@ func (s *Session) startTlsIfSupported(o Config) { } } -func (s *Session) auth(o Config) { +func (s *Session) auth(o *Config) { if s.err != nil { return } @@ -134,7 +159,7 @@ func (s *Session) auth(o Config) { } // Attempt to resume session using stream management -func (s *Session) resume(o Config) bool { +func (s *Session) resume(o *Config) bool { if !s.Features.DoesStreamManagement() { return false } @@ -142,9 +167,16 @@ func (s *Session) resume(o Config) bool { return false } - fmt.Fprintf(s.transport, "", - stanza.NSStreamManagement, s.SMState.Inbound, s.SMState.Id) + rsm := stanza.SMResume{ + PrevId: s.SMState.Id, + H: &s.SMState.Inbound, + } + data, err := xml.Marshal(rsm) + _, err = s.transport.Write(data) + if err != nil { + return false + } var packet stanza.Packet packet, s.err = stanza.NextPacket(s.transport.GetDecoder()) if s.err == nil { @@ -165,20 +197,48 @@ func (s *Session) resume(o Config) bool { return false } -func (s *Session) bind(o Config) { +func (s *Session) bind(o *Config) { if s.err != nil { return } // Send IQ message asking to bind to the local user name. var resource = o.parsedJid.Resource - if resource != "" { - fmt.Fprintf(s.transport, "%s", - s.PacketId(), stanza.NSBind, resource) - } else { - fmt.Fprintf(s.transport, "", s.PacketId(), stanza.NSBind) + iqB, err := stanza.NewIQ(stanza.Attrs{ + Type: stanza.IQTypeSet, + Id: s.PacketId(), + }) + if err != nil { + s.err = err + return } + // Check if we already have a resource name, and include it in the request if so + if resource != "" { + iqB.Payload = &stanza.Bind{ + Resource: resource, + } + } else { + iqB.Payload = &stanza.Bind{} + + } + + // Send the bind request IQ + data, err := xml.Marshal(iqB) + if err != nil { + s.err = err + return + } + n, err := s.transport.Write(data) + if err != nil { + s.err = err + return + } else if n == 0 { + s.err = errors.New("failed to write bind iq stanza to the server : wrote 0 bytes") + return + } + + // Check the server response var iq stanza.IQ if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil { s.err = errors.New("error decoding iq bind result: " + s.err.Error()) @@ -197,7 +257,7 @@ func (s *Session) bind(o Config) { } // After the bind, if the session is not optional (as per old RFC 3921), we send the session open iq. -func (s *Session) rfc3921Session(o Config) { +func (s *Session) rfc3921Session() { if s.err != nil { return } @@ -205,7 +265,29 @@ func (s *Session) rfc3921Session(o Config) { var iq stanza.IQ // We only negotiate session binding if it is mandatory, we skip it when optional. if !s.Features.Session.IsOptional() { - fmt.Fprintf(s.transport, "", s.PacketId(), stanza.NSSession) + se, err := stanza.NewIQ(stanza.Attrs{ + Type: stanza.IQTypeSet, + Id: s.PacketId(), + }) + if err != nil { + s.err = err + return + } + se.Payload = &stanza.StreamSession{} + data, err := xml.Marshal(se) + if err != nil { + s.err = err + return + } + n, err := s.transport.Write(data) + if err != nil { + s.err = err + return + } else if n == 0 { + s.err = errors.New("there was a problem marshaling the session IQ : wrote 0 bytes to server") + return + } + if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil { s.err = errors.New("expecting iq result after session open: " + s.err.Error()) return @@ -214,28 +296,47 @@ func (s *Session) rfc3921Session(o Config) { } // Enable stream management, with session resumption, if supported. -func (s *Session) EnableStreamManagement(o Config) { +func (s *Session) EnableStreamManagement(o *Config) { if s.err != nil { return } - if !s.Features.DoesStreamManagement() { + if !s.Features.DoesStreamManagement() || !o.StreamManagementEnable { + return + } + q := stanza.NewUnAckQueue() + ebleNonza := stanza.SMEnable{Resume: &o.streamManagementResume} + pktStr, err := xml.Marshal(ebleNonza) + if err != nil { + s.err = err + return + } + _, err = s.transport.Write(pktStr) + if err != nil { + s.err = err return } - - fmt.Fprintf(s.transport, "", stanza.NSStreamManagement) var packet stanza.Packet packet, s.err = stanza.NextPacket(s.transport.GetDecoder()) if s.err == nil { switch p := packet.(type) { case stanza.SMEnabled: - s.SMState = SMState{Id: p.Id} + // Server allows resumption or not using SMEnabled attribute "resume". We must read the server response + // and update config accordingly + b, err := strconv.ParseBool(p.Resume) + if err != nil || !b { + o.StreamManagementEnable = false + } + s.SMState = SMState{Id: p.Id, preferredReconAddr: p.Location} + s.SMState.UnAckQueue = q case stanza.SMFailed: // TODO: Store error in SMState, for later inspection + s.SMState = SMState{StreamErrorGroup: p.StreamErrorGroup} + s.SMState.UnAckQueue = q + s.err = errors.New("failed to establish session : " + s.SMState.StreamErrorGroup.GroupErrorName()) default: s.err = errors.New("unexpected reply to SM enable") } } - return } diff --git a/stanza/fifo_queue.go b/stanza/fifo_queue.go new file mode 100644 index 0000000..dcdab02 --- /dev/null +++ b/stanza/fifo_queue.go @@ -0,0 +1,32 @@ +package stanza + +// FIFO queue for string contents +// Implementations have no guarantee regarding thread safety ! +type FifoQueue interface { + // Pop returns the first inserted element still in queue and delete it from queue + // No guarantee regarding thread safety ! + Pop() Queueable + + // PopN returns the N first inserted elements still in queue and delete them from queue + // No guarantee regarding thread safety ! + PopN(i int) []Queueable + + // Peek returns a copy of the first inserted element in queue without deleting it + // No guarantee regarding thread safety ! + Peek() Queueable + + // Peek returns a copy of the first inserted element in queue without deleting it + // No guarantee regarding thread safety ! + PeekN() []Queueable + // Push adds an element to the queue + // No guarantee regarding thread safety ! + Push(s Queueable) error + + // Empty returns true if queue is empty + // No guarantee regarding thread safety ! + Empty() bool +} + +type Queueable interface { + QueueableName() string +} diff --git a/stanza/sasl_auth.go b/stanza/sasl_auth.go index 9dfe557..2fb660e 100644 --- a/stanza/sasl_auth.go +++ b/stanza/sasl_auth.go @@ -93,8 +93,8 @@ func (b *Bind) GetSet() *ResultSet { // This is the draft defining how to handle the transition: // https://tools.ietf.org/html/draft-cridland-xmpp-session-01 type StreamSession struct { - XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-session session"` - Optional bool // If element does exist, it mean we are not required to open session + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-session session"` + Optional *struct{} // If element does exist, it mean we are not required to open session // Result sets ResultSet *ResultSet `xml:"set,omitempty"` } @@ -109,7 +109,7 @@ func (s *StreamSession) GetSet() *ResultSet { func (s *StreamSession) IsOptional() bool { if s.XMLName.Local == "session" { - return s.Optional + return s.Optional != nil } // If session element is missing, then we should not use session return true diff --git a/stanza/sasl_auth_test.go b/stanza/sasl_auth_test.go index 3e37453..600035c 100644 --- a/stanza/sasl_auth_test.go +++ b/stanza/sasl_auth_test.go @@ -9,7 +9,7 @@ import ( // Check that we can detect optional session from advertised stream features func TestSessionFeatures(t *testing.T) { - streamFeatures := stanza.StreamFeatures{Session: stanza.StreamSession{Optional: true}} + streamFeatures := stanza.StreamFeatures{Session: stanza.StreamSession{Optional: &struct{}{}}} data, err := xml.Marshal(streamFeatures) if err != nil { @@ -32,7 +32,7 @@ func TestSessionIQ(t *testing.T) { if err != nil { t.Fatalf("failed to create IQ: %v", err) } - iq.Payload = &stanza.StreamSession{XMLName: xml.Name{Local: "session"}, Optional: true} + iq.Payload = &stanza.StreamSession{XMLName: xml.Name{Local: "session"}, Optional: &struct{}{}} data, err := xml.Marshal(iq) if err != nil { diff --git a/stanza/stanza_errors.go b/stanza/stanza_errors.go new file mode 100644 index 0000000..c66ea33 --- /dev/null +++ b/stanza/stanza_errors.go @@ -0,0 +1,171 @@ +package stanza + +import ( + "encoding/xml" +) + +type StanzaErrorGroup interface { + GroupErrorName() string +} + +type BadFormat struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas bad-format"` +} + +func (e *BadFormat) GroupErrorName() string { return "bad-format" } + +type BadNamespacePrefix struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas bad-namespace-prefix"` +} + +func (e *BadNamespacePrefix) GroupErrorName() string { return "bad-namespace-prefix" } + +type Conflict struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas conflict"` +} + +func (e *Conflict) GroupErrorName() string { return "conflict" } + +type ConnectionTimeout struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas connection-timeout"` +} + +func (e *ConnectionTimeout) GroupErrorName() string { return "connection-timeout" } + +type HostGone struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas host-gone"` +} + +func (e *HostGone) GroupErrorName() string { return "host-gone" } + +type HostUnknown struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas host-unknown"` +} + +func (e *HostUnknown) GroupErrorName() string { return "host-unknown" } + +type ImproperAddressing struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas improper-addressing"` +} + +func (e *ImproperAddressing) GroupErrorName() string { return "improper-addressing" } + +type InternalServerError struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas internal-server-error"` +} + +func (e *InternalServerError) GroupErrorName() string { return "internal-server-error" } + +type InvalidForm struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas invalid-from"` +} + +func (e *InvalidForm) GroupErrorName() string { return "invalid-from" } + +type InvalidId struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas invalid-id"` +} + +func (e *InvalidId) GroupErrorName() string { return "invalid-id" } + +type InvalidNamespace struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas invalid-namespace"` +} + +func (e *InvalidNamespace) GroupErrorName() string { return "invalid-namespace" } + +type InvalidXML struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas invalid-xml"` +} + +func (e *InvalidXML) GroupErrorName() string { return "invalid-xml" } + +type NotAuthorized struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas not-authorized"` +} + +func (e *NotAuthorized) GroupErrorName() string { return "not-authorized" } + +type NotWellFormed struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas not-well-formed"` +} + +func (e *NotWellFormed) GroupErrorName() string { return "not-well-formed" } + +type PolicyViolation struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas policy-violation"` +} + +func (e *PolicyViolation) GroupErrorName() string { return "policy-violation" } + +type RemoteConnectionFailed struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas remote-connection-failed"` +} + +func (e *RemoteConnectionFailed) GroupErrorName() string { return "remote-connection-failed" } + +type Reset struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas reset"` +} + +func (e *Reset) GroupErrorName() string { return "reset" } + +type ResourceConstraint struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas resource-constraint"` +} + +func (e *ResourceConstraint) GroupErrorName() string { return "resource-constraint" } + +type RestrictedXML struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas restricted-xml"` +} + +func (e *RestrictedXML) GroupErrorName() string { return "restricted-xml" } + +type SeeOtherHost struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas see-other-host"` +} + +func (e *SeeOtherHost) GroupErrorName() string { return "see-other-host" } + +type SystemShutdown struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas system-shutdown"` +} + +func (e *SystemShutdown) GroupErrorName() string { return "system-shutdown" } + +type UndefinedCondition struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas undefined-condition"` +} + +func (e *UndefinedCondition) GroupErrorName() string { return "undefined-condition" } + +type UnsupportedEncoding struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas unsupported-encoding"` +} + +type UnexpectedRequest struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas unexpected-request"` +} + +func (e *UnexpectedRequest) GroupErrorName() string { return "unexpected-request" } + +func (e *UnsupportedEncoding) GroupErrorName() string { return "unsupported-encoding" } + +type UnsupportedStanzaType struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas unsupported-stanza-type"` +} + +func (e *UnsupportedStanzaType) GroupErrorName() string { return "unsupported-stanza-type" } + +type UnsupportedVersion struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas unsupported-version"` +} + +func (e *UnsupportedVersion) GroupErrorName() string { return "unsupported-version" } + +type XMLNotWellFormed struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-stanzas xml-not-well-formed"` +} + +func (e *XMLNotWellFormed) GroupErrorName() string { return "xml-not-well-formed" } diff --git a/stanza/stream_features.go b/stanza/stream_features.go index d5bed5c..d1b6274 100644 --- a/stanza/stream_features.go +++ b/stanza/stream_features.go @@ -118,6 +118,10 @@ type streamManagement struct { XMLName xml.Name `xml:"urn:xmpp:sm:3 sm"` } +func (streamManagement) Name() string { + return "streamManagement" +} + func (sf *StreamFeatures) DoesStreamManagement() (isSupported bool) { if sf.StreamManagement.XMLName.Space+" "+sf.StreamManagement.XMLName.Local == "urn:xmpp:sm:3 sm" { return true diff --git a/stanza/stream_management.go b/stanza/stream_management.go index ddbe9cd..f48d9ca 100644 --- a/stanza/stream_management.go +++ b/stanza/stream_management.go @@ -3,12 +3,19 @@ package stanza import ( "encoding/xml" "errors" + "sync" ) const ( NSStreamManagement = "urn:xmpp:sm:3" ) +type SMEnable struct { + XMLName xml.Name `xml:"urn:xmpp:sm:3 enable"` + Max *uint `xml:"max,attr,omitempty"` + Resume *bool `xml:"resume,attr,omitempty"` +} + // Enabled as defined in Stream Management spec // Reference: https://xmpp.org/extensions/xep-0198.html#enable type SMEnabled struct { @@ -23,6 +30,94 @@ func (SMEnabled) Name() string { return "Stream Management: enabled" } +type UnAckQueue struct { + Uslice []*UnAckedStz + sync.RWMutex +} +type UnAckedStz struct { + Id int + Stz string +} + +func NewUnAckQueue() *UnAckQueue { + return &UnAckQueue{ + Uslice: make([]*UnAckedStz, 0, 10), // Capacity is 0 to comply with "Push" implementation (so that no reachable element is nil) + RWMutex: sync.RWMutex{}, + } +} + +func (u *UnAckedStz) QueueableName() string { + return "Un-acknowledged stanza" +} + +func (uaq *UnAckQueue) PeekN(n int) []Queueable { + if n <= 0 { + return []Queueable{} + } + if len(uaq.Uslice) < n { + n = len(uaq.Uslice) + } + + if len(uaq.Uslice) == 0 { + return []Queueable{} + } + var r []Queueable + for i := 0; i < n; i++ { + r = append(r, uaq.Uslice[i]) + } + return r +} + +// No guarantee regarding thread safety ! +func (uaq *UnAckQueue) Pop() Queueable { + r := uaq.Peek() + if r != nil { + uaq.Uslice = uaq.Uslice[1:] + } + return r +} + +// No guarantee regarding thread safety ! +func (uaq *UnAckQueue) PopN(n int) []Queueable { + r := uaq.PeekN(n) + uaq.Uslice = uaq.Uslice[len(r):] + return r +} + +func (uaq *UnAckQueue) Peek() Queueable { + if len(uaq.Uslice) == 0 { + return nil + } + r := uaq.Uslice[0] + return r +} + +func (uaq *UnAckQueue) Push(s Queueable) error { + pushIdx := 1 + if len(uaq.Uslice) != 0 { + pushIdx = uaq.Uslice[len(uaq.Uslice)-1].Id + 1 + } + + sStz, ok := s.(*UnAckedStz) + if !ok { + return errors.New("element in not compatible with this queue. expected an UnAckedStz") + } + + e := UnAckedStz{ + Id: pushIdx, + Stz: sStz.Stz, + } + + uaq.Uslice = append(uaq.Uslice, &e) + + return nil +} + +func (uaq *UnAckQueue) Empty() bool { + r := len(uaq.Uslice) + return r == 0 +} + // Request as defined in Stream Management spec // Reference: https://xmpp.org/extensions/xep-0198.html#acking type SMRequest struct { @@ -37,7 +132,7 @@ func (SMRequest) Name() string { // Reference: https://xmpp.org/extensions/xep-0198.html#acking type SMAnswer struct { XMLName xml.Name `xml:"urn:xmpp:sm:3 a"` - H uint `xml:"h,attr,omitempty"` + H uint `xml:"h,attr"` } func (SMAnswer) Name() string { @@ -49,24 +144,175 @@ func (SMAnswer) Name() string { type SMResumed struct { XMLName xml.Name `xml:"urn:xmpp:sm:3 resumed"` PrevId string `xml:"previd,attr,omitempty"` - H uint `xml:"h,attr,omitempty"` + H *uint `xml:"h,attr,omitempty"` } func (SMResumed) Name() string { return "Stream Management: resumed" } +// Resumed as defined in Stream Management spec +// Reference: https://xmpp.org/extensions/xep-0198.html#acking +type SMResume struct { + XMLName xml.Name `xml:"urn:xmpp:sm:3 resume"` + PrevId string `xml:"previd,attr,omitempty"` + H *uint `xml:"h,attr,omitempty"` +} + +func (SMResume) Name() string { + return "Stream Management: resume" +} + // Failed as defined in Stream Management spec // Reference: https://xmpp.org/extensions/xep-0198.html#acking type SMFailed struct { XMLName xml.Name `xml:"urn:xmpp:sm:3 failed"` - // TODO: Handle decoding error cause (need custom parsing). + H *uint `xml:"h,attr,omitempty"` + + StreamErrorGroup StanzaErrorGroup } func (SMFailed) Name() string { return "Stream Management: failed" } +func (smf *SMFailed) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + smf.XMLName = start.Name + + // According to https://xmpp.org/rfcs/rfc3920.html#def we should have no attributes aside from the namespace + // which we don't use internally + + // decode inner elements + for { + t, err := d.Token() + if err != nil { + return err + } + + switch tt := t.(type) { + + case xml.StartElement: + // Decode sub-elements + var err error + switch tt.Name.Local { + case "bad-format": + bf := BadFormat{} + err = d.DecodeElement(&bf, &tt) + smf.StreamErrorGroup = &bf + case "bad-namespace-prefix": + bnp := BadNamespacePrefix{} + err = d.DecodeElement(&bnp, &tt) + smf.StreamErrorGroup = &bnp + case "conflict": + c := Conflict{} + err = d.DecodeElement(&c, &tt) + smf.StreamErrorGroup = &c + case "connection-timeout": + ct := ConnectionTimeout{} + err = d.DecodeElement(&ct, &tt) + smf.StreamErrorGroup = &ct + case "host-gone": + hg := HostGone{} + err = d.DecodeElement(&hg, &tt) + smf.StreamErrorGroup = &hg + case "host-unknown": + hu := HostUnknown{} + err = d.DecodeElement(&hu, &tt) + smf.StreamErrorGroup = &hu + case "improper-addressing": + ia := ImproperAddressing{} + err = d.DecodeElement(&ia, &tt) + smf.StreamErrorGroup = &ia + case "internal-server-error": + ise := InternalServerError{} + err = d.DecodeElement(&ise, &tt) + smf.StreamErrorGroup = &ise + case "invalid-from": + ifrm := InvalidForm{} + err = d.DecodeElement(&ifrm, &tt) + smf.StreamErrorGroup = &ifrm + case "invalid-id": + id := InvalidId{} + err = d.DecodeElement(&id, &tt) + smf.StreamErrorGroup = &id + case "invalid-namespace": + ins := InvalidNamespace{} + err = d.DecodeElement(&ins, &tt) + smf.StreamErrorGroup = &ins + case "invalid-xml": + ix := InvalidXML{} + err = d.DecodeElement(&ix, &tt) + smf.StreamErrorGroup = &ix + case "not-authorized": + na := NotAuthorized{} + err = d.DecodeElement(&na, &tt) + smf.StreamErrorGroup = &na + case "not-well-formed": + nwf := NotWellFormed{} + err = d.DecodeElement(&nwf, &tt) + smf.StreamErrorGroup = &nwf + case "policy-violation": + pv := PolicyViolation{} + err = d.DecodeElement(&pv, &tt) + smf.StreamErrorGroup = &pv + case "remote-connection-failed": + rcf := RemoteConnectionFailed{} + err = d.DecodeElement(&rcf, &tt) + smf.StreamErrorGroup = &rcf + case "resource-constraint": + rc := ResourceConstraint{} + err = d.DecodeElement(&rc, &tt) + smf.StreamErrorGroup = &rc + case "restricted-xml": + rx := RestrictedXML{} + err = d.DecodeElement(&rx, &tt) + smf.StreamErrorGroup = &rx + case "see-other-host": + soh := SeeOtherHost{} + err = d.DecodeElement(&soh, &tt) + smf.StreamErrorGroup = &soh + case "system-shutdown": + ss := SystemShutdown{} + err = d.DecodeElement(&ss, &tt) + smf.StreamErrorGroup = &ss + case "undefined-condition": + uc := UndefinedCondition{} + err = d.DecodeElement(&uc, &tt) + smf.StreamErrorGroup = &uc + case "unexpected-request": + ur := UnexpectedRequest{} + err = d.DecodeElement(&ur, &tt) + smf.StreamErrorGroup = &ur + case "unsupported-encoding": + ue := UnsupportedEncoding{} + err = d.DecodeElement(&ue, &tt) + smf.StreamErrorGroup = &ue + case "unsupported-stanza-type": + ust := UnsupportedStanzaType{} + err = d.DecodeElement(&ust, &tt) + smf.StreamErrorGroup = &ust + case "unsupported-version": + uv := UnsupportedVersion{} + err = d.DecodeElement(&uv, &tt) + smf.StreamErrorGroup = &uv + case "xml-not-well-formed": + xnwf := XMLNotWellFormed{} + err = d.DecodeElement(&xnwf, &tt) + smf.StreamErrorGroup = &xnwf + default: + return errors.New("error is unknown") + } + if err != nil { + return err + } + case xml.EndElement: + if tt == start.End() { + return nil + } + } + } +} + type smDecoder struct{} var sm smDecoder @@ -78,9 +324,11 @@ func (s smDecoder) decode(p *xml.Decoder, se xml.StartElement) (Packet, error) { return s.decodeEnabled(p, se) case "resumed": return s.decodeResumed(p, se) + case "resume": + return s.decodeResume(p, se) case "r": return s.decodeRequest(p, se) - case "h": + case "a": return s.decodeAnswer(p, se) case "failed": return s.decodeFailed(p, se) @@ -102,6 +350,11 @@ func (smDecoder) decodeResumed(p *xml.Decoder, se xml.StartElement) (SMResumed, return packet, err } +func (smDecoder) decodeResume(p *xml.Decoder, se xml.StartElement) (SMResume, error) { + var packet SMResume + err := p.DecodeElement(&packet, &se) + return packet, err +} func (smDecoder) decodeRequest(p *xml.Decoder, se xml.StartElement) (SMRequest, error) { var packet SMRequest err := p.DecodeElement(&packet, &se) diff --git a/stanza/stream_management_test.go b/stanza/stream_management_test.go new file mode 100644 index 0000000..8f51ba0 --- /dev/null +++ b/stanza/stream_management_test.go @@ -0,0 +1,187 @@ +package stanza_test + +import ( + "gosrc.io/xmpp/stanza" + "math/rand" + "reflect" + "testing" + "time" +) + +// TODO : tests to add +// - Pop on nil or empty slice +// - PeekN (normal and too long) + +func TestPushUnack(t *testing.T) { + uaq := initUnAckQueue() + toPush := stanza.UnAckedStz{ + Id: 3, + Stz: ` + + confucius + Qui + Kong + +`, + } + + err := uaq.Push(&toPush) + if err != nil { + t.Fatalf("could not push element to the queue : %v", err) + } + + if len(uaq.Uslice) != 4 { + t.Fatalf("push to the non-acked queue failed") + } + for i := 0; i < 4; i++ { + if uaq.Uslice[i].Id != i+1 { + t.Fatalf("indexes were not updated correctly. Expected %d got %d", i, uaq.Uslice[i].Id) + } + } + + // Check that the queue is a fifo : popped element should not be the one we just pushed. + popped := uaq.Pop() + poppedElt, ok := popped.(*stanza.UnAckedStz) + if !ok { + t.Fatalf("popped element is not a *stanza.UnAckedStz") + } + + if reflect.DeepEqual(*poppedElt, toPush) { + t.Fatalf("pushed element is at the top of the fifo queue when it should be at the bottom") + } + +} + +func TestPeekUnack(t *testing.T) { + uaq := initUnAckQueue() + + expectedPeek := stanza.UnAckedStz{ + Id: 1, + Stz: ` + + Capulet + +`, + } + + if !reflect.DeepEqual(expectedPeek, *uaq.Uslice[0]) { + t.Fatalf("peek failed to return the correct stanza") + } + +} + +func TestPopNUnack(t *testing.T) { + uaq := initUnAckQueue() + initLen := len(uaq.Uslice) + randPop := rand.Int31n(int32(initLen)) + + popped := uaq.PopN(int(randPop)) + + if len(uaq.Uslice)+len(popped) != initLen { + t.Fatalf("total length changed whith pop n operation : had %d found %d after pop", initLen, len(uaq.Uslice)+len(popped)) + } + + for _, elt := range popped { + for _, oldElt := range uaq.Uslice { + if reflect.DeepEqual(elt, oldElt) { + t.Fatalf("pop n operation duplicated some elements") + } + } + } +} + +func TestPopNUnackTooLong(t *testing.T) { + uaq := initUnAckQueue() + initLen := len(uaq.Uslice) + + // Have a random number of elements to pop that's greater than the queue size + randPop := rand.Int31n(int32(initLen)) + 1 + int32(initLen) + + popped := uaq.PopN(int(randPop)) + + if len(uaq.Uslice)+len(popped) != initLen { + t.Fatalf("total length changed whith pop n operation : had %d found %d after pop", initLen, len(uaq.Uslice)+len(popped)) + } + + for _, elt := range popped { + for _, oldElt := range uaq.Uslice { + if reflect.DeepEqual(elt, oldElt) { + t.Fatalf("pop n operation duplicated some elements") + } + } + } +} + +func TestPopUnack(t *testing.T) { + uaq := initUnAckQueue() + initLen := len(uaq.Uslice) + + popped := uaq.Pop() + + if len(uaq.Uslice)+1 != initLen { + t.Fatalf("total length changed whith pop operation : had %d found %d after pop", initLen, len(uaq.Uslice)+1) + } + for _, oldElt := range uaq.Uslice { + if reflect.DeepEqual(popped, oldElt) { + t.Fatalf("pop n operation duplicated some elements") + } + } + +} + +func initUnAckQueue() stanza.UnAckQueue { + q := []*stanza.UnAckedStz{ + { + Id: 1, + Stz: ` + + Capulet + +`, + }, + {Id: 2, + Stz: ` + +`}, + {Id: 3, + Stz: ` + + + + jabber:iq:search + + + male + + + +`}, + } + + return stanza.UnAckQueue{Uslice: q} + +} + +func init() { + rand.Seed(time.Now().UTC().UnixNano()) +} diff --git a/stream_manager.go b/stream_manager.go index ebef1fa..da23df1 100644 --- a/stream_manager.go +++ b/stream_manager.go @@ -25,7 +25,7 @@ import ( // set callback and trigger reconnection. type StreamClient interface { Connect() error - Resume(state SMState) error + Resume() error Send(packet stanza.Packet) error SendIQ(ctx context.Context, iq *stanza.IQ) (chan stanza.IQ, error) SendRaw(packet string) error @@ -75,9 +75,7 @@ func (sm *StreamManager) Run() error { } handler := func(e Event) error { - switch e.State { - case StateConnected: - sm.Metrics.setConnectTime() + switch e.State.state { case StateSessionEstablished: sm.Metrics.setLoginTime() case StateDisconnected: @@ -128,7 +126,7 @@ func (sm *StreamManager) resume(state SMState) error { // TODO: Make it possible to define logger to log disconnect and reconnection attempts sm.Metrics = initMetrics() - if err = sm.client.Resume(state); err != nil { + if err = sm.client.Resume(); err != nil { var actualErr ConnError if xerrors.As(err, &actualErr) { if actualErr.Permanent { @@ -152,11 +150,6 @@ func (sm *StreamManager) resume(state SMState) error { type Metrics struct { startTime time.Time - // ConnectTime returns the duration between client initiation of the TCP/IP - // connection to the server and actual TCP/IP session establishment. - // This time includes DNS resolution and can be slightly higher if the DNS - // resolution result was not in cache. - ConnectTime time.Duration // LoginTime returns the between client initiation of the TCP/IP // connection to the server and the return of the login result. // This includes ConnectTime, but also XMPP level protocol negotiation @@ -172,10 +165,6 @@ func initMetrics() *Metrics { } } -func (m *Metrics) setConnectTime() { - m.ConnectTime = time.Since(m.startTime) -} - func (m *Metrics) setLoginTime() { m.LoginTime = time.Since(m.startTime) } diff --git a/tcp_server_mock.go b/tcp_server_mock.go index 55740fa..d189e3a 100644 --- a/tcp_server_mock.go +++ b/tcp_server_mock.go @@ -36,6 +36,10 @@ const ( testClientRawPort testClientIqPort testClientIqFailPort + testClientPostConnectHook + + // Client internal tests + testClientStreamManagement ) // ClientHandler is passed by the test client to provide custom behaviour to diff --git a/xmpp_transport.go b/xmpp_transport.go index 092b95d..800f1b1 100644 --- a/xmpp_transport.go +++ b/xmpp_transport.go @@ -24,7 +24,8 @@ type XMPPTransport struct { readWriter io.ReadWriter logFile io.Writer isSecure bool - closeChan chan stanza.StreamClosePacket + // Used to close TCP connection when a stream close message is received from the server + closeChan chan stanza.StreamClosePacket } var componentStreamOpen = fmt.Sprintf("", stanza.NSComponent, stanza.NSStream)