forked from jshiffer/go-xmpp
Refactor tests
This commit is contained in:
+141
-40
@@ -1,10 +1,11 @@
|
||||
package xmpp
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gosrc.io/xmpp/stanza"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
@@ -23,44 +24,67 @@ type Session struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func NewSession(transport Transport, o Config, state SMState) (*Session, error) {
|
||||
s := new(Session)
|
||||
s.transport = transport
|
||||
s.SMState = state
|
||||
s.init(o)
|
||||
func NewSession(c *Client, state SMState) (*Session, error) {
|
||||
var s *Session
|
||||
if c.Session == nil {
|
||||
s = new(Session)
|
||||
s.transport = c.transport
|
||||
s.SMState = state
|
||||
s.init()
|
||||
} else {
|
||||
s = c.Session
|
||||
// We keep information about the previously set session, like the session ID, but we read server provided
|
||||
// info again in case it changed between session break and resume, such as features.
|
||||
s.init()
|
||||
}
|
||||
|
||||
if s.err != nil {
|
||||
return nil, NewConnError(s.err, true)
|
||||
}
|
||||
|
||||
if !transport.IsSecure() {
|
||||
s.startTlsIfSupported(o)
|
||||
if !c.transport.IsSecure() {
|
||||
s.startTlsIfSupported(c.config)
|
||||
}
|
||||
|
||||
if !transport.IsSecure() && !o.Insecure {
|
||||
if !c.transport.IsSecure() && !c.config.Insecure {
|
||||
err := fmt.Errorf("failed to negotiate TLS session : %s", s.err)
|
||||
return nil, NewConnError(err, true)
|
||||
}
|
||||
|
||||
if s.TlsEnabled {
|
||||
s.reset(o)
|
||||
s.reset()
|
||||
}
|
||||
|
||||
// auth
|
||||
s.auth(o)
|
||||
s.reset(o)
|
||||
s.auth(c.config)
|
||||
if s.err != nil {
|
||||
return s, s.err
|
||||
}
|
||||
s.reset()
|
||||
if s.err != nil {
|
||||
return s, s.err
|
||||
}
|
||||
|
||||
// attempt resumption
|
||||
if s.resume(o) {
|
||||
if s.resume(c.config) {
|
||||
return s, s.err
|
||||
}
|
||||
|
||||
// otherwise, bind resource and 'start' XMPP session
|
||||
s.bind(o)
|
||||
s.rfc3921Session(o)
|
||||
s.bind(c.config)
|
||||
if s.err != nil {
|
||||
return s, s.err
|
||||
}
|
||||
s.rfc3921Session()
|
||||
if s.err != nil {
|
||||
return s, s.err
|
||||
}
|
||||
|
||||
// Enable stream management if supported
|
||||
s.EnableStreamManagement(o)
|
||||
s.EnableStreamManagement(c.config)
|
||||
if s.err != nil {
|
||||
return s, s.err
|
||||
}
|
||||
|
||||
return s, s.err
|
||||
}
|
||||
@@ -70,19 +94,20 @@ func (s *Session) PacketId() string {
|
||||
return fmt.Sprintf("%x", s.lastPacketId)
|
||||
}
|
||||
|
||||
func (s *Session) init(o Config) {
|
||||
s.Features = s.open(o.parsedJid.Domain)
|
||||
// init gathers information on the session such as stream features from the server.
|
||||
func (s *Session) init() {
|
||||
s.Features = s.extractStreamFeatures()
|
||||
}
|
||||
|
||||
func (s *Session) reset(o Config) {
|
||||
func (s *Session) reset() {
|
||||
if s.StreamId, s.err = s.transport.StartStream(); s.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.Features = s.open(o.parsedJid.Domain)
|
||||
s.Features = s.extractStreamFeatures()
|
||||
}
|
||||
|
||||
func (s *Session) open(domain string) (f stanza.StreamFeatures) {
|
||||
func (s *Session) extractStreamFeatures() (f stanza.StreamFeatures) {
|
||||
// extract stream features
|
||||
if s.err = s.transport.GetDecoder().Decode(&f); s.err != nil {
|
||||
s.err = errors.New("stream open decode features: " + s.err.Error())
|
||||
@@ -90,7 +115,7 @@ func (s *Session) open(domain string) (f stanza.StreamFeatures) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Session) startTlsIfSupported(o Config) {
|
||||
func (s *Session) startTlsIfSupported(o *Config) {
|
||||
if s.err != nil {
|
||||
return
|
||||
}
|
||||
@@ -125,7 +150,7 @@ func (s *Session) startTlsIfSupported(o Config) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) auth(o Config) {
|
||||
func (s *Session) auth(o *Config) {
|
||||
if s.err != nil {
|
||||
return
|
||||
}
|
||||
@@ -134,7 +159,7 @@ func (s *Session) auth(o Config) {
|
||||
}
|
||||
|
||||
// Attempt to resume session using stream management
|
||||
func (s *Session) resume(o Config) bool {
|
||||
func (s *Session) resume(o *Config) bool {
|
||||
if !s.Features.DoesStreamManagement() {
|
||||
return false
|
||||
}
|
||||
@@ -142,9 +167,16 @@ func (s *Session) resume(o Config) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
fmt.Fprintf(s.transport, "<resume xmlns='%s' h='%d' previd='%s'/>",
|
||||
stanza.NSStreamManagement, s.SMState.Inbound, s.SMState.Id)
|
||||
rsm := stanza.SMResume{
|
||||
PrevId: s.SMState.Id,
|
||||
H: &s.SMState.Inbound,
|
||||
}
|
||||
data, err := xml.Marshal(rsm)
|
||||
|
||||
_, err = s.transport.Write(data)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
var packet stanza.Packet
|
||||
packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
|
||||
if s.err == nil {
|
||||
@@ -165,20 +197,48 @@ func (s *Session) resume(o Config) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Session) bind(o Config) {
|
||||
func (s *Session) bind(o *Config) {
|
||||
if s.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Send IQ message asking to bind to the local user name.
|
||||
var resource = o.parsedJid.Resource
|
||||
if resource != "" {
|
||||
fmt.Fprintf(s.transport, "<iq type='set' id='%s'><bind xmlns='%s'><resource>%s</resource></bind></iq>",
|
||||
s.PacketId(), stanza.NSBind, resource)
|
||||
} else {
|
||||
fmt.Fprintf(s.transport, "<iq type='set' id='%s'><bind xmlns='%s'/></iq>", s.PacketId(), stanza.NSBind)
|
||||
iqB, err := stanza.NewIQ(stanza.Attrs{
|
||||
Type: stanza.IQTypeSet,
|
||||
Id: s.PacketId(),
|
||||
})
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return
|
||||
}
|
||||
|
||||
// Check if we already have a resource name, and include it in the request if so
|
||||
if resource != "" {
|
||||
iqB.Payload = &stanza.Bind{
|
||||
Resource: resource,
|
||||
}
|
||||
} else {
|
||||
iqB.Payload = &stanza.Bind{}
|
||||
|
||||
}
|
||||
|
||||
// Send the bind request IQ
|
||||
data, err := xml.Marshal(iqB)
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return
|
||||
}
|
||||
n, err := s.transport.Write(data)
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return
|
||||
} else if n == 0 {
|
||||
s.err = errors.New("failed to write bind iq stanza to the server : wrote 0 bytes")
|
||||
return
|
||||
}
|
||||
|
||||
// Check the server response
|
||||
var iq stanza.IQ
|
||||
if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
|
||||
s.err = errors.New("error decoding iq bind result: " + s.err.Error())
|
||||
@@ -197,7 +257,7 @@ func (s *Session) bind(o Config) {
|
||||
}
|
||||
|
||||
// After the bind, if the session is not optional (as per old RFC 3921), we send the session open iq.
|
||||
func (s *Session) rfc3921Session(o Config) {
|
||||
func (s *Session) rfc3921Session() {
|
||||
if s.err != nil {
|
||||
return
|
||||
}
|
||||
@@ -205,7 +265,29 @@ func (s *Session) rfc3921Session(o Config) {
|
||||
var iq stanza.IQ
|
||||
// We only negotiate session binding if it is mandatory, we skip it when optional.
|
||||
if !s.Features.Session.IsOptional() {
|
||||
fmt.Fprintf(s.transport, "<iq type='set' id='%s'><session xmlns='%s'/></iq>", s.PacketId(), stanza.NSSession)
|
||||
se, err := stanza.NewIQ(stanza.Attrs{
|
||||
Type: stanza.IQTypeSet,
|
||||
Id: s.PacketId(),
|
||||
})
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return
|
||||
}
|
||||
se.Payload = &stanza.StreamSession{}
|
||||
data, err := xml.Marshal(se)
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return
|
||||
}
|
||||
n, err := s.transport.Write(data)
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return
|
||||
} else if n == 0 {
|
||||
s.err = errors.New("there was a problem marshaling the session IQ : wrote 0 bytes to server")
|
||||
return
|
||||
}
|
||||
|
||||
if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
|
||||
s.err = errors.New("expecting iq result after session open: " + s.err.Error())
|
||||
return
|
||||
@@ -214,28 +296,47 @@ func (s *Session) rfc3921Session(o Config) {
|
||||
}
|
||||
|
||||
// Enable stream management, with session resumption, if supported.
|
||||
func (s *Session) EnableStreamManagement(o Config) {
|
||||
func (s *Session) EnableStreamManagement(o *Config) {
|
||||
if s.err != nil {
|
||||
return
|
||||
}
|
||||
if !s.Features.DoesStreamManagement() {
|
||||
if !s.Features.DoesStreamManagement() || !o.StreamManagementEnable {
|
||||
return
|
||||
}
|
||||
q := stanza.NewUnAckQueue()
|
||||
ebleNonza := stanza.SMEnable{Resume: &o.streamManagementResume}
|
||||
pktStr, err := xml.Marshal(ebleNonza)
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return
|
||||
}
|
||||
_, err = s.transport.Write(pktStr)
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(s.transport, "<enable xmlns='%s' resume='true'/>", stanza.NSStreamManagement)
|
||||
|
||||
var packet stanza.Packet
|
||||
packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
|
||||
if s.err == nil {
|
||||
switch p := packet.(type) {
|
||||
case stanza.SMEnabled:
|
||||
s.SMState = SMState{Id: p.Id}
|
||||
// Server allows resumption or not using SMEnabled attribute "resume". We must read the server response
|
||||
// and update config accordingly
|
||||
b, err := strconv.ParseBool(p.Resume)
|
||||
if err != nil || !b {
|
||||
o.StreamManagementEnable = false
|
||||
}
|
||||
s.SMState = SMState{Id: p.Id, preferredReconAddr: p.Location}
|
||||
s.SMState.UnAckQueue = q
|
||||
case stanza.SMFailed:
|
||||
// TODO: Store error in SMState, for later inspection
|
||||
s.SMState = SMState{StreamErrorGroup: p.StreamErrorGroup}
|
||||
s.SMState.UnAckQueue = q
|
||||
s.err = errors.New("failed to establish session : " + s.SMState.StreamErrorGroup.GroupErrorName())
|
||||
default:
|
||||
s.err = errors.New("unexpected reply to SM enable")
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user