From 4c385a334c606e8bc387f0a3d4d84975802b3984 Mon Sep 17 00:00:00 2001 From: Martin Date: Sat, 11 Nov 2023 13:08:17 +0100 Subject: [PATCH] Add SCRAM PLUS variants. (#163) --- xmpp.go | 216 ++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 156 insertions(+), 60 deletions(-) diff --git a/xmpp.go b/xmpp.go index e14f84d..6b58995 100644 --- a/xmpp.go +++ b/xmpp.go @@ -74,6 +74,7 @@ type Client struct { domain string p *xml.Decoder stanzaWriter io.Writer + Mechanism string } func (c *Client) JID() string { @@ -207,6 +208,9 @@ type Options struct { // Status message StatusMessage string + + // Auth mechanism to use + Mechanism string } // NewClient establishes a new Client connection based on a set of Options. @@ -344,7 +348,6 @@ func cnonce() string { } func (c *Client) init(o *Options) error { - var domain string var user string a := strings.SplitN(o.User, "@", 2) @@ -372,8 +375,10 @@ func (c *Client) init(o *Options) error { if f, err = c.startTLSIfRequired(f, o, domain); err != nil { return err } - var mechanism string - var serverSignature []byte + var mechanism, channelBinding, clientFirstMessage, clientFinalMessageBare, authMessage string + var serverSignature, keyingMaterial []byte + var scramPlus, ok, tlsConnOK, tls13 bool + var tlsConn *tls.Conn if o.User == "" && o.Password == "" { foundAnonymous := false for _, m := range f.Mechanisms.Mechanism { @@ -394,51 +399,120 @@ func (c *Client) init(o *Options) error { return errors.New("refusing to authenticate over unencrypted TCP connection") } + tlsConn, ok = c.conn.(*tls.Conn) + if ok { + tlsConnOK = true + } mechanism = "" - for _, m := range f.Mechanisms.Mechanism { - switch m { - case "SCRAM-SHA-512": - mechanism = m - case "SCRAM-SHA-256": - if mechanism != "SCRAM-SHA-512" { - mechanism = m + if o.Mechanism == "" { + for _, m := range f.Mechanisms.Mechanism { + switch m { + case "SCRAM-SHA-512-PLUS": + if tlsConnOK { + mechanism = m + } + case "SCRAM-SHA-256-PLUS": + if mechanism != "SCRAM-SHA-512-PLUS" && tlsConnOK { + mechanism = m + } + case "SCRAM-SHA-1-PLUS": + if mechanism != "SCRAM-SHA-512-PLUS" && + mechanism != "SCRAM-SHA-256-PLUS" && + tlsConnOK { + mechanism = m + } + case "SCRAM-SHA-512": + if mechanism != "SCRAM-SHA-512-PLUS" && + mechanism != "SCRAM-SHA-256-PLUS" && + mechanism != "SCRAM-SHA-1-PLUS" { + mechanism = m + } + case "SCRAM-SHA-256": + if mechanism != "SCRAM-SHA-512-PLUS" && + mechanism != "SCRAM-SHA-256-PLUS" && + mechanism != "SCRAM-SHA-1-PLUS" && + mechanism != "SCRAM-SHA-512" { + mechanism = m + } + case "SCRAM-SHA-1": + if mechanism != "SCRAM-SHA-512-PLUS" && + mechanism != "SCRAM-SHA-256-PLUS" && + mechanism != "SCRAM-SHA-1-PLUS" && + mechanism != "SCRAM-SHA-512" && + mechanism != "SCRAM-SHA-256" { + mechanism = m + } + case "X-OAUTH2": + if mechanism == "" { + mechanism = m + } + case "PLAIN": + if mechanism == "" { + mechanism = m + } + case "DIGEST-MD5": + if mechanism == "" { + mechanism = m + } } - case "SCRAM-SHA-1": - if mechanism != "SCRAM-SHA-512" && - mechanism != "SCRAM-SHA-256" { - mechanism = m - } - case "X-OAUTH2": - if mechanism == "" { - mechanism = m - } - case "PLAIN": - if mechanism == "" { - mechanism = m - } - case "DIGEST-MD5": - if mechanism == "" { - mechanism = m + } + } else { + for _, m := range f.Mechanisms.Mechanism { + if m == o.Mechanism { + mechanism = o.Mechanism } } } if strings.HasPrefix(mechanism, "SCRAM-SHA") { + if strings.HasSuffix(mechanism, "PLUS") { + scramPlus = true + } + if scramPlus { + tlsState := tlsConn.ConnectionState() + switch tlsState.Version { + case tls.VersionTLS13: + tls13 = true + keyingMaterial, err = tlsState.ExportKeyingMaterial("EXPORTER-Channel-Binding", nil, 32) + if err != nil { + return err + } + case tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12: + keyingMaterial = tlsState.TLSUnique + default: + return errors.New(mechanism + ": unknown TLS version") + } + if len(keyingMaterial) == 0 { + return errors.New(mechanism + ": no keying material") + } + if tls13 { + channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-exporter,,"), keyingMaterial[:]...)) + } else { + channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-unique,,"), keyingMaterial[:]...)) + } + } var shaNewFn func() hash.Hash switch mechanism { - case "SCRAM-SHA-512": + case "SCRAM-SHA-512", "SCRAM-SHA-512-PLUS": shaNewFn = sha512.New - case "SCRAM-SHA-256": + case "SCRAM-SHA-256", "SCRAM-SHA-256-PLUS": shaNewFn = sha256.New - case "SCRAM-SHA-1": + case "SCRAM-SHA-1", "SCRAM-SHA-1-PLUS": shaNewFn = sha1.New default: return errors.New("unsupported auth mechanism") } clientNonce := cnonce() - clientFirstMessage := "n=" + user + ",r=" + clientNonce + if scramPlus { + if tls13 { + clientFirstMessage = "p=tls-exporter,,n=" + user + ",r=" + clientNonce + } else { + clientFirstMessage = "p=tls-unique,,n=" + user + ",r=" + clientNonce + } + } else { + clientFirstMessage = "n,,n=" + user + ",r=" + clientNonce + } fmt.Fprintf(c.stanzaWriter, "%s", - nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte("n,,"+ - clientFirstMessage))) + nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte(clientFirstMessage))) var sfm string if err = c.p.DecodeElement(&sfm, nil); err != nil { return errors.New("unmarshal : " + err.Error()) @@ -472,7 +546,11 @@ func (c *Client) init(o *Options) error { return errors.New("unexpected content in SCRAM challenge") } } - clientFinalMessageBare := "c=biws,r=" + serverNonce + if scramPlus { + clientFinalMessageBare = "c=" + channelBinding + ",r=" + serverNonce + } else { + clientFinalMessageBare = "c=biws,r=" + serverNonce + } saltedPassword := pbkdf2.Key([]byte(o.Password), salt, iterations, shaNewFn().Size(), shaNewFn) h := hmac.New(shaNewFn, saltedPassword) @@ -484,13 +562,13 @@ func (c *Client) init(o *Options) error { h.Reset() var storedKey []byte switch mechanism { - case "SCRAM-SHA-512": + case "SCRAM-SHA-512", "SCRAM-SHA-512-PLUS": storedKey512 := sha512.Sum512(clientKey) storedKey = storedKey512[:] - case "SCRAM-SHA-256": + case "SCRAM-SHA-256", "SCRAM-SH-256-PLUS": storedKey256 := sha256.Sum256(clientKey) storedKey = storedKey256[:] - case "SCRAM-SHA-1": + case "SCRAM-SHA-1", "SCRAM-SHA-1-PLUS": storedKey1 := sha1.Sum(clientKey) storedKey = storedKey1[:] } @@ -502,8 +580,8 @@ func (c *Client) init(o *Options) error { if err != nil { return err } - authMessage := clientFirstMessage + "," + string(serverFirstMessage) + - "," + clientFinalMessageBare + authMessage = strings.SplitAfter(clientFirstMessage, ",,")[1] + "," + + string(serverFirstMessage) + "," + clientFinalMessageBare h = hmac.New(shaNewFn, storedKey[:]) _, err = h.Write([]byte(authMessage)) if err != nil { @@ -536,7 +614,7 @@ func (c *Client) init(o *Options) error { } clientFinalMessage := base64.StdEncoding.EncodeToString([]byte(clientFinalMessageBare + ",p=" + base64.StdEncoding.EncodeToString(clientProof))) - fmt.Fprintf(c.stanzaWriter, "%s", nsSASL, + fmt.Fprintf(c.stanzaWriter, "%s\n", nsSASL, clientFinalMessage) } if mechanism == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" { @@ -625,6 +703,7 @@ func (c *Client) init(o *Options) error { if string(serverSignature) != string(serverSignatureRemote) { return errors.New("SCRAM: server signature mismatch") } + c.Mechanism = mechanism } case *saslFailure: errorMessage := v.Text @@ -664,12 +743,12 @@ func (c *Client) init(o *Options) error { c.domain = domain if o.Session { - //if server support session, open it - fmt.Fprintf(c.stanzaWriter, "", xmlEscape(domain), cookie, nsSession) + // if server support session, open it + fmt.Fprintf(c.stanzaWriter, "\n", xmlEscape(domain), cookie, nsSession) } // We're connected and can now receive and send messages. - fmt.Fprintf(c.stanzaWriter, "%s%s", o.Status, o.StatusMessage) + fmt.Fprintf(c.stanzaWriter, "%s%s\n", o.Status, o.StatusMessage) return nil } @@ -700,7 +779,7 @@ func (c *Client) startTLSIfRequired(f *streamFeatures, o *Options, domain string tc := o.TLSConfig if tc == nil { tc = DefaultConfig.Clone() - //TODO(scott): we should consider using the server's address or reverse lookup + // TODO(scott): we should consider using the server's address or reverse lookup tc.ServerName = domain } t := tls.Client(c.conn, tc) @@ -732,7 +811,7 @@ func (c *Client) startStream(o *Options, domain string) (*streamFeatures, error) _, err := fmt.Fprintf(c.stanzaWriter, ""+ "", + " xmlns:stream='%s' version='1.0'>\n", xmlEscape(domain), nsClient, nsStream) if err != nil { return nil, err @@ -892,8 +971,10 @@ func (c *Client) Recv() (stanza interface{}, err error) { return Chat{}, err } - return IQ{ID: v.ID, From: v.From, To: v.To, Type: v.Type, - Query: res}, nil + return IQ{ + ID: v.ID, From: v.From, To: v.To, Type: v.Type, + Query: res, + }, nil } case v.Type == "result": switch v.ID { @@ -1021,8 +1102,10 @@ func (c *Client) Recv() (stanza interface{}, err error) { return Chat{}, err } - return IQ{ID: v.ID, From: v.From, To: v.To, Type: v.Type, - Query: res}, nil + return IQ{ + ID: v.ID, From: v.From, To: v.To, Type: v.Type, + Query: res, + }, nil } case v.Query.XMLName.Local == "": return IQ{ID: v.ID, From: v.From, To: v.To, Type: v.Type}, nil @@ -1032,8 +1115,10 @@ func (c *Client) Recv() (stanza interface{}, err error) { return Chat{}, err } - return IQ{ID: v.ID, From: v.From, To: v.To, Type: v.Type, - Query: res}, nil + return IQ{ + ID: v.ID, From: v.From, To: v.To, Type: v.Type, + Query: res, + }, nil } } } @@ -1056,7 +1141,7 @@ func (c *Client) Send(chat Chat) (n int, err error) { oobtext += `` } - stanza := "" + subtext + "%s" + oobtext + thdtext + "" + stanza := "" + subtext + "%s" + oobtext + thdtext + "\n" return fmt.Fprintf(c.stanzaWriter, stanza, xmlEscape(chat.Remote), xmlEscape(chat.Type), cnonce(), xmlEscape(chat.Text)) @@ -1075,17 +1160,17 @@ func (c *Client) SendOOB(chat Chat) (n int, err error) { } oobtext += `` } - return fmt.Fprintf(c.stanzaWriter, ""+oobtext+thdtext+"", + return fmt.Fprintf(c.stanzaWriter, ""+oobtext+thdtext+"\n", xmlEscape(chat.Remote), xmlEscape(chat.Type), cnonce()) } // SendOrg sends the original text without being wrapped in an XMPP message stanza. func (c *Client) SendOrg(org string) (n int, err error) { - return fmt.Fprint(c.stanzaWriter, org) + return fmt.Fprint(c.stanzaWriter, org+"\n") } func (c *Client) SendPresence(presence Presence) (n int, err error) { - return fmt.Fprintf(c.stanzaWriter, "", xmlEscape(presence.From), xmlEscape(presence.To)) + return fmt.Fprintf(c.stanzaWriter, "\n", xmlEscape(presence.From), xmlEscape(presence.To)) } // SendKeepAlive sends a "whitespace keepalive" as described in chapter 4.6.1 of RFC6120. @@ -1097,7 +1182,7 @@ func (c *Client) SendKeepAlive() (n int, err error) { func (c *Client) SendHtml(chat Chat) (n int, err error) { return fmt.Fprintf(c.stanzaWriter, ""+ "%s"+ - "%s", + "%s\n", xmlEscape(chat.Remote), xmlEscape(chat.Type), xmlEscape(chat.Text), chat.Text) } @@ -1109,11 +1194,12 @@ func (c *Client) Roster() error { // RFC 3920 C.1 Streams name space type streamFeatures struct { - XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` - StartTLS *tlsStartTLS - Mechanisms saslMechanisms - Bind bindBind - Session bool + XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` + StartTLS *tlsStartTLS + Mechanisms saslMechanisms + ChannelBinding saslChannelBinding + Bind bindBind + Session bool } type streamError struct { @@ -1147,6 +1233,16 @@ type saslAuth struct { Mechanism string `xml:",attr"` } +type saslChannelBinding struct { + XMLName xml.Name `xml:"sasl-channel-binding"` + Text string `xml:",chardata"` + Xmlns string `xml:"xmlns,attr"` + ChannelBinding []struct { + Text string `xml:",chardata"` + Type string `xml:"type,attr"` + } `xml:"channel-binding"` +} + type saslChallenge string type saslRspAuth string