Clean up and fix StartTLS feature discovery

Required field was never set to true
This commit is contained in:
Mickael Remond 2019-06-10 16:27:52 +02:00
parent 44568fcf2b
commit 709a95129e
No known key found for this signature in database
GPG Key ID: E6F6045D79965AA3
7 changed files with 149 additions and 27 deletions

View File

@ -50,11 +50,6 @@ func authPlain(socket io.ReadWriter, decoder *xml.Decoder, user string, password
return err return err
} }
type saslMechanisms struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
Mechanism []string `xml:"mechanism"`
}
// ============================================================================ // ============================================================================
// SASLSuccess // SASLSuccess

View File

@ -76,8 +76,7 @@ func (c *ServerCheck) Check() error {
return errors.New("expected packet received while expecting features, got " + p.Name()) return errors.New("expected packet received while expecting features, got " + p.Name())
} }
startTLSFeature := f.StartTLS.XMLName.Space + " " + f.StartTLS.XMLName.Local if _, ok := f.DoesStartTLS(); ok {
if startTLSFeature == nsTLS+" starttls" {
fmt.Fprintf(tcpconn, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>") fmt.Fprintf(tcpconn, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
var k tlsProceed var k tlsProceed

View File

@ -60,6 +60,30 @@ func TestClient_NoInsecure(t *testing.T) {
mock.Stop() mock.Stop()
} }
// Check that the client is properly tracking features, as session negotiation progresses.
func TestClient_FeaturesTracking(t *testing.T) {
// Setup Mock server
mock := ServerMock{}
mock.Start(t, testXMPPAddress, handlerAbortTLS)
// Test / Check result
config := Config{Address: testXMPPAddress, Jid: "test@localhost", Password: "test"}
var client *Client
var err error
if client, err = NewClient(config); err != nil {
t.Errorf("cannot create XMPP client: %s", err)
}
if err = client.Connect(); err == nil {
// When insecure is not allowed:
t.Errorf("should fail as insecure connection is not allowed and server does not support TLS")
}
mock.Stop()
}
//============================================================================= //=============================================================================
// Basic XMPP Server Mock Handlers. // Basic XMPP Server Mock Handlers.

View File

@ -109,7 +109,7 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string) net.Conn {
return conn return conn
} }
if s.Features.StartTLS.XMLName.Space+" "+s.Features.StartTLS.XMLName.Local == nsTLS+" starttls" { if _, ok := s.Features.DoesStartTLS(); ok {
fmt.Fprintf(s.socketProxy, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>") fmt.Fprintf(s.socketProxy, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
var k tlsProceed var k tlsProceed

View File

@ -7,12 +7,7 @@ import (
var DefaultTlsConfig tls.Config var DefaultTlsConfig tls.Config
// XMPP Packet Parsing // Used during stream initiation / session establishment
type tlsStartTLS struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
Required bool
}
type tlsProceed struct { type tlsProceed struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"` XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"`
} }

View File

@ -6,11 +6,14 @@ import (
// ============================================================================ // ============================================================================
// StreamFeatures Packet // StreamFeatures Packet
// Reference: https://xmpp.org/registrar/stream-features.html
type StreamFeatures struct { type StreamFeatures struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
// Server capabilities hash
Caps Caps
// Stream features
StartTLS tlsStartTLS StartTLS tlsStartTLS
Caps Caps
Mechanisms saslMechanisms Mechanisms saslMechanisms
Bind BindBind Bind BindBind
Session sessionSession Session sessionSession
@ -31,6 +34,76 @@ func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamF
return packet, err return packet, err
} }
// Capabilities
// Reference: https://xmpp.org/extensions/xep-0115.html#stream
// "A server MAY include its entity capabilities in a stream feature element so that connecting clients
// and peer servers do not need to send service discovery requests each time they connect."
// This is not a stream feature but a way to let client cache server disco info.
type Caps struct {
XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
Hash string `xml:"hash,attr"`
Node string `xml:"node,attr"`
Ver string `xml:"ver,attr"`
Ext string `xml:"ext,attr,omitempty"`
}
// ============================================================================
// Supported Stream Features
// StartTLS feature
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
type tlsStartTLS struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
Required bool
}
// UnmarshalXML implements custom parsing startTLS required flag
func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
stls.XMLName = start.Name
// Check subelements to extract required field as boolean
for {
t, err := d.Token()
if err != nil {
return err
}
switch tt := t.(type) {
case xml.StartElement:
elt := new(Node)
err = d.DecodeElement(elt, &tt)
if err != nil {
return err
}
if elt.XMLName.Local == "required" {
stls.Required = true
}
case xml.EndElement:
if tt == start.End() {
return nil
}
}
}
}
func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) {
if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
return sf.StartTLS, true
}
return feature, false
}
// Mechanisms
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-6.4.1
type saslMechanisms struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
Mechanism []string `xml:"mechanism"`
}
// ============================================================================ // ============================================================================
// StreamError Packet // StreamError Packet
@ -53,14 +126,3 @@ func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamErr
err := p.DecodeElement(&packet, &se) err := p.DecodeElement(&packet, &se)
return packet, err return packet, err
} }
// ============================================================================
// Caps subElement
type Caps struct {
XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
Hash string `xml:"hash,attr"`
Node string `xml:"node,attr"`
Ver string `xml:"ver,attr"`
Ext string `xml:"ext,attr,omitempty"`
}

47
stream_test.go Normal file
View File

@ -0,0 +1,47 @@
package xmpp_test
import (
"encoding/xml"
"testing"
"gosrc.io/xmpp"
)
func TestNoStartTLS(t *testing.T) {
streamFeatures := `<stream:features xmlns:stream='http://etherx.jabber.org/streams'>
</stream:features>`
var parsedSF xmpp.StreamFeatures
if err := xml.Unmarshal([]byte(streamFeatures), &parsedSF); err != nil {
t.Errorf("Unmarshal(%s) returned error: %v", streamFeatures, err)
}
startTLS, ok := parsedSF.DoesStartTLS()
if ok {
t.Error("StartTLS feature should not be enabled")
}
if startTLS.Required {
t.Error("StartTLS cannot be required as default")
}
}
func TestStartTLS(t *testing.T) {
streamFeatures := `<stream:features xmlns:stream='http://etherx.jabber.org/streams'>
<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'>
<required/>
</starttls>
</stream:features>`
var parsedSF xmpp.StreamFeatures
if err := xml.Unmarshal([]byte(streamFeatures), &parsedSF); err != nil {
t.Errorf("Unmarshal(%s) returned error: %v", streamFeatures, err)
}
startTLS, ok := parsedSF.DoesStartTLS()
if !ok {
t.Error("StartTLS feature should be enabled")
}
if !startTLS.Required {
t.Error("StartTLS feature should be required")
}
}