go-xmpp/xmpp_transport.go

130 lines
2.9 KiB
Go
Raw Normal View History

2019-10-10 21:32:26 -07:00
package xmpp
import (
"crypto/tls"
"encoding/xml"
"errors"
"fmt"
"io"
2019-10-10 21:32:26 -07:00
"net"
"time"
"gosrc.io/xmpp/stanza"
2019-10-10 21:32:26 -07:00
)
// XMPPTransport implements the XMPP native TCP transport
type XMPPTransport struct {
2019-10-25 06:22:01 -07:00
openStatement string
Config TransportConfiguration
TLSConfig *tls.Config
decoder *xml.Decoder
conn net.Conn
readWriter io.ReadWriter
logFile io.Writer
isSecure bool
2019-10-10 21:32:26 -07:00
}
2019-10-25 06:22:01 -07:00
// Stream-opening templates. Each is rendered once at package initialisation
// with its namespaces filled in; the escaped %%s is left behind as a single
// %s verb so the result can itself be used as a format string later.
// NOTE(review): presumably these are what gets stored in
// XMPPTransport.openStatement and formatted with the target domain — the
// assignment is not visible in this file section; confirm in the constructors.
var (
	// componentStreamOpen is the opening tag for the component namespace.
	// It carries no version attribute.
	componentStreamOpen = fmt.Sprintf("<?xml version='1.0'?><stream:stream to='%%s' xmlns='%s' xmlns:stream='%s'>", stanza.NSComponent, stanza.NSStream)

	// clientStreamOpen is the opening tag for the client namespace and
	// advertises version='1.0'.
	clientStreamOpen = fmt.Sprintf("<?xml version='1.0'?><stream:stream to='%%s' xmlns='%s' xmlns:stream='%s' version='1.0'>", stanza.NSClient, stanza.NSStream)
)
func (t *XMPPTransport) Connect() (string, error) {
2019-10-10 21:32:26 -07:00
var err error
t.conn, err = net.DialTimeout("tcp", t.Config.Address, time.Duration(t.Config.ConnectTimeout)*time.Second)
if err != nil {
return "", NewConnError(err, true)
}
t.readWriter = newStreamLogger(t.conn, t.logFile)
2019-10-25 06:22:01 -07:00
return t.startStream()
}
2019-10-25 06:22:01 -07:00
func (t *XMPPTransport) startStream() (string, error) {
if _, err := fmt.Fprintf(t.readWriter, t.openStatement, t.Config.Domain); err != nil {
t.conn.Close()
return "", NewConnError(err, true)
}
t.decoder = xml.NewDecoder(t.readWriter)
t.decoder.CharsetReader = t.Config.CharsetReader
sessionID, err := stanza.InitStream(t.decoder)
if err != nil {
t.conn.Close()
return "", NewConnError(err, false)
}
return sessionID, nil
2019-10-10 21:32:26 -07:00
}
// DoesStartTLS reports whether this transport can be upgraded with StartTLS.
// For the native TCP transport the answer is unconditionally true.
func (t XMPPTransport) DoesStartTLS() bool {
	const supportsStartTLS = true
	return supportsStartTLS
}
// GetDecoder returns the XML decoder currently attached to the transport.
// The result is nil until a stream has been opened, and it is replaced each
// time a new stream is started.
func (t XMPPTransport) GetDecoder() *xml.Decoder {
	current := t.decoder
	return current
}
2019-10-10 22:15:47 -07:00
// IsSecure reports the transport's isSecure flag, which StartTLS sets after
// a completed handshake and hostname verification.
func (t XMPPTransport) IsSecure() bool {
	secured := t.isSecure
	return secured
}
2019-10-25 06:22:01 -07:00
func (t *XMPPTransport) StartTLS() (string, error) {
2019-10-10 21:32:26 -07:00
if t.Config.TLSConfig == nil {
2019-10-12 11:47:16 -07:00
t.TLSConfig = &tls.Config{}
} else {
t.TLSConfig = t.Config.TLSConfig.Clone()
2019-10-10 21:32:26 -07:00
}
2019-10-12 11:47:16 -07:00
if t.TLSConfig.ServerName == "" {
t.TLSConfig.ServerName = t.Config.Domain
2019-10-10 21:32:26 -07:00
}
tlsConn := tls.Client(t.conn, t.TLSConfig)
// We convert existing connection to TLS
if err := tlsConn.Handshake(); err != nil {
2019-10-25 06:22:01 -07:00
return "", err
2019-10-10 21:32:26 -07:00
}
t.conn = tlsConn
t.readWriter = newStreamLogger(tlsConn, t.logFile)
t.decoder = xml.NewDecoder(t.readWriter)
t.decoder.CharsetReader = t.Config.CharsetReader
2019-10-10 21:32:26 -07:00
if !t.TLSConfig.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil {
2019-10-25 06:22:01 -07:00
return "", err
2019-10-10 21:32:26 -07:00
}
}
2019-10-10 22:15:47 -07:00
t.isSecure = true
2019-10-25 06:22:01 -07:00
return t.startStream()
2019-10-10 21:32:26 -07:00
}
// Ping checks that the connection is still writable by sending a single
// newline byte directly on the underlying connection, bypassing readWriter.
//
// It returns an error when the transport has no connection, when the write
// fails, or when the write reports a byte count other than one.
func (t XMPPTransport) Ping() error {
	// Without a connection (never connected, or the dial failed) there is
	// nothing to write to; report that instead of dereferencing a nil conn.
	if t.conn == nil {
		return errors.New("cannot ping: not connected")
	}

	n, err := t.conn.Write([]byte("\n"))
	if err != nil {
		return err
	}
	if n != 1 {
		return errors.New("Could not write ping")
	}
	return nil
}
2019-10-10 21:32:26 -07:00
func (t XMPPTransport) Read(p []byte) (n int, err error) {
return t.readWriter.Read(p)
2019-10-10 21:32:26 -07:00
}
func (t XMPPTransport) Write(p []byte) (n int, err error) {
return t.readWriter.Write(p)
2019-10-10 21:32:26 -07:00
}
func (t XMPPTransport) Close() error {
_, _ = t.readWriter.Write([]byte("</stream:stream>"))
2019-10-10 21:32:26 -07:00
return t.conn.Close()
}
// LogTraffic stores logFile as the writer handed to newStreamLogger the next
// time the transport builds its readWriter (in Connect and in StartTLS).
// A readWriter that already exists is not rebuilt by this call.
func (t *XMPPTransport) LogTraffic(logFile io.Writer) {
	sink := logFile
	t.logFile = sink
}