Changed "Disconnect" to wait for the closing stream tag. (#141)

Updated example with a README.md and fixed some logs.
This commit is contained in:
remicorniere 2019-12-26 13:47:02 +00:00 committed by Jérôme Sautret
parent e62b7fa0c7
commit 94aceac802
13 changed files with 252 additions and 52 deletions

View File

@ -0,0 +1,51 @@
# Chat TUI example
This is a simple chat example, with a TUI.
It shows the library usage and a few of its capabilities.
## How to run
### Build
You can build the client using :
```
go build -o example_client
```
and then run with (on unix for example):
```
./example_client
```
or you can simply build + run in one command while at the example directory root, like this:
```
go run xmpp_chat_client.go interface.go
```
### Configuration
The example needs a configuration file to run. A sample file is provided.
By default, the example will look for a file named "config" in the current directory.
To provide a different configuration file, pass the following argument to the example :
```
go run xmpp_chat_client.go interface.go -c /path/to/config
```
where /path/to/config is the path to the directory containing the configuration file. The configuration file must be named
"config" and be using the yaml format.
Required fields are :
```yaml
Server :
- full_address: "localhost:5222"
Client : # This is you
- jid: "testuser2@localhost"
- pass: "pass123" #Password in a config file yay
# Contacts list, ";" separated
Contacts : "testuser1@localhost;testuser3@localhost"
# Should we log stanzas ?
LogStanzas:
- logger_on: "true"
- logfile_path: "./logs" # Path to directory, not file.
```
## How to use
Shortcuts :
- ctrl+space : switch between input window and menu window.
- While in input window :
- enter : sends a message if in message mode (see menu options)
- ctrl+e : sends a raw stanza when in raw mode (see menu options)
- ctrl+c : quit

View File

@ -1,9 +1,7 @@
# Default config for the client # Sample config for the client
Server : Server :
- full_address: "localhost:5222" - full_address: "localhost:5222"
- port: 5222
Client : Client :
- name: "testuser2"
- jid: "testuser2@localhost" - jid: "testuser2@localhost"
- pass: "pass123" #Password in a config file yay - pass: "pass123" #Password in a config file yay

View File

@ -17,6 +17,13 @@ const (
menuWindow = "mw" // Where the menu is shown menuWindow = "mw" // Where the menu is shown
disconnectMsg = "msg" disconnectMsg = "msg"
// Windows titles
chatLogWindowTitle = "Chat log"
menuWindowTitle = "Menu"
chatInputWindowTitle = "Write a message :"
rawInputWindowTitle = "Write or paste a raw stanza. Press \"Ctrl+E\" to send :"
contactsListWindowTitle = "Contacts"
// Menu options // Menu options
disconnect = "Disconnect" disconnect = "Disconnect"
askServerForRoster = "Ask server for roster" askServerForRoster = "Ask server for roster"
@ -60,7 +67,7 @@ func layout(g *gocui.Gui) error {
if !gocui.IsUnknownView(err) { if !gocui.IsUnknownView(err) {
return err return err
} }
v.Title = "Chat log" v.Title = chatLogWindowTitle
v.Wrap = true v.Wrap = true
v.Autoscroll = true v.Autoscroll = true
} }
@ -69,7 +76,7 @@ func layout(g *gocui.Gui) error {
if !gocui.IsUnknownView(err) { if !gocui.IsUnknownView(err) {
return err return err
} }
v.Title = "Contacts" v.Title = contactsListWindowTitle
v.Wrap = true v.Wrap = true
// If we set this to true, the contacts list will "fit" in the window but if the number // If we set this to true, the contacts list will "fit" in the window but if the number
// of contacts exceeds the maximum height, some contacts will be hidden... // of contacts exceeds the maximum height, some contacts will be hidden...
@ -82,7 +89,7 @@ func layout(g *gocui.Gui) error {
if !gocui.IsUnknownView(err) { if !gocui.IsUnknownView(err) {
return err return err
} }
v.Title = "Menu" v.Title = menuWindowTitle
v.Wrap = true v.Wrap = true
v.Autoscroll = true v.Autoscroll = true
fmt.Fprint(v, strings.Join(menuOptions, "\n")) fmt.Fprint(v, strings.Join(menuOptions, "\n"))
@ -95,7 +102,7 @@ func layout(g *gocui.Gui) error {
if !gocui.IsUnknownView(err) { if !gocui.IsUnknownView(err) {
return err return err
} }
v.Title = "Write or paste a raw stanza. Press \"Ctrl+E\" to send :" v.Title = rawInputWindowTitle
v.Editable = true v.Editable = true
v.Wrap = true v.Wrap = true
} }
@ -104,7 +111,7 @@ func layout(g *gocui.Gui) error {
if !gocui.IsUnknownView(err) { if !gocui.IsUnknownView(err) {
return err return err
} }
v.Title = "Write a message :" v.Title = chatInputWindowTitle
v.Editable = true v.Editable = true
v.Wrap = true v.Wrap = true

View File

@ -63,7 +63,7 @@ func main() {
// ============================================================ // ============================================================
// Parse the flag with the config directory path as argument // Parse the flag with the config directory path as argument
flag.String("c", defaultConfigFilePath, "Provide a path to the directory that contains the configuration"+ flag.String("c", defaultConfigFilePath, "Provide a path to the directory that contains the configuration"+
" file you want to use. Config file should be named \"config\" and be of YAML format..") " file you want to use. Config file should be named \"config\" and be in YAML format..")
pflag.CommandLine.AddGoFlagSet(flag.CommandLine) pflag.CommandLine.AddGoFlagSet(flag.CommandLine)
pflag.Parse() pflag.Parse()
@ -139,7 +139,8 @@ func startClient(g *gocui.Gui, config *config) {
handlerWithGui := func(_ xmpp.Sender, p stanza.Packet) { handlerWithGui := func(_ xmpp.Sender, p stanza.Packet) {
msg, ok := p.(stanza.Message) msg, ok := p.(stanza.Message)
if logger != nil { if logger != nil {
logger.Println(msg) m, _ := xml.Marshal(msg)
logger.Println(string(m))
} }
v, err := g.View(chatLogWindow) v, err := g.View(chatLogWindow)
@ -209,7 +210,7 @@ func startMessaging(client xmpp.Sender, config *config, g *gocui.Gui) {
} }
return return
case text = <-textChan: case text = <-textChan:
reply := stanza.Message{Attrs: stanza.Attrs{To: correspondent, From: config.Client[clientJid], Type: stanza.MessageTypeChat}, Body: text} reply := stanza.Message{Attrs: stanza.Attrs{To: correspondent, Type: stanza.MessageTypeChat}, Body: text}
if logger != nil { if logger != nil {
raw, _ := xml.Marshal(reply) raw, _ := xml.Marshal(reply)
logger.Println(string(raw)) logger.Println(string(raw))
@ -284,6 +285,8 @@ func errorHandler(err error) {
// If user tries to send a message to someone not registered with the server, the server will return an error. // If user tries to send a message to someone not registered with the server, the server will return an error.
func updateRosterFromConfig(g *gocui.Gui, config *config) { func updateRosterFromConfig(g *gocui.Gui, config *config) {
viewState.contacts = append(strings.Split(config.Contacts, configContactSep), backFromContacts) viewState.contacts = append(strings.Split(config.Contacts, configContactSep), backFromContacts)
// Put a "go back" button at the end of the list
viewState.contacts = append(viewState.contacts, backFromContacts)
} }
// Updates the menu panel of the view with the current user's roster, by asking the server. // Updates the menu panel of the view with the current user's roster, by asking the server.
@ -318,6 +321,7 @@ func askForRoster(client xmpp.Sender, g *gocui.Gui, config *config) {
for _, item := range rosterItems.Items { for _, item := range rosterItems.Items {
viewState.contacts = append(viewState.contacts, item.Jid) viewState.contacts = append(viewState.contacts, item.Jid)
} }
// Put a "go back" button at the end of the list
viewState.contacts = append(viewState.contacts, backFromContacts) viewState.contacts = append(viewState.contacts, backFromContacts)
fmt.Fprintln(chlw, infoFormat+"Contacts list updated !") fmt.Fprintln(chlw, infoFormat+"Contacts list updated !")
return return

View File

@ -154,7 +154,8 @@ func NewClient(config Config, r *Router, errorHandler func(error)) (c *Client, e
if config.TransportConfiguration.Domain == "" { if config.TransportConfiguration.Domain == "" {
config.TransportConfiguration.Domain = config.parsedJid.Domain config.TransportConfiguration.Domain = config.parsedJid.Domain
} }
c.transport = NewClientTransport(config.TransportConfiguration) c.config.TransportConfiguration.ConnectTimeout = c.config.ConnectTimeout
c.transport = NewClientTransport(c.config.TransportConfiguration)
if config.StreamLogger != nil { if config.StreamLogger != nil {
c.transport.LogTraffic(config.StreamLogger) c.transport.LogTraffic(config.StreamLogger)
@ -183,7 +184,24 @@ func (c *Client) Resume(state SMState) error {
// Client is ok, we now open XMPP session // Client is ok, we now open XMPP session
if c.Session, err = NewSession(c.transport, c.config, state); err != nil { if c.Session, err = NewSession(c.transport, c.config, state); err != nil {
c.transport.Close() // Try to get the stream close tag from the server.
go func() {
for {
val, err := stanza.NextPacket(c.transport.GetDecoder())
if err != nil {
c.ErrorHandler(err)
c.disconnected(state)
return
}
switch val.(type) {
case stanza.StreamClosePacket:
// TCP messages should arrive in order, so we can expect to get nothing more after this occurs
c.transport.ReceivedStreamClose()
return
}
}
}()
c.Disconnect()
return err return err
} }
c.Session.StreamId = streamId c.Session.StreamId = streamId
@ -205,15 +223,12 @@ func (c *Client) Resume(state SMState) error {
return err return err
} }
func (c *Client) Disconnect() { func (c *Client) Disconnect() error {
// TODO : Wait for server response for clean disconnect
presence := stanza.NewPresence(stanza.Attrs{From: c.config.Jid})
presence.Type = stanza.PresenceTypeUnavailable
c.Send(presence)
c.SendRaw(stanza.StreamClose)
if c.transport != nil { if c.transport != nil {
_ = c.transport.Close() return c.transport.Close()
} }
// No transport so no connection.
return nil
} }
func (c *Client) SetHandler(handler EventHandler) { func (c *Client) SetHandler(handler EventHandler) {
@ -294,7 +309,8 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) {
close(keepaliveQuit) close(keepaliveQuit)
c.streamError(packet.Error.Local, packet.Text) c.streamError(packet.Error.Local, packet.Text)
c.ErrorHandler(errors.New("stream error: " + packet.Error.Local)) c.ErrorHandler(errors.New("stream error: " + packet.Error.Local))
return // We don't return here, because we want to wait for the stream close tag from the server, or timeout.
c.Disconnect()
// Process Stream management nonzas // Process Stream management nonzas
case stanza.SMRequest: case stanza.SMRequest:
answer := stanza.SMAnswer{XMLName: xml.Name{ answer := stanza.SMAnswer{XMLName: xml.Name{
@ -306,6 +322,10 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) {
c.ErrorHandler(err) c.ErrorHandler(err)
return return
} }
case stanza.StreamClosePacket:
// TCP messages should arrive in order, so we can expect to get nothing more after this occurs
c.transport.ReceivedStreamClose()
return
default: default:
state.Inbound++ state.Inbound++
} }

View File

@ -67,7 +67,10 @@ func TestClient_Connect(t *testing.T) {
func TestClient_NoInsecure(t *testing.T) { func TestClient_NoInsecure(t *testing.T) {
// Setup Mock server // Setup Mock server
mock := ServerMock{} mock := ServerMock{}
mock.Start(t, testXMPPAddress, handlerAbortTLS) mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) {
handlerAbortTLS(t, sc)
closeConn(t, sc)
})
// Test / Check result // Test / Check result
config := Config{ config := Config{
@ -97,7 +100,10 @@ func TestClient_NoInsecure(t *testing.T) {
func TestClient_FeaturesTracking(t *testing.T) { func TestClient_FeaturesTracking(t *testing.T) {
// Setup Mock server // Setup Mock server
mock := ServerMock{} mock := ServerMock{}
mock.Start(t, testXMPPAddress, handlerAbortTLS) mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) {
handlerAbortTLS(t, sc)
closeConn(t, sc)
})
// Test / Check result // Test / Check result
config := Config{ config := Config{
@ -247,6 +253,7 @@ func TestClient_SendRaw(t *testing.T) {
handlerClientConnectSuccess(t, sc) handlerClientConnectSuccess(t, sc)
discardPresence(t, sc) discardPresence(t, sc)
respondToIQ(t, sc) respondToIQ(t, sc)
closeConn(t, sc)
done <- struct{}{} done <- struct{}{}
} }
type testCase struct { type testCase struct {
@ -290,6 +297,7 @@ func TestClient_SendRaw(t *testing.T) {
select { select {
// We don't use the default "long" timeout here because waiting it out means passing the test. // We don't use the default "long" timeout here because waiting it out means passing the test.
case <-time.After(100 * time.Millisecond): case <-time.After(100 * time.Millisecond):
c.Disconnect()
case err = <-errChan: case err = <-errChan:
if err == nil && tcase.shouldErr { if err == nil && tcase.shouldErr {
t.Errorf("Failed to get closing stream err") t.Errorf("Failed to get closing stream err")
@ -297,7 +305,6 @@ func TestClient_SendRaw(t *testing.T) {
t.Errorf("This test is not supposed to err !") t.Errorf("This test is not supposed to err !")
} }
} }
c.transport.Close()
select { select {
case <-done: case <-done:
m.Stop() m.Stop()
@ -309,7 +316,10 @@ func TestClient_SendRaw(t *testing.T) {
} }
func TestClient_Disconnect(t *testing.T) { func TestClient_Disconnect(t *testing.T) {
c, m := mockClientConnection(t, handlerClientConnectSuccess, testClientBasePort) c, m := mockClientConnection(t, func(t *testing.T, sc *ServerConn) {
handlerClientConnectSuccess(t, sc)
closeConn(t, sc)
}, testClientBasePort)
err := c.transport.Ping() err := c.transport.Ping()
if err != nil { if err != nil {
t.Errorf("Could not ping but not disconnected yet") t.Errorf("Could not ping but not disconnected yet")
@ -326,7 +336,10 @@ func TestClient_DisconnectStreamManager(t *testing.T) {
// Init mock server // Init mock server
// Setup Mock server // Setup Mock server
mock := ServerMock{} mock := ServerMock{}
mock.Start(t, testXMPPAddress, handlerAbortTLS) mock.Start(t, testXMPPAddress, func(t *testing.T, sc *ServerConn) {
handlerAbortTLS(t, sc)
closeConn(t, sc)
})
// Test / Check result // Test / Check result
config := Config{ config := Config{
@ -375,6 +388,23 @@ func handlerClientConnectSuccess(t *testing.T, sc *ServerConn) {
bind(t, sc) bind(t, sc)
} }
// closeConn closes the connection on request from the client
func closeConn(t *testing.T, sc *ServerConn) {
for {
cls, err := stanza.NextPacket(sc.decoder)
if err != nil {
t.Errorf("cannot read from socket: %s", err)
return
}
switch cls.(type) {
case stanza.StreamClosePacket:
fmt.Fprintf(sc.connection, stanza.StreamClose)
return
}
}
}
// We expect client will abort on TLS // We expect client will abort on TLS
func handlerAbortTLS(t *testing.T, sc *ServerConn) { func handlerAbortTLS(t *testing.T, sc *ServerConn) {
checkClientOpenStream(t, sc) checkClientOpenStream(t, sc)

View File

@ -113,11 +113,13 @@ func (c *Component) Resume(sm SMState) error {
} }
} }
func (c *Component) Disconnect() { func (c *Component) Disconnect() error {
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect // TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
if c.transport != nil { if c.transport != nil {
_ = c.transport.Close() return c.transport.Close()
} }
// No transport so no connection.
return nil
} }
func (c *Component) SetHandler(handler EventHandler) { func (c *Component) SetHandler(handler EventHandler) {
@ -126,7 +128,6 @@ func (c *Component) SetHandler(handler EventHandler) {
// Receiver Go routine receiver // Receiver Go routine receiver
func (c *Component) recv() { func (c *Component) recv() {
for { for {
val, err := stanza.NextPacket(c.transport.GetDecoder()) val, err := stanza.NextPacket(c.transport.GetDecoder())
if err != nil { if err != nil {
@ -140,6 +141,11 @@ func (c *Component) recv() {
c.router.route(c, val) c.router.route(c, val)
c.streamError(p.Error.Local, p.Text) c.streamError(p.Error.Local, p.Text)
c.ErrorHandler(errors.New("stream error: " + p.Error.Local)) c.ErrorHandler(errors.New("stream error: " + p.Error.Local))
// We don't return here, because we want to wait for the stream close tag from the server, or timeout.
c.Disconnect()
case stanza.StreamClosePacket:
// TCP messages should arrive in order, so we can expect to get nothing more after this occurs
c.transport.ReceivedStreamClose()
return return
} }
c.router.route(c, val) c.router.route(c, val)

View File

@ -50,11 +50,20 @@ func InitStream(p *xml.Decoder) (sessionID string, err error) {
// TODO make auth and bind use NextPacket instead of directly NextStart // TODO make auth and bind use NextPacket instead of directly NextStart
func NextPacket(p *xml.Decoder) (Packet, error) { func NextPacket(p *xml.Decoder) (Packet, error) {
// Read start element to find out how we want to parse the XMPP packet // Read start element to find out how we want to parse the XMPP packet
se, err := NextStart(p) t, err := NextXmppToken(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if ee, ok := t.(xml.EndElement); ok {
return decodeStream(p, ee)
}
// If not an end element, then must be a start
se, ok := t.(xml.StartElement)
if !ok {
return nil, errors.New("unknown token ")
}
// Decode one of the top level XMPP namespace // Decode one of the top level XMPP namespace
switch se.Name.Space { switch se.Name.Space {
case NSStream: case NSStream:
@ -73,7 +82,29 @@ func NextPacket(p *xml.Decoder) (Packet, error) {
} }
} }
// Scan XML token stream to find next StartElement. // NextXmppToken scans XML token stream to find next StartElement or stream EndElement.
// We need the EndElement scan, because we must register stream close tags
func NextXmppToken(p *xml.Decoder) (xml.Token, error) {
for {
t, err := p.Token()
if err == io.EOF {
return xml.StartElement{}, errors.New("connection closed")
}
if err != nil {
return xml.StartElement{}, fmt.Errorf("NextStart %s", err)
}
switch t := t.(type) {
case xml.StartElement:
return t, nil
case xml.EndElement:
if t.Name.Space == NSStream && t.Name.Local == "stream" {
return t, nil
}
}
}
}
// NextStart scans XML token stream to find next StartElement.
func NextStart(p *xml.Decoder) (xml.StartElement, error) { func NextStart(p *xml.Decoder) (xml.StartElement, error) {
for { for {
t, err := p.Token() t, err := p.Token()
@ -97,7 +128,8 @@ TODO: From all the decoder, we can return a pointer to the actual concrete type,
*/ */
// decodeStream will fully decode a stream packet // decodeStream will fully decode a stream packet
func decodeStream(p *xml.Decoder, se xml.StartElement) (Packet, error) { func decodeStream(p *xml.Decoder, t xml.Token) (Packet, error) {
if se, ok := t.(xml.StartElement); ok {
switch se.Name.Local { switch se.Name.Local {
case "error": case "error":
return streamError.decode(p, se) return streamError.decode(p, se)
@ -107,6 +139,18 @@ func decodeStream(p *xml.Decoder, se xml.StartElement) (Packet, error) {
return nil, errors.New("unexpected XMPP packet " + return nil, errors.New("unexpected XMPP packet " +
se.Name.Space + " <" + se.Name.Local + "/>") se.Name.Space + " <" + se.Name.Local + "/>")
} }
}
if ee, ok := t.(xml.EndElement); ok {
if ee.Name.Local == "stream" {
return streamClose.decode(ee), nil
}
return nil, errors.New("unexpected XMPP packet " +
ee.Name.Space + " <" + ee.Name.Local + "/>")
}
// Should not happen
return nil, errors.New("unexpected XML token ")
} }
// decodeSASL decodes a packet related to SASL authentication. // decodeSASL decodes a packet related to SASL authentication.

View File

@ -165,3 +165,21 @@ func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamErr
err := p.DecodeElement(&packet, &se) err := p.DecodeElement(&packet, &se)
return packet, err return packet, err
} }
// ============================================================================
// StreamClose "Packet"
// This is just a closing tag and hold no information
type StreamClosePacket struct{}
func (StreamClosePacket) Name() string {
return "stream:stream"
}
type streamCloseDecoder struct{}
var streamClose streamCloseDecoder
func (streamCloseDecoder) decode(_ xml.EndElement) StreamClosePacket {
return StreamClosePacket{}
}

View File

@ -29,7 +29,7 @@ type StreamClient interface {
Send(packet stanza.Packet) error Send(packet stanza.Packet) error
SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error)
SendRaw(packet string) error SendRaw(packet string) error
Disconnect() Disconnect() error
SetHandler(handler EventHandler) SetHandler(handler EventHandler)
} }

View File

@ -40,6 +40,9 @@ type Transport interface {
Read(p []byte) (n int, err error) Read(p []byte) (n int, err error)
Write(p []byte) (n int, err error) Write(p []byte) (n int, err error)
Close() error Close() error
// ReceivedStreamClose signals to the transport that a </stream:stream> has been received and that the tcp connection
// should be closed.
ReceivedStreamClose()
} }
// NewClientTransport creates a new Transport instance for clients. // NewClientTransport creates a new Transport instance for clients.

View File

@ -18,7 +18,7 @@ const maxPacketSize = 32768
const pingTimeout = time.Duration(5) * time.Second const pingTimeout = time.Duration(5) * time.Second
var ServerDoesNotSupportXmppOverWebsocket = errors.New("The websocket server does not support the xmpp subprotocol") var ServerDoesNotSupportXmppOverWebsocket = errors.New("the websocket server does not support the xmpp subprotocol")
// The decoder is expected to be initialized after connecting to a server. // The decoder is expected to be initialized after connecting to a server.
type WebsocketTransport struct { type WebsocketTransport struct {
@ -47,6 +47,7 @@ func (t *WebsocketTransport) Connect() (string, error) {
wsConn, response, err := websocket.Dial(ctx, t.Config.Address, &websocket.DialOptions{ wsConn, response, err := websocket.Dial(ctx, t.Config.Address, &websocket.DialOptions{
Subprotocols: []string{"xmpp"}, Subprotocols: []string{"xmpp"},
}) })
if err != nil { if err != nil {
return "", NewConnError(err, true) return "", NewConnError(err, true)
} }
@ -177,3 +178,8 @@ func (t *WebsocketTransport) cleanup(code websocket.StatusCode) error {
} }
return err return err
} }
// ReceivedStreamClose is not used for websockets for now
func (t *WebsocketTransport) ReceivedStreamClose() {
return
}

View File

@ -24,6 +24,7 @@ type XMPPTransport struct {
readWriter io.ReadWriter readWriter io.ReadWriter
logFile io.Writer logFile io.Writer
isSecure bool isSecure bool
closeChan chan stanza.StreamClosePacket
} }
var componentStreamOpen = fmt.Sprintf("<?xml version='1.0'?><stream:stream to='%%s' xmlns='%s' xmlns:stream='%s'>", stanza.NSComponent, stanza.NSStream) var componentStreamOpen = fmt.Sprintf("<?xml version='1.0'?><stream:stream to='%%s' xmlns='%s' xmlns:stream='%s'>", stanza.NSComponent, stanza.NSStream)
@ -38,13 +39,14 @@ func (t *XMPPTransport) Connect() (string, error) {
return "", NewConnError(err, true) return "", NewConnError(err, true)
} }
t.closeChan = make(chan stanza.StreamClosePacket)
t.readWriter = newStreamLogger(t.conn, t.logFile) t.readWriter = newStreamLogger(t.conn, t.logFile)
t.decoder = xml.NewDecoder(bufio.NewReaderSize(t.readWriter, maxPacketSize)) t.decoder = xml.NewDecoder(bufio.NewReaderSize(t.readWriter, maxPacketSize))
t.decoder.CharsetReader = t.Config.CharsetReader t.decoder.CharsetReader = t.Config.CharsetReader
return t.StartStream() return t.StartStream()
} }
func (t XMPPTransport) StartStream() (string, error) { func (t *XMPPTransport) StartStream() (string, error) {
if _, err := fmt.Fprintf(t, t.openStatement, t.Config.Domain); err != nil { if _, err := fmt.Fprintf(t, t.openStatement, t.Config.Domain); err != nil {
t.Close() t.Close()
return "", NewConnError(err, true) return "", NewConnError(err, true)
@ -58,19 +60,19 @@ func (t XMPPTransport) StartStream() (string, error) {
return sessionID, nil return sessionID, nil
} }
func (t XMPPTransport) DoesStartTLS() bool { func (t *XMPPTransport) DoesStartTLS() bool {
return true return true
} }
func (t XMPPTransport) GetDomain() string { func (t *XMPPTransport) GetDomain() string {
return t.Config.Domain return t.Config.Domain
} }
func (t XMPPTransport) GetDecoder() *xml.Decoder { func (t *XMPPTransport) GetDecoder() *xml.Decoder {
return t.decoder return t.decoder
} }
func (t XMPPTransport) IsSecure() bool { func (t *XMPPTransport) IsSecure() bool {
return t.isSecure return t.isSecure
} }
@ -105,7 +107,7 @@ func (t *XMPPTransport) StartTLS() error {
return nil return nil
} }
func (t XMPPTransport) Ping() error { func (t *XMPPTransport) Ping() error {
n, err := t.conn.Write([]byte("\n")) n, err := t.conn.Write([]byte("\n"))
if err != nil { if err != nil {
return err return err
@ -116,24 +118,31 @@ func (t XMPPTransport) Ping() error {
return nil return nil
} }
func (t XMPPTransport) Read(p []byte) (n int, err error) { func (t *XMPPTransport) Read(p []byte) (n int, err error) {
if t.readWriter == nil { if t.readWriter == nil {
return 0, errors.New("cannot read: not connected, no readwriter") return 0, errors.New("cannot read: not connected, no readwriter")
} }
return t.readWriter.Read(p) return t.readWriter.Read(p)
} }
func (t XMPPTransport) Write(p []byte) (n int, err error) { func (t *XMPPTransport) Write(p []byte) (n int, err error) {
if t.readWriter == nil { if t.readWriter == nil {
return 0, errors.New("cannot write: not connected, no readwriter") return 0, errors.New("cannot write: not connected, no readwriter")
} }
return t.readWriter.Write(p) return t.readWriter.Write(p)
} }
func (t XMPPTransport) Close() error { func (t *XMPPTransport) Close() error {
if t.readWriter != nil { if t.readWriter != nil {
_, _ = t.readWriter.Write([]byte("</stream:stream>")) _, _ = t.readWriter.Write([]byte(stanza.StreamClose))
} }
// Try to wait for the stream close tag from the server. After a timeout, disconnect anyway.
select {
case <-t.closeChan:
case <-time.After(time.Duration(t.Config.ConnectTimeout) * time.Second):
}
if t.conn != nil { if t.conn != nil {
return t.conn.Close() return t.conn.Close()
} }
@ -143,3 +152,7 @@ func (t XMPPTransport) Close() error {
func (t *XMPPTransport) LogTraffic(logFile io.Writer) { func (t *XMPPTransport) LogTraffic(logFile io.Writer) {
t.logFile = logFile t.logFile = logFile
} }
func (t *XMPPTransport) ReceivedStreamClose() {
t.closeChan <- stanza.StreamClosePacket{}
}