diff --git a/session.go b/session.go index ab7191e..04678b7 100644 --- a/session.go +++ b/session.go @@ -73,9 +73,10 @@ func (s *Session) init(o Config) { } func (s *Session) reset(o Config) { - if s.err != nil { + if s.StreamId, s.err = s.transport.StartStream(); s.err != nil { return } + s.Features = s.open(o.parsedJid.Domain) } @@ -108,7 +109,7 @@ func (s *Session) startTlsIfSupported(o Config) { return } - s.StreamId, s.err = s.transport.StartTLS() + s.err = s.transport.StartTLS() if s.err == nil { s.TlsEnabled = true diff --git a/stanza/ns.go b/stanza/ns.go index e955a35..f1be9b1 100644 --- a/stanza/ns.go +++ b/stanza/ns.go @@ -6,6 +6,7 @@ const ( NSSASL = "urn:ietf:params:xml:ns:xmpp-sasl" NSBind = "urn:ietf:params:xml:ns:xmpp-bind" NSSession = "urn:ietf:params:xml:ns:xmpp-session" + NSFraming = "urn:ietf:params:xml:ns:xmpp-framing" NSClient = "jabber:client" NSComponent = "jabber:component:accept" ) diff --git a/stanza/parser.go b/stanza/parser.go index cdd8b70..75f78e7 100644 --- a/stanza/parser.go +++ b/stanza/parser.go @@ -24,8 +24,10 @@ func InitStream(p *xml.Decoder) (sessionID string, err error) { switch elem := t.(type) { case xml.StartElement: - if elem.Name.Space != NSStream || elem.Name.Local != "stream" { - err = errors.New("xmpp: expected but got <" + elem.Name.Local + "> in " + elem.Name.Space) + isStreamOpen := elem.Name.Space == NSStream && elem.Name.Local == "stream" + isFrameOpen := elem.Name.Space == NSFraming && elem.Name.Local == "open" + if !isStreamOpen && !isFrameOpen { + err = errors.New("xmpp: expected or but got <" + elem.Name.Local + "> in " + elem.Name.Space) return sessionID, err } diff --git a/transport.go b/transport.go index 2e44381..22006c9 100644 --- a/transport.go +++ b/transport.go @@ -27,10 +27,11 @@ type TransportConfiguration struct { type Transport interface { Connect() (string, error) DoesStartTLS() bool - StartTLS() (string, error) + StartTLS() error LogTraffic(logFile io.Writer) + StartStream() (string, error) GetDecoder() *xml.Decoder IsSecure() bool diff --git a/websocket_transport.go b/websocket_transport.go index ad422c1..bf43611 100644 --- a/websocket_transport.go +++ b/websocket_transport.go @@ -57,29 +57,24 @@ func (t *WebsocketTransport) Connect() (string, error) { t.wsConn = wsConn t.startReader() - handshake := fmt.Sprintf(``, t.Config.Domain) - if _, err = t.Write([]byte(handshake)); err != nil { - t.cleanup(websocket.StatusBadGateway) - return "", NewConnError(err, false) - } - - handshakeResponse := make([]byte, 2048) - if _, err = t.Read(handshakeResponse); err != nil { - t.cleanup(websocket.StatusBadGateway) - - return "", NewConnError(err, false) - } - - var openResponse = stanza.WebsocketOpen{} - if err = xml.Unmarshal(handshakeResponse, &openResponse); err != nil { - t.cleanup(websocket.StatusBadGateway) - return "", NewConnError(err, false) - } - t.decoder = xml.NewDecoder(t) t.decoder.CharsetReader = t.Config.CharsetReader - return openResponse.Id, nil + return t.StartStream() +} + +func (t WebsocketTransport) StartStream() (string, error) { + if _, err := fmt.Fprintf(t, ``, t.Config.Domain); err != nil { + t.cleanup(websocket.StatusBadGateway) + return "", NewConnError(err, true) + } + + sessionID, err := stanza.InitStream(t.GetDecoder()) + if err != nil { + t.Close() + return "", NewConnError(err, false) + } + return sessionID, nil } // startReader runs a go function that keeps reading from the websocket. This @@ -108,14 +103,18 @@ func (t WebsocketTransport) startReader() { }() } -func (t WebsocketTransport) StartTLS() (string, error) { - return "", ErrTLSNotSupported +func (t WebsocketTransport) StartTLS() error { + return ErrTLSNotSupported } func (t WebsocketTransport) DoesStartTLS() bool { return false } +func (t WebsocketTransport) GetDomain() string { + return t.Config.Domain +} + func (t WebsocketTransport) GetDecoder() *xml.Decoder { return t.decoder } @@ -152,20 +151,21 @@ func (t WebsocketTransport) Write(p []byte) (int, error) { func (t WebsocketTransport) Close() error { t.Write([]byte("")) - return t.wsConn.Close(websocket.StatusGoingAway, "Done") + return t.cleanup(websocket.StatusGoingAway) } func (t *WebsocketTransport) LogTraffic(logFile io.Writer) { t.logFile = logFile } -func (t *WebsocketTransport) cleanup(code websocket.StatusCode) { +func (t *WebsocketTransport) cleanup(code websocket.StatusCode) error { + var err error if t.queue != nil { close(t.queue) t.queue = nil } if t.wsConn != nil { - t.wsConn.Close(websocket.StatusGoingAway, "Done") + err = t.wsConn.Close(websocket.StatusGoingAway, "Done") t.wsConn = nil } if t.closeFunc != nil { @@ -173,4 +173,5 @@ func (t *WebsocketTransport) cleanup(code websocket.StatusCode) { t.closeFunc = nil t.closeCtx = nil } + return err } diff --git a/xmpp_transport.go b/xmpp_transport.go index edcabdf..a1a11be 100644 --- a/xmpp_transport.go +++ b/xmpp_transport.go @@ -37,20 +37,20 @@ func (t *XMPPTransport) Connect() (string, error) { } t.readWriter = newStreamLogger(t.conn, t.logFile) - return t.startStream() + t.decoder = xml.NewDecoder(t.readWriter) + t.decoder.CharsetReader = t.Config.CharsetReader + return t.StartStream() } -func (t *XMPPTransport) startStream() (string, error) { - if _, err := fmt.Fprintf(t.readWriter, t.openStatement, t.Config.Domain); err != nil { - t.conn.Close() +func (t XMPPTransport) StartStream() (string, error) { + if _, err := fmt.Fprintf(t, t.openStatement, t.Config.Domain); err != nil { + t.Close() return "", NewConnError(err, true) } - t.decoder = xml.NewDecoder(t.readWriter) - t.decoder.CharsetReader = t.Config.CharsetReader - sessionID, err := stanza.InitStream(t.decoder) + sessionID, err := stanza.InitStream(t.GetDecoder()) if err != nil { - t.conn.Close() + t.Close() return "", NewConnError(err, false) } return sessionID, nil @@ -60,6 +60,10 @@ func (t XMPPTransport) DoesStartTLS() bool { return true } +func (t XMPPTransport) GetDomain() string { + return t.Config.Domain +} + func (t XMPPTransport) GetDecoder() *xml.Decoder { return t.decoder } @@ -68,7 +72,7 @@ func (t XMPPTransport) IsSecure() bool { return t.isSecure } -func (t *XMPPTransport) StartTLS() (string, error) { +func (t *XMPPTransport) StartTLS() error { if t.Config.TLSConfig == nil { t.TLSConfig = &tls.Config{} } else { @@ -81,7 +85,7 @@ func (t *XMPPTransport) StartTLS() (string, error) { tlsConn := tls.Client(t.conn, t.TLSConfig) // We convert existing connection to TLS if err := tlsConn.Handshake(); err != nil { - return "", err + return err } t.conn = tlsConn @@ -91,13 +95,12 @@ func (t *XMPPTransport) StartTLS() (string, error) { if !t.TLSConfig.InsecureSkipVerify { if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil { - return "", err + return err } } t.isSecure = true - - return t.startStream() + return nil } func (t XMPPTransport) Ping() error {