go-xmpp/session.go

343 lines
7.5 KiB
Go
Raw Permalink Normal View History

package xmpp
import (
2020-03-06 07:44:01 -08:00
"encoding/xml"
"errors"
"fmt"
"gosrc.io/xmpp/stanza"
2020-03-06 07:44:01 -08:00
"strconv"
)
// Session holds the state of one XMPP stream negotiation: the values learned
// from the server, the transport used to exchange data with it, and the error
// recorded by the negotiation steps.
type Session struct {
	// Session info
	BindJid      string                // Jabber ID as provided by XMPP server
	StreamId     string                // stream identifier returned by transport.StartStream (see reset)
	SMState      SMState               // stream management state, used by resume and EnableStreamManagement
	Features     stanza.StreamFeatures // last stream features decoded from the server
	TlsEnabled   bool                  // set to true once startTlsIfSupported has upgraded the transport
	lastPacketId int                   // counter incremented by PacketId

	// read / write
	transport Transport

	// error management
	// err is the error recorded by the last negotiation step. Steps that begin
	// with an `s.err != nil` guard become no-ops once it is set.
	err error
}
2020-03-06 07:44:01 -08:00
func NewSession(c *Client, state SMState) (*Session, error) {
var s *Session
if c.Session == nil {
s = new(Session)
s.transport = c.transport
s.SMState = state
s.init()
} else {
s = c.Session
// We keep information about the previously set session, like the session ID, but we read server provided
// info again in case it changed between session break and resume, such as features.
s.init()
}
if s.err != nil {
2019-10-06 10:37:56 -07:00
return nil, NewConnError(s.err, true)
}
2020-03-06 07:44:01 -08:00
if !c.transport.IsSecure() {
s.startTlsIfSupported(c.config)
}
2020-03-06 07:44:01 -08:00
if !c.transport.IsSecure() && !c.config.Insecure {
err := fmt.Errorf("failed to negotiate TLS session : %s", s.err)
2019-10-06 10:37:56 -07:00
return nil, NewConnError(err, true)
}
if s.TlsEnabled {
2020-03-06 07:44:01 -08:00
s.reset()
}
// auth
2020-03-06 07:44:01 -08:00
s.auth(c.config)
if s.err != nil {
return s, s.err
}
s.reset()
if s.err != nil {
return s, s.err
}
// attempt resumption
2020-03-06 07:44:01 -08:00
if s.resume(c.config) {
2019-10-06 10:37:56 -07:00
return s, s.err
}
// otherwise, bind resource and 'start' XMPP session
2020-03-06 07:44:01 -08:00
s.bind(c.config)
if s.err != nil {
return s, s.err
}
s.rfc3921Session()
if s.err != nil {
return s, s.err
}
// Enable stream management if supported
2020-03-06 07:44:01 -08:00
s.EnableStreamManagement(c.config)
if s.err != nil {
return s, s.err
}
2019-10-06 10:37:56 -07:00
return s, s.err
}
// PacketId advances the session's packet counter and returns the new value
// formatted as lowercase hexadecimal.
func (s *Session) PacketId() string {
	next := s.lastPacketId + 1
	s.lastPacketId = next
	return fmt.Sprintf("%x", next)
}
2020-03-06 07:44:01 -08:00
// init decodes the stream features sent by the server and stores them on the
// session. A decoding failure is recorded in s.err by extractStreamFeatures.
func (s *Session) init() {
	features := s.extractStreamFeatures()
	s.Features = features
}
2020-03-06 07:44:01 -08:00
func (s *Session) reset() {
if s.StreamId, s.err = s.transport.StartStream(); s.err != nil {
return
}
2020-03-06 07:44:01 -08:00
s.Features = s.extractStreamFeatures()
}
2020-03-06 07:44:01 -08:00
// extractStreamFeatures decodes the next element from the transport decoder
// into a stanza.StreamFeatures value and returns it.
//
// s.err is always overwritten: it becomes nil on success, or an error prefixed
// with "stream open decode features: " when decoding fails.
func (s *Session) extractStreamFeatures() (f stanza.StreamFeatures) {
	err := s.transport.GetDecoder().Decode(&f)
	if err != nil {
		err = errors.New("stream open decode features: " + err.Error())
	}
	s.err = err
	return f
}
2020-03-06 07:44:01 -08:00
func (s *Session) startTlsIfSupported(o *Config) {
if s.err != nil {
2019-10-06 10:37:56 -07:00
return
}
if !s.transport.DoesStartTLS() {
2019-10-06 10:37:56 -07:00
if !o.Insecure {
s.err = errors.New("transport does not support starttls")
2019-10-06 10:37:56 -07:00
}
return
}
if _, ok := s.Features.DoesStartTLS(); ok {
fmt.Fprintf(s.transport, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
var k stanza.TLSProceed
if s.err = s.transport.GetDecoder().DecodeElement(&k, nil); s.err != nil {
s.err = errors.New("expecting starttls proceed: " + s.err.Error())
2019-10-06 10:37:56 -07:00
return
2019-07-15 15:26:21 -07:00
}
s.err = s.transport.StartTLS()
if s.err == nil {
s.TlsEnabled = true
}
2019-10-06 10:37:56 -07:00
return
}
// If we do not allow cleartext serverConnections, make it explicit that server do not support starttls
if !o.Insecure {
s.err = errors.New("XMPP server does not advertise support for starttls")
}
}
2020-03-06 07:44:01 -08:00
// auth runs SASL authentication through authSASL, using the node part of the
// configured JID and the configured credential, and records the result in
// s.err. It is skipped when an error is already recorded on the session.
func (s *Session) auth(o *Config) {
	if s.err == nil {
		s.err = authSASL(s.transport, s.transport.GetDecoder(), s.Features, o.parsedJid.Node, o.Credential)
	}
}
// Attempt to resume session using stream management
2020-03-06 07:44:01 -08:00
func (s *Session) resume(o *Config) bool {
if !s.Features.DoesStreamManagement() {
return false
}
if s.SMState.Id == "" {
return false
}
2020-03-06 07:44:01 -08:00
rsm := stanza.SMResume{
PrevId: s.SMState.Id,
H: &s.SMState.Inbound,
}
data, err := xml.Marshal(rsm)
2020-03-06 07:44:01 -08:00
_, err = s.transport.Write(data)
if err != nil {
return false
}
var packet stanza.Packet
packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
if s.err == nil {
switch p := packet.(type) {
case stanza.SMResumed:
if p.PrevId != s.SMState.Id {
s.err = errors.New("session resumption: mismatched id")
s.SMState = SMState{}
return false
}
return true
case stanza.SMFailed:
default:
s.err = errors.New("unexpected reply to SM resume")
}
}
s.SMState = SMState{}
return false
}
2020-03-06 07:44:01 -08:00
func (s *Session) bind(o *Config) {
if s.err != nil {
return
}
// Send IQ message asking to bind to the local user name.
var resource = o.parsedJid.Resource
2020-03-06 07:44:01 -08:00
iqB, err := stanza.NewIQ(stanza.Attrs{
Type: stanza.IQTypeSet,
Id: s.PacketId(),
})
if err != nil {
s.err = err
return
}
// Check if we already have a resource name, and include it in the request if so
if resource != "" {
2020-03-06 07:44:01 -08:00
iqB.Payload = &stanza.Bind{
Resource: resource,
}
} else {
2020-03-06 07:44:01 -08:00
iqB.Payload = &stanza.Bind{}
}
// Send the bind request IQ
data, err := xml.Marshal(iqB)
if err != nil {
s.err = err
return
}
n, err := s.transport.Write(data)
if err != nil {
s.err = err
return
} else if n == 0 {
s.err = errors.New("failed to write bind iq stanza to the server : wrote 0 bytes")
return
}
2020-03-06 07:44:01 -08:00
// Check the server response
var iq stanza.IQ
if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
s.err = errors.New("error decoding iq bind result: " + s.err.Error())
return
}
2018-01-15 03:28:34 -08:00
// TODO Check all elements
switch payload := iq.Payload.(type) {
2019-06-29 07:49:54 -07:00
case *stanza.Bind:
s.BindJid = payload.Jid // our local id (with possibly randomly generated resource
default:
s.err = errors.New("iq bind result missing")
}
return
}
2019-06-29 07:49:54 -07:00
// After the bind, if the session is not optional (as per old RFC 3921), we send the session open iq.
2020-03-06 07:44:01 -08:00
func (s *Session) rfc3921Session() {
if s.err != nil {
return
}
var iq stanza.IQ
2019-06-29 08:48:38 -07:00
// We only negotiate session binding if it is mandatory, we skip it when optional.
2019-06-29 08:39:19 -07:00
if !s.Features.Session.IsOptional() {
2020-03-06 07:44:01 -08:00
se, err := stanza.NewIQ(stanza.Attrs{
Type: stanza.IQTypeSet,
Id: s.PacketId(),
})
if err != nil {
s.err = err
return
}
se.Payload = &stanza.StreamSession{}
data, err := xml.Marshal(se)
if err != nil {
s.err = err
return
}
n, err := s.transport.Write(data)
if err != nil {
s.err = err
return
} else if n == 0 {
s.err = errors.New("there was a problem marshaling the session IQ : wrote 0 bytes to server")
return
}
if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
s.err = errors.New("expecting iq result after session open: " + s.err.Error())
return
}
}
}
// Enable stream management, with session resumption, if supported.
2020-03-06 07:44:01 -08:00
func (s *Session) EnableStreamManagement(o *Config) {
if s.err != nil {
return
}
2020-03-06 07:44:01 -08:00
if !s.Features.DoesStreamManagement() || !o.StreamManagementEnable {
return
}
q := stanza.NewUnAckQueue()
ebleNonza := stanza.SMEnable{Resume: &o.streamManagementResume}
pktStr, err := xml.Marshal(ebleNonza)
if err != nil {
s.err = err
return
}
_, err = s.transport.Write(pktStr)
if err != nil {
s.err = err
return
}
var packet stanza.Packet
packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
if s.err == nil {
switch p := packet.(type) {
case stanza.SMEnabled:
2020-03-06 07:44:01 -08:00
// Server allows resumption or not using SMEnabled attribute "resume". We must read the server response
// and update config accordingly
b, err := strconv.ParseBool(p.Resume)
if err != nil || !b {
o.StreamManagementEnable = false
}
s.SMState = SMState{Id: p.Id, preferredReconAddr: p.Location}
s.SMState.UnAckQueue = q
case stanza.SMFailed:
2019-07-31 09:51:16 -07:00
// TODO: Store error in SMState, for later inspection
2020-03-06 07:44:01 -08:00
s.SMState = SMState{StreamErrorGroup: p.StreamErrorGroup}
s.SMState.UnAckQueue = q
s.err = errors.New("failed to establish session : " + s.SMState.StreamErrorGroup.GroupErrorName())
default:
s.err = errors.New("unexpected reply to SM enable")
}
}
return
}