diff --git a/_examples/delegation/delegation.go b/_examples/delegation/delegation.go
index 61e3716..81642c6 100644
--- a/_examples/delegation/delegation.go
+++ b/_examples/delegation/delegation.go
@@ -13,6 +13,7 @@ func main() {
opts := xmpp.ComponentOptions{
TransportConfiguration: xmpp.TransportConfiguration{
Address: "localhost:9999",
+ Domain: "service.localhost",
},
Domain: "service.localhost",
Secret: "mypass",
diff --git a/_examples/xmpp_component/xmpp_component.go b/_examples/xmpp_component/xmpp_component.go
index fc07f05..e36b287 100644
--- a/_examples/xmpp_component/xmpp_component.go
+++ b/_examples/xmpp_component/xmpp_component.go
@@ -12,6 +12,7 @@ func main() {
opts := xmpp.ComponentOptions{
TransportConfiguration: xmpp.TransportConfiguration{
Address: "localhost:8888",
+ Domain: "service2.localhost",
},
Domain: "service2.localhost",
Secret: "mypass",
diff --git a/_examples/xmpp_echo/xmpp_echo.go b/_examples/xmpp_echo/xmpp_echo.go
index 5654a2b..63e36b6 100644
--- a/_examples/xmpp_echo/xmpp_echo.go
+++ b/_examples/xmpp_echo/xmpp_echo.go
@@ -16,7 +16,8 @@ import (
func main() {
config := xmpp.Config{
TransportConfiguration: xmpp.TransportConfiguration{
- Address: "localhost:5222",
+ // Address: "localhost:5222",
+ Address: "ws://127.0.0.1:5280/xmpp",
},
Jid: "test@localhost",
Credential: xmpp.Password("test"),
diff --git a/client.go b/client.go
index 9c74da7..3d7c8b4 100644
--- a/client.go
+++ b/client.go
@@ -141,8 +141,15 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
c.config.ConnectTimeout = 15 // 15 second as default
}
+ if config.TransportConfiguration.Domain == "" {
+ config.TransportConfiguration.Domain = config.parsedJid.Domain
+ }
c.transport = NewTransport(config.TransportConfiguration)
+ if config.StreamLogger != nil {
+ c.transport.LogTraffic(config.StreamLogger)
+ }
+
return
}
@@ -158,7 +165,7 @@ func (c *Client) Connect() error {
func (c *Client) Resume(state SMState) error {
var err error
- err = c.transport.Connect()
+ streamId, err := c.transport.Connect()
if err != nil {
return err
}
@@ -168,6 +175,7 @@ func (c *Client) Resume(state SMState) error {
if c.Session, err = NewSession(c.transport, c.config, state); err != nil {
return err
}
+ c.Session.StreamId = streamId
c.updateState(StateSessionEstablished)
// Start the keepalive go routine
@@ -181,13 +189,12 @@ func (c *Client) Resume(state SMState) error {
//fmt.Fprintf(client.conn, "%s%s", "chat", "Online")
// TODO: Do we always want to send initial presence automatically ?
// Do we need an option to avoid that or do we rely on client to send the presence itself ?
- fmt.Fprintf(c.Session.streamLogger, "")
+ fmt.Fprintf(c.transport, "")
return err
}
func (c *Client) Disconnect() {
- _ = c.SendRaw("")
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
if c.transport != nil {
_ = c.transport.Close()
@@ -210,7 +217,7 @@ func (c *Client) Send(packet stanza.Packet) error {
return errors.New("cannot marshal packet " + err.Error())
}
- return c.sendWithWriter(c.Session.streamLogger, data)
+ return c.sendWithWriter(c.transport, data)
}
// SendRaw sends an XMPP stanza as a string to the server.
@@ -223,7 +230,7 @@ func (c *Client) SendRaw(packet string) error {
return errors.New("client is not connected")
}
- return c.sendWithWriter(c.Session.streamLogger, []byte(packet))
+ return c.sendWithWriter(c.transport, []byte(packet))
}
func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error {
@@ -238,7 +245,7 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error {
// Loop: Receive data from server
func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) {
for {
- val, err := stanza.NextPacket(c.Session.decoder)
+ val, err := stanza.NextPacket(c.transport.GetDecoder())
if err != nil {
close(keepaliveQuit)
c.disconnected(state)
diff --git a/component.go b/component.go
index 3301c96..137bc05 100644
--- a/component.go
+++ b/component.go
@@ -67,33 +67,25 @@ func (c *Component) Connect() error {
}
func (c *Component) Resume(sm SMState) error {
var err error
+ var streamId string
+ if c.ComponentOptions.TransportConfiguration.Domain == "" {
+ c.ComponentOptions.TransportConfiguration.Domain = c.ComponentOptions.Domain
+ }
c.transport = NewTransport(c.ComponentOptions.TransportConfiguration)
- if err = c.transport.Connect(); err != nil {
+
+ if streamId, err = c.transport.Connect(); err != nil {
+ c.updateState(StateStreamError)
return err
}
c.updateState(StateConnected)
- // 1. Send stream open tag
- if _, err := fmt.Fprintf(c.transport, componentStreamOpen, c.Domain, stanza.NSComponent, stanza.NSStream); err != nil {
- c.updateState(StateStreamError)
- return NewConnError(errors.New("cannot send stream open "+err.Error()), false)
- }
- c.decoder = xml.NewDecoder(c.transport)
-
- // 2. Initialize xml decoder and extract streamID from reply
- streamId, err := stanza.InitStream(c.decoder)
- if err != nil {
- c.updateState(StateStreamError)
- return NewConnError(errors.New("cannot init decoder "+err.Error()), false)
- }
-
- // 3. Authentication
+ // Authentication
if _, err := fmt.Fprintf(c.transport, "%s", c.handshake(streamId)); err != nil {
c.updateState(StateStreamError)
return NewConnError(errors.New("cannot send handshake "+err.Error()), false)
}
- // 4. Check server response for authentication
+ // Check server response for authentication
val, err := stanza.NextPacket(c.decoder)
if err != nil {
c.updateState(StateDisconnected)
@@ -116,7 +108,6 @@ func (c *Component) Resume(sm SMState) error {
}
func (c *Component) Disconnect() {
- _ = c.SendRaw("")
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
if c.transport != nil {
_ = c.transport.Close()
diff --git a/config.go b/config.go
index 178feae..6e836da 100644
--- a/config.go
+++ b/config.go
@@ -1,7 +1,6 @@
package xmpp
import (
- "io"
"os"
)
@@ -18,6 +17,5 @@ type Config struct {
ConnectTimeout int // Client timeout in seconds. Default to 15
// Insecure can be set to true to allow to open a session without TLS. If TLS
// is supported on the server, we will still try to use it.
- Insecure bool
- CharsetReader func(charset string, input io.Reader) (io.Reader, error) // passed to xml decoder
+ Insecure bool
}
diff --git a/session.go b/session.go
index ccf1993..692203e 100644
--- a/session.go
+++ b/session.go
@@ -1,16 +1,12 @@
package xmpp
import (
- "encoding/xml"
"errors"
"fmt"
- "io"
"gosrc.io/xmpp/stanza"
)
-const xmppStreamOpen = ""
-
type Session struct {
// Session info
BindJid string // Jabber ID as provided by XMPP server
@@ -21,8 +17,7 @@ type Session struct {
lastPacketId int
// read / write
- streamLogger io.ReadWriter
- decoder *xml.Decoder
+ transport Transport
// error management
err error
@@ -30,10 +25,11 @@ type Session struct {
func NewSession(transport Transport, o Config, state SMState) (*Session, error) {
s := new(Session)
+ s.transport = transport
s.SMState = state
- s.init(transport, o)
+ s.init(o)
- s.startTlsIfSupported(transport, o.parsedJid.Domain, o)
+ s.startTlsIfSupported(o)
if s.err != nil {
return nil, NewConnError(s.err, true)
@@ -45,12 +41,12 @@ func NewSession(transport Transport, o Config, state SMState) (*Session, error)
}
if s.TlsEnabled {
- s.reset(transport, o)
+ s.reset(o)
}
// auth
s.auth(o)
- s.reset(transport, o)
+ s.reset(o)
// attempt resumption
if s.resume(o) {
@@ -72,51 +68,31 @@ func (s *Session) PacketId() string {
return fmt.Sprintf("%x", s.lastPacketId)
}
-func (s *Session) init(transport Transport, o Config) {
- s.setStreamLogger(transport, o)
+func (s *Session) init(o Config) {
s.Features = s.open(o.parsedJid.Domain)
}
-func (s *Session) reset(transport Transport, o Config) {
+func (s *Session) reset(o Config) {
if s.err != nil {
return
}
-
- s.setStreamLogger(transport, o)
s.Features = s.open(o.parsedJid.Domain)
}
-func (s *Session) setStreamLogger(transport Transport, o Config) {
- s.streamLogger = newStreamLogger(transport, o.StreamLogger)
- s.decoder = xml.NewDecoder(s.streamLogger)
- s.decoder.CharsetReader = o.CharsetReader
-}
-
func (s *Session) open(domain string) (f stanza.StreamFeatures) {
- // Send stream open tag
- if _, s.err = fmt.Fprintf(s.streamLogger, xmppStreamOpen, domain, stanza.NSClient, stanza.NSStream); s.err != nil {
- return
- }
-
- // Set xml decoder and extract streamID from reply
- s.StreamId, s.err = stanza.InitStream(s.decoder) // TODO refactor / rename
- if s.err != nil {
- return
- }
-
// extract stream features
- if s.err = s.decoder.Decode(&f); s.err != nil {
+ if s.err = s.transport.GetDecoder().Decode(&f); s.err != nil {
s.err = errors.New("stream open decode features: " + s.err.Error())
}
return
}
-func (s *Session) startTlsIfSupported(transport Transport, domain string, o Config) {
+func (s *Session) startTlsIfSupported(o Config) {
if s.err != nil {
return
}
- if !transport.DoesStartTLS() {
+ if !s.transport.DoesStartTLS() {
if !o.Insecure {
s.err = errors.New("Transport does not support starttls")
}
@@ -124,15 +100,15 @@ func (s *Session) startTlsIfSupported(transport Transport, domain string, o Conf
}
if _, ok := s.Features.DoesStartTLS(); ok {
- fmt.Fprintf(s.streamLogger, "")
+ fmt.Fprintf(s.transport, "")
var k stanza.TLSProceed
- if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil {
+ if s.err = s.transport.GetDecoder().DecodeElement(&k, nil); s.err != nil {
s.err = errors.New("expecting starttls proceed: " + s.err.Error())
return
}
- s.err = transport.StartTLS(domain)
+ s.err = s.transport.StartTLS()
if s.err == nil {
s.TlsEnabled = true
@@ -151,7 +127,7 @@ func (s *Session) auth(o Config) {
return
}
- s.err = authSASL(s.streamLogger, s.decoder, s.Features, o.parsedJid.Node, o.Credential)
+ s.err = authSASL(s.transport, s.transport.GetDecoder(), s.Features, o.parsedJid.Node, o.Credential)
}
// Attempt to resume session using stream management
@@ -163,11 +139,11 @@ func (s *Session) resume(o Config) bool {
return false
}
- fmt.Fprintf(s.streamLogger, "",
+ fmt.Fprintf(s.transport, "",
stanza.NSStreamManagement, s.SMState.Inbound, s.SMState.Id)
var packet stanza.Packet
- packet, s.err = stanza.NextPacket(s.decoder)
+ packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
if s.err == nil {
switch p := packet.(type) {
case stanza.SMResumed:
@@ -194,14 +170,14 @@ func (s *Session) bind(o Config) {
// Send IQ message asking to bind to the local user name.
var resource = o.parsedJid.Resource
if resource != "" {
- fmt.Fprintf(s.streamLogger, "%s",
+ fmt.Fprintf(s.transport, "%s",
s.PacketId(), stanza.NSBind, resource)
} else {
- fmt.Fprintf(s.streamLogger, "", s.PacketId(), stanza.NSBind)
+ fmt.Fprintf(s.transport, "", s.PacketId(), stanza.NSBind)
}
var iq stanza.IQ
- if s.err = s.decoder.Decode(&iq); s.err != nil {
+ if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
s.err = errors.New("error decoding iq bind result: " + s.err.Error())
return
}
@@ -226,8 +202,8 @@ func (s *Session) rfc3921Session(o Config) {
var iq stanza.IQ
// We only negotiate session binding if it is mandatory, we skip it when optional.
if !s.Features.Session.IsOptional() {
- fmt.Fprintf(s.streamLogger, "", s.PacketId(), stanza.NSSession)
- if s.err = s.decoder.Decode(&iq); s.err != nil {
+ fmt.Fprintf(s.transport, "", s.PacketId(), stanza.NSSession)
+ if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
s.err = errors.New("expecting iq result after session open: " + s.err.Error())
return
}
@@ -243,10 +219,10 @@ func (s *Session) EnableStreamManagement(o Config) {
return
}
- fmt.Fprintf(s.streamLogger, "", stanza.NSStreamManagement)
+ fmt.Fprintf(s.transport, "", stanza.NSStreamManagement)
var packet stanza.Packet
- packet, s.err = stanza.NextPacket(s.decoder)
+ packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
if s.err == nil {
switch p := packet.(type) {
case stanza.SMEnabled:
diff --git a/stanza/open.go b/stanza/open.go
new file mode 100644
index 0000000..32ece17
--- /dev/null
+++ b/stanza/open.go
@@ -0,0 +1,13 @@
+package stanza
+
+import "encoding/xml"
+
+// Open Packet
+// Reference: WebSocket connections must start with this element
+// https://tools.ietf.org/html/rfc7395#section-3.4
+type WebsocketOpen struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-framing open"`
+ From string `xml:"from,attr"`
+ Id string `xml:"id,attr"`
+ Version string `xml:"version,attr"`
+}
diff --git a/stanza/stream.go b/stanza/stream.go
index 11cd96b..290abfe 100644
--- a/stanza/stream.go
+++ b/stanza/stream.go
@@ -1,167 +1,14 @@
package stanza
-import (
- "encoding/xml"
-)
+import "encoding/xml"
-// ============================================================================
-// StreamFeatures Packet
-// Reference: The active stream features are published on
-// https://xmpp.org/registrar/stream-features.html
-// Note: That page misses draft and experimental XEP (i.e CSI, etc)
-
-type StreamFeatures struct {
- XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
- // Server capabilities hash
- Caps Caps
- // Stream features
- StartTLS tlsStartTLS
- Mechanisms saslMechanisms
- Bind Bind
- StreamManagement streamManagement
- // Obsolete
- Session StreamSession
- // ProcessOne Stream Features
- P1Push p1Push
- P1Rebind p1Rebind
- p1Ack p1Ack
- Any []xml.Name `xml:",any"`
-}
-
-func (StreamFeatures) Name() string {
- return "stream:features"
-}
-
-type streamFeatureDecoder struct{}
-
-var streamFeatures streamFeatureDecoder
-
-func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamFeatures, error) {
- var packet StreamFeatures
- err := p.DecodeElement(&packet, &se)
- return packet, err
-}
-
-// Capabilities
-// Reference: https://xmpp.org/extensions/xep-0115.html#stream
-// "A server MAY include its entity capabilities in a stream feature element so that connecting clients
-// and peer servers do not need to send service discovery requests each time they connect."
-// This is not a stream feature but a way to let client cache server disco info.
-type Caps struct {
- XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
- Hash string `xml:"hash,attr"`
- Node string `xml:"node,attr"`
- Ver string `xml:"ver,attr"`
- Ext string `xml:"ext,attr,omitempty"`
-}
-
-// ============================================================================
-// Supported Stream Features
-
-// StartTLS feature
-// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
-type tlsStartTLS struct {
- XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
- Required bool
-}
-
-// UnmarshalXML implements custom parsing startTLS required flag
-func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
- stls.XMLName = start.Name
-
- // Check subelements to extract required field as boolean
- for {
- t, err := d.Token()
- if err != nil {
- return err
- }
-
- switch tt := t.(type) {
-
- case xml.StartElement:
- elt := new(Node)
-
- err = d.DecodeElement(elt, &tt)
- if err != nil {
- return err
- }
-
- if elt.XMLName.Local == "required" {
- stls.Required = true
- }
-
- case xml.EndElement:
- if tt == start.End() {
- return nil
- }
- }
- }
-}
-
-func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) {
- if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
- return sf.StartTLS, true
- }
- return feature, false
-}
-
-// Mechanisms
-// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-6.4.1
-type saslMechanisms struct {
- XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
- Mechanism []string `xml:"mechanism"`
-}
-
-// StreamManagement
-// Reference: XEP-0198 - https://xmpp.org/extensions/xep-0198.html#feature
-type streamManagement struct {
- XMLName xml.Name `xml:"urn:xmpp:sm:3 sm"`
-}
-
-func (sf *StreamFeatures) DoesStreamManagement() (isSupported bool) {
- if sf.StreamManagement.XMLName.Space+" "+sf.StreamManagement.XMLName.Local == "urn:xmpp:sm:3 sm" {
- return true
- }
- return false
-}
-
-// P1 extensions
-// Reference: https://docs.ejabberd.im/developer/mobile/core-features/
-
-// p1:push support
-type p1Push struct {
- XMLName xml.Name `xml:"p1:push push"`
-}
-
-// p1:rebind suppor
-type p1Rebind struct {
- XMLName xml.Name `xml:"p1:rebind rebind"`
-}
-
-// p1:ack support
-type p1Ack struct {
- XMLName xml.Name `xml:"p1:ack ack"`
-}
-
-// ============================================================================
-// StreamError Packet
-
-type StreamError struct {
- XMLName xml.Name `xml:"http://etherx.jabber.org/streams error"`
- Error xml.Name `xml:",any"`
- Text string `xml:"urn:ietf:params:xml:ns:xmpp-streams text"`
-}
-
-func (StreamError) Name() string {
- return "stream:error"
-}
-
-type streamErrorDecoder struct{}
-
-var streamError streamErrorDecoder
-
-func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamError, error) {
- var packet StreamError
- err := p.DecodeElement(&packet, &se)
- return packet, err
+// Start of stream
+// Reference: XMPP Core stream open
+// https://tools.ietf.org/html/rfc6120#section-4.2
+type Stream struct {
+ XMLName xml.Name `xml:"http://etherx.jabber.org/streams stream"`
+ From string `xml:"from,attr"`
+ To string `xml:"to,attr"`
+ Id string `xml:"id,attr"`
+ Version string `xml:"version,attr"`
}
diff --git a/stanza/stream_features.go b/stanza/stream_features.go
new file mode 100644
index 0000000..11cd96b
--- /dev/null
+++ b/stanza/stream_features.go
@@ -0,0 +1,167 @@
+package stanza
+
+import (
+ "encoding/xml"
+)
+
+// ============================================================================
+// StreamFeatures Packet
+// Reference: The active stream features are published on
+// https://xmpp.org/registrar/stream-features.html
+// Note: That page misses draft and experimental XEP (i.e CSI, etc)
+
+type StreamFeatures struct {
+ XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
+ // Server capabilities hash
+ Caps Caps
+ // Stream features
+ StartTLS tlsStartTLS
+ Mechanisms saslMechanisms
+ Bind Bind
+ StreamManagement streamManagement
+ // Obsolete
+ Session StreamSession
+ // ProcessOne Stream Features
+ P1Push p1Push
+ P1Rebind p1Rebind
+ p1Ack p1Ack
+ Any []xml.Name `xml:",any"`
+}
+
+func (StreamFeatures) Name() string {
+ return "stream:features"
+}
+
+type streamFeatureDecoder struct{}
+
+var streamFeatures streamFeatureDecoder
+
+func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamFeatures, error) {
+ var packet StreamFeatures
+ err := p.DecodeElement(&packet, &se)
+ return packet, err
+}
+
+// Capabilities
+// Reference: https://xmpp.org/extensions/xep-0115.html#stream
+// "A server MAY include its entity capabilities in a stream feature element so that connecting clients
+// and peer servers do not need to send service discovery requests each time they connect."
+// This is not a stream feature but a way to let client cache server disco info.
+type Caps struct {
+ XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
+ Hash string `xml:"hash,attr"`
+ Node string `xml:"node,attr"`
+ Ver string `xml:"ver,attr"`
+ Ext string `xml:"ext,attr,omitempty"`
+}
+
+// ============================================================================
+// Supported Stream Features
+
+// StartTLS feature
+// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
+type tlsStartTLS struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
+ Required bool
+}
+
+// UnmarshalXML implements custom parsing startTLS required flag
+func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
+ stls.XMLName = start.Name
+
+ // Check subelements to extract required field as boolean
+ for {
+ t, err := d.Token()
+ if err != nil {
+ return err
+ }
+
+ switch tt := t.(type) {
+
+ case xml.StartElement:
+ elt := new(Node)
+
+ err = d.DecodeElement(elt, &tt)
+ if err != nil {
+ return err
+ }
+
+ if elt.XMLName.Local == "required" {
+ stls.Required = true
+ }
+
+ case xml.EndElement:
+ if tt == start.End() {
+ return nil
+ }
+ }
+ }
+}
+
+func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) {
+ if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
+ return sf.StartTLS, true
+ }
+ return feature, false
+}
+
+// Mechanisms
+// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-6.4.1
+type saslMechanisms struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
+ Mechanism []string `xml:"mechanism"`
+}
+
+// StreamManagement
+// Reference: XEP-0198 - https://xmpp.org/extensions/xep-0198.html#feature
+type streamManagement struct {
+ XMLName xml.Name `xml:"urn:xmpp:sm:3 sm"`
+}
+
+func (sf *StreamFeatures) DoesStreamManagement() (isSupported bool) {
+ if sf.StreamManagement.XMLName.Space+" "+sf.StreamManagement.XMLName.Local == "urn:xmpp:sm:3 sm" {
+ return true
+ }
+ return false
+}
+
+// P1 extensions
+// Reference: https://docs.ejabberd.im/developer/mobile/core-features/
+
+// p1:push support
+type p1Push struct {
+ XMLName xml.Name `xml:"p1:push push"`
+}
+
+// p1:rebind suppor
+type p1Rebind struct {
+ XMLName xml.Name `xml:"p1:rebind rebind"`
+}
+
+// p1:ack support
+type p1Ack struct {
+ XMLName xml.Name `xml:"p1:ack ack"`
+}
+
+// ============================================================================
+// StreamError Packet
+
+type StreamError struct {
+ XMLName xml.Name `xml:"http://etherx.jabber.org/streams error"`
+ Error xml.Name `xml:",any"`
+ Text string `xml:"urn:ietf:params:xml:ns:xmpp-streams text"`
+}
+
+func (StreamError) Name() string {
+ return "stream:error"
+}
+
+type streamErrorDecoder struct{}
+
+var streamError streamErrorDecoder
+
+func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamError, error) {
+ var packet StreamError
+ err := p.DecodeElement(&packet, &se)
+ return packet, err
+}
diff --git a/stanza/stream_test.go b/stanza/stream_features_test.go
similarity index 100%
rename from stanza/stream_test.go
rename to stanza/stream_features_test.go
diff --git a/stream_logger.go b/stream_logger.go
index 8bbecdf..154dc87 100644
--- a/stream_logger.go
+++ b/stream_logger.go
@@ -2,17 +2,16 @@ package xmpp
import (
"io"
- "os"
)
// Mediated Read / Write on socket
// Used if logFile from Config is not nil
type streamLogger struct {
socket io.ReadWriter // Actual connection
- logFile *os.File
+ logFile io.Writer
}
-func newStreamLogger(conn io.ReadWriter, logFile *os.File) io.ReadWriter {
+func newStreamLogger(conn io.ReadWriter, logFile io.Writer) io.ReadWriter {
if logFile == nil {
return conn
} else {
diff --git a/transport.go b/transport.go
index 6c4b8e0..daf66a1 100644
--- a/transport.go
+++ b/transport.go
@@ -2,7 +2,9 @@ package xmpp
import (
"crypto/tls"
+ "encoding/xml"
"errors"
+ "io"
"strings"
)
@@ -12,17 +14,22 @@ type TransportConfiguration struct {
// Address is the XMPP Host and port to connect to. Host is of
// the form 'serverhost:port' i.e "localhost:8888"
Address string
+ Domain string
ConnectTimeout int // Client timeout in seconds. Default to 15
// tls.Config must not be modified after having been passed to NewClient. Any
// changes made after connecting are ignored.
- TLSConfig *tls.Config
+ TLSConfig *tls.Config
+ CharsetReader func(charset string, input io.Reader) (io.Reader, error) // passed to xml decoder
}
type Transport interface {
- Connect() error
+ Connect() (string, error)
DoesStartTLS() bool
- StartTLS(domain string) error
+ StartTLS() error
+ LogTraffic(logFile io.Writer)
+
+ GetDecoder() *xml.Decoder
IsSecure() bool
Ping() error
diff --git a/websocket_transport.go b/websocket_transport.go
index 690bc1d..391c012 100644
--- a/websocket_transport.go
+++ b/websocket_transport.go
@@ -2,11 +2,15 @@ package xmpp
import (
"context"
+ "encoding/xml"
"errors"
+ "fmt"
+ "io"
"net"
"strings"
"time"
+ "gosrc.io/xmpp/stanza"
"nhooyr.io/websocket"
)
@@ -16,35 +20,60 @@ var ServerDoesNotSupportXmppOverWebsocket = errors.New("The websocket server doe
type WebsocketTransport struct {
Config TransportConfiguration
+ decoder *xml.Decoder
wsConn *websocket.Conn
netConn net.Conn
- ctx context.Context
+ logFile io.Writer
}
-func (t *WebsocketTransport) Connect() error {
- t.ctx = context.Background()
+func (t *WebsocketTransport) Connect() (string, error) {
+ ctx := context.Background()
if t.Config.ConnectTimeout > 0 {
- ctx, cancel := context.WithTimeout(t.ctx, time.Duration(t.Config.ConnectTimeout)*time.Second)
- t.ctx = ctx
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithTimeout(ctx, time.Duration(t.Config.ConnectTimeout)*time.Second)
defer cancel()
}
- wsConn, response, err := websocket.Dial(t.ctx, t.Config.Address, &websocket.DialOptions{
+ wsConn, response, err := websocket.Dial(ctx, t.Config.Address, &websocket.DialOptions{
Subprotocols: []string{"xmpp"},
})
if err != nil {
- return NewConnError(err, true)
+ return "", NewConnError(err, true)
}
if response.Header.Get("Sec-WebSocket-Protocol") != "xmpp" {
- return ServerDoesNotSupportXmppOverWebsocket
+ _ = wsConn.Close(websocket.StatusBadGateway, "Could not negotiate XMPP subprotocol")
+ return "", NewConnError(ServerDoesNotSupportXmppOverWebsocket, true)
}
+
t.wsConn = wsConn
- t.netConn = websocket.NetConn(t.ctx, t.wsConn, websocket.MessageText)
- return nil
+ t.netConn = websocket.NetConn(ctx, t.wsConn, websocket.MessageText)
+
+ handshake := fmt.Sprintf("", t.Config.Domain)
+ if _, err = t.Write([]byte(handshake)); err != nil {
+ _ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
+ return "", NewConnError(err, false)
+ }
+
+ handshakeResponse := make([]byte, 2048)
+ if _, err = t.Read(handshakeResponse); err != nil {
+ _ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
+ return "", NewConnError(err, false)
+ }
+
+ var openResponse = stanza.WebsocketOpen{}
+ if err = xml.Unmarshal(handshakeResponse, &openResponse); err != nil {
+ _ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
+ return "", NewConnError(err, false)
+ }
+
+ t.decoder = xml.NewDecoder(t)
+ t.decoder.CharsetReader = t.Config.CharsetReader
+
+ return openResponse.Id, nil
}
-func (t WebsocketTransport) StartTLS(domain string) error {
+func (t WebsocketTransport) StartTLS() error {
return TLSNotSupported
}
@@ -52,6 +81,10 @@ func (t WebsocketTransport) DoesStartTLS() bool {
return false
}
+func (t WebsocketTransport) GetDecoder() *xml.Decoder {
+ return t.decoder
+}
+
func (t WebsocketTransport) IsSecure() bool {
return strings.HasPrefix(t.Config.Address, "wss:")
}
@@ -59,19 +92,29 @@ func (t WebsocketTransport) IsSecure() bool {
func (t WebsocketTransport) Ping() error {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
- // Note that we do not use wsConn.Ping(), because not all websocket servers
- // (ejabberd for example) implement ping frames
- return t.wsConn.Write(ctx, websocket.MessageText, []byte(" "))
+ return t.wsConn.Ping(ctx)
}
-func (t WebsocketTransport) Read(p []byte) (n int, err error) {
- return t.netConn.Read(p)
+func (t *WebsocketTransport) Read(p []byte) (n int, err error) {
+ n, err = t.netConn.Read(p)
+ if t.logFile != nil && n > 0 {
+ _, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", p)
+ }
+ return
}
func (t WebsocketTransport) Write(p []byte) (n int, err error) {
+ if t.logFile != nil {
+ _, _ = fmt.Fprintf(t.logFile, "SEND:\n%s\n\n", p)
+ }
return t.netConn.Write(p)
}
func (t WebsocketTransport) Close() error {
+ t.Write([]byte(""))
return t.netConn.Close()
}
+
+func (t *WebsocketTransport) LogTraffic(logFile io.Writer) {
+ t.logFile = logFile
+}
diff --git a/xmpp_transport.go b/xmpp_transport.go
index 614a76d..64f2b9a 100644
--- a/xmpp_transport.go
+++ b/xmpp_transport.go
@@ -2,39 +2,65 @@ package xmpp
import (
"crypto/tls"
+ "encoding/xml"
"errors"
+ "fmt"
+ "io"
"net"
"time"
+
+ "gosrc.io/xmpp/stanza"
)
// XMPPTransport implements the XMPP native TCP transport
type XMPPTransport struct {
- Config TransportConfiguration
- TLSConfig *tls.Config
- // TCP level connection / can be replaced by a TLS session after starttls
- conn net.Conn
- isSecure bool
+ Config TransportConfiguration
+ TLSConfig *tls.Config
+ decoder *xml.Decoder
+ conn net.Conn
+ readWriter io.ReadWriter
+ isSecure bool
}
-func (t *XMPPTransport) Connect() error {
+const xmppStreamOpen = ""
+
+func (t *XMPPTransport) Connect() (string, error) {
var err error
t.conn, err = net.DialTimeout("tcp", t.Config.Address, time.Duration(t.Config.ConnectTimeout)*time.Second)
if err != nil {
- return NewConnError(err, true)
+ return "", NewConnError(err, true)
}
- return nil
+
+ if _, err = fmt.Fprintf(t.conn, xmppStreamOpen, t.Config.Domain, stanza.NSClient, stanza.NSStream); err != nil {
+ t.conn.Close()
+ return "", NewConnError(err, true)
+ }
+
+ t.decoder = xml.NewDecoder(t.readWriter)
+ t.decoder.CharsetReader = t.Config.CharsetReader
+ sessionId, err := stanza.InitStream(t.decoder)
+ if err != nil {
+ t.conn.Close()
+ return "", NewConnError(err, false)
+ }
+ t.readWriter = t.conn
+ return sessionId, nil
}
func (t XMPPTransport) DoesStartTLS() bool {
return true
}
+func (t XMPPTransport) GetDecoder() *xml.Decoder {
+ return t.decoder
+}
+
func (t XMPPTransport) IsSecure() bool {
return t.isSecure
}
-func (t *XMPPTransport) StartTLS(domain string) error {
+func (t *XMPPTransport) StartTLS() error {
if t.Config.TLSConfig == nil {
t.TLSConfig = &tls.Config{}
} else {
@@ -42,7 +68,7 @@ func (t *XMPPTransport) StartTLS(domain string) error {
}
if t.TLSConfig.ServerName == "" {
- t.TLSConfig.ServerName = domain
+ t.TLSConfig.ServerName = t.Config.Domain
}
tlsConn := tls.Client(t.conn, t.TLSConfig)
// We convert existing connection to TLS
@@ -51,7 +77,7 @@ func (t *XMPPTransport) StartTLS(domain string) error {
}
if !t.TLSConfig.InsecureSkipVerify {
- if err := tlsConn.VerifyHostname(domain); err != nil {
+ if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil {
return err
}
}
@@ -72,13 +98,18 @@ func (t XMPPTransport) Ping() error {
}
func (t XMPPTransport) Read(p []byte) (n int, err error) {
- return t.conn.Read(p)
+ return t.readWriter.Read(p)
}
func (t XMPPTransport) Write(p []byte) (n int, err error) {
- return t.conn.Write(p)
+ return t.readWriter.Write(p)
}
func (t XMPPTransport) Close() error {
+ _, _ = t.readWriter.Write([]byte(""))
return t.conn.Close()
}
+
+func (t *XMPPTransport) LogTraffic(logFile io.Writer) {
+ t.readWriter = &streamLogger{t.conn, logFile}
+}