Introduce Transport interface

This commit is contained in:
Wichert Akkerman 2019-10-06 19:37:56 +02:00 committed by Mickaël Rémond
parent 2781563ea7
commit a3c62e515e
4 changed files with 115 additions and 57 deletions

View File

@ -90,7 +90,7 @@ type Client struct {
// Session gather data that can be accessed by users of this library // Session gather data that can be accessed by users of this library
Session *Session Session *Session
// TCP level connection / can be replaced by a TLS session after starttls // TCP level connection / can be replaced by a TLS session after starttls
conn net.Conn transport Transport
// Router is used to dispatch packets // Router is used to dispatch packets
router *Router router *Router
// Track and broadcast connection state // Track and broadcast connection state
@ -139,6 +139,7 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
c = new(Client) c = new(Client)
c.config = config c.config = config
c.router = r c.router = r
c.transport = &XMPPTransport{}
if c.config.ConnectTimeout == 0 { if c.config.ConnectTimeout == 0 {
c.config.ConnectTimeout = 15 // 15 second as default c.config.ConnectTimeout = 15 // 15 second as default
@ -159,21 +160,21 @@ func (c *Client) Connect() error {
func (c *Client) Resume(state SMState) error { func (c *Client) Resume(state SMState) error {
var err error var err error
c.conn, err = net.DialTimeout("tcp", c.config.Address, time.Duration(c.config.ConnectTimeout)*time.Second) err = c.transport.Connect(c.config.Address, c.config)
if err != nil { if err != nil {
return err return err
} }
c.updateState(StateConnected) c.updateState(StateConnected)
// Client is ok, we now open XMPP session // Client is ok, we now open XMPP session
if c.conn, c.Session, err = NewSession(c.conn, c.config, state); err != nil { if c.Session, err = NewSession(c.transport, c.config, state); err != nil {
return err return err
} }
c.updateState(StateSessionEstablished) c.updateState(StateSessionEstablished)
// Start the keepalive go routine // Start the keepalive go routine
keepaliveQuit := make(chan struct{}) keepaliveQuit := make(chan struct{})
go keepalive(c.conn, keepaliveQuit) go keepalive(c.transport, keepaliveQuit)
// Start the receiver go routine // Start the receiver go routine
state = c.Session.SMState state = c.Session.SMState
go c.recv(state, keepaliveQuit) go c.recv(state, keepaliveQuit)
@ -190,7 +191,7 @@ func (c *Client) Resume(state SMState) error {
func (c *Client) Disconnect() { func (c *Client) Disconnect() {
_ = c.SendRaw("</stream:stream>") _ = c.SendRaw("</stream:stream>")
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect // TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
conn := c.conn conn := c.transport
if conn != nil { if conn != nil {
_ = conn.Close() _ = conn.Close()
} }
@ -202,7 +203,7 @@ func (c *Client) SetHandler(handler EventHandler) {
// Send marshals XMPP stanza and sends it to the server. // Send marshals XMPP stanza and sends it to the server.
func (c *Client) Send(packet stanza.Packet) error { func (c *Client) Send(packet stanza.Packet) error {
conn := c.conn conn := c.transport
if conn == nil { if conn == nil {
return errors.New("client is not connected") return errors.New("client is not connected")
} }
@ -220,7 +221,7 @@ func (c *Client) Send(packet stanza.Packet) error {
// disconnect the client. It is up to the user of this method to // disconnect the client. It is up to the user of this method to
// carefully craft the XML content to produce valid XMPP. // carefully craft the XML content to produce valid XMPP.
func (c *Client) SendRaw(packet string) error { func (c *Client) SendRaw(packet string) error {
conn := c.conn conn := c.transport
if conn == nil { if conn == nil {
return errors.New("client is not connected") return errors.New("client is not connected")
} }
@ -272,16 +273,16 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error)
// Loop: send whitespace keepalive to server // Loop: send whitespace keepalive to server
// This is use to keep the connection open, but also to detect connection loss // This is use to keep the connection open, but also to detect connection loss
// and trigger proper client connection shutdown. // and trigger proper client connection shutdown.
func keepalive(conn net.Conn, quit <-chan struct{}) { func keepalive(transport Transport, quit <-chan struct{}) {
// TODO: Make keepalive interval configurable // TODO: Make keepalive interval configurable
ticker := time.NewTicker(30 * time.Second) ticker := time.NewTicker(30 * time.Second)
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
if n, err := fmt.Fprintf(conn, "\n"); err != nil || n != 1 { if n, err := fmt.Fprintf(transport, "\n"); err != nil || n != 1 {
// When keep alive fails, we force close the connection. In all cases, the recv will also fail. // When keep alive fails, we force close the transportection. In all cases, the recv will also fail.
ticker.Stop() ticker.Stop()
_ = conn.Close() _ = transport.Close()
return return
} }
case <-quit: case <-quit:

View File

@ -14,8 +14,8 @@ type Config struct {
StreamLogger *os.File // Used for debugging StreamLogger *os.File // Used for debugging
Lang string // TODO: should default to 'en' Lang string // TODO: should default to 'en'
ConnectTimeout int // Client timeout in seconds. Default to 15 ConnectTimeout int // Client timeout in seconds. Default to 15
// tls.Config must not be modified after having been passed to NewClient. The // tls.Config must not be modified after having been passed to NewClient. Any
// Client connect method may override the tls.Config.ServerName if it was not set. // changes made after connecting are ignored.
TLSConfig *tls.Config TLSConfig *tls.Config
// Insecure can be set to true to allow to open a session without TLS. If TLS // Insecure can be set to true to allow to open a session without TLS. If TLS
// is supported on the server, we will still try to use it. // is supported on the server, we will still try to use it.

View File

@ -1,12 +1,10 @@
package xmpp package xmpp
import ( import (
"crypto/tls"
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"gosrc.io/xmpp/stanza" "gosrc.io/xmpp/stanza"
) )
@ -30,35 +28,33 @@ type Session struct {
err error err error
} }
func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, error) { func NewSession(transport Transport, o Config, state SMState) (*Session, error) {
s := new(Session) s := new(Session)
s.SMState = state s.SMState = state
s.init(conn, o) s.init(transport, o)
// starttls s.startTlsIfSupported(transport, o.parsedJid.Domain, o)
var tlsConn net.Conn
tlsConn = s.startTlsIfSupported(conn, o.parsedJid.Domain, o)
if s.err != nil { if s.err != nil {
return nil, nil, NewConnError(s.err, true) return nil, NewConnError(s.err, true)
} }
if !s.TlsEnabled && !o.Insecure { if !s.TlsEnabled && !o.Insecure {
err := fmt.Errorf("failed to negotiate TLS session : %s", s.err) err := fmt.Errorf("failed to negotiate TLS session : %s", s.err)
return nil, nil, NewConnError(err, true) return nil, NewConnError(err, true)
} }
if s.TlsEnabled { if s.TlsEnabled {
s.reset(conn, tlsConn, o) s.reset(transport, o)
} }
// auth // auth
s.auth(o) s.auth(o)
s.reset(tlsConn, tlsConn, o) s.reset(transport, o)
// attempt resumption // attempt resumption
if s.resume(o) { if s.resume(o) {
return tlsConn, s, s.err return s, s.err
} }
// otherwise, bind resource and 'start' XMPP session // otherwise, bind resource and 'start' XMPP session
@ -68,7 +64,7 @@ func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, err
// Enable stream management if supported // Enable stream management if supported
s.EnableStreamManagement(o) s.EnableStreamManagement(o)
return tlsConn, s, s.err return s, s.err
} }
func (s *Session) PacketId() string { func (s *Session) PacketId() string {
@ -76,24 +72,22 @@ func (s *Session) PacketId() string {
return fmt.Sprintf("%x", s.lastPacketId) return fmt.Sprintf("%x", s.lastPacketId)
} }
func (s *Session) init(conn net.Conn, o Config) { func (s *Session) init(transport Transport, o Config) {
s.setStreamLogger(nil, conn, o) s.setStreamLogger(transport, o)
s.Features = s.open(o.parsedJid.Domain) s.Features = s.open(o.parsedJid.Domain)
} }
func (s *Session) reset(conn net.Conn, newConn net.Conn, o Config) { func (s *Session) reset(transport Transport, o Config) {
if s.err != nil { if s.err != nil {
return return
} }
s.setStreamLogger(conn, newConn, o) s.setStreamLogger(transport, o)
s.Features = s.open(o.parsedJid.Domain) s.Features = s.open(o.parsedJid.Domain)
} }
func (s *Session) setStreamLogger(conn net.Conn, newConn net.Conn, o Config) { func (s *Session) setStreamLogger(transport Transport, o Config) {
if newConn != conn { s.streamLogger = newStreamLogger(transport, o.StreamLogger)
s.streamLogger = newStreamLogger(newConn, o.StreamLogger)
}
s.decoder = xml.NewDecoder(s.streamLogger) s.decoder = xml.NewDecoder(s.streamLogger)
s.decoder.CharsetReader = o.CharsetReader s.decoder.CharsetReader = o.CharsetReader
} }
@ -117,9 +111,16 @@ func (s *Session) open(domain string) (f stanza.StreamFeatures) {
return return
} }
func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) net.Conn { func (s *Session) startTlsIfSupported(transport Transport, domain string, o Config) {
if s.err != nil { if s.err != nil {
return conn return
}
if !transport.DoesStartTLS() {
if !o.Insecure {
s.err = errors.New("Transport does not support starttls")
}
return
} }
if _, ok := s.Features.DoesStartTLS(); ok { if _, ok := s.Features.DoesStartTLS(); ok {
@ -128,39 +129,21 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) ne
var k stanza.TLSProceed var k stanza.TLSProceed
if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil { if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil {
s.err = errors.New("expecting starttls proceed: " + s.err.Error()) s.err = errors.New("expecting starttls proceed: " + s.err.Error())
return conn return
} }
if o.TLSConfig == nil { s.err = transport.StartTLS(domain, o)
o.TLSConfig = &tls.Config{}
}
if o.TLSConfig.ServerName == "" {
o.TLSConfig.ServerName = domain
}
tlsConn := tls.Client(conn, o.TLSConfig)
// We convert existing connection to TLS
if s.err = tlsConn.Handshake(); s.err != nil {
return tlsConn
}
if !o.TLSConfig.InsecureSkipVerify {
s.err = tlsConn.VerifyHostname(domain)
}
if s.err == nil { if s.err == nil {
s.TlsEnabled = true s.TlsEnabled = true
} }
return tlsConn return
} }
// If we do not allow cleartext connections, make it explicit that server do not support starttls // If we do not allow cleartext connections, make it explicit that server do not support starttls
if !o.Insecure { if !o.Insecure {
s.err = errors.New("XMPP server does not advertise support for starttls") s.err = errors.New("XMPP server does not advertise support for starttls")
} }
// starttls is not supported => we do not upgrade the connection:
return conn
} }
func (s *Session) auth(o Config) { func (s *Session) auth(o Config) {

74
transport.go Normal file
View File

@ -0,0 +1,74 @@
package xmpp
import (
"crypto/tls"
"net"
"time"
)
type Transport interface {
Connect(address string, c Config) error
DoesStartTLS() bool
StartTLS(domain string, c Config) error
Read(p []byte) (n int, err error)
Write(p []byte) (n int, err error)
Close() error
}
// XMPPTransport implements the XMPP native TCP transport
type XMPPTransport struct {
TLSConfig *tls.Config
// TCP level connection / can be replaced by a TLS session after starttls
conn net.Conn
}
func (t *XMPPTransport) Connect(address string, c Config) error {
var err error
t.conn, err = net.DialTimeout("tcp", address, time.Duration(c.ConnectTimeout)*time.Second)
return err
}
func (t XMPPTransport) DoesStartTLS() bool {
return true
}
func (t *XMPPTransport) StartTLS(domain string, c Config) error {
if t.TLSConfig == nil {
if c.TLSConfig != nil {
t.TLSConfig = c.TLSConfig
} else {
t.TLSConfig = &tls.Config{}
}
}
if t.TLSConfig.ServerName == "" {
t.TLSConfig.ServerName = domain
}
tlsConn := tls.Client(t.conn, t.TLSConfig)
// We convert existing connection to TLS
if err := tlsConn.Handshake(); err != nil {
return err
}
if !t.TLSConfig.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(domain); err != nil {
return err
}
}
return nil
}
func (t XMPPTransport) Read(p []byte) (n int, err error) {
return t.conn.Read(p)
}
func (t XMPPTransport) Write(p []byte) (n int, err error) {
return t.conn.Write(p)
}
func (t XMPPTransport) Close() error {
return t.conn.Close()
}