Transports need to handle open/close stanzas

XMPP and WebSocket transports require different open and close stanzas. To
handle this the responsibility handling those and creating the XML decoder is
moved to the Transport.
This commit is contained in:
Wichert Akkerman
2019-10-18 20:29:54 +02:00
committed by Mickaël Rémond
parent 25fd476328
commit 92329b48e6
15 changed files with 356 additions and 274 deletions

View File

@@ -1,16 +1,12 @@
package xmpp
import (
"encoding/xml"
"errors"
"fmt"
"io"
"gosrc.io/xmpp/stanza"
)
const xmppStreamOpen = "<?xml version='1.0'?><stream:stream to='%s' xmlns='%s' xmlns:stream='%s' version='1.0'>"
type Session struct {
// Session info
BindJid string // Jabber ID as provided by XMPP server
@@ -21,8 +17,7 @@ type Session struct {
lastPacketId int
// read / write
streamLogger io.ReadWriter
decoder *xml.Decoder
transport Transport
// error management
err error
@@ -30,10 +25,11 @@ type Session struct {
func NewSession(transport Transport, o Config, state SMState) (*Session, error) {
s := new(Session)
s.transport = transport
s.SMState = state
s.init(transport, o)
s.init(o)
s.startTlsIfSupported(transport, o.parsedJid.Domain, o)
s.startTlsIfSupported(o)
if s.err != nil {
return nil, NewConnError(s.err, true)
@@ -45,12 +41,12 @@ func NewSession(transport Transport, o Config, state SMState) (*Session, error)
}
if s.TlsEnabled {
s.reset(transport, o)
s.reset(o)
}
// auth
s.auth(o)
s.reset(transport, o)
s.reset(o)
// attempt resumption
if s.resume(o) {
@@ -72,51 +68,31 @@ func (s *Session) PacketId() string {
return fmt.Sprintf("%x", s.lastPacketId)
}
func (s *Session) init(transport Transport, o Config) {
s.setStreamLogger(transport, o)
func (s *Session) init(o Config) {
s.Features = s.open(o.parsedJid.Domain)
}
func (s *Session) reset(transport Transport, o Config) {
func (s *Session) reset(o Config) {
if s.err != nil {
return
}
s.setStreamLogger(transport, o)
s.Features = s.open(o.parsedJid.Domain)
}
func (s *Session) setStreamLogger(transport Transport, o Config) {
s.streamLogger = newStreamLogger(transport, o.StreamLogger)
s.decoder = xml.NewDecoder(s.streamLogger)
s.decoder.CharsetReader = o.CharsetReader
}
func (s *Session) open(domain string) (f stanza.StreamFeatures) {
// Send stream open tag
if _, s.err = fmt.Fprintf(s.streamLogger, xmppStreamOpen, domain, stanza.NSClient, stanza.NSStream); s.err != nil {
return
}
// Set xml decoder and extract streamID from reply
s.StreamId, s.err = stanza.InitStream(s.decoder) // TODO refactor / rename
if s.err != nil {
return
}
// extract stream features
if s.err = s.decoder.Decode(&f); s.err != nil {
if s.err = s.transport.GetDecoder().Decode(&f); s.err != nil {
s.err = errors.New("stream open decode features: " + s.err.Error())
}
return
}
func (s *Session) startTlsIfSupported(transport Transport, domain string, o Config) {
func (s *Session) startTlsIfSupported(o Config) {
if s.err != nil {
return
}
if !transport.DoesStartTLS() {
if !s.transport.DoesStartTLS() {
if !o.Insecure {
s.err = errors.New("Transport does not support starttls")
}
@@ -124,15 +100,15 @@ func (s *Session) startTlsIfSupported(transport Transport, domain string, o Conf
}
if _, ok := s.Features.DoesStartTLS(); ok {
fmt.Fprintf(s.streamLogger, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
fmt.Fprintf(s.transport, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
var k stanza.TLSProceed
if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil {
if s.err = s.transport.GetDecoder().DecodeElement(&k, nil); s.err != nil {
s.err = errors.New("expecting starttls proceed: " + s.err.Error())
return
}
s.err = transport.StartTLS(domain)
s.err = s.transport.StartTLS()
if s.err == nil {
s.TlsEnabled = true
@@ -151,7 +127,7 @@ func (s *Session) auth(o Config) {
return
}
s.err = authSASL(s.streamLogger, s.decoder, s.Features, o.parsedJid.Node, o.Credential)
s.err = authSASL(s.transport, s.transport.GetDecoder(), s.Features, o.parsedJid.Node, o.Credential)
}
// Attempt to resume session using stream management
@@ -163,11 +139,11 @@ func (s *Session) resume(o Config) bool {
return false
}
fmt.Fprintf(s.streamLogger, "<resume xmlns='%s' h='%d' previd='%s'/>",
fmt.Fprintf(s.transport, "<resume xmlns='%s' h='%d' previd='%s'/>",
stanza.NSStreamManagement, s.SMState.Inbound, s.SMState.Id)
var packet stanza.Packet
packet, s.err = stanza.NextPacket(s.decoder)
packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
if s.err == nil {
switch p := packet.(type) {
case stanza.SMResumed:
@@ -194,14 +170,14 @@ func (s *Session) bind(o Config) {
// Send IQ message asking to bind to the local user name.
var resource = o.parsedJid.Resource
if resource != "" {
fmt.Fprintf(s.streamLogger, "<iq type='set' id='%s'><bind xmlns='%s'><resource>%s</resource></bind></iq>",
fmt.Fprintf(s.transport, "<iq type='set' id='%s'><bind xmlns='%s'><resource>%s</resource></bind></iq>",
s.PacketId(), stanza.NSBind, resource)
} else {
fmt.Fprintf(s.streamLogger, "<iq type='set' id='%s'><bind xmlns='%s'/></iq>", s.PacketId(), stanza.NSBind)
fmt.Fprintf(s.transport, "<iq type='set' id='%s'><bind xmlns='%s'/></iq>", s.PacketId(), stanza.NSBind)
}
var iq stanza.IQ
if s.err = s.decoder.Decode(&iq); s.err != nil {
if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
s.err = errors.New("error decoding iq bind result: " + s.err.Error())
return
}
@@ -226,8 +202,8 @@ func (s *Session) rfc3921Session(o Config) {
var iq stanza.IQ
// We only negotiate session binding if it is mandatory, we skip it when optional.
if !s.Features.Session.IsOptional() {
fmt.Fprintf(s.streamLogger, "<iq type='set' id='%s'><session xmlns='%s'/></iq>", s.PacketId(), stanza.NSSession)
if s.err = s.decoder.Decode(&iq); s.err != nil {
fmt.Fprintf(s.transport, "<iq type='set' id='%s'><session xmlns='%s'/></iq>", s.PacketId(), stanza.NSSession)
if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
s.err = errors.New("expecting iq result after session open: " + s.err.Error())
return
}
@@ -243,10 +219,10 @@ func (s *Session) EnableStreamManagement(o Config) {
return
}
fmt.Fprintf(s.streamLogger, "<enable xmlns='%s' resume='true'/>", stanza.NSStreamManagement)
fmt.Fprintf(s.transport, "<enable xmlns='%s' resume='true'/>", stanza.NSStreamManagement)
var packet stanza.Packet
packet, s.err = stanza.NextPacket(s.decoder)
packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
if s.err == nil {
switch p := packet.(type) {
case stanza.SMEnabled: