Transports need to handle open/close stanzas

XMPP and WebSocket transports require different open and close stanzas. To
handle this the responsibility handling those and creating the XML decoder is
moved to the Transport.
This commit is contained in:
Wichert Akkerman 2019-10-18 20:29:54 +02:00 committed by Mickaël Rémond
parent 25fd476328
commit 92329b48e6
15 changed files with 356 additions and 274 deletions

View File

@ -13,6 +13,7 @@ func main() {
opts := xmpp.ComponentOptions{
TransportConfiguration: xmpp.TransportConfiguration{
Address: "localhost:9999",
Domain: "service.localhost",
},
Domain: "service.localhost",
Secret: "mypass",

View File

@ -12,6 +12,7 @@ func main() {
opts := xmpp.ComponentOptions{
TransportConfiguration: xmpp.TransportConfiguration{
Address: "localhost:8888",
Domain: "service2.localhost",
},
Domain: "service2.localhost",
Secret: "mypass",

View File

@ -16,7 +16,8 @@ import (
func main() {
config := xmpp.Config{
TransportConfiguration: xmpp.TransportConfiguration{
Address: "localhost:5222",
// Address: "localhost:5222",
Address: "ws://127.0.0.1:5280/xmpp",
},
Jid: "test@localhost",
Credential: xmpp.Password("test"),

View File

@ -141,8 +141,15 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
c.config.ConnectTimeout = 15 // 15 second as default
}
if config.TransportConfiguration.Domain == "" {
config.TransportConfiguration.Domain = config.parsedJid.Domain
}
c.transport = NewTransport(config.TransportConfiguration)
if config.StreamLogger != nil {
c.transport.LogTraffic(config.StreamLogger)
}
return
}
@ -158,7 +165,7 @@ func (c *Client) Connect() error {
func (c *Client) Resume(state SMState) error {
var err error
err = c.transport.Connect()
streamId, err := c.transport.Connect()
if err != nil {
return err
}
@ -168,6 +175,7 @@ func (c *Client) Resume(state SMState) error {
if c.Session, err = NewSession(c.transport, c.config, state); err != nil {
return err
}
c.Session.StreamId = streamId
c.updateState(StateSessionEstablished)
// Start the keepalive go routine
@ -181,13 +189,12 @@ func (c *Client) Resume(state SMState) error {
//fmt.Fprintf(client.conn, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>", "chat", "Online")
// TODO: Do we always want to send initial presence automatically ?
// Do we need an option to avoid that or do we rely on client to send the presence itself ?
fmt.Fprintf(c.Session.streamLogger, "<presence/>")
fmt.Fprintf(c.transport, "<presence/>")
return err
}
func (c *Client) Disconnect() {
_ = c.SendRaw("</stream:stream>")
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
if c.transport != nil {
_ = c.transport.Close()
@ -210,7 +217,7 @@ func (c *Client) Send(packet stanza.Packet) error {
return errors.New("cannot marshal packet " + err.Error())
}
return c.sendWithWriter(c.Session.streamLogger, data)
return c.sendWithWriter(c.transport, data)
}
// SendRaw sends an XMPP stanza as a string to the server.
@ -223,7 +230,7 @@ func (c *Client) SendRaw(packet string) error {
return errors.New("client is not connected")
}
return c.sendWithWriter(c.Session.streamLogger, []byte(packet))
return c.sendWithWriter(c.transport, []byte(packet))
}
func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error {
@ -238,7 +245,7 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error {
// Loop: Receive data from server
func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) {
for {
val, err := stanza.NextPacket(c.Session.decoder)
val, err := stanza.NextPacket(c.transport.GetDecoder())
if err != nil {
close(keepaliveQuit)
c.disconnected(state)

View File

@ -67,33 +67,25 @@ func (c *Component) Connect() error {
}
func (c *Component) Resume(sm SMState) error {
var err error
var streamId string
if c.ComponentOptions.TransportConfiguration.Domain == "" {
c.ComponentOptions.TransportConfiguration.Domain = c.ComponentOptions.Domain
}
c.transport = NewTransport(c.ComponentOptions.TransportConfiguration)
if err = c.transport.Connect(); err != nil {
if streamId, err = c.transport.Connect(); err != nil {
c.updateState(StateStreamError)
return err
}
c.updateState(StateConnected)
// 1. Send stream open tag
if _, err := fmt.Fprintf(c.transport, componentStreamOpen, c.Domain, stanza.NSComponent, stanza.NSStream); err != nil {
c.updateState(StateStreamError)
return NewConnError(errors.New("cannot send stream open "+err.Error()), false)
}
c.decoder = xml.NewDecoder(c.transport)
// 2. Initialize xml decoder and extract streamID from reply
streamId, err := stanza.InitStream(c.decoder)
if err != nil {
c.updateState(StateStreamError)
return NewConnError(errors.New("cannot init decoder "+err.Error()), false)
}
// 3. Authentication
// Authentication
if _, err := fmt.Fprintf(c.transport, "<handshake>%s</handshake>", c.handshake(streamId)); err != nil {
c.updateState(StateStreamError)
return NewConnError(errors.New("cannot send handshake "+err.Error()), false)
}
// 4. Check server response for authentication
// Check server response for authentication
val, err := stanza.NextPacket(c.decoder)
if err != nil {
c.updateState(StateDisconnected)
@ -116,7 +108,6 @@ func (c *Component) Resume(sm SMState) error {
}
func (c *Component) Disconnect() {
_ = c.SendRaw("</stream:stream>")
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
if c.transport != nil {
_ = c.transport.Close()

View File

@ -1,7 +1,6 @@
package xmpp
import (
"io"
"os"
)
@ -19,5 +18,4 @@ type Config struct {
// Insecure can be set to true to allow to open a session without TLS. If TLS
// is supported on the server, we will still try to use it.
Insecure bool
CharsetReader func(charset string, input io.Reader) (io.Reader, error) // passed to xml decoder
}

View File

@ -1,16 +1,12 @@
package xmpp
import (
"encoding/xml"
"errors"
"fmt"
"io"
"gosrc.io/xmpp/stanza"
)
const xmppStreamOpen = "<?xml version='1.0'?><stream:stream to='%s' xmlns='%s' xmlns:stream='%s' version='1.0'>"
type Session struct {
// Session info
BindJid string // Jabber ID as provided by XMPP server
@ -21,8 +17,7 @@ type Session struct {
lastPacketId int
// read / write
streamLogger io.ReadWriter
decoder *xml.Decoder
transport Transport
// error management
err error
@ -30,10 +25,11 @@ type Session struct {
func NewSession(transport Transport, o Config, state SMState) (*Session, error) {
s := new(Session)
s.transport = transport
s.SMState = state
s.init(transport, o)
s.init(o)
s.startTlsIfSupported(transport, o.parsedJid.Domain, o)
s.startTlsIfSupported(o)
if s.err != nil {
return nil, NewConnError(s.err, true)
@ -45,12 +41,12 @@ func NewSession(transport Transport, o Config, state SMState) (*Session, error)
}
if s.TlsEnabled {
s.reset(transport, o)
s.reset(o)
}
// auth
s.auth(o)
s.reset(transport, o)
s.reset(o)
// attempt resumption
if s.resume(o) {
@ -72,51 +68,31 @@ func (s *Session) PacketId() string {
return fmt.Sprintf("%x", s.lastPacketId)
}
func (s *Session) init(transport Transport, o Config) {
s.setStreamLogger(transport, o)
func (s *Session) init(o Config) {
s.Features = s.open(o.parsedJid.Domain)
}
func (s *Session) reset(transport Transport, o Config) {
func (s *Session) reset(o Config) {
if s.err != nil {
return
}
s.setStreamLogger(transport, o)
s.Features = s.open(o.parsedJid.Domain)
}
func (s *Session) setStreamLogger(transport Transport, o Config) {
s.streamLogger = newStreamLogger(transport, o.StreamLogger)
s.decoder = xml.NewDecoder(s.streamLogger)
s.decoder.CharsetReader = o.CharsetReader
}
func (s *Session) open(domain string) (f stanza.StreamFeatures) {
// Send stream open tag
if _, s.err = fmt.Fprintf(s.streamLogger, xmppStreamOpen, domain, stanza.NSClient, stanza.NSStream); s.err != nil {
return
}
// Set xml decoder and extract streamID from reply
s.StreamId, s.err = stanza.InitStream(s.decoder) // TODO refactor / rename
if s.err != nil {
return
}
// extract stream features
if s.err = s.decoder.Decode(&f); s.err != nil {
if s.err = s.transport.GetDecoder().Decode(&f); s.err != nil {
s.err = errors.New("stream open decode features: " + s.err.Error())
}
return
}
func (s *Session) startTlsIfSupported(transport Transport, domain string, o Config) {
func (s *Session) startTlsIfSupported(o Config) {
if s.err != nil {
return
}
if !transport.DoesStartTLS() {
if !s.transport.DoesStartTLS() {
if !o.Insecure {
s.err = errors.New("Transport does not support starttls")
}
@ -124,15 +100,15 @@ func (s *Session) startTlsIfSupported(transport Transport, domain string, o Conf
}
if _, ok := s.Features.DoesStartTLS(); ok {
fmt.Fprintf(s.streamLogger, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
fmt.Fprintf(s.transport, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
var k stanza.TLSProceed
if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil {
if s.err = s.transport.GetDecoder().DecodeElement(&k, nil); s.err != nil {
s.err = errors.New("expecting starttls proceed: " + s.err.Error())
return
}
s.err = transport.StartTLS(domain)
s.err = s.transport.StartTLS()
if s.err == nil {
s.TlsEnabled = true
@ -151,7 +127,7 @@ func (s *Session) auth(o Config) {
return
}
s.err = authSASL(s.streamLogger, s.decoder, s.Features, o.parsedJid.Node, o.Credential)
s.err = authSASL(s.transport, s.transport.GetDecoder(), s.Features, o.parsedJid.Node, o.Credential)
}
// Attempt to resume session using stream management
@ -163,11 +139,11 @@ func (s *Session) resume(o Config) bool {
return false
}
fmt.Fprintf(s.streamLogger, "<resume xmlns='%s' h='%d' previd='%s'/>",
fmt.Fprintf(s.transport, "<resume xmlns='%s' h='%d' previd='%s'/>",
stanza.NSStreamManagement, s.SMState.Inbound, s.SMState.Id)
var packet stanza.Packet
packet, s.err = stanza.NextPacket(s.decoder)
packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
if s.err == nil {
switch p := packet.(type) {
case stanza.SMResumed:
@ -194,14 +170,14 @@ func (s *Session) bind(o Config) {
// Send IQ message asking to bind to the local user name.
var resource = o.parsedJid.Resource
if resource != "" {
fmt.Fprintf(s.streamLogger, "<iq type='set' id='%s'><bind xmlns='%s'><resource>%s</resource></bind></iq>",
fmt.Fprintf(s.transport, "<iq type='set' id='%s'><bind xmlns='%s'><resource>%s</resource></bind></iq>",
s.PacketId(), stanza.NSBind, resource)
} else {
fmt.Fprintf(s.streamLogger, "<iq type='set' id='%s'><bind xmlns='%s'/></iq>", s.PacketId(), stanza.NSBind)
fmt.Fprintf(s.transport, "<iq type='set' id='%s'><bind xmlns='%s'/></iq>", s.PacketId(), stanza.NSBind)
}
var iq stanza.IQ
if s.err = s.decoder.Decode(&iq); s.err != nil {
if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
s.err = errors.New("error decoding iq bind result: " + s.err.Error())
return
}
@ -226,8 +202,8 @@ func (s *Session) rfc3921Session(o Config) {
var iq stanza.IQ
// We only negotiate session binding if it is mandatory, we skip it when optional.
if !s.Features.Session.IsOptional() {
fmt.Fprintf(s.streamLogger, "<iq type='set' id='%s'><session xmlns='%s'/></iq>", s.PacketId(), stanza.NSSession)
if s.err = s.decoder.Decode(&iq); s.err != nil {
fmt.Fprintf(s.transport, "<iq type='set' id='%s'><session xmlns='%s'/></iq>", s.PacketId(), stanza.NSSession)
if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
s.err = errors.New("expecting iq result after session open: " + s.err.Error())
return
}
@ -243,10 +219,10 @@ func (s *Session) EnableStreamManagement(o Config) {
return
}
fmt.Fprintf(s.streamLogger, "<enable xmlns='%s' resume='true'/>", stanza.NSStreamManagement)
fmt.Fprintf(s.transport, "<enable xmlns='%s' resume='true'/>", stanza.NSStreamManagement)
var packet stanza.Packet
packet, s.err = stanza.NextPacket(s.decoder)
packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
if s.err == nil {
switch p := packet.(type) {
case stanza.SMEnabled:

13
stanza/open.go Normal file
View File

@ -0,0 +1,13 @@
package stanza
import "encoding/xml"
// Open Packet
// Reference: WebSocket connections must start with this element
// https://tools.ietf.org/html/rfc7395#section-3.4
type WebsocketOpen struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-framing open"`
From string `xml:"from,attr"`
Id string `xml:"id,attr"`
Version string `xml:"version,attr"`
}

View File

@ -1,167 +1,14 @@
package stanza
import (
"encoding/xml"
)
import "encoding/xml"
// ============================================================================
// StreamFeatures Packet
// Reference: The active stream features are published on
// https://xmpp.org/registrar/stream-features.html
// Note: That page misses draft and experimental XEP (i.e CSI, etc)
type StreamFeatures struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
// Server capabilities hash
Caps Caps
// Stream features
StartTLS tlsStartTLS
Mechanisms saslMechanisms
Bind Bind
StreamManagement streamManagement
// Obsolete
Session StreamSession
// ProcessOne Stream Features
P1Push p1Push
P1Rebind p1Rebind
p1Ack p1Ack
Any []xml.Name `xml:",any"`
}
func (StreamFeatures) Name() string {
return "stream:features"
}
type streamFeatureDecoder struct{}
var streamFeatures streamFeatureDecoder
func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamFeatures, error) {
var packet StreamFeatures
err := p.DecodeElement(&packet, &se)
return packet, err
}
// Capabilities
// Reference: https://xmpp.org/extensions/xep-0115.html#stream
// "A server MAY include its entity capabilities in a stream feature element so that connecting clients
// and peer servers do not need to send service discovery requests each time they connect."
// This is not a stream feature but a way to let client cache server disco info.
type Caps struct {
XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
Hash string `xml:"hash,attr"`
Node string `xml:"node,attr"`
Ver string `xml:"ver,attr"`
Ext string `xml:"ext,attr,omitempty"`
}
// ============================================================================
// Supported Stream Features
// StartTLS feature
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
type tlsStartTLS struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
Required bool
}
// UnmarshalXML implements custom parsing startTLS required flag
func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
stls.XMLName = start.Name
// Check subelements to extract required field as boolean
for {
t, err := d.Token()
if err != nil {
return err
}
switch tt := t.(type) {
case xml.StartElement:
elt := new(Node)
err = d.DecodeElement(elt, &tt)
if err != nil {
return err
}
if elt.XMLName.Local == "required" {
stls.Required = true
}
case xml.EndElement:
if tt == start.End() {
return nil
}
}
}
}
func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) {
if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
return sf.StartTLS, true
}
return feature, false
}
// Mechanisms
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-6.4.1
type saslMechanisms struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
Mechanism []string `xml:"mechanism"`
}
// StreamManagement
// Reference: XEP-0198 - https://xmpp.org/extensions/xep-0198.html#feature
type streamManagement struct {
XMLName xml.Name `xml:"urn:xmpp:sm:3 sm"`
}
func (sf *StreamFeatures) DoesStreamManagement() (isSupported bool) {
if sf.StreamManagement.XMLName.Space+" "+sf.StreamManagement.XMLName.Local == "urn:xmpp:sm:3 sm" {
return true
}
return false
}
// P1 extensions
// Reference: https://docs.ejabberd.im/developer/mobile/core-features/
// p1:push support
type p1Push struct {
XMLName xml.Name `xml:"p1:push push"`
}
// p1:rebind suppor
type p1Rebind struct {
XMLName xml.Name `xml:"p1:rebind rebind"`
}
// p1:ack support
type p1Ack struct {
XMLName xml.Name `xml:"p1:ack ack"`
}
// ============================================================================
// StreamError Packet
type StreamError struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams error"`
Error xml.Name `xml:",any"`
Text string `xml:"urn:ietf:params:xml:ns:xmpp-streams text"`
}
func (StreamError) Name() string {
return "stream:error"
}
type streamErrorDecoder struct{}
var streamError streamErrorDecoder
func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamError, error) {
var packet StreamError
err := p.DecodeElement(&packet, &se)
return packet, err
// Start of stream
// Reference: XMPP Core stream open
// https://tools.ietf.org/html/rfc6120#section-4.2
type Stream struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams stream"`
From string `xml:"from,attr"`
To string `xml:"to,attr"`
Id string `xml:"id,attr"`
Version string `xml:"version,attr"`
}

167
stanza/stream_features.go Normal file
View File

@ -0,0 +1,167 @@
package stanza
import (
"encoding/xml"
)
// ============================================================================
// StreamFeatures Packet
// Reference: The active stream features are published on
// https://xmpp.org/registrar/stream-features.html
// Note: That page misses draft and experimental XEP (i.e CSI, etc)
type StreamFeatures struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
// Server capabilities hash
Caps Caps
// Stream features
StartTLS tlsStartTLS
Mechanisms saslMechanisms
Bind Bind
StreamManagement streamManagement
// Obsolete
Session StreamSession
// ProcessOne Stream Features
P1Push p1Push
P1Rebind p1Rebind
p1Ack p1Ack
Any []xml.Name `xml:",any"`
}
func (StreamFeatures) Name() string {
return "stream:features"
}
type streamFeatureDecoder struct{}
var streamFeatures streamFeatureDecoder
func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamFeatures, error) {
var packet StreamFeatures
err := p.DecodeElement(&packet, &se)
return packet, err
}
// Capabilities
// Reference: https://xmpp.org/extensions/xep-0115.html#stream
// "A server MAY include its entity capabilities in a stream feature element so that connecting clients
// and peer servers do not need to send service discovery requests each time they connect."
// This is not a stream feature but a way to let client cache server disco info.
type Caps struct {
XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
Hash string `xml:"hash,attr"`
Node string `xml:"node,attr"`
Ver string `xml:"ver,attr"`
Ext string `xml:"ext,attr,omitempty"`
}
// ============================================================================
// Supported Stream Features
// StartTLS feature
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
type tlsStartTLS struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
Required bool
}
// UnmarshalXML implements custom parsing startTLS required flag
func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
stls.XMLName = start.Name
// Check subelements to extract required field as boolean
for {
t, err := d.Token()
if err != nil {
return err
}
switch tt := t.(type) {
case xml.StartElement:
elt := new(Node)
err = d.DecodeElement(elt, &tt)
if err != nil {
return err
}
if elt.XMLName.Local == "required" {
stls.Required = true
}
case xml.EndElement:
if tt == start.End() {
return nil
}
}
}
}
func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) {
if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
return sf.StartTLS, true
}
return feature, false
}
// Mechanisms
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-6.4.1
type saslMechanisms struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
Mechanism []string `xml:"mechanism"`
}
// StreamManagement
// Reference: XEP-0198 - https://xmpp.org/extensions/xep-0198.html#feature
type streamManagement struct {
XMLName xml.Name `xml:"urn:xmpp:sm:3 sm"`
}
func (sf *StreamFeatures) DoesStreamManagement() (isSupported bool) {
if sf.StreamManagement.XMLName.Space+" "+sf.StreamManagement.XMLName.Local == "urn:xmpp:sm:3 sm" {
return true
}
return false
}
// P1 extensions
// Reference: https://docs.ejabberd.im/developer/mobile/core-features/
// p1:push support
type p1Push struct {
XMLName xml.Name `xml:"p1:push push"`
}
// p1:rebind suppor
type p1Rebind struct {
XMLName xml.Name `xml:"p1:rebind rebind"`
}
// p1:ack support
type p1Ack struct {
XMLName xml.Name `xml:"p1:ack ack"`
}
// ============================================================================
// StreamError Packet
type StreamError struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams error"`
Error xml.Name `xml:",any"`
Text string `xml:"urn:ietf:params:xml:ns:xmpp-streams text"`
}
func (StreamError) Name() string {
return "stream:error"
}
type streamErrorDecoder struct{}
var streamError streamErrorDecoder
func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamError, error) {
var packet StreamError
err := p.DecodeElement(&packet, &se)
return packet, err
}

View File

@ -2,17 +2,16 @@ package xmpp
import (
"io"
"os"
)
// Mediated Read / Write on socket
// Used if logFile from Config is not nil
type streamLogger struct {
socket io.ReadWriter // Actual connection
logFile *os.File
logFile io.Writer
}
func newStreamLogger(conn io.ReadWriter, logFile *os.File) io.ReadWriter {
func newStreamLogger(conn io.ReadWriter, logFile io.Writer) io.ReadWriter {
if logFile == nil {
return conn
} else {

View File

@ -2,7 +2,9 @@ package xmpp
import (
"crypto/tls"
"encoding/xml"
"errors"
"io"
"strings"
)
@ -12,17 +14,22 @@ type TransportConfiguration struct {
// Address is the XMPP Host and port to connect to. Host is of
// the form 'serverhost:port' i.e "localhost:8888"
Address string
Domain string
ConnectTimeout int // Client timeout in seconds. Default to 15
// tls.Config must not be modified after having been passed to NewClient. Any
// changes made after connecting are ignored.
TLSConfig *tls.Config
CharsetReader func(charset string, input io.Reader) (io.Reader, error) // passed to xml decoder
}
type Transport interface {
Connect() error
Connect() (string, error)
DoesStartTLS() bool
StartTLS(domain string) error
StartTLS() error
LogTraffic(logFile io.Writer)
GetDecoder() *xml.Decoder
IsSecure() bool
Ping() error

View File

@ -2,11 +2,15 @@ package xmpp
import (
"context"
"encoding/xml"
"errors"
"fmt"
"io"
"net"
"strings"
"time"
"gosrc.io/xmpp/stanza"
"nhooyr.io/websocket"
)
@ -16,35 +20,60 @@ var ServerDoesNotSupportXmppOverWebsocket = errors.New("The websocket server doe
type WebsocketTransport struct {
Config TransportConfiguration
decoder *xml.Decoder
wsConn *websocket.Conn
netConn net.Conn
ctx context.Context
logFile io.Writer
}
func (t *WebsocketTransport) Connect() error {
t.ctx = context.Background()
func (t *WebsocketTransport) Connect() (string, error) {
ctx := context.Background()
if t.Config.ConnectTimeout > 0 {
ctx, cancel := context.WithTimeout(t.ctx, time.Duration(t.Config.ConnectTimeout)*time.Second)
t.ctx = ctx
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(t.Config.ConnectTimeout)*time.Second)
defer cancel()
}
wsConn, response, err := websocket.Dial(t.ctx, t.Config.Address, &websocket.DialOptions{
wsConn, response, err := websocket.Dial(ctx, t.Config.Address, &websocket.DialOptions{
Subprotocols: []string{"xmpp"},
})
if err != nil {
return NewConnError(err, true)
return "", NewConnError(err, true)
}
if response.Header.Get("Sec-WebSocket-Protocol") != "xmpp" {
return ServerDoesNotSupportXmppOverWebsocket
_ = wsConn.Close(websocket.StatusBadGateway, "Could not negotiate XMPP subprotocol")
return "", NewConnError(ServerDoesNotSupportXmppOverWebsocket, true)
}
t.wsConn = wsConn
t.netConn = websocket.NetConn(t.ctx, t.wsConn, websocket.MessageText)
return nil
t.netConn = websocket.NetConn(ctx, t.wsConn, websocket.MessageText)
handshake := fmt.Sprintf("<open xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" to=\"%s\" version=\"1.0\" />", t.Config.Domain)
if _, err = t.Write([]byte(handshake)); err != nil {
_ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
return "", NewConnError(err, false)
}
handshakeResponse := make([]byte, 2048)
if _, err = t.Read(handshakeResponse); err != nil {
_ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
return "", NewConnError(err, false)
}
var openResponse = stanza.WebsocketOpen{}
if err = xml.Unmarshal(handshakeResponse, &openResponse); err != nil {
_ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
return "", NewConnError(err, false)
}
t.decoder = xml.NewDecoder(t)
t.decoder.CharsetReader = t.Config.CharsetReader
return openResponse.Id, nil
}
func (t WebsocketTransport) StartTLS(domain string) error {
func (t WebsocketTransport) StartTLS() error {
return TLSNotSupported
}
@ -52,6 +81,10 @@ func (t WebsocketTransport) DoesStartTLS() bool {
return false
}
func (t WebsocketTransport) GetDecoder() *xml.Decoder {
return t.decoder
}
func (t WebsocketTransport) IsSecure() bool {
return strings.HasPrefix(t.Config.Address, "wss:")
}
@ -59,19 +92,29 @@ func (t WebsocketTransport) IsSecure() bool {
func (t WebsocketTransport) Ping() error {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
// Note that we do not use wsConn.Ping(), because not all websocket servers
// (ejabberd for example) implement ping frames
return t.wsConn.Write(ctx, websocket.MessageText, []byte(" "))
return t.wsConn.Ping(ctx)
}
func (t WebsocketTransport) Read(p []byte) (n int, err error) {
return t.netConn.Read(p)
func (t *WebsocketTransport) Read(p []byte) (n int, err error) {
n, err = t.netConn.Read(p)
if t.logFile != nil && n > 0 {
_, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", p)
}
return
}
func (t WebsocketTransport) Write(p []byte) (n int, err error) {
if t.logFile != nil {
_, _ = fmt.Fprintf(t.logFile, "SEND:\n%s\n\n", p)
}
return t.netConn.Write(p)
}
func (t WebsocketTransport) Close() error {
t.Write([]byte("<close xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" />"))
return t.netConn.Close()
}
func (t *WebsocketTransport) LogTraffic(logFile io.Writer) {
t.logFile = logFile
}

View File

@ -2,39 +2,65 @@ package xmpp
import (
"crypto/tls"
"encoding/xml"
"errors"
"fmt"
"io"
"net"
"time"
"gosrc.io/xmpp/stanza"
)
// XMPPTransport implements the XMPP native TCP transport
type XMPPTransport struct {
Config TransportConfiguration
TLSConfig *tls.Config
// TCP level connection / can be replaced by a TLS session after starttls
decoder *xml.Decoder
conn net.Conn
readWriter io.ReadWriter
isSecure bool
}
func (t *XMPPTransport) Connect() error {
const xmppStreamOpen = "<?xml version='1.0'?><stream:stream to='%s' xmlns='%s' xmlns:stream='%s' version='1.0'>"
func (t *XMPPTransport) Connect() (string, error) {
var err error
t.conn, err = net.DialTimeout("tcp", t.Config.Address, time.Duration(t.Config.ConnectTimeout)*time.Second)
if err != nil {
return NewConnError(err, true)
return "", NewConnError(err, true)
}
return nil
if _, err = fmt.Fprintf(t.conn, xmppStreamOpen, t.Config.Domain, stanza.NSClient, stanza.NSStream); err != nil {
t.conn.Close()
return "", NewConnError(err, true)
}
t.decoder = xml.NewDecoder(t.readWriter)
t.decoder.CharsetReader = t.Config.CharsetReader
sessionId, err := stanza.InitStream(t.decoder)
if err != nil {
t.conn.Close()
return "", NewConnError(err, false)
}
t.readWriter = t.conn
return sessionId, nil
}
func (t XMPPTransport) DoesStartTLS() bool {
return true
}
func (t XMPPTransport) GetDecoder() *xml.Decoder {
return t.decoder
}
func (t XMPPTransport) IsSecure() bool {
return t.isSecure
}
func (t *XMPPTransport) StartTLS(domain string) error {
func (t *XMPPTransport) StartTLS() error {
if t.Config.TLSConfig == nil {
t.TLSConfig = &tls.Config{}
} else {
@ -42,7 +68,7 @@ func (t *XMPPTransport) StartTLS(domain string) error {
}
if t.TLSConfig.ServerName == "" {
t.TLSConfig.ServerName = domain
t.TLSConfig.ServerName = t.Config.Domain
}
tlsConn := tls.Client(t.conn, t.TLSConfig)
// We convert existing connection to TLS
@ -51,7 +77,7 @@ func (t *XMPPTransport) StartTLS(domain string) error {
}
if !t.TLSConfig.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(domain); err != nil {
if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil {
return err
}
}
@ -72,13 +98,18 @@ func (t XMPPTransport) Ping() error {
}
func (t XMPPTransport) Read(p []byte) (n int, err error) {
return t.conn.Read(p)
return t.readWriter.Read(p)
}
func (t XMPPTransport) Write(p []byte) (n int, err error) {
return t.conn.Write(p)
return t.readWriter.Write(p)
}
func (t XMPPTransport) Close() error {
_, _ = t.readWriter.Write([]byte("</stream:stream>"))
return t.conn.Close()
}
func (t *XMPPTransport) LogTraffic(logFile io.Writer) {
t.readWriter = &streamLogger{t.conn, logFile}
}