mirror of
https://github.com/FluuxIO/go-xmpp.git
synced 2024-11-21 10:02:00 -08:00
Clean up and fix StartTLS feature discovery
Required field was never set to true
This commit is contained in:
parent
44568fcf2b
commit
709a95129e
5
auth.go
5
auth.go
@ -50,11 +50,6 @@ func authPlain(socket io.ReadWriter, decoder *xml.Decoder, user string, password
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
type saslMechanisms struct {
|
|
||||||
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
|
|
||||||
Mechanism []string `xml:"mechanism"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// SASLSuccess
|
// SASLSuccess
|
||||||
|
|
||||||
|
@ -76,8 +76,7 @@ func (c *ServerCheck) Check() error {
|
|||||||
return errors.New("expected packet received while expecting features, got " + p.Name())
|
return errors.New("expected packet received while expecting features, got " + p.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
startTLSFeature := f.StartTLS.XMLName.Space + " " + f.StartTLS.XMLName.Local
|
if _, ok := f.DoesStartTLS(); ok {
|
||||||
if startTLSFeature == nsTLS+" starttls" {
|
|
||||||
fmt.Fprintf(tcpconn, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
|
fmt.Fprintf(tcpconn, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
|
||||||
|
|
||||||
var k tlsProceed
|
var k tlsProceed
|
||||||
|
@ -60,6 +60,30 @@ func TestClient_NoInsecure(t *testing.T) {
|
|||||||
mock.Stop()
|
mock.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check that the client is properly tracking features, as session negotiation progresses.
|
||||||
|
func TestClient_FeaturesTracking(t *testing.T) {
|
||||||
|
// Setup Mock server
|
||||||
|
mock := ServerMock{}
|
||||||
|
mock.Start(t, testXMPPAddress, handlerAbortTLS)
|
||||||
|
|
||||||
|
// Test / Check result
|
||||||
|
config := Config{Address: testXMPPAddress, Jid: "test@localhost", Password: "test"}
|
||||||
|
|
||||||
|
var client *Client
|
||||||
|
var err error
|
||||||
|
if client, err = NewClient(config); err != nil {
|
||||||
|
t.Errorf("cannot create XMPP client: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = client.Connect(); err == nil {
|
||||||
|
// When insecure is not allowed:
|
||||||
|
t.Errorf("should fail as insecure connection is not allowed and server does not support TLS")
|
||||||
|
}
|
||||||
|
|
||||||
|
mock.Stop()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
//=============================================================================
|
//=============================================================================
|
||||||
// Basic XMPP Server Mock Handlers.
|
// Basic XMPP Server Mock Handlers.
|
||||||
|
|
||||||
|
@ -109,7 +109,7 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string) net.Conn {
|
|||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.Features.StartTLS.XMLName.Space+" "+s.Features.StartTLS.XMLName.Local == nsTLS+" starttls" {
|
if _, ok := s.Features.DoesStartTLS(); ok {
|
||||||
fmt.Fprintf(s.socketProxy, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
|
fmt.Fprintf(s.socketProxy, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
|
||||||
|
|
||||||
var k tlsProceed
|
var k tlsProceed
|
||||||
|
@ -7,12 +7,7 @@ import (
|
|||||||
|
|
||||||
var DefaultTlsConfig tls.Config
|
var DefaultTlsConfig tls.Config
|
||||||
|
|
||||||
// XMPP Packet Parsing
|
// Used during stream initiation / session establishment
|
||||||
type tlsStartTLS struct {
|
|
||||||
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
|
|
||||||
Required bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type tlsProceed struct {
|
type tlsProceed struct {
|
||||||
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"`
|
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"`
|
||||||
}
|
}
|
||||||
|
86
stream.go
86
stream.go
@ -6,11 +6,14 @@ import (
|
|||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// StreamFeatures Packet
|
// StreamFeatures Packet
|
||||||
|
// Reference: https://xmpp.org/registrar/stream-features.html
|
||||||
|
|
||||||
type StreamFeatures struct {
|
type StreamFeatures struct {
|
||||||
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
|
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
|
||||||
StartTLS tlsStartTLS
|
// Server capabilities hash
|
||||||
Caps Caps
|
Caps Caps
|
||||||
|
// Stream features
|
||||||
|
StartTLS tlsStartTLS
|
||||||
Mechanisms saslMechanisms
|
Mechanisms saslMechanisms
|
||||||
Bind BindBind
|
Bind BindBind
|
||||||
Session sessionSession
|
Session sessionSession
|
||||||
@ -31,6 +34,76 @@ func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamF
|
|||||||
return packet, err
|
return packet, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Capabilities
|
||||||
|
// Reference: https://xmpp.org/extensions/xep-0115.html#stream
|
||||||
|
// "A server MAY include its entity capabilities in a stream feature element so that connecting clients
|
||||||
|
// and peer servers do not need to send service discovery requests each time they connect."
|
||||||
|
// This is not a stream feature but a way to let client cache server disco info.
|
||||||
|
type Caps struct {
|
||||||
|
XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
|
||||||
|
Hash string `xml:"hash,attr"`
|
||||||
|
Node string `xml:"node,attr"`
|
||||||
|
Ver string `xml:"ver,attr"`
|
||||||
|
Ext string `xml:"ext,attr,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Supported Stream Features
|
||||||
|
|
||||||
|
// StartTLS feature
|
||||||
|
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
|
||||||
|
type tlsStartTLS struct {
|
||||||
|
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
|
||||||
|
Required bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalXML implements custom parsing startTLS required flag
|
||||||
|
func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
|
||||||
|
stls.XMLName = start.Name
|
||||||
|
|
||||||
|
// Check subelements to extract required field as boolean
|
||||||
|
for {
|
||||||
|
t, err := d.Token()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch tt := t.(type) {
|
||||||
|
|
||||||
|
case xml.StartElement:
|
||||||
|
elt := new(Node)
|
||||||
|
|
||||||
|
err = d.DecodeElement(elt, &tt)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if elt.XMLName.Local == "required" {
|
||||||
|
stls.Required = true
|
||||||
|
}
|
||||||
|
|
||||||
|
case xml.EndElement:
|
||||||
|
if tt == start.End() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) {
|
||||||
|
if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
|
||||||
|
return sf.StartTLS, true
|
||||||
|
}
|
||||||
|
return feature, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mechanisms
|
||||||
|
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-6.4.1
|
||||||
|
type saslMechanisms struct {
|
||||||
|
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
|
||||||
|
Mechanism []string `xml:"mechanism"`
|
||||||
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// StreamError Packet
|
// StreamError Packet
|
||||||
|
|
||||||
@ -53,14 +126,3 @@ func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamErr
|
|||||||
err := p.DecodeElement(&packet, &se)
|
err := p.DecodeElement(&packet, &se)
|
||||||
return packet, err
|
return packet, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================================
|
|
||||||
// Caps subElement
|
|
||||||
|
|
||||||
type Caps struct {
|
|
||||||
XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
|
|
||||||
Hash string `xml:"hash,attr"`
|
|
||||||
Node string `xml:"node,attr"`
|
|
||||||
Ver string `xml:"ver,attr"`
|
|
||||||
Ext string `xml:"ext,attr,omitempty"`
|
|
||||||
}
|
|
||||||
|
47
stream_test.go
Normal file
47
stream_test.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package xmpp_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/xml"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gosrc.io/xmpp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNoStartTLS(t *testing.T) {
|
||||||
|
streamFeatures := `<stream:features xmlns:stream='http://etherx.jabber.org/streams'>
|
||||||
|
</stream:features>`
|
||||||
|
|
||||||
|
var parsedSF xmpp.StreamFeatures
|
||||||
|
if err := xml.Unmarshal([]byte(streamFeatures), &parsedSF); err != nil {
|
||||||
|
t.Errorf("Unmarshal(%s) returned error: %v", streamFeatures, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
startTLS, ok := parsedSF.DoesStartTLS()
|
||||||
|
if ok {
|
||||||
|
t.Error("StartTLS feature should not be enabled")
|
||||||
|
}
|
||||||
|
if startTLS.Required {
|
||||||
|
t.Error("StartTLS cannot be required as default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartTLS(t *testing.T) {
|
||||||
|
streamFeatures := `<stream:features xmlns:stream='http://etherx.jabber.org/streams'>
|
||||||
|
<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'>
|
||||||
|
<required/>
|
||||||
|
</starttls>
|
||||||
|
</stream:features>`
|
||||||
|
|
||||||
|
var parsedSF xmpp.StreamFeatures
|
||||||
|
if err := xml.Unmarshal([]byte(streamFeatures), &parsedSF); err != nil {
|
||||||
|
t.Errorf("Unmarshal(%s) returned error: %v", streamFeatures, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
startTLS, ok := parsedSF.DoesStartTLS()
|
||||||
|
if !ok {
|
||||||
|
t.Error("StartTLS feature should be enabled")
|
||||||
|
}
|
||||||
|
if !startTLS.Required {
|
||||||
|
t.Error("StartTLS feature should be required")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user