forked from jshiffer/go-xmpp
Introduce Transport interface
This commit is contained in:
committed by
Mickaël Rémond
parent
2781563ea7
commit
a3c62e515e
+27
-44
@@ -1,12 +1,10 @@
|
||||
package xmpp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"gosrc.io/xmpp/stanza"
|
||||
)
|
||||
@@ -30,35 +28,33 @@ type Session struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, error) {
|
||||
func NewSession(transport Transport, o Config, state SMState) (*Session, error) {
|
||||
s := new(Session)
|
||||
s.SMState = state
|
||||
s.init(conn, o)
|
||||
s.init(transport, o)
|
||||
|
||||
// starttls
|
||||
var tlsConn net.Conn
|
||||
tlsConn = s.startTlsIfSupported(conn, o.parsedJid.Domain, o)
|
||||
s.startTlsIfSupported(transport, o.parsedJid.Domain, o)
|
||||
|
||||
if s.err != nil {
|
||||
return nil, nil, NewConnError(s.err, true)
|
||||
return nil, NewConnError(s.err, true)
|
||||
}
|
||||
|
||||
if !s.TlsEnabled && !o.Insecure {
|
||||
err := fmt.Errorf("failed to negotiate TLS session : %s", s.err)
|
||||
return nil, nil, NewConnError(err, true)
|
||||
return nil, NewConnError(err, true)
|
||||
}
|
||||
|
||||
if s.TlsEnabled {
|
||||
s.reset(conn, tlsConn, o)
|
||||
s.reset(transport, o)
|
||||
}
|
||||
|
||||
// auth
|
||||
s.auth(o)
|
||||
s.reset(tlsConn, tlsConn, o)
|
||||
s.reset(transport, o)
|
||||
|
||||
// attempt resumption
|
||||
if s.resume(o) {
|
||||
return tlsConn, s, s.err
|
||||
return s, s.err
|
||||
}
|
||||
|
||||
// otherwise, bind resource and 'start' XMPP session
|
||||
@@ -68,7 +64,7 @@ func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, err
|
||||
// Enable stream management if supported
|
||||
s.EnableStreamManagement(o)
|
||||
|
||||
return tlsConn, s, s.err
|
||||
return s, s.err
|
||||
}
|
||||
|
||||
func (s *Session) PacketId() string {
|
||||
@@ -76,24 +72,22 @@ func (s *Session) PacketId() string {
|
||||
return fmt.Sprintf("%x", s.lastPacketId)
|
||||
}
|
||||
|
||||
func (s *Session) init(conn net.Conn, o Config) {
|
||||
s.setStreamLogger(nil, conn, o)
|
||||
func (s *Session) init(transport Transport, o Config) {
|
||||
s.setStreamLogger(transport, o)
|
||||
s.Features = s.open(o.parsedJid.Domain)
|
||||
}
|
||||
|
||||
func (s *Session) reset(conn net.Conn, newConn net.Conn, o Config) {
|
||||
func (s *Session) reset(transport Transport, o Config) {
|
||||
if s.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.setStreamLogger(conn, newConn, o)
|
||||
s.setStreamLogger(transport, o)
|
||||
s.Features = s.open(o.parsedJid.Domain)
|
||||
}
|
||||
|
||||
func (s *Session) setStreamLogger(conn net.Conn, newConn net.Conn, o Config) {
|
||||
if newConn != conn {
|
||||
s.streamLogger = newStreamLogger(newConn, o.StreamLogger)
|
||||
}
|
||||
func (s *Session) setStreamLogger(transport Transport, o Config) {
|
||||
s.streamLogger = newStreamLogger(transport, o.StreamLogger)
|
||||
s.decoder = xml.NewDecoder(s.streamLogger)
|
||||
s.decoder.CharsetReader = o.CharsetReader
|
||||
}
|
||||
@@ -117,9 +111,16 @@ func (s *Session) open(domain string) (f stanza.StreamFeatures) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) net.Conn {
|
||||
func (s *Session) startTlsIfSupported(transport Transport, domain string, o Config) {
|
||||
if s.err != nil {
|
||||
return conn
|
||||
return
|
||||
}
|
||||
|
||||
if !transport.DoesStartTLS() {
|
||||
if !o.Insecure {
|
||||
s.err = errors.New("Transport does not support starttls")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := s.Features.DoesStartTLS(); ok {
|
||||
@@ -128,39 +129,21 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) ne
|
||||
var k stanza.TLSProceed
|
||||
if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil {
|
||||
s.err = errors.New("expecting starttls proceed: " + s.err.Error())
|
||||
return conn
|
||||
return
|
||||
}
|
||||
|
||||
if o.TLSConfig == nil {
|
||||
o.TLSConfig = &tls.Config{}
|
||||
}
|
||||
|
||||
if o.TLSConfig.ServerName == "" {
|
||||
o.TLSConfig.ServerName = domain
|
||||
}
|
||||
tlsConn := tls.Client(conn, o.TLSConfig)
|
||||
// We convert existing connection to TLS
|
||||
if s.err = tlsConn.Handshake(); s.err != nil {
|
||||
return tlsConn
|
||||
}
|
||||
|
||||
if !o.TLSConfig.InsecureSkipVerify {
|
||||
s.err = tlsConn.VerifyHostname(domain)
|
||||
}
|
||||
s.err = transport.StartTLS(domain, o)
|
||||
|
||||
if s.err == nil {
|
||||
s.TlsEnabled = true
|
||||
}
|
||||
return tlsConn
|
||||
return
|
||||
}
|
||||
|
||||
// If we do not allow cleartext connections, make it explicit that server do not support starttls
|
||||
if !o.Insecure {
|
||||
s.err = errors.New("XMPP server does not advertise support for starttls")
|
||||
}
|
||||
|
||||
// starttls is not supported => we do not upgrade the connection:
|
||||
return conn
|
||||
}
|
||||
|
||||
func (s *Session) auth(o Config) {
|
||||
|
||||
Reference in New Issue
Block a user