Introduce Transport interface

This commit is contained in:
Wichert Akkerman
2019-10-06 19:37:56 +02:00
committed by Mickaël Rémond
parent 2781563ea7
commit a3c62e515e
4 changed files with 115 additions and 57 deletions
+27 -44
View File
@@ -1,12 +1,10 @@
package xmpp
import (
"crypto/tls"
"encoding/xml"
"errors"
"fmt"
"io"
"net"
"gosrc.io/xmpp/stanza"
)
@@ -30,35 +28,33 @@ type Session struct {
err error
}
func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, error) {
func NewSession(transport Transport, o Config, state SMState) (*Session, error) {
s := new(Session)
s.SMState = state
s.init(conn, o)
s.init(transport, o)
// starttls
var tlsConn net.Conn
tlsConn = s.startTlsIfSupported(conn, o.parsedJid.Domain, o)
s.startTlsIfSupported(transport, o.parsedJid.Domain, o)
if s.err != nil {
return nil, nil, NewConnError(s.err, true)
return nil, NewConnError(s.err, true)
}
if !s.TlsEnabled && !o.Insecure {
err := fmt.Errorf("failed to negotiate TLS session : %s", s.err)
return nil, nil, NewConnError(err, true)
return nil, NewConnError(err, true)
}
if s.TlsEnabled {
s.reset(conn, tlsConn, o)
s.reset(transport, o)
}
// auth
s.auth(o)
s.reset(tlsConn, tlsConn, o)
s.reset(transport, o)
// attempt resumption
if s.resume(o) {
return tlsConn, s, s.err
return s, s.err
}
// otherwise, bind resource and 'start' XMPP session
@@ -68,7 +64,7 @@ func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, err
// Enable stream management if supported
s.EnableStreamManagement(o)
return tlsConn, s, s.err
return s, s.err
}
func (s *Session) PacketId() string {
@@ -76,24 +72,22 @@ func (s *Session) PacketId() string {
return fmt.Sprintf("%x", s.lastPacketId)
}
func (s *Session) init(conn net.Conn, o Config) {
s.setStreamLogger(nil, conn, o)
func (s *Session) init(transport Transport, o Config) {
s.setStreamLogger(transport, o)
s.Features = s.open(o.parsedJid.Domain)
}
func (s *Session) reset(conn net.Conn, newConn net.Conn, o Config) {
func (s *Session) reset(transport Transport, o Config) {
if s.err != nil {
return
}
s.setStreamLogger(conn, newConn, o)
s.setStreamLogger(transport, o)
s.Features = s.open(o.parsedJid.Domain)
}
func (s *Session) setStreamLogger(conn net.Conn, newConn net.Conn, o Config) {
if newConn != conn {
s.streamLogger = newStreamLogger(newConn, o.StreamLogger)
}
func (s *Session) setStreamLogger(transport Transport, o Config) {
s.streamLogger = newStreamLogger(transport, o.StreamLogger)
s.decoder = xml.NewDecoder(s.streamLogger)
s.decoder.CharsetReader = o.CharsetReader
}
@@ -117,9 +111,16 @@ func (s *Session) open(domain string) (f stanza.StreamFeatures) {
return
}
func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) net.Conn {
func (s *Session) startTlsIfSupported(transport Transport, domain string, o Config) {
if s.err != nil {
return conn
return
}
if !transport.DoesStartTLS() {
if !o.Insecure {
s.err = errors.New("Transport does not support starttls")
}
return
}
if _, ok := s.Features.DoesStartTLS(); ok {
@@ -128,39 +129,21 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) ne
var k stanza.TLSProceed
if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil {
s.err = errors.New("expecting starttls proceed: " + s.err.Error())
return conn
return
}
if o.TLSConfig == nil {
o.TLSConfig = &tls.Config{}
}
if o.TLSConfig.ServerName == "" {
o.TLSConfig.ServerName = domain
}
tlsConn := tls.Client(conn, o.TLSConfig)
// We convert existing connection to TLS
if s.err = tlsConn.Handshake(); s.err != nil {
return tlsConn
}
if !o.TLSConfig.InsecureSkipVerify {
s.err = tlsConn.VerifyHostname(domain)
}
s.err = transport.StartTLS(domain, o)
if s.err == nil {
s.TlsEnabled = true
}
return tlsConn
return
}
// If we do not allow cleartext connections, make it explicit that server do not support starttls
if !o.Insecure {
s.err = errors.New("XMPP server does not advertise support for starttls")
}
// starttls is not supported => we do not upgrade the connection:
return conn
}
func (s *Session) auth(o Config) {