forked from jshiffer/go-xmpp
add scram auth (#147)
* Fix syntax errors. * gofmt * Add SCRAM-SHA-1, SCRAM-SHA-256 and SCRAM-SHA-512 auth
This commit is contained in:
parent
9129a110df
commit
bef3e549f7
288
xmpp.go
288
xmpp.go
@ -15,22 +15,30 @@ package xmpp
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -364,7 +372,8 @@ func (c *Client) init(o *Options) error {
|
||||
if f, err = c.startTLSIfRequired(f, o, domain); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var mechanism string
|
||||
var serverSignature []byte
|
||||
if o.User == "" && o.Password == "" {
|
||||
foundAnonymous := false
|
||||
for _, m := range f.Mechanisms.Mechanism {
|
||||
@ -385,77 +394,213 @@ func (c *Client) init(o *Options) error {
|
||||
return errors.New("refusing to authenticate over unencrypted TCP connection")
|
||||
}
|
||||
|
||||
mechanism := ""
|
||||
mechanism = ""
|
||||
for _, m := range f.Mechanisms.Mechanism {
|
||||
if m == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" {
|
||||
switch m {
|
||||
case "SCRAM-SHA-512":
|
||||
mechanism = m
|
||||
// Oauth authentication: send base64-encoded \x00 user \x00 token.
|
||||
raw := "\x00" + user + "\x00" + o.OAuthToken
|
||||
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
|
||||
base64.StdEncoding.Encode(enc, []byte(raw))
|
||||
fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='X-OAUTH2' auth:service='oauth2' "+
|
||||
"xmlns:auth='%s'>%s</auth>\n", nsSASL, o.OAuthXmlNs, enc)
|
||||
break
|
||||
}
|
||||
if m == "PLAIN" {
|
||||
mechanism = m
|
||||
// Plain authentication: send base64-encoded \x00 user \x00 password.
|
||||
raw := "\x00" + user + "\x00" + o.Password
|
||||
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
|
||||
base64.StdEncoding.Encode(enc, []byte(raw))
|
||||
fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='PLAIN'>%s</auth>\n", nsSASL, enc)
|
||||
break
|
||||
}
|
||||
if m == "DIGEST-MD5" {
|
||||
mechanism = m
|
||||
// Digest-MD5 authentication
|
||||
fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='DIGEST-MD5'/>\n", nsSASL)
|
||||
var ch saslChallenge
|
||||
if err = c.p.DecodeElement(&ch, nil); err != nil {
|
||||
return errors.New("unmarshal <challenge>: " + err.Error())
|
||||
case "SCRAM-SHA-256":
|
||||
if mechanism != "SCRAM-SHA-512" {
|
||||
mechanism = m
|
||||
}
|
||||
b, err := base64.StdEncoding.DecodeString(string(ch))
|
||||
if err != nil {
|
||||
return err
|
||||
case "SCRAM-SHA-1":
|
||||
if mechanism != "SCRAM-SHA-512" &&
|
||||
mechanism != "SCRAM-SHA-256" {
|
||||
mechanism = m
|
||||
}
|
||||
tokens := map[string]string{}
|
||||
for _, token := range strings.Split(string(b), ",") {
|
||||
kv := strings.SplitN(strings.TrimSpace(token), "=", 2)
|
||||
if len(kv) == 2 {
|
||||
if kv[1][0] == '"' && kv[1][len(kv[1])-1] == '"' {
|
||||
kv[1] = kv[1][1 : len(kv[1])-1]
|
||||
}
|
||||
tokens[kv[0]] = kv[1]
|
||||
case "X-OAUTH2":
|
||||
if mechanism == "" {
|
||||
mechanism = m
|
||||
}
|
||||
case "PLAIN":
|
||||
if mechanism == "" {
|
||||
mechanism = m
|
||||
}
|
||||
case "DIGEST-MD5":
|
||||
if mechanism == "" {
|
||||
mechanism = m
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(mechanism, "SCRAM-SHA") {
|
||||
var shaNewFn func() hash.Hash
|
||||
switch mechanism {
|
||||
case "SCRAM-SHA-512":
|
||||
shaNewFn = sha512.New
|
||||
case "SCRAM-SHA-256":
|
||||
shaNewFn = sha256.New
|
||||
case "SCRAM-SHA-1":
|
||||
shaNewFn = sha1.New
|
||||
default:
|
||||
return errors.New("unsupported auth mechanism")
|
||||
}
|
||||
clientNonce := cnonce()
|
||||
clientFirstMessage := "n=" + user + ",r=" + clientNonce
|
||||
fmt.Fprintf(StanzaWriter, "<auth xmlns='%s' mechanism='%s'>%s</auth>",
|
||||
nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte("n,,"+
|
||||
clientFirstMessage)))
|
||||
var sfm string
|
||||
if err = c.p.DecodeElement(&sfm, nil); err != nil {
|
||||
return errors.New("unmarshal <challenge>: " + err.Error())
|
||||
}
|
||||
b, err := base64.StdEncoding.DecodeString(string(sfm))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var serverNonce string
|
||||
var salt []byte
|
||||
var iterations int
|
||||
for _, serverReply := range strings.Split(string(b), ",") {
|
||||
switch {
|
||||
case strings.HasPrefix(serverReply, "r="):
|
||||
serverNonce = strings.SplitN(serverReply, "=", 2)[1]
|
||||
if !strings.HasPrefix(serverNonce, clientNonce) {
|
||||
return errors.New("SCRAM: server nonce didn't start with client nonce")
|
||||
}
|
||||
case strings.HasPrefix(serverReply, "s="):
|
||||
salt, err = base64.StdEncoding.DecodeString(strings.SplitN(serverReply, "=", 2)[1])
|
||||
if string(salt) == "" {
|
||||
return errors.New("SCRAM: server sent empty salt")
|
||||
}
|
||||
case strings.HasPrefix(serverReply, "i="):
|
||||
iterations, err = strconv.Atoi(strings.SplitN(serverReply,
|
||||
"=", 2)[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.New("unexpected conted in SCRAM challenge")
|
||||
}
|
||||
realm, _ := tokens["realm"]
|
||||
nonce, _ := tokens["nonce"]
|
||||
qop, _ := tokens["qop"]
|
||||
charset, _ := tokens["charset"]
|
||||
cnonceStr := cnonce()
|
||||
digestURI := "xmpp/" + domain
|
||||
nonceCount := fmt.Sprintf("%08x", 1)
|
||||
digest := saslDigestResponse(user, realm, o.Password, nonce, cnonceStr, "AUTHENTICATE", digestURI, nonceCount)
|
||||
message := "username=\"" + user + "\", realm=\"" + realm + "\", nonce=\"" + nonce + "\", cnonce=\"" + cnonceStr +
|
||||
"\", nc=" + nonceCount + ", qop=" + qop + ", digest-uri=\"" + digestURI + "\", response=" + digest + ", charset=" + charset
|
||||
|
||||
fmt.Fprintf(c.conn, "<response xmlns='%s'>%s</response>\n", nsSASL, base64.StdEncoding.EncodeToString([]byte(message)))
|
||||
|
||||
var rspauth saslRspAuth
|
||||
if err = c.p.DecodeElement(&rspauth, nil); err != nil {
|
||||
return errors.New("unmarshal <challenge>: " + err.Error())
|
||||
}
|
||||
b, err = base64.StdEncoding.DecodeString(string(rspauth))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(c.conn, "<response xmlns='%s'/>\n", nsSASL)
|
||||
break
|
||||
}
|
||||
clientFinalMessageBare := "c=biws,r=" + serverNonce
|
||||
saltedPassword := pbkdf2.Key([]byte(o.Password), salt,
|
||||
iterations, shaNewFn().Size(), shaNewFn)
|
||||
h := hmac.New(shaNewFn, saltedPassword)
|
||||
_, err = h.Write([]byte("Client Key"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientKey := h.Sum(nil)
|
||||
h.Reset()
|
||||
var storedKey []byte
|
||||
switch mechanism {
|
||||
case "SCRAM-SHA-512":
|
||||
storedKey512 := sha512.Sum512(clientKey)
|
||||
storedKey = storedKey512[:]
|
||||
case "SCRAM-SHA-256":
|
||||
storedKey256 := sha256.Sum256(clientKey)
|
||||
storedKey = storedKey256[:]
|
||||
case "SCRAM-SHA-1":
|
||||
storedKey1 := sha1.Sum(clientKey)
|
||||
storedKey = storedKey1[:]
|
||||
}
|
||||
_, err = h.Write([]byte("Server Key"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverFirstMessage, err := base64.StdEncoding.DecodeString(sfm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
authMessage := clientFirstMessage + "," + string(serverFirstMessage) +
|
||||
"," + clientFinalMessageBare
|
||||
h = hmac.New(shaNewFn, storedKey[:])
|
||||
_, err = h.Write([]byte(authMessage))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientSignature := h.Sum(nil)
|
||||
h.Reset()
|
||||
if len(clientKey) != len(clientSignature) {
|
||||
return errors.New("SCRAM: client key and signature length mismatch")
|
||||
}
|
||||
clientProof := make([]byte, len(clientKey))
|
||||
for i := range clientKey {
|
||||
clientProof[i] = clientKey[i] ^ clientSignature[i]
|
||||
}
|
||||
h = hmac.New(shaNewFn, saltedPassword)
|
||||
_, err = h.Write([]byte("Server Key"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverKey := h.Sum(nil)
|
||||
h.Reset()
|
||||
h = hmac.New(shaNewFn, serverKey)
|
||||
_, err = h.Write([]byte(authMessage))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverSignature = h.Sum(nil)
|
||||
if string(serverSignature) == "" {
|
||||
return errors.New("SCRAM: calculated an empty server signature")
|
||||
}
|
||||
clientFinalMessage := base64.StdEncoding.EncodeToString([]byte(clientFinalMessageBare +
|
||||
",p=" + base64.StdEncoding.EncodeToString(clientProof)))
|
||||
fmt.Fprintf(StanzaWriter, "<response xmlns='%s'>%s</response>", nsSASL,
|
||||
clientFinalMessage)
|
||||
}
|
||||
if mechanism == "" {
|
||||
return fmt.Errorf("PLAIN authentication is not an option: %v", f.Mechanisms.Mechanism)
|
||||
if mechanism == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" {
|
||||
// Oauth authentication: send base64-encoded \x00 user \x00 token.
|
||||
raw := "\x00" + user + "\x00" + o.OAuthToken
|
||||
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
|
||||
base64.StdEncoding.Encode(enc, []byte(raw))
|
||||
fmt.Fprintf(StanzaWriter, "<auth xmlns='%s' mechanism='X-OAUTH2' auth:service='oauth2' "+
|
||||
"xmlns:auth='%s'>%s</auth>\n", nsSASL, o.OAuthXmlNs, enc)
|
||||
}
|
||||
if mechanism == "PLAIN" {
|
||||
// Plain authentication: send base64-encoded \x00 user \x00 password.
|
||||
raw := "\x00" + user + "\x00" + o.Password
|
||||
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
|
||||
base64.StdEncoding.Encode(enc, []byte(raw))
|
||||
fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='PLAIN'>%s</auth>\n", nsSASL, enc)
|
||||
}
|
||||
if mechanism == "DIGEST-MD5" {
|
||||
// Digest-MD5 authentication
|
||||
fmt.Fprintf(StanzaWriter, "<auth xmlns='%s' mechanism='DIGEST-MD5'/>\n", nsSASL)
|
||||
var ch saslChallenge
|
||||
if err = c.p.DecodeElement(&ch, nil); err != nil {
|
||||
return errors.New("unmarshal <challenge>: " + err.Error())
|
||||
}
|
||||
b, err := base64.StdEncoding.DecodeString(string(ch))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tokens := map[string]string{}
|
||||
for _, token := range strings.Split(string(b), ",") {
|
||||
kv := strings.SplitN(strings.TrimSpace(token), "=", 2)
|
||||
if len(kv) == 2 {
|
||||
if kv[1][0] == '"' && kv[1][len(kv[1])-1] == '"' {
|
||||
kv[1] = kv[1][1 : len(kv[1])-1]
|
||||
}
|
||||
tokens[kv[0]] = kv[1]
|
||||
}
|
||||
}
|
||||
realm, _ := tokens["realm"]
|
||||
nonce, _ := tokens["nonce"]
|
||||
qop, _ := tokens["qop"]
|
||||
charset, _ := tokens["charset"]
|
||||
cnonceStr := cnonce()
|
||||
digestURI := "xmpp/" + domain
|
||||
nonceCount := fmt.Sprintf("%08x", 1)
|
||||
digest := saslDigestResponse(user, realm, o.Password, nonce, cnonceStr, "AUTHENTICATE", digestURI, nonceCount)
|
||||
message := "username=\"" + user + "\", realm=\"" + realm + "\", nonce=\"" + nonce + "\", cnonce=\"" + cnonceStr +
|
||||
"\", nc=" + nonceCount + ", qop=" + qop + ", digest-uri=\"" + digestURI + "\", response=" + digest + ", charset=" + charset
|
||||
|
||||
fmt.Fprintf(StanzaWriter, "<response xmlns='%s'>%s</response>\n", nsSASL, base64.StdEncoding.EncodeToString([]byte(message)))
|
||||
|
||||
var rspauth saslRspAuth
|
||||
if err = c.p.DecodeElement(&rspauth, nil); err != nil {
|
||||
return errors.New("unmarshal <challenge>: " + err.Error())
|
||||
}
|
||||
b, err = base64.StdEncoding.DecodeString(string(rspauth))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(StanzaWriter, "<response xmlns='%s'/>\n", nsSASL)
|
||||
}
|
||||
}
|
||||
if mechanism == "" {
|
||||
return fmt.Errorf("PLAIN authentication is not an option: %v", f.Mechanisms.Mechanism)
|
||||
}
|
||||
// Next message should be either success or failure.
|
||||
name, val, err := next(c.p)
|
||||
@ -464,6 +609,23 @@ func (c *Client) init(o *Options) error {
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case *saslSuccess:
|
||||
if strings.HasPrefix(mechanism, "SCRAM-SHA") {
|
||||
successMsg, err := base64.StdEncoding.DecodeString(v.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !strings.HasPrefix(string(successMsg), "v=") {
|
||||
return errors.New("server sent unexpected content in SCRAM success message")
|
||||
}
|
||||
serverSignatureReply := strings.SplitN(string(successMsg), "v=", 2)[1]
|
||||
serverSignatureRemote, err := base64.StdEncoding.DecodeString(serverSignatureReply)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if string(serverSignature) != string(serverSignatureRemote) {
|
||||
return errors.New("SCRAM: server signature mismatch")
|
||||
}
|
||||
}
|
||||
case *saslFailure:
|
||||
errorMessage := v.Text
|
||||
if errorMessage == "" {
|
||||
|
Loading…
Reference in New Issue
Block a user