Add SCRAM PLUS variants. (#163)

This commit is contained in:
Martin 2023-11-11 13:08:17 +01:00 committed by GitHub
parent 24e0f536cb
commit 4c385a334c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

216
xmpp.go
View File

@ -74,6 +74,7 @@ type Client struct {
domain string domain string
p *xml.Decoder p *xml.Decoder
stanzaWriter io.Writer stanzaWriter io.Writer
Mechanism string
} }
func (c *Client) JID() string { func (c *Client) JID() string {
@ -207,6 +208,9 @@ type Options struct {
// Status message // Status message
StatusMessage string StatusMessage string
// Auth mechanism to use
Mechanism string
} }
// NewClient establishes a new Client connection based on a set of Options. // NewClient establishes a new Client connection based on a set of Options.
@ -344,7 +348,6 @@ func cnonce() string {
} }
func (c *Client) init(o *Options) error { func (c *Client) init(o *Options) error {
var domain string var domain string
var user string var user string
a := strings.SplitN(o.User, "@", 2) a := strings.SplitN(o.User, "@", 2)
@ -372,8 +375,10 @@ func (c *Client) init(o *Options) error {
if f, err = c.startTLSIfRequired(f, o, domain); err != nil { if f, err = c.startTLSIfRequired(f, o, domain); err != nil {
return err return err
} }
var mechanism string var mechanism, channelBinding, clientFirstMessage, clientFinalMessageBare, authMessage string
var serverSignature []byte var serverSignature, keyingMaterial []byte
var scramPlus, ok, tlsConnOK, tls13 bool
var tlsConn *tls.Conn
if o.User == "" && o.Password == "" { if o.User == "" && o.Password == "" {
foundAnonymous := false foundAnonymous := false
for _, m := range f.Mechanisms.Mechanism { for _, m := range f.Mechanisms.Mechanism {
@ -394,51 +399,120 @@ func (c *Client) init(o *Options) error {
return errors.New("refusing to authenticate over unencrypted TCP connection") return errors.New("refusing to authenticate over unencrypted TCP connection")
} }
tlsConn, ok = c.conn.(*tls.Conn)
if ok {
tlsConnOK = true
}
mechanism = "" mechanism = ""
for _, m := range f.Mechanisms.Mechanism { if o.Mechanism == "" {
switch m { for _, m := range f.Mechanisms.Mechanism {
case "SCRAM-SHA-512": switch m {
mechanism = m case "SCRAM-SHA-512-PLUS":
case "SCRAM-SHA-256": if tlsConnOK {
if mechanism != "SCRAM-SHA-512" { mechanism = m
mechanism = m }
case "SCRAM-SHA-256-PLUS":
if mechanism != "SCRAM-SHA-512-PLUS" && tlsConnOK {
mechanism = m
}
case "SCRAM-SHA-1-PLUS":
if mechanism != "SCRAM-SHA-512-PLUS" &&
mechanism != "SCRAM-SHA-256-PLUS" &&
tlsConnOK {
mechanism = m
}
case "SCRAM-SHA-512":
if mechanism != "SCRAM-SHA-512-PLUS" &&
mechanism != "SCRAM-SHA-256-PLUS" &&
mechanism != "SCRAM-SHA-1-PLUS" {
mechanism = m
}
case "SCRAM-SHA-256":
if mechanism != "SCRAM-SHA-512-PLUS" &&
mechanism != "SCRAM-SHA-256-PLUS" &&
mechanism != "SCRAM-SHA-1-PLUS" &&
mechanism != "SCRAM-SHA-512" {
mechanism = m
}
case "SCRAM-SHA-1":
if mechanism != "SCRAM-SHA-512-PLUS" &&
mechanism != "SCRAM-SHA-256-PLUS" &&
mechanism != "SCRAM-SHA-1-PLUS" &&
mechanism != "SCRAM-SHA-512" &&
mechanism != "SCRAM-SHA-256" {
mechanism = m
}
case "X-OAUTH2":
if mechanism == "" {
mechanism = m
}
case "PLAIN":
if mechanism == "" {
mechanism = m
}
case "DIGEST-MD5":
if mechanism == "" {
mechanism = m
}
} }
case "SCRAM-SHA-1": }
if mechanism != "SCRAM-SHA-512" && } else {
mechanism != "SCRAM-SHA-256" { for _, m := range f.Mechanisms.Mechanism {
mechanism = m if m == o.Mechanism {
} mechanism = o.Mechanism
case "X-OAUTH2":
if mechanism == "" {
mechanism = m
}
case "PLAIN":
if mechanism == "" {
mechanism = m
}
case "DIGEST-MD5":
if mechanism == "" {
mechanism = m
} }
} }
} }
if strings.HasPrefix(mechanism, "SCRAM-SHA") { if strings.HasPrefix(mechanism, "SCRAM-SHA") {
if strings.HasSuffix(mechanism, "PLUS") {
scramPlus = true
}
if scramPlus {
tlsState := tlsConn.ConnectionState()
switch tlsState.Version {
case tls.VersionTLS13:
tls13 = true
keyingMaterial, err = tlsState.ExportKeyingMaterial("EXPORTER-Channel-Binding", nil, 32)
if err != nil {
return err
}
case tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12:
keyingMaterial = tlsState.TLSUnique
default:
return errors.New(mechanism + ": unknown TLS version")
}
if len(keyingMaterial) == 0 {
return errors.New(mechanism + ": no keying material")
}
if tls13 {
channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-exporter,,"), keyingMaterial[:]...))
} else {
channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-unique,,"), keyingMaterial[:]...))
}
}
var shaNewFn func() hash.Hash var shaNewFn func() hash.Hash
switch mechanism { switch mechanism {
case "SCRAM-SHA-512": case "SCRAM-SHA-512", "SCRAM-SHA-512-PLUS":
shaNewFn = sha512.New shaNewFn = sha512.New
case "SCRAM-SHA-256": case "SCRAM-SHA-256", "SCRAM-SHA-256-PLUS":
shaNewFn = sha256.New shaNewFn = sha256.New
case "SCRAM-SHA-1": case "SCRAM-SHA-1", "SCRAM-SHA-1-PLUS":
shaNewFn = sha1.New shaNewFn = sha1.New
default: default:
return errors.New("unsupported auth mechanism") return errors.New("unsupported auth mechanism")
} }
clientNonce := cnonce() clientNonce := cnonce()
clientFirstMessage := "n=" + user + ",r=" + clientNonce if scramPlus {
if tls13 {
clientFirstMessage = "p=tls-exporter,,n=" + user + ",r=" + clientNonce
} else {
clientFirstMessage = "p=tls-unique,,n=" + user + ",r=" + clientNonce
}
} else {
clientFirstMessage = "n,,n=" + user + ",r=" + clientNonce
}
fmt.Fprintf(c.stanzaWriter, "<auth xmlns='%s' mechanism='%s'>%s</auth>", fmt.Fprintf(c.stanzaWriter, "<auth xmlns='%s' mechanism='%s'>%s</auth>",
nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte("n,,"+ nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte(clientFirstMessage)))
clientFirstMessage)))
var sfm string var sfm string
if err = c.p.DecodeElement(&sfm, nil); err != nil { if err = c.p.DecodeElement(&sfm, nil); err != nil {
return errors.New("unmarshal <challenge>: " + err.Error()) return errors.New("unmarshal <challenge>: " + err.Error())
@ -472,7 +546,11 @@ func (c *Client) init(o *Options) error {
return errors.New("unexpected content in SCRAM challenge") return errors.New("unexpected content in SCRAM challenge")
} }
} }
clientFinalMessageBare := "c=biws,r=" + serverNonce if scramPlus {
clientFinalMessageBare = "c=" + channelBinding + ",r=" + serverNonce
} else {
clientFinalMessageBare = "c=biws,r=" + serverNonce
}
saltedPassword := pbkdf2.Key([]byte(o.Password), salt, saltedPassword := pbkdf2.Key([]byte(o.Password), salt,
iterations, shaNewFn().Size(), shaNewFn) iterations, shaNewFn().Size(), shaNewFn)
h := hmac.New(shaNewFn, saltedPassword) h := hmac.New(shaNewFn, saltedPassword)
@ -484,13 +562,13 @@ func (c *Client) init(o *Options) error {
h.Reset() h.Reset()
var storedKey []byte var storedKey []byte
switch mechanism { switch mechanism {
case "SCRAM-SHA-512": case "SCRAM-SHA-512", "SCRAM-SHA-512-PLUS":
storedKey512 := sha512.Sum512(clientKey) storedKey512 := sha512.Sum512(clientKey)
storedKey = storedKey512[:] storedKey = storedKey512[:]
case "SCRAM-SHA-256": case "SCRAM-SHA-256", "SCRAM-SH-256-PLUS":
storedKey256 := sha256.Sum256(clientKey) storedKey256 := sha256.Sum256(clientKey)
storedKey = storedKey256[:] storedKey = storedKey256[:]
case "SCRAM-SHA-1": case "SCRAM-SHA-1", "SCRAM-SHA-1-PLUS":
storedKey1 := sha1.Sum(clientKey) storedKey1 := sha1.Sum(clientKey)
storedKey = storedKey1[:] storedKey = storedKey1[:]
} }
@ -502,8 +580,8 @@ func (c *Client) init(o *Options) error {
if err != nil { if err != nil {
return err return err
} }
authMessage := clientFirstMessage + "," + string(serverFirstMessage) + authMessage = strings.SplitAfter(clientFirstMessage, ",,")[1] + "," +
"," + clientFinalMessageBare string(serverFirstMessage) + "," + clientFinalMessageBare
h = hmac.New(shaNewFn, storedKey[:]) h = hmac.New(shaNewFn, storedKey[:])
_, err = h.Write([]byte(authMessage)) _, err = h.Write([]byte(authMessage))
if err != nil { if err != nil {
@ -536,7 +614,7 @@ func (c *Client) init(o *Options) error {
} }
clientFinalMessage := base64.StdEncoding.EncodeToString([]byte(clientFinalMessageBare + clientFinalMessage := base64.StdEncoding.EncodeToString([]byte(clientFinalMessageBare +
",p=" + base64.StdEncoding.EncodeToString(clientProof))) ",p=" + base64.StdEncoding.EncodeToString(clientProof)))
fmt.Fprintf(c.stanzaWriter, "<response xmlns='%s'>%s</response>", nsSASL, fmt.Fprintf(c.stanzaWriter, "<response xmlns='%s'>%s</response>\n", nsSASL,
clientFinalMessage) clientFinalMessage)
} }
if mechanism == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" { if mechanism == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" {
@ -625,6 +703,7 @@ func (c *Client) init(o *Options) error {
if string(serverSignature) != string(serverSignatureRemote) { if string(serverSignature) != string(serverSignatureRemote) {
return errors.New("SCRAM: server signature mismatch") return errors.New("SCRAM: server signature mismatch")
} }
c.Mechanism = mechanism
} }
case *saslFailure: case *saslFailure:
errorMessage := v.Text errorMessage := v.Text
@ -664,12 +743,12 @@ func (c *Client) init(o *Options) error {
c.domain = domain c.domain = domain
if o.Session { if o.Session {
//if server support session, open it // if server support session, open it
fmt.Fprintf(c.stanzaWriter, "<iq to='%s' type='set' id='%x'><session xmlns='%s'/></iq>", xmlEscape(domain), cookie, nsSession) fmt.Fprintf(c.stanzaWriter, "<iq to='%s' type='set' id='%x'><session xmlns='%s'/></iq>\n", xmlEscape(domain), cookie, nsSession)
} }
// We're connected and can now receive and send messages. // We're connected and can now receive and send messages.
fmt.Fprintf(c.stanzaWriter, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>", o.Status, o.StatusMessage) fmt.Fprintf(c.stanzaWriter, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>\n", o.Status, o.StatusMessage)
return nil return nil
} }
@ -700,7 +779,7 @@ func (c *Client) startTLSIfRequired(f *streamFeatures, o *Options, domain string
tc := o.TLSConfig tc := o.TLSConfig
if tc == nil { if tc == nil {
tc = DefaultConfig.Clone() tc = DefaultConfig.Clone()
//TODO(scott): we should consider using the server's address or reverse lookup // TODO(scott): we should consider using the server's address or reverse lookup
tc.ServerName = domain tc.ServerName = domain
} }
t := tls.Client(c.conn, tc) t := tls.Client(c.conn, tc)
@ -732,7 +811,7 @@ func (c *Client) startStream(o *Options, domain string) (*streamFeatures, error)
_, err := fmt.Fprintf(c.stanzaWriter, "<?xml version='1.0'?>"+ _, err := fmt.Fprintf(c.stanzaWriter, "<?xml version='1.0'?>"+
"<stream:stream to='%s' xmlns='%s'"+ "<stream:stream to='%s' xmlns='%s'"+
" xmlns:stream='%s' version='1.0'>", " xmlns:stream='%s' version='1.0'>\n",
xmlEscape(domain), nsClient, nsStream) xmlEscape(domain), nsClient, nsStream)
if err != nil { if err != nil {
return nil, err return nil, err
@ -892,8 +971,10 @@ func (c *Client) Recv() (stanza interface{}, err error) {
return Chat{}, err return Chat{}, err
} }
return IQ{ID: v.ID, From: v.From, To: v.To, Type: v.Type, return IQ{
Query: res}, nil ID: v.ID, From: v.From, To: v.To, Type: v.Type,
Query: res,
}, nil
} }
case v.Type == "result": case v.Type == "result":
switch v.ID { switch v.ID {
@ -1021,8 +1102,10 @@ func (c *Client) Recv() (stanza interface{}, err error) {
return Chat{}, err return Chat{}, err
} }
return IQ{ID: v.ID, From: v.From, To: v.To, Type: v.Type, return IQ{
Query: res}, nil ID: v.ID, From: v.From, To: v.To, Type: v.Type,
Query: res,
}, nil
} }
case v.Query.XMLName.Local == "": case v.Query.XMLName.Local == "":
return IQ{ID: v.ID, From: v.From, To: v.To, Type: v.Type}, nil return IQ{ID: v.ID, From: v.From, To: v.To, Type: v.Type}, nil
@ -1032,8 +1115,10 @@ func (c *Client) Recv() (stanza interface{}, err error) {
return Chat{}, err return Chat{}, err
} }
return IQ{ID: v.ID, From: v.From, To: v.To, Type: v.Type, return IQ{
Query: res}, nil ID: v.ID, From: v.From, To: v.To, Type: v.Type,
Query: res,
}, nil
} }
} }
} }
@ -1056,7 +1141,7 @@ func (c *Client) Send(chat Chat) (n int, err error) {
oobtext += `</x>` oobtext += `</x>`
} }
stanza := "<message to='%s' type='%s' id='%s' xml:lang='en'>" + subtext + "<body>%s</body>" + oobtext + thdtext + "</message>" stanza := "<message to='%s' type='%s' id='%s' xml:lang='en'>" + subtext + "<body>%s</body>" + oobtext + thdtext + "</message>\n"
return fmt.Fprintf(c.stanzaWriter, stanza, return fmt.Fprintf(c.stanzaWriter, stanza,
xmlEscape(chat.Remote), xmlEscape(chat.Type), cnonce(), xmlEscape(chat.Text)) xmlEscape(chat.Remote), xmlEscape(chat.Type), cnonce(), xmlEscape(chat.Text))
@ -1075,17 +1160,17 @@ func (c *Client) SendOOB(chat Chat) (n int, err error) {
} }
oobtext += `</x>` oobtext += `</x>`
} }
return fmt.Fprintf(c.stanzaWriter, "<message to='%s' type='%s' id='%s' xml:lang='en'>"+oobtext+thdtext+"</message>", return fmt.Fprintf(c.stanzaWriter, "<message to='%s' type='%s' id='%s' xml:lang='en'>"+oobtext+thdtext+"</message>\n",
xmlEscape(chat.Remote), xmlEscape(chat.Type), cnonce()) xmlEscape(chat.Remote), xmlEscape(chat.Type), cnonce())
} }
// SendOrg sends the original text without being wrapped in an XMPP message stanza. // SendOrg sends the original text without being wrapped in an XMPP message stanza.
func (c *Client) SendOrg(org string) (n int, err error) { func (c *Client) SendOrg(org string) (n int, err error) {
return fmt.Fprint(c.stanzaWriter, org) return fmt.Fprint(c.stanzaWriter, org+"\n")
} }
func (c *Client) SendPresence(presence Presence) (n int, err error) { func (c *Client) SendPresence(presence Presence) (n int, err error) {
return fmt.Fprintf(c.stanzaWriter, "<presence from='%s' to='%s'/>", xmlEscape(presence.From), xmlEscape(presence.To)) return fmt.Fprintf(c.stanzaWriter, "<presence from='%s' to='%s'/>\n", xmlEscape(presence.From), xmlEscape(presence.To))
} }
// SendKeepAlive sends a "whitespace keepalive" as described in chapter 4.6.1 of RFC6120. // SendKeepAlive sends a "whitespace keepalive" as described in chapter 4.6.1 of RFC6120.
@ -1097,7 +1182,7 @@ func (c *Client) SendKeepAlive() (n int, err error) {
func (c *Client) SendHtml(chat Chat) (n int, err error) { func (c *Client) SendHtml(chat Chat) (n int, err error) {
return fmt.Fprintf(c.stanzaWriter, "<message to='%s' type='%s' xml:lang='en'>"+ return fmt.Fprintf(c.stanzaWriter, "<message to='%s' type='%s' xml:lang='en'>"+
"<body>%s</body>"+ "<body>%s</body>"+
"<html xmlns='http://jabber.org/protocol/xhtml-im'><body xmlns='http://www.w3.org/1999/xhtml'>%s</body></html></message>", "<html xmlns='http://jabber.org/protocol/xhtml-im'><body xmlns='http://www.w3.org/1999/xhtml'>%s</body></html></message>\n",
xmlEscape(chat.Remote), xmlEscape(chat.Type), xmlEscape(chat.Text), chat.Text) xmlEscape(chat.Remote), xmlEscape(chat.Type), xmlEscape(chat.Text), chat.Text)
} }
@ -1109,11 +1194,12 @@ func (c *Client) Roster() error {
// RFC 3920 C.1 Streams name space // RFC 3920 C.1 Streams name space
type streamFeatures struct { type streamFeatures struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
StartTLS *tlsStartTLS StartTLS *tlsStartTLS
Mechanisms saslMechanisms Mechanisms saslMechanisms
Bind bindBind ChannelBinding saslChannelBinding
Session bool Bind bindBind
Session bool
} }
type streamError struct { type streamError struct {
@ -1147,6 +1233,16 @@ type saslAuth struct {
Mechanism string `xml:",attr"` Mechanism string `xml:",attr"`
} }
type saslChannelBinding struct {
XMLName xml.Name `xml:"sasl-channel-binding"`
Text string `xml:",chardata"`
Xmlns string `xml:"xmlns,attr"`
ChannelBinding []struct {
Text string `xml:",chardata"`
Type string `xml:"type,attr"`
} `xml:"channel-binding"`
}
type saslChallenge string type saslChallenge string
type saslRspAuth string type saslRspAuth string