From 92329b48e612919a2f2d7aa148ca3a922d7c81b7 Mon Sep 17 00:00:00 2001 From: Wichert Akkerman Date: Fri, 18 Oct 2019 20:29:54 +0200 Subject: [PATCH] Transports need to handle open/close stanzas XMPP and WebSocket transports require different open and close stanzas. To handle this the responsibility handling those and creating the XML decoder is moved to the Transport. --- _examples/delegation/delegation.go | 1 + _examples/xmpp_component/xmpp_component.go | 1 + _examples/xmpp_echo/xmpp_echo.go | 3 +- client.go | 19 +- component.go | 27 +-- config.go | 4 +- session.go | 72 +++----- stanza/open.go | 13 ++ stanza/stream.go | 173 +----------------- stanza/stream_features.go | 167 +++++++++++++++++ ...stream_test.go => stream_features_test.go} | 0 stream_logger.go | 5 +- transport.go | 13 +- websocket_transport.go | 75 ++++++-- xmpp_transport.go | 57 ++++-- 15 files changed, 356 insertions(+), 274 deletions(-) create mode 100644 stanza/open.go create mode 100644 stanza/stream_features.go rename stanza/{stream_test.go => stream_features_test.go} (100%) diff --git a/_examples/delegation/delegation.go b/_examples/delegation/delegation.go index 61e3716..81642c6 100644 --- a/_examples/delegation/delegation.go +++ b/_examples/delegation/delegation.go @@ -13,6 +13,7 @@ func main() { opts := xmpp.ComponentOptions{ TransportConfiguration: xmpp.TransportConfiguration{ Address: "localhost:9999", + Domain: "service.localhost", }, Domain: "service.localhost", Secret: "mypass", diff --git a/_examples/xmpp_component/xmpp_component.go b/_examples/xmpp_component/xmpp_component.go index fc07f05..e36b287 100644 --- a/_examples/xmpp_component/xmpp_component.go +++ b/_examples/xmpp_component/xmpp_component.go @@ -12,6 +12,7 @@ func main() { opts := xmpp.ComponentOptions{ TransportConfiguration: xmpp.TransportConfiguration{ Address: "localhost:8888", + Domain: "service2.localhost", }, Domain: "service2.localhost", Secret: "mypass", diff --git a/_examples/xmpp_echo/xmpp_echo.go b/_examples/xmpp_echo/xmpp_echo.go index 5654a2b..63e36b6 100644 --- a/_examples/xmpp_echo/xmpp_echo.go +++ b/_examples/xmpp_echo/xmpp_echo.go @@ -16,7 +16,8 @@ import ( func main() { config := xmpp.Config{ TransportConfiguration: xmpp.TransportConfiguration{ - Address: "localhost:5222", + // Address: "localhost:5222", + Address: "ws://127.0.0.1:5280/xmpp", }, Jid: "test@localhost", Credential: xmpp.Password("test"), diff --git a/client.go b/client.go index 9c74da7..3d7c8b4 100644 --- a/client.go +++ b/client.go @@ -141,8 +141,15 @@ func NewClient(config Config, r *Router) (c *Client, err error) { c.config.ConnectTimeout = 15 // 15 second as default } + if config.TransportConfiguration.Domain == "" { + config.TransportConfiguration.Domain = config.parsedJid.Domain + } c.transport = NewTransport(config.TransportConfiguration) + if config.StreamLogger != nil { + c.transport.LogTraffic(config.StreamLogger) + } + return } @@ -158,7 +165,7 @@ func (c *Client) Connect() error { func (c *Client) Resume(state SMState) error { var err error - err = c.transport.Connect() + streamId, err := c.transport.Connect() if err != nil { return err } @@ -168,6 +175,7 @@ func (c *Client) Resume(state SMState) error { if c.Session, err = NewSession(c.transport, c.config, state); err != nil { return err } + c.Session.StreamId = streamId c.updateState(StateSessionEstablished) // Start the keepalive go routine @@ -181,13 +189,12 @@ func (c *Client) Resume(state SMState) error { //fmt.Fprintf(client.conn, "%s%s", "chat", "Online") // TODO: Do we always want to send initial presence automatically ? // Do we need an option to avoid that or do we rely on client to send the presence itself ? - fmt.Fprintf(c.Session.streamLogger, "") + fmt.Fprintf(c.transport, "") return err } func (c *Client) Disconnect() { - _ = c.SendRaw("") // TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect if c.transport != nil { _ = c.transport.Close() @@ -210,7 +217,7 @@ func (c *Client) Send(packet stanza.Packet) error { return errors.New("cannot marshal packet " + err.Error()) } - return c.sendWithWriter(c.Session.streamLogger, data) + return c.sendWithWriter(c.transport, data) } // SendRaw sends an XMPP stanza as a string to the server. @@ -223,7 +230,7 @@ func (c *Client) SendRaw(packet string) error { return errors.New("client is not connected") } - return c.sendWithWriter(c.Session.streamLogger, []byte(packet)) + return c.sendWithWriter(c.transport, []byte(packet)) } func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error { @@ -238,7 +245,7 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error { // Loop: Receive data from server func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) { for { - val, err := stanza.NextPacket(c.Session.decoder) + val, err := stanza.NextPacket(c.transport.GetDecoder()) if err != nil { close(keepaliveQuit) c.disconnected(state) diff --git a/component.go b/component.go index 3301c96..137bc05 100644 --- a/component.go +++ b/component.go @@ -67,33 +67,25 @@ func (c *Component) Connect() error { } func (c *Component) Resume(sm SMState) error { var err error + var streamId string + if c.ComponentOptions.TransportConfiguration.Domain == "" { + c.ComponentOptions.TransportConfiguration.Domain = c.ComponentOptions.Domain + } c.transport = NewTransport(c.ComponentOptions.TransportConfiguration) - if err = c.transport.Connect(); err != nil { + + if streamId, err = c.transport.Connect(); err != nil { + c.updateState(StateStreamError) return err } c.updateState(StateConnected) - // 1. Send stream open tag - if _, err := fmt.Fprintf(c.transport, componentStreamOpen, c.Domain, stanza.NSComponent, stanza.NSStream); err != nil { - c.updateState(StateStreamError) - return NewConnError(errors.New("cannot send stream open "+err.Error()), false) - } - c.decoder = xml.NewDecoder(c.transport) - - // 2. Initialize xml decoder and extract streamID from reply - streamId, err := stanza.InitStream(c.decoder) - if err != nil { - c.updateState(StateStreamError) - return NewConnError(errors.New("cannot init decoder "+err.Error()), false) - } - - // 3. Authentication + // Authentication if _, err := fmt.Fprintf(c.transport, "%s", c.handshake(streamId)); err != nil { c.updateState(StateStreamError) return NewConnError(errors.New("cannot send handshake "+err.Error()), false) } - // 4. Check server response for authentication + // Check server response for authentication val, err := stanza.NextPacket(c.decoder) if err != nil { c.updateState(StateDisconnected) @@ -116,7 +108,6 @@ func (c *Component) Resume(sm SMState) error { } func (c *Component) Disconnect() { - _ = c.SendRaw("") // TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect if c.transport != nil { _ = c.transport.Close() diff --git a/config.go b/config.go index 178feae..6e836da 100644 --- a/config.go +++ b/config.go @@ -1,7 +1,6 @@ package xmpp import ( - "io" "os" ) @@ -18,6 +17,5 @@ type Config struct { ConnectTimeout int // Client timeout in seconds. Default to 15 // Insecure can be set to true to allow to open a session without TLS. If TLS // is supported on the server, we will still try to use it. - Insecure bool - CharsetReader func(charset string, input io.Reader) (io.Reader, error) // passed to xml decoder + Insecure bool } diff --git a/session.go b/session.go index ccf1993..692203e 100644 --- a/session.go +++ b/session.go @@ -1,16 +1,12 @@ package xmpp import ( - "encoding/xml" "errors" "fmt" - "io" "gosrc.io/xmpp/stanza" ) -const xmppStreamOpen = "" - type Session struct { // Session info BindJid string // Jabber ID as provided by XMPP server @@ -21,8 +17,7 @@ type Session struct { lastPacketId int // read / write - streamLogger io.ReadWriter - decoder *xml.Decoder + transport Transport // error management err error @@ -30,10 +25,11 @@ type Session struct { func NewSession(transport Transport, o Config, state SMState) (*Session, error) { s := new(Session) + s.transport = transport s.SMState = state - s.init(transport, o) + s.init(o) - s.startTlsIfSupported(transport, o.parsedJid.Domain, o) + s.startTlsIfSupported(o) if s.err != nil { return nil, NewConnError(s.err, true) @@ -45,12 +41,12 @@ func NewSession(transport Transport, o Config, state SMState) (*Session, error) } if s.TlsEnabled { - s.reset(transport, o) + s.reset(o) } // auth s.auth(o) - s.reset(transport, o) + s.reset(o) // attempt resumption if s.resume(o) { @@ -72,51 +68,31 @@ func (s *Session) PacketId() string { return fmt.Sprintf("%x", s.lastPacketId) } -func (s *Session) init(transport Transport, o Config) { - s.setStreamLogger(transport, o) +func (s *Session) init(o Config) { s.Features = s.open(o.parsedJid.Domain) } -func (s *Session) reset(transport Transport, o Config) { +func (s *Session) reset(o Config) { if s.err != nil { return } - - s.setStreamLogger(transport, o) s.Features = s.open(o.parsedJid.Domain) } -func (s *Session) setStreamLogger(transport Transport, o Config) { - s.streamLogger = newStreamLogger(transport, o.StreamLogger) - s.decoder = xml.NewDecoder(s.streamLogger) - s.decoder.CharsetReader = o.CharsetReader -} - func (s *Session) open(domain string) (f stanza.StreamFeatures) { - // Send stream open tag - if _, s.err = fmt.Fprintf(s.streamLogger, xmppStreamOpen, domain, stanza.NSClient, stanza.NSStream); s.err != nil { - return - } - - // Set xml decoder and extract streamID from reply - s.StreamId, s.err = stanza.InitStream(s.decoder) // TODO refactor / rename - if s.err != nil { - return - } - // extract stream features - if s.err = s.decoder.Decode(&f); s.err != nil { + if s.err = s.transport.GetDecoder().Decode(&f); s.err != nil { s.err = errors.New("stream open decode features: " + s.err.Error()) } return } -func (s *Session) startTlsIfSupported(transport Transport, domain string, o Config) { +func (s *Session) startTlsIfSupported(o Config) { if s.err != nil { return } - if !transport.DoesStartTLS() { + if !s.transport.DoesStartTLS() { if !o.Insecure { s.err = errors.New("Transport does not support starttls") } @@ -124,15 +100,15 @@ func (s *Session) startTlsIfSupported(transport Transport, domain string, o Conf } if _, ok := s.Features.DoesStartTLS(); ok { - fmt.Fprintf(s.streamLogger, "") + fmt.Fprintf(s.transport, "") var k stanza.TLSProceed - if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil { + if s.err = s.transport.GetDecoder().DecodeElement(&k, nil); s.err != nil { s.err = errors.New("expecting starttls proceed: " + s.err.Error()) return } - s.err = transport.StartTLS(domain) + s.err = s.transport.StartTLS() if s.err == nil { s.TlsEnabled = true @@ -151,7 +127,7 @@ func (s *Session) auth(o Config) { return } - s.err = authSASL(s.streamLogger, s.decoder, s.Features, o.parsedJid.Node, o.Credential) + s.err = authSASL(s.transport, s.transport.GetDecoder(), s.Features, o.parsedJid.Node, o.Credential) } // Attempt to resume session using stream management @@ -163,11 +139,11 @@ func (s *Session) resume(o Config) bool { return false } - fmt.Fprintf(s.streamLogger, "", + fmt.Fprintf(s.transport, "", stanza.NSStreamManagement, s.SMState.Inbound, s.SMState.Id) var packet stanza.Packet - packet, s.err = stanza.NextPacket(s.decoder) + packet, s.err = stanza.NextPacket(s.transport.GetDecoder()) if s.err == nil { switch p := packet.(type) { case stanza.SMResumed: @@ -194,14 +170,14 @@ func (s *Session) bind(o Config) { // Send IQ message asking to bind to the local user name. var resource = o.parsedJid.Resource if resource != "" { - fmt.Fprintf(s.streamLogger, "%s", + fmt.Fprintf(s.transport, "%s", s.PacketId(), stanza.NSBind, resource) } else { - fmt.Fprintf(s.streamLogger, "", s.PacketId(), stanza.NSBind) + fmt.Fprintf(s.transport, "", s.PacketId(), stanza.NSBind) } var iq stanza.IQ - if s.err = s.decoder.Decode(&iq); s.err != nil { + if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil { s.err = errors.New("error decoding iq bind result: " + s.err.Error()) return } @@ -226,8 +202,8 @@ func (s *Session) rfc3921Session(o Config) { var iq stanza.IQ // We only negotiate session binding if it is mandatory, we skip it when optional. if !s.Features.Session.IsOptional() { - fmt.Fprintf(s.streamLogger, "", s.PacketId(), stanza.NSSession) - if s.err = s.decoder.Decode(&iq); s.err != nil { + fmt.Fprintf(s.transport, "", s.PacketId(), stanza.NSSession) + if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil { s.err = errors.New("expecting iq result after session open: " + s.err.Error()) return } @@ -243,10 +219,10 @@ func (s *Session) EnableStreamManagement(o Config) { return } - fmt.Fprintf(s.streamLogger, "", stanza.NSStreamManagement) + fmt.Fprintf(s.transport, "", stanza.NSStreamManagement) var packet stanza.Packet - packet, s.err = stanza.NextPacket(s.decoder) + packet, s.err = stanza.NextPacket(s.transport.GetDecoder()) if s.err == nil { switch p := packet.(type) { case stanza.SMEnabled: diff --git a/stanza/open.go b/stanza/open.go new file mode 100644 index 0000000..32ece17 --- /dev/null +++ b/stanza/open.go @@ -0,0 +1,13 @@ +package stanza + +import "encoding/xml" + +// Open Packet +// Reference: WebSocket connections must start with this element +// https://tools.ietf.org/html/rfc7395#section-3.4 +type WebsocketOpen struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-framing open"` + From string `xml:"from,attr"` + Id string `xml:"id,attr"` + Version string `xml:"version,attr"` +} diff --git a/stanza/stream.go b/stanza/stream.go index 11cd96b..290abfe 100644 --- a/stanza/stream.go +++ b/stanza/stream.go @@ -1,167 +1,14 @@ package stanza -import ( - "encoding/xml" -) +import "encoding/xml" -// ============================================================================ -// StreamFeatures Packet -// Reference: The active stream features are published on -// https://xmpp.org/registrar/stream-features.html -// Note: That page misses draft and experimental XEP (i.e CSI, etc) - -type StreamFeatures struct { - XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` - // Server capabilities hash - Caps Caps - // Stream features - StartTLS tlsStartTLS - Mechanisms saslMechanisms - Bind Bind - StreamManagement streamManagement - // Obsolete - Session StreamSession - // ProcessOne Stream Features - P1Push p1Push - P1Rebind p1Rebind - p1Ack p1Ack - Any []xml.Name `xml:",any"` -} - -func (StreamFeatures) Name() string { - return "stream:features" -} - -type streamFeatureDecoder struct{} - -var streamFeatures streamFeatureDecoder - -func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamFeatures, error) { - var packet StreamFeatures - err := p.DecodeElement(&packet, &se) - return packet, err -} - -// Capabilities -// Reference: https://xmpp.org/extensions/xep-0115.html#stream -// "A server MAY include its entity capabilities in a stream feature element so that connecting clients -// and peer servers do not need to send service discovery requests each time they connect." -// This is not a stream feature but a way to let client cache server disco info. -type Caps struct { - XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"` - Hash string `xml:"hash,attr"` - Node string `xml:"node,attr"` - Ver string `xml:"ver,attr"` - Ext string `xml:"ext,attr,omitempty"` -} - -// ============================================================================ -// Supported Stream Features - -// StartTLS feature -// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4 -type tlsStartTLS struct { - XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"` - Required bool -} - -// UnmarshalXML implements custom parsing startTLS required flag -func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { - stls.XMLName = start.Name - - // Check subelements to extract required field as boolean - for { - t, err := d.Token() - if err != nil { - return err - } - - switch tt := t.(type) { - - case xml.StartElement: - elt := new(Node) - - err = d.DecodeElement(elt, &tt) - if err != nil { - return err - } - - if elt.XMLName.Local == "required" { - stls.Required = true - } - - case xml.EndElement: - if tt == start.End() { - return nil - } - } - } -} - -func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) { - if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" { - return sf.StartTLS, true - } - return feature, false -} - -// Mechanisms -// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-6.4.1 -type saslMechanisms struct { - XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"` - Mechanism []string `xml:"mechanism"` -} - -// StreamManagement -// Reference: XEP-0198 - https://xmpp.org/extensions/xep-0198.html#feature -type streamManagement struct { - XMLName xml.Name `xml:"urn:xmpp:sm:3 sm"` -} - -func (sf *StreamFeatures) DoesStreamManagement() (isSupported bool) { - if sf.StreamManagement.XMLName.Space+" "+sf.StreamManagement.XMLName.Local == "urn:xmpp:sm:3 sm" { - return true - } - return false -} - -// P1 extensions -// Reference: https://docs.ejabberd.im/developer/mobile/core-features/ - -// p1:push support -type p1Push struct { - XMLName xml.Name `xml:"p1:push push"` -} - -// p1:rebind suppor -type p1Rebind struct { - XMLName xml.Name `xml:"p1:rebind rebind"` -} - -// p1:ack support -type p1Ack struct { - XMLName xml.Name `xml:"p1:ack ack"` -} - -// ============================================================================ -// StreamError Packet - -type StreamError struct { - XMLName xml.Name `xml:"http://etherx.jabber.org/streams error"` - Error xml.Name `xml:",any"` - Text string `xml:"urn:ietf:params:xml:ns:xmpp-streams text"` -} - -func (StreamError) Name() string { - return "stream:error" -} - -type streamErrorDecoder struct{} - -var streamError streamErrorDecoder - -func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamError, error) { - var packet StreamError - err := p.DecodeElement(&packet, &se) - return packet, err +// Start of stream +// Reference: XMPP Core stream open +// https://tools.ietf.org/html/rfc6120#section-4.2 +type Stream struct { + XMLName xml.Name `xml:"http://etherx.jabber.org/streams stream"` + From string `xml:"from,attr"` + To string `xml:"to,attr"` + Id string `xml:"id,attr"` + Version string `xml:"version,attr"` } diff --git a/stanza/stream_features.go b/stanza/stream_features.go new file mode 100644 index 0000000..11cd96b --- /dev/null +++ b/stanza/stream_features.go @@ -0,0 +1,167 @@ +package stanza + +import ( + "encoding/xml" +) + +// ============================================================================ +// StreamFeatures Packet +// Reference: The active stream features are published on +// https://xmpp.org/registrar/stream-features.html +// Note: That page misses draft and experimental XEP (i.e CSI, etc) + +type StreamFeatures struct { + XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` + // Server capabilities hash + Caps Caps + // Stream features + StartTLS tlsStartTLS + Mechanisms saslMechanisms + Bind Bind + StreamManagement streamManagement + // Obsolete + Session StreamSession + // ProcessOne Stream Features + P1Push p1Push + P1Rebind p1Rebind + p1Ack p1Ack + Any []xml.Name `xml:",any"` +} + +func (StreamFeatures) Name() string { + return "stream:features" +} + +type streamFeatureDecoder struct{} + +var streamFeatures streamFeatureDecoder + +func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamFeatures, error) { + var packet StreamFeatures + err := p.DecodeElement(&packet, &se) + return packet, err +} + +// Capabilities +// Reference: https://xmpp.org/extensions/xep-0115.html#stream +// "A server MAY include its entity capabilities in a stream feature element so that connecting clients +// and peer servers do not need to send service discovery requests each time they connect." +// This is not a stream feature but a way to let client cache server disco info. +type Caps struct { + XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"` + Hash string `xml:"hash,attr"` + Node string `xml:"node,attr"` + Ver string `xml:"ver,attr"` + Ext string `xml:"ext,attr,omitempty"` +} + +// ============================================================================ +// Supported Stream Features + +// StartTLS feature +// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4 +type tlsStartTLS struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"` + Required bool +} + +// UnmarshalXML implements custom parsing startTLS required flag +func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + stls.XMLName = start.Name + + // Check subelements to extract required field as boolean + for { + t, err := d.Token() + if err != nil { + return err + } + + switch tt := t.(type) { + + case xml.StartElement: + elt := new(Node) + + err = d.DecodeElement(elt, &tt) + if err != nil { + return err + } + + if elt.XMLName.Local == "required" { + stls.Required = true + } + + case xml.EndElement: + if tt == start.End() { + return nil + } + } + } +} + +func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) { + if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" { + return sf.StartTLS, true + } + return feature, false +} + +// Mechanisms +// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-6.4.1 +type saslMechanisms struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"` + Mechanism []string `xml:"mechanism"` +} + +// StreamManagement +// Reference: XEP-0198 - https://xmpp.org/extensions/xep-0198.html#feature +type streamManagement struct { + XMLName xml.Name `xml:"urn:xmpp:sm:3 sm"` +} + +func (sf *StreamFeatures) DoesStreamManagement() (isSupported bool) { + if sf.StreamManagement.XMLName.Space+" "+sf.StreamManagement.XMLName.Local == "urn:xmpp:sm:3 sm" { + return true + } + return false +} + +// P1 extensions +// Reference: https://docs.ejabberd.im/developer/mobile/core-features/ + +// p1:push support +type p1Push struct { + XMLName xml.Name `xml:"p1:push push"` +} + +// p1:rebind suppor +type p1Rebind struct { + XMLName xml.Name `xml:"p1:rebind rebind"` +} + +// p1:ack support +type p1Ack struct { + XMLName xml.Name `xml:"p1:ack ack"` +} + +// ============================================================================ +// StreamError Packet + +type StreamError struct { + XMLName xml.Name `xml:"http://etherx.jabber.org/streams error"` + Error xml.Name `xml:",any"` + Text string `xml:"urn:ietf:params:xml:ns:xmpp-streams text"` +} + +func (StreamError) Name() string { + return "stream:error" +} + +type streamErrorDecoder struct{} + +var streamError streamErrorDecoder + +func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamError, error) { + var packet StreamError + err := p.DecodeElement(&packet, &se) + return packet, err +} diff --git a/stanza/stream_test.go b/stanza/stream_features_test.go similarity index 100% rename from stanza/stream_test.go rename to stanza/stream_features_test.go diff --git a/stream_logger.go b/stream_logger.go index 8bbecdf..154dc87 100644 --- a/stream_logger.go +++ b/stream_logger.go @@ -2,17 +2,16 @@ package xmpp import ( "io" - "os" ) // Mediated Read / Write on socket // Used if logFile from Config is not nil type streamLogger struct { socket io.ReadWriter // Actual connection - logFile *os.File + logFile io.Writer } -func newStreamLogger(conn io.ReadWriter, logFile *os.File) io.ReadWriter { +func newStreamLogger(conn io.ReadWriter, logFile io.Writer) io.ReadWriter { if logFile == nil { return conn } else { diff --git a/transport.go b/transport.go index 6c4b8e0..daf66a1 100644 --- a/transport.go +++ b/transport.go @@ -2,7 +2,9 @@ package xmpp import ( "crypto/tls" + "encoding/xml" "errors" + "io" "strings" ) @@ -12,17 +14,22 @@ type TransportConfiguration struct { // Address is the XMPP Host and port to connect to. Host is of // the form 'serverhost:port' i.e "localhost:8888" Address string + Domain string ConnectTimeout int // Client timeout in seconds. Default to 15 // tls.Config must not be modified after having been passed to NewClient. Any // changes made after connecting are ignored. - TLSConfig *tls.Config + TLSConfig *tls.Config + CharsetReader func(charset string, input io.Reader) (io.Reader, error) // passed to xml decoder } type Transport interface { - Connect() error + Connect() (string, error) DoesStartTLS() bool - StartTLS(domain string) error + StartTLS() error + LogTraffic(logFile io.Writer) + + GetDecoder() *xml.Decoder IsSecure() bool Ping() error diff --git a/websocket_transport.go b/websocket_transport.go index 690bc1d..391c012 100644 --- a/websocket_transport.go +++ b/websocket_transport.go @@ -2,11 +2,15 @@ package xmpp import ( "context" + "encoding/xml" "errors" + "fmt" + "io" "net" "strings" "time" + "gosrc.io/xmpp/stanza" "nhooyr.io/websocket" ) @@ -16,35 +20,60 @@ var ServerDoesNotSupportXmppOverWebsocket = errors.New("The websocket server doe type WebsocketTransport struct { Config TransportConfiguration + decoder *xml.Decoder wsConn *websocket.Conn netConn net.Conn - ctx context.Context + logFile io.Writer } -func (t *WebsocketTransport) Connect() error { - t.ctx = context.Background() +func (t *WebsocketTransport) Connect() (string, error) { + ctx := context.Background() if t.Config.ConnectTimeout > 0 { - ctx, cancel := context.WithTimeout(t.ctx, time.Duration(t.Config.ConnectTimeout)*time.Second) - t.ctx = ctx + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(t.Config.ConnectTimeout)*time.Second) defer cancel() } - wsConn, response, err := websocket.Dial(t.ctx, t.Config.Address, &websocket.DialOptions{ + wsConn, response, err := websocket.Dial(ctx, t.Config.Address, &websocket.DialOptions{ Subprotocols: []string{"xmpp"}, }) if err != nil { - return NewConnError(err, true) + return "", NewConnError(err, true) } if response.Header.Get("Sec-WebSocket-Protocol") != "xmpp" { - return ServerDoesNotSupportXmppOverWebsocket + _ = wsConn.Close(websocket.StatusBadGateway, "Could not negotiate XMPP subprotocol") + return "", NewConnError(ServerDoesNotSupportXmppOverWebsocket, true) } + t.wsConn = wsConn - t.netConn = websocket.NetConn(t.ctx, t.wsConn, websocket.MessageText) - return nil + t.netConn = websocket.NetConn(ctx, t.wsConn, websocket.MessageText) + + handshake := fmt.Sprintf("", t.Config.Domain) + if _, err = t.Write([]byte(handshake)); err != nil { + _ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error") + return "", NewConnError(err, false) + } + + handshakeResponse := make([]byte, 2048) + if _, err = t.Read(handshakeResponse); err != nil { + _ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error") + return "", NewConnError(err, false) + } + + var openResponse = stanza.WebsocketOpen{} + if err = xml.Unmarshal(handshakeResponse, &openResponse); err != nil { + _ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error") + return "", NewConnError(err, false) + } + + t.decoder = xml.NewDecoder(t) + t.decoder.CharsetReader = t.Config.CharsetReader + + return openResponse.Id, nil } -func (t WebsocketTransport) StartTLS(domain string) error { +func (t WebsocketTransport) StartTLS() error { return TLSNotSupported } @@ -52,6 +81,10 @@ func (t WebsocketTransport) DoesStartTLS() bool { return false } +func (t WebsocketTransport) GetDecoder() *xml.Decoder { + return t.decoder +} + func (t WebsocketTransport) IsSecure() bool { return strings.HasPrefix(t.Config.Address, "wss:") } @@ -59,19 +92,29 @@ func (t WebsocketTransport) IsSecure() bool { func (t WebsocketTransport) Ping() error { ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) defer cancel() - // Note that we do not use wsConn.Ping(), because not all websocket servers - // (ejabberd for example) implement ping frames - return t.wsConn.Write(ctx, websocket.MessageText, []byte(" ")) + return t.wsConn.Ping(ctx) } -func (t WebsocketTransport) Read(p []byte) (n int, err error) { - return t.netConn.Read(p) +func (t *WebsocketTransport) Read(p []byte) (n int, err error) { + n, err = t.netConn.Read(p) + if t.logFile != nil && n > 0 { + _, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", p) + } + return } func (t WebsocketTransport) Write(p []byte) (n int, err error) { + if t.logFile != nil { + _, _ = fmt.Fprintf(t.logFile, "SEND:\n%s\n\n", p) + } return t.netConn.Write(p) } func (t WebsocketTransport) Close() error { + t.Write([]byte("")) return t.netConn.Close() } + +func (t *WebsocketTransport) LogTraffic(logFile io.Writer) { + t.logFile = logFile +} diff --git a/xmpp_transport.go b/xmpp_transport.go index 614a76d..64f2b9a 100644 --- a/xmpp_transport.go +++ b/xmpp_transport.go @@ -2,39 +2,65 @@ package xmpp import ( "crypto/tls" + "encoding/xml" "errors" + "fmt" + "io" "net" "time" + + "gosrc.io/xmpp/stanza" ) // XMPPTransport implements the XMPP native TCP transport type XMPPTransport struct { - Config TransportConfiguration - TLSConfig *tls.Config - // TCP level connection / can be replaced by a TLS session after starttls - conn net.Conn - isSecure bool + Config TransportConfiguration + TLSConfig *tls.Config + decoder *xml.Decoder + conn net.Conn + readWriter io.ReadWriter + isSecure bool } -func (t *XMPPTransport) Connect() error { +const xmppStreamOpen = "" + +func (t *XMPPTransport) Connect() (string, error) { var err error t.conn, err = net.DialTimeout("tcp", t.Config.Address, time.Duration(t.Config.ConnectTimeout)*time.Second) if err != nil { - return NewConnError(err, true) + return "", NewConnError(err, true) } - return nil + + if _, err = fmt.Fprintf(t.conn, xmppStreamOpen, t.Config.Domain, stanza.NSClient, stanza.NSStream); err != nil { + t.conn.Close() + return "", NewConnError(err, true) + } + + t.decoder = xml.NewDecoder(t.readWriter) + t.decoder.CharsetReader = t.Config.CharsetReader + sessionId, err := stanza.InitStream(t.decoder) + if err != nil { + t.conn.Close() + return "", NewConnError(err, false) + } + t.readWriter = t.conn + return sessionId, nil } func (t XMPPTransport) DoesStartTLS() bool { return true } +func (t XMPPTransport) GetDecoder() *xml.Decoder { + return t.decoder +} + func (t XMPPTransport) IsSecure() bool { return t.isSecure } -func (t *XMPPTransport) StartTLS(domain string) error { +func (t *XMPPTransport) StartTLS() error { if t.Config.TLSConfig == nil { t.TLSConfig = &tls.Config{} } else { @@ -42,7 +68,7 @@ func (t *XMPPTransport) StartTLS(domain string) error { } if t.TLSConfig.ServerName == "" { - t.TLSConfig.ServerName = domain + t.TLSConfig.ServerName = t.Config.Domain } tlsConn := tls.Client(t.conn, t.TLSConfig) // We convert existing connection to TLS @@ -51,7 +77,7 @@ func (t *XMPPTransport) StartTLS(domain string) error { } if !t.TLSConfig.InsecureSkipVerify { - if err := tlsConn.VerifyHostname(domain); err != nil { + if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil { return err } } @@ -72,13 +98,18 @@ func (t XMPPTransport) Ping() error { } func (t XMPPTransport) Read(p []byte) (n int, err error) { - return t.conn.Read(p) + return t.readWriter.Read(p) } func (t XMPPTransport) Write(p []byte) (n int, err error) { - return t.conn.Write(p) + return t.readWriter.Write(p) } func (t XMPPTransport) Close() error { + _, _ = t.readWriter.Write([]byte("")) return t.conn.Close() } + +func (t *XMPPTransport) LogTraffic(logFile io.Writer) { + t.readWriter = &streamLogger{t.conn, logFile} +}