Add support for SASL2 and BIND2 (#187)

* Add basic support for SASL2 (XEP-0388) and Bind2 (XEP-0386).
This commit is contained in:
Martin 2024-04-09 10:53:38 +02:00 committed by GitHub
parent da2377ecb0
commit 7486b7a363
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 205 additions and 53 deletions

182
xmpp.go
View File

@ -49,7 +49,9 @@ const (
nsStream = "http://etherx.jabber.org/streams" nsStream = "http://etherx.jabber.org/streams"
nsTLS = "urn:ietf:params:xml:ns:xmpp-tls" nsTLS = "urn:ietf:params:xml:ns:xmpp-tls"
nsSASL = "urn:ietf:params:xml:ns:xmpp-sasl" nsSASL = "urn:ietf:params:xml:ns:xmpp-sasl"
nsSASL2 = "urn:xmpp:sasl:2"
nsBind = "urn:ietf:params:xml:ns:xmpp-bind" nsBind = "urn:ietf:params:xml:ns:xmpp-bind"
nsBind2 = "urn:xmpp:bind:0"
nsSASLCB = "urn:xmpp:sasl-cb:0" nsSASLCB = "urn:xmpp:sasl-cb:0"
nsClient = "jabber:client" nsClient = "jabber:client"
nsSession = "urn:ietf:params:xml:ns:xmpp-session" nsSession = "urn:ietf:params:xml:ns:xmpp-session"
@ -237,6 +239,19 @@ type Options struct {
// XEP-0474: SASL SCRAM Downgrade Protection // XEP-0474: SASL SCRAM Downgrade Protection
SSDP bool SSDP bool
// XEP-0388: XEP-0388: Extensible SASL Profile
// Value for software
UserAgentSW string
// XEP-0388: XEP-0388: Extensible SASL Profile
// Value for device
UserAgentDev string
// XEP-0388: XEP-0388: Extensible SASL Profile
// Unique stable identifier for the client installation
// MUST be a valid UUIDv4
UserAgentID string
} }
// NewClient establishes a new Client connection based on a set of Options. // NewClient establishes a new Client connection based on a set of Options.
@ -416,13 +431,25 @@ func (c *Client) init(o *Options) error {
return err return err
} }
var mechanism, channelBinding, clientFirstMessage, clientFinalMessageBare, authMessage string var mechanism, channelBinding, clientFirstMessage, clientFinalMessageBare, authMessage string
var bind2Data, resource, userAgentSW, userAgentDev, userAgentID string
var serverSignature, keyingMaterial []byte var serverSignature, keyingMaterial []byte
var scramPlus, ok, tlsConnOK, tls13, serverEndPoint bool var scramPlus, ok, tlsConnOK, tls13, serverEndPoint, sasl2, bind2 bool
var cbsSlice []string var cbsSlice, mechSlice []string
var tlsConn *tls.Conn var tlsConn *tls.Conn
// Use SASL2 if available
if f.Authentication.Mechanism != nil && c.IsEncrypted() {
sasl2 = true
mechSlice = f.Authentication.Mechanism
// Detect whether bind2 is available
if f.Authentication.Inline.Bind.Xmlns != "" {
bind2 = true
}
} else {
mechSlice = f.Mechanisms.Mechanism
}
if o.User == "" && o.Password == "" { if o.User == "" && o.Password == "" {
foundAnonymous := false foundAnonymous := false
for _, m := range f.Mechanisms.Mechanism { for _, m := range mechSlice {
if m == "ANONYMOUS" { if m == "ANONYMOUS" {
fmt.Fprintf(c.stanzaWriter, "<auth xmlns='%s' mechanism='ANONYMOUS' />\n", nsSASL) fmt.Fprintf(c.stanzaWriter, "<auth xmlns='%s' mechanism='ANONYMOUS' />\n", nsSASL)
foundAnonymous = true foundAnonymous = true
@ -446,26 +473,26 @@ func (c *Client) init(o *Options) error {
} }
mechanism = "" mechanism = ""
if o.Mechanism != "" { if o.Mechanism != "" {
if slices.Contains(f.Mechanisms.Mechanism, o.Mechanism) { if slices.Contains(mechSlice, o.Mechanism) {
mechanism = o.Mechanism mechanism = o.Mechanism
} }
} else { } else {
switch { switch {
case slices.Contains(f.Mechanisms.Mechanism, "SCRAM-SHA-512-PLUS") && tlsConnOK: case slices.Contains(mechSlice, "SCRAM-SHA-512-PLUS") && tlsConnOK:
mechanism = "SCRAM-SHA-512-PLUS" mechanism = "SCRAM-SHA-512-PLUS"
case slices.Contains(f.Mechanisms.Mechanism, "SCRAM-SHA-256-PLUS") && tlsConnOK: case slices.Contains(mechSlice, "SCRAM-SHA-256-PLUS") && tlsConnOK:
mechanism = "SCRAM-SHA-256-PLUS" mechanism = "SCRAM-SHA-256-PLUS"
case slices.Contains(f.Mechanisms.Mechanism, "SCRAM-SHA-1-PLUS") && tlsConnOK: case slices.Contains(mechSlice, "SCRAM-SHA-1-PLUS") && tlsConnOK:
mechanism = "SCRAM-SHA-1-PLUS" mechanism = "SCRAM-SHA-1-PLUS"
case slices.Contains(f.Mechanisms.Mechanism, "SCRAM-SHA-512"): case slices.Contains(mechSlice, "SCRAM-SHA-512"):
mechanism = "SCRAM-SHA-512" mechanism = "SCRAM-SHA-512"
case slices.Contains(f.Mechanisms.Mechanism, "SCRAM-SHA-256"): case slices.Contains(mechSlice, "SCRAM-SHA-256"):
mechanism = "SCRAM-SHA-256" mechanism = "SCRAM-SHA-256"
case slices.Contains(f.Mechanisms.Mechanism, "SCRAM-SHA-1"): case slices.Contains(mechSlice, "SCRAM-SHA-1"):
mechanism = "SCRAM-SHA-1" mechanism = "SCRAM-SHA-1"
case slices.Contains(f.Mechanisms.Mechanism, "X-OAUTH2"): case slices.Contains(mechSlice, "X-OAUTH2"):
mechanism = "X-OAUTH2" mechanism = "X-OAUTH2"
case slices.Contains(f.Mechanisms.Mechanism, "PLAIN") && tlsConnOK: case slices.Contains(mechSlice, "PLAIN") && tlsConnOK:
mechanism = "PLAIN" mechanism = "PLAIN"
} }
} }
@ -556,14 +583,48 @@ func (c *Client) init(o *Options) error {
} else { } else {
clientFirstMessage = "n,,n=" + user + ",r=" + clientNonce clientFirstMessage = "n,,n=" + user + ",r=" + clientNonce
} }
if sasl2 {
if bind2 {
if o.UserAgentSW != "" {
resource = o.UserAgentSW
} else {
resource = "go-xmpp"
}
bind2Data = fmt.Sprintf("<bind xmlns='%s'><tag>%s</tag></bind>",
nsBind2, resource)
}
if o.UserAgentSW != "" {
userAgentSW = fmt.Sprintf("<software>%s</software>", o.UserAgentSW)
} else {
userAgentSW = "<software>go-xmpp</software>"
}
if o.UserAgentDev != "" {
userAgentDev = fmt.Sprintf("<device>%s</device>", o.UserAgentDev)
}
if o.UserAgentID != "" {
userAgentID = fmt.Sprintf(" id='%s'", o.UserAgentID)
}
fmt.Fprintf(c.stanzaWriter,
"<authenticate xmlns='%s' mechanism='%s'><initial-response>%s</initial-response><user-agent%s>%s%s</user-agent>%s</authenticate>\n",
nsSASL2, mechanism, base64.StdEncoding.EncodeToString([]byte(clientFirstMessage)), userAgentID, userAgentSW, userAgentDev, bind2Data)
} else {
fmt.Fprintf(c.stanzaWriter, "<auth xmlns='%s' mechanism='%s'>%s</auth>\n", fmt.Fprintf(c.stanzaWriter, "<auth xmlns='%s' mechanism='%s'>%s</auth>\n",
nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte(clientFirstMessage))) nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte(clientFirstMessage)))
}
var sfm string var sfm string
_, val, err := c.next() _, val, err := c.next()
if err != nil { if err != nil {
return err return err
} }
switch v := val.(type) { switch v := val.(type) {
case *sasl2Failure:
errorMessage := v.Text
if errorMessage == "" {
// v.Any is type of sub-element in failure,
// which gives a description of what failed if there was no text element
errorMessage = v.Any.Local
}
return errors.New("auth failure: " + errorMessage)
case *saslFailure: case *saslFailure:
errorMessage := v.Text errorMessage := v.Text
if errorMessage == "" { if errorMessage == "" {
@ -572,6 +633,8 @@ func (c *Client) init(o *Options) error {
errorMessage = v.Any.Local errorMessage = v.Any.Local
} }
return errors.New("auth failure: " + errorMessage) return errors.New("auth failure: " + errorMessage)
case *sasl2Challenge:
sfm = v.Text
case *saslChallenge: case *saslChallenge:
sfm = v.Text sfm = v.Text
} }
@ -702,25 +765,39 @@ func (c *Client) init(o *Options) error {
} }
clientFinalMessage := base64.StdEncoding.EncodeToString([]byte(clientFinalMessageBare + clientFinalMessage := base64.StdEncoding.EncodeToString([]byte(clientFinalMessageBare +
",p=" + base64.StdEncoding.EncodeToString(clientProof))) ",p=" + base64.StdEncoding.EncodeToString(clientProof)))
if sasl2 {
fmt.Fprintf(c.stanzaWriter, "<response xmlns='%s'>%s</response>\n", nsSASL2,
clientFinalMessage)
} else {
fmt.Fprintf(c.stanzaWriter, "<response xmlns='%s'>%s</response>\n", nsSASL, fmt.Fprintf(c.stanzaWriter, "<response xmlns='%s'>%s</response>\n", nsSASL,
clientFinalMessage) clientFinalMessage)
} }
}
if mechanism == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" { if mechanism == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" {
// Oauth authentication: send base64-encoded \x00 user \x00 token. // Oauth authentication: send base64-encoded \x00 user \x00 token.
raw := "\x00" + user + "\x00" + o.OAuthToken raw := "\x00" + user + "\x00" + o.OAuthToken
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw))) enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
base64.StdEncoding.Encode(enc, []byte(raw)) base64.StdEncoding.Encode(enc, []byte(raw))
if sasl2 {
fmt.Fprintf(c.stanzaWriter, "<auth xmlns='%s' mechanism='X-OAUTH2' auth:service='oauth2' "+
"xmlns:auth='%s'>%s</auth>\n", nsSASL2, o.OAuthXmlNs, enc)
} else {
fmt.Fprintf(c.stanzaWriter, "<auth xmlns='%s' mechanism='X-OAUTH2' auth:service='oauth2' "+ fmt.Fprintf(c.stanzaWriter, "<auth xmlns='%s' mechanism='X-OAUTH2' auth:service='oauth2' "+
"xmlns:auth='%s'>%s</auth>\n", nsSASL, o.OAuthXmlNs, enc) "xmlns:auth='%s'>%s</auth>\n", nsSASL, o.OAuthXmlNs, enc)
} }
}
if mechanism == "PLAIN" { if mechanism == "PLAIN" {
// Plain authentication: send base64-encoded \x00 user \x00 password. // Plain authentication: send base64-encoded \x00 user \x00 password.
raw := "\x00" + user + "\x00" + o.Password raw := "\x00" + user + "\x00" + o.Password
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw))) enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
base64.StdEncoding.Encode(enc, []byte(raw)) base64.StdEncoding.Encode(enc, []byte(raw))
if sasl2 {
fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='PLAIN'>%s</auth>\n", nsSASL2, enc)
} else {
fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='PLAIN'>%s</auth>\n", nsSASL, enc) fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='PLAIN'>%s</auth>\n", nsSASL, enc)
} }
} }
}
if mechanism == "" { if mechanism == "" {
return fmt.Errorf("no viable authentication method available: %v", f.Mechanisms.Mechanism) return fmt.Errorf("no viable authentication method available: %v", f.Mechanisms.Mechanism)
} }
@ -730,6 +807,28 @@ func (c *Client) init(o *Options) error {
return err return err
} }
switch v := val.(type) { switch v := val.(type) {
case *sasl2Success:
if strings.HasPrefix(mechanism, "SCRAM-SHA") {
successMsg, err := base64.StdEncoding.DecodeString(v.AdditionalData)
if err != nil {
return err
}
if !strings.HasPrefix(string(successMsg), "v=") {
return errors.New("server sent unexpected content in SCRAM success message")
}
serverSignatureReply := strings.SplitN(string(successMsg), "v=", 2)[1]
serverSignatureRemote, err := base64.StdEncoding.DecodeString(serverSignatureReply)
if err != nil {
return err
}
if string(serverSignature) != string(serverSignatureRemote) {
return errors.New("SCRAM: server signature mismatch")
}
c.Mechanism = mechanism
}
if bind2 {
c.jid = v.AuthorizationIdentifier
}
case *saslSuccess: case *saslSuccess:
if strings.HasPrefix(mechanism, "SCRAM-SHA") { if strings.HasPrefix(mechanism, "SCRAM-SHA") {
successMsg, err := base64.StdEncoding.DecodeString(v.Text) successMsg, err := base64.StdEncoding.DecodeString(v.Text)
@ -749,6 +848,14 @@ func (c *Client) init(o *Options) error {
} }
c.Mechanism = mechanism c.Mechanism = mechanism
} }
case *sasl2Failure:
errorMessage := v.Text
if errorMessage == "" {
// v.Any is type of sub-element in failure,
// which gives a description of what failed if there was no text element
errorMessage = v.Any.Local
}
return errors.New("auth failure: " + errorMessage)
case *saslFailure: case *saslFailure:
errorMessage := v.Text errorMessage := v.Text
if errorMessage == "" { if errorMessage == "" {
@ -761,11 +868,13 @@ func (c *Client) init(o *Options) error {
return errors.New("expected <success> or <failure>, got <" + name.Local + "> in " + name.Space) return errors.New("expected <success> or <failure>, got <" + name.Local + "> in " + name.Space)
} }
if !sasl2 {
// Now that we're authenticated, we're supposed to start the stream over again. // Now that we're authenticated, we're supposed to start the stream over again.
// Declare intent to be a jabber client. // Declare intent to be a jabber client.
if f, err = c.startStream(o, domain); err != nil { if f, err = c.startStream(o, domain); err != nil {
return err return err
} }
}
// Make the max. stanza size limit available. // Make the max. stanza size limit available.
if f.Limits.MaxBytes != "" { if f.Limits.MaxBytes != "" {
c.LimitMaxBytes, err = strconv.Atoi(f.Limits.MaxBytes) c.LimitMaxBytes, err = strconv.Atoi(f.Limits.MaxBytes)
@ -781,6 +890,7 @@ func (c *Client) init(o *Options) error {
} }
} }
if !bind2 {
// Generate a unique cookie // Generate a unique cookie
cookie := getCookie() cookie := getCookie()
@ -811,9 +921,10 @@ func (c *Client) init(o *Options) error {
return errors.New("bind: unexpected reply to xmpp-bind IQ") return errors.New("bind: unexpected reply to xmpp-bind IQ")
} }
} }
}
if o.Session { if o.Session {
// if server support session, open it // if server support session, open it
cookie = getCookie() // generate new id value for session cookie := getCookie() // generate new id value for session
fmt.Fprintf(c.stanzaWriter, "<iq to='%s' type='set' id='%x'><session xmlns='%s'/></iq>\n", xmlEscape(domain), cookie, nsSession) fmt.Fprintf(c.stanzaWriter, "<iq to='%s' type='set' id='%x'><session xmlns='%s'/></iq>\n", xmlEscape(domain), cookie, nsSession)
} }
@ -880,9 +991,9 @@ func (c *Client) startStream(o *Options, domain string) (*streamFeatures, error)
} }
_, err := fmt.Fprintf(c.stanzaWriter, "<?xml version='1.0'?>"+ _, err := fmt.Fprintf(c.stanzaWriter, "<?xml version='1.0'?>"+
"<stream:stream to='%s' xmlns='%s'"+ "<stream:stream from='%s' to='%s' xmlns='%s'"+
" xmlns:stream='%s' version='1.0'>\n", " xmlns:stream='%s' version='1.0'>\n",
xmlEscape(domain), nsClient, nsStream) xmlEscape(o.User), xmlEscape(domain), nsClient, nsStream)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1338,6 +1449,7 @@ func (c *Client) Roster() error {
// RFC 3920 C.1 Streams name space // RFC 3920 C.1 Streams name space
type streamFeatures struct { type streamFeatures struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
Authentication sasl2Authentication
StartTLS *tlsStartTLS StartTLS *tlsStartTLS
Mechanisms saslMechanisms Mechanisms saslMechanisms
ChannelBindings saslChannelBindings ChannelBindings saslChannelBindings
@ -1370,6 +1482,18 @@ type tlsFailure struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls failure"` XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls failure"`
} }
type sasl2Authentication struct {
XMLName xml.Name `xml:"urn:xmpp:sasl:2 authentication"`
Mechanism []string `xml:"mechanism"`
Inline struct {
Text string `xml:",chardata"`
Bind struct {
Text string `xml:",chardata"`
Xmlns string `xml:"xmlns,attr"`
} `xml:"bind"`
} `xml:"inline"`
}
// RFC 3920 C.4 SASL name space // RFC 3920 C.4 SASL name space
type saslMechanisms struct { type saslMechanisms struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"` XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
@ -1390,17 +1514,39 @@ type saslAbort struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl abort"` XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl abort"`
} }
type sasl2Success struct {
XMLName xml.Name `xml:"urn:xmpp:sasl:2 success"`
Text string `xml:",chardata"`
AdditionalData string `xml:"additional-data"`
AuthorizationIdentifier string `xml:"authorization-identifier"`
Bound struct {
Text string `xml:",chardata"`
Xmlns string `xml:"xmlns,attr"`
} `xml:"bound"`
}
type saslSuccess struct { type saslSuccess struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl success"` XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl success"`
Text string `xml:",chardata"` Text string `xml:",chardata"`
} }
type sasl2Failure struct {
XMLName xml.Name `xml:"urn:xmpp:sasl:2 failure"`
Any xml.Name `xml:",any"`
Text string `xml:"text"`
}
type saslFailure struct { type saslFailure struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl failure"` XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl failure"`
Any xml.Name `xml:",any"` Any xml.Name `xml:",any"`
Text string `xml:"text"` Text string `xml:"text"`
} }
type sasl2Challenge struct {
XMLName xml.Name `xml:"urn:xmpp:sasl:2 challenge"`
Text string `xml:",chardata"`
}
type saslChallenge struct { type saslChallenge struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl challenge"` XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl challenge"`
Text string `xml:",chardata"` Text string `xml:",chardata"`
@ -1607,14 +1753,20 @@ func (c *Client) next() (xml.Name, interface{}, error) {
nv = &tlsFailure{} nv = &tlsFailure{}
case nsSASL + " mechanisms": case nsSASL + " mechanisms":
nv = &saslMechanisms{} nv = &saslMechanisms{}
case nsSASL2 + " challenge":
nv = &sasl2Challenge{}
case nsSASL + " challenge": case nsSASL + " challenge":
nv = &saslChallenge{} nv = &saslChallenge{}
case nsSASL + " response": case nsSASL + " response":
nv = "" nv = ""
case nsSASL + " abort": case nsSASL + " abort":
nv = &saslAbort{} nv = &saslAbort{}
case nsSASL2 + " success":
nv = &sasl2Success{}
case nsSASL + " success": case nsSASL + " success":
nv = &saslSuccess{} nv = &saslSuccess{}
case nsSASL2 + " failure":
nv = &sasl2Failure{}
case nsSASL + " failure": case nsSASL + " failure":
nv = &saslFailure{} nv = &saslFailure{}
case nsSASLCB + " sasl-channel-binding": case nsSASLCB + " sasl-channel-binding":

View File

@ -12,7 +12,7 @@ const (
) )
func (c *Client) Discovery() (string, error) { func (c *Client) Discovery() (string, error) {
// use getCookie for a pseudo random id. // use UUIDv4 for a pseudo random id.
reqID := strconv.FormatUint(uint64(getCookie()), 10) reqID := strconv.FormatUint(uint64(getCookie()), 10)
return c.RawInformationQuery(c.jid, c.domain, reqID, IQTypeGet, XMPPNS_DISCO_ITEMS, "") return c.RawInformationQuery(c.jid, c.domain, reqID, IQTypeGet, XMPPNS_DISCO_ITEMS, "")
} }