diff --git a/xmpp.go b/xmpp.go
index 117262e..9af89a2 100644
--- a/xmpp.go
+++ b/xmpp.go
@@ -15,22 +15,30 @@ package xmpp
import (
"bufio"
"bytes"
+ "crypto/hmac"
"crypto/md5"
"crypto/rand"
+ "crypto/sha1"
+ "crypto/sha256"
+ "crypto/sha512"
"crypto/tls"
"encoding/base64"
"encoding/binary"
"encoding/xml"
"errors"
"fmt"
+ "hash"
"io"
"math/big"
"net"
"net/http"
"net/url"
"os"
+ "strconv"
"strings"
"time"
+
+ "golang.org/x/crypto/pbkdf2"
)
const (
@@ -364,7 +372,8 @@ func (c *Client) init(o *Options) error {
if f, err = c.startTLSIfRequired(f, o, domain); err != nil {
return err
}
-
+ var mechanism string
+ var serverSignature []byte
if o.User == "" && o.Password == "" {
foundAnonymous := false
for _, m := range f.Mechanisms.Mechanism {
@@ -385,77 +394,213 @@ func (c *Client) init(o *Options) error {
return errors.New("refusing to authenticate over unencrypted TCP connection")
}
- mechanism := ""
+ mechanism = ""
for _, m := range f.Mechanisms.Mechanism {
- if m == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" {
+ switch m {
+ case "SCRAM-SHA-512":
mechanism = m
- // Oauth authentication: send base64-encoded \x00 user \x00 token.
- raw := "\x00" + user + "\x00" + o.OAuthToken
- enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
- base64.StdEncoding.Encode(enc, []byte(raw))
- fmt.Fprintf(c.conn, "%s\n", nsSASL, o.OAuthXmlNs, enc)
- break
- }
- if m == "PLAIN" {
- mechanism = m
- // Plain authentication: send base64-encoded \x00 user \x00 password.
- raw := "\x00" + user + "\x00" + o.Password
- enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
- base64.StdEncoding.Encode(enc, []byte(raw))
- fmt.Fprintf(c.conn, "%s\n", nsSASL, enc)
- break
- }
- if m == "DIGEST-MD5" {
- mechanism = m
- // Digest-MD5 authentication
- fmt.Fprintf(c.conn, "\n", nsSASL)
- var ch saslChallenge
- if err = c.p.DecodeElement(&ch, nil); err != nil {
- return errors.New("unmarshal : " + err.Error())
+ case "SCRAM-SHA-256":
+ if mechanism != "SCRAM-SHA-512" {
+ mechanism = m
}
- b, err := base64.StdEncoding.DecodeString(string(ch))
- if err != nil {
- return err
+ case "SCRAM-SHA-1":
+ if mechanism != "SCRAM-SHA-512" &&
+ mechanism != "SCRAM-SHA-256" {
+ mechanism = m
}
- tokens := map[string]string{}
- for _, token := range strings.Split(string(b), ",") {
- kv := strings.SplitN(strings.TrimSpace(token), "=", 2)
- if len(kv) == 2 {
- if kv[1][0] == '"' && kv[1][len(kv[1])-1] == '"' {
- kv[1] = kv[1][1 : len(kv[1])-1]
- }
- tokens[kv[0]] = kv[1]
+ case "X-OAUTH2":
+ if mechanism == "" {
+ mechanism = m
+ }
+ case "PLAIN":
+ if mechanism == "" {
+ mechanism = m
+ }
+ case "DIGEST-MD5":
+ if mechanism == "" {
+ mechanism = m
+ }
+ }
+ }
+ if strings.HasPrefix(mechanism, "SCRAM-SHA") {
+ var shaNewFn func() hash.Hash
+ switch mechanism {
+ case "SCRAM-SHA-512":
+ shaNewFn = sha512.New
+ case "SCRAM-SHA-256":
+ shaNewFn = sha256.New
+ case "SCRAM-SHA-1":
+ shaNewFn = sha1.New
+ default:
+ return errors.New("unsupported auth mechanism")
+ }
+ clientNonce := cnonce()
+ clientFirstMessage := "n=" + user + ",r=" + clientNonce
+ fmt.Fprintf(StanzaWriter, "%s",
+ nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte("n,,"+
+ clientFirstMessage)))
+ var sfm string
+ if err = c.p.DecodeElement(&sfm, nil); err != nil {
+ return errors.New("unmarshal : " + err.Error())
+ }
+ b, err := base64.StdEncoding.DecodeString(string(sfm))
+ if err != nil {
+ return err
+ }
+ var serverNonce string
+ var salt []byte
+ var iterations int
+ for _, serverReply := range strings.Split(string(b), ",") {
+ switch {
+ case strings.HasPrefix(serverReply, "r="):
+ serverNonce = strings.SplitN(serverReply, "=", 2)[1]
+ if !strings.HasPrefix(serverNonce, clientNonce) {
+ return errors.New("SCRAM: server nonce didn't start with client nonce")
}
+ case strings.HasPrefix(serverReply, "s="):
+ salt, err = base64.StdEncoding.DecodeString(strings.SplitN(serverReply, "=", 2)[1])
+ if string(salt) == "" {
+ return errors.New("SCRAM: server sent empty salt")
+ }
+ case strings.HasPrefix(serverReply, "i="):
+ iterations, err = strconv.Atoi(strings.SplitN(serverReply,
+ "=", 2)[1])
+ if err != nil {
+ return err
+ }
+ default:
+ return errors.New("unexpected conted in SCRAM challenge")
}
- realm, _ := tokens["realm"]
- nonce, _ := tokens["nonce"]
- qop, _ := tokens["qop"]
- charset, _ := tokens["charset"]
- cnonceStr := cnonce()
- digestURI := "xmpp/" + domain
- nonceCount := fmt.Sprintf("%08x", 1)
- digest := saslDigestResponse(user, realm, o.Password, nonce, cnonceStr, "AUTHENTICATE", digestURI, nonceCount)
- message := "username=\"" + user + "\", realm=\"" + realm + "\", nonce=\"" + nonce + "\", cnonce=\"" + cnonceStr +
- "\", nc=" + nonceCount + ", qop=" + qop + ", digest-uri=\"" + digestURI + "\", response=" + digest + ", charset=" + charset
-
- fmt.Fprintf(c.conn, "%s\n", nsSASL, base64.StdEncoding.EncodeToString([]byte(message)))
-
- var rspauth saslRspAuth
- if err = c.p.DecodeElement(&rspauth, nil); err != nil {
- return errors.New("unmarshal : " + err.Error())
- }
- b, err = base64.StdEncoding.DecodeString(string(rspauth))
- if err != nil {
- return err
- }
- fmt.Fprintf(c.conn, "\n", nsSASL)
- break
}
+ clientFinalMessageBare := "c=biws,r=" + serverNonce
+ saltedPassword := pbkdf2.Key([]byte(o.Password), salt,
+ iterations, shaNewFn().Size(), shaNewFn)
+ h := hmac.New(shaNewFn, saltedPassword)
+ _, err = h.Write([]byte("Client Key"))
+ if err != nil {
+ return err
+ }
+ clientKey := h.Sum(nil)
+ h.Reset()
+ var storedKey []byte
+ switch mechanism {
+ case "SCRAM-SHA-512":
+ storedKey512 := sha512.Sum512(clientKey)
+ storedKey = storedKey512[:]
+ case "SCRAM-SHA-256":
+ storedKey256 := sha256.Sum256(clientKey)
+ storedKey = storedKey256[:]
+ case "SCRAM-SHA-1":
+ storedKey1 := sha1.Sum(clientKey)
+ storedKey = storedKey1[:]
+ }
+ _, err = h.Write([]byte("Server Key"))
+ if err != nil {
+ return err
+ }
+ serverFirstMessage, err := base64.StdEncoding.DecodeString(sfm)
+ if err != nil {
+ return err
+ }
+ authMessage := clientFirstMessage + "," + string(serverFirstMessage) +
+ "," + clientFinalMessageBare
+ h = hmac.New(shaNewFn, storedKey[:])
+ _, err = h.Write([]byte(authMessage))
+ if err != nil {
+ return err
+ }
+ clientSignature := h.Sum(nil)
+ h.Reset()
+ if len(clientKey) != len(clientSignature) {
+ return errors.New("SCRAM: client key and signature length mismatch")
+ }
+ clientProof := make([]byte, len(clientKey))
+ for i := range clientKey {
+ clientProof[i] = clientKey[i] ^ clientSignature[i]
+ }
+ h = hmac.New(shaNewFn, saltedPassword)
+ _, err = h.Write([]byte("Server Key"))
+ if err != nil {
+ return err
+ }
+ serverKey := h.Sum(nil)
+ h.Reset()
+ h = hmac.New(shaNewFn, serverKey)
+ _, err = h.Write([]byte(authMessage))
+ if err != nil {
+ return err
+ }
+ serverSignature = h.Sum(nil)
+ if string(serverSignature) == "" {
+ return errors.New("SCRAM: calculated an empty server signature")
+ }
+ clientFinalMessage := base64.StdEncoding.EncodeToString([]byte(clientFinalMessageBare +
+ ",p=" + base64.StdEncoding.EncodeToString(clientProof)))
+ fmt.Fprintf(StanzaWriter, "%s", nsSASL,
+ clientFinalMessage)
}
- if mechanism == "" {
- return fmt.Errorf("PLAIN authentication is not an option: %v", f.Mechanisms.Mechanism)
+ if mechanism == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" {
+ // Oauth authentication: send base64-encoded \x00 user \x00 token.
+ raw := "\x00" + user + "\x00" + o.OAuthToken
+ enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
+ base64.StdEncoding.Encode(enc, []byte(raw))
+ fmt.Fprintf(StanzaWriter, "%s\n", nsSASL, o.OAuthXmlNs, enc)
}
+ if mechanism == "PLAIN" {
+ // Plain authentication: send base64-encoded \x00 user \x00 password.
+ raw := "\x00" + user + "\x00" + o.Password
+ enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
+ base64.StdEncoding.Encode(enc, []byte(raw))
+ fmt.Fprintf(c.conn, "%s\n", nsSASL, enc)
+ }
+ if mechanism == "DIGEST-MD5" {
+ // Digest-MD5 authentication
+ fmt.Fprintf(StanzaWriter, "\n", nsSASL)
+ var ch saslChallenge
+ if err = c.p.DecodeElement(&ch, nil); err != nil {
+ return errors.New("unmarshal : " + err.Error())
+ }
+ b, err := base64.StdEncoding.DecodeString(string(ch))
+ if err != nil {
+ return err
+ }
+ tokens := map[string]string{}
+ for _, token := range strings.Split(string(b), ",") {
+ kv := strings.SplitN(strings.TrimSpace(token), "=", 2)
+ if len(kv) == 2 {
+ if kv[1][0] == '"' && kv[1][len(kv[1])-1] == '"' {
+ kv[1] = kv[1][1 : len(kv[1])-1]
+ }
+ tokens[kv[0]] = kv[1]
+ }
+ }
+ realm, _ := tokens["realm"]
+ nonce, _ := tokens["nonce"]
+ qop, _ := tokens["qop"]
+ charset, _ := tokens["charset"]
+ cnonceStr := cnonce()
+ digestURI := "xmpp/" + domain
+ nonceCount := fmt.Sprintf("%08x", 1)
+ digest := saslDigestResponse(user, realm, o.Password, nonce, cnonceStr, "AUTHENTICATE", digestURI, nonceCount)
+ message := "username=\"" + user + "\", realm=\"" + realm + "\", nonce=\"" + nonce + "\", cnonce=\"" + cnonceStr +
+ "\", nc=" + nonceCount + ", qop=" + qop + ", digest-uri=\"" + digestURI + "\", response=" + digest + ", charset=" + charset
+
+ fmt.Fprintf(StanzaWriter, "%s\n", nsSASL, base64.StdEncoding.EncodeToString([]byte(message)))
+
+ var rspauth saslRspAuth
+ if err = c.p.DecodeElement(&rspauth, nil); err != nil {
+ return errors.New("unmarshal : " + err.Error())
+ }
+ b, err = base64.StdEncoding.DecodeString(string(rspauth))
+ if err != nil {
+ return err
+ }
+ fmt.Fprintf(StanzaWriter, "\n", nsSASL)
+ }
+ }
+ if mechanism == "" {
+ return fmt.Errorf("PLAIN authentication is not an option: %v", f.Mechanisms.Mechanism)
}
// Next message should be either success or failure.
name, val, err := next(c.p)
@@ -464,6 +609,23 @@ func (c *Client) init(o *Options) error {
}
switch v := val.(type) {
case *saslSuccess:
+ if strings.HasPrefix(mechanism, "SCRAM-SHA") {
+ successMsg, err := base64.StdEncoding.DecodeString(v.Text)
+ if err != nil {
+ return err
+ }
+ if !strings.HasPrefix(string(successMsg), "v=") {
+ return errors.New("server sent unexpected content in SCRAM success message")
+ }
+ serverSignatureReply := strings.SplitN(string(successMsg), "v=", 2)[1]
+ serverSignatureRemote, err := base64.StdEncoding.DecodeString(serverSignatureReply)
+ if err != nil {
+ return err
+ }
+ if string(serverSignature) != string(serverSignatureRemote) {
+ return errors.New("SCRAM: server signature mismatch")
+ }
+ }
case *saslFailure:
errorMessage := v.Text
if errorMessage == "" {