forked from jshiffer/go-xmpp
add scram auth (#147)
* Fix syntax errors. * gofmt * Add SCRAM-SHA-1, SCRAM-SHA-256 and SCRAM-SHA-512 auth
This commit is contained in:
parent
9129a110df
commit
bef3e549f7
192
xmpp.go
192
xmpp.go
@ -15,22 +15,30 @@ package xmpp
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/hmac"
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/sha512"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"hash"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/pbkdf2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -364,7 +372,8 @@ func (c *Client) init(o *Options) error {
|
|||||||
if f, err = c.startTLSIfRequired(f, o, domain); err != nil {
|
if f, err = c.startTLSIfRequired(f, o, domain); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
var mechanism string
|
||||||
|
var serverSignature []byte
|
||||||
if o.User == "" && o.Password == "" {
|
if o.User == "" && o.Password == "" {
|
||||||
foundAnonymous := false
|
foundAnonymous := false
|
||||||
for _, m := range f.Mechanisms.Mechanism {
|
for _, m := range f.Mechanisms.Mechanism {
|
||||||
@ -385,31 +394,169 @@ func (c *Client) init(o *Options) error {
|
|||||||
return errors.New("refusing to authenticate over unencrypted TCP connection")
|
return errors.New("refusing to authenticate over unencrypted TCP connection")
|
||||||
}
|
}
|
||||||
|
|
||||||
mechanism := ""
|
mechanism = ""
|
||||||
for _, m := range f.Mechanisms.Mechanism {
|
for _, m := range f.Mechanisms.Mechanism {
|
||||||
if m == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" {
|
switch m {
|
||||||
|
case "SCRAM-SHA-512":
|
||||||
mechanism = m
|
mechanism = m
|
||||||
|
case "SCRAM-SHA-256":
|
||||||
|
if mechanism != "SCRAM-SHA-512" {
|
||||||
|
mechanism = m
|
||||||
|
}
|
||||||
|
case "SCRAM-SHA-1":
|
||||||
|
if mechanism != "SCRAM-SHA-512" &&
|
||||||
|
mechanism != "SCRAM-SHA-256" {
|
||||||
|
mechanism = m
|
||||||
|
}
|
||||||
|
case "X-OAUTH2":
|
||||||
|
if mechanism == "" {
|
||||||
|
mechanism = m
|
||||||
|
}
|
||||||
|
case "PLAIN":
|
||||||
|
if mechanism == "" {
|
||||||
|
mechanism = m
|
||||||
|
}
|
||||||
|
case "DIGEST-MD5":
|
||||||
|
if mechanism == "" {
|
||||||
|
mechanism = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(mechanism, "SCRAM-SHA") {
|
||||||
|
var shaNewFn func() hash.Hash
|
||||||
|
switch mechanism {
|
||||||
|
case "SCRAM-SHA-512":
|
||||||
|
shaNewFn = sha512.New
|
||||||
|
case "SCRAM-SHA-256":
|
||||||
|
shaNewFn = sha256.New
|
||||||
|
case "SCRAM-SHA-1":
|
||||||
|
shaNewFn = sha1.New
|
||||||
|
default:
|
||||||
|
return errors.New("unsupported auth mechanism")
|
||||||
|
}
|
||||||
|
clientNonce := cnonce()
|
||||||
|
clientFirstMessage := "n=" + user + ",r=" + clientNonce
|
||||||
|
fmt.Fprintf(StanzaWriter, "<auth xmlns='%s' mechanism='%s'>%s</auth>",
|
||||||
|
nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte("n,,"+
|
||||||
|
clientFirstMessage)))
|
||||||
|
var sfm string
|
||||||
|
if err = c.p.DecodeElement(&sfm, nil); err != nil {
|
||||||
|
return errors.New("unmarshal <challenge>: " + err.Error())
|
||||||
|
}
|
||||||
|
b, err := base64.StdEncoding.DecodeString(string(sfm))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var serverNonce string
|
||||||
|
var salt []byte
|
||||||
|
var iterations int
|
||||||
|
for _, serverReply := range strings.Split(string(b), ",") {
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(serverReply, "r="):
|
||||||
|
serverNonce = strings.SplitN(serverReply, "=", 2)[1]
|
||||||
|
if !strings.HasPrefix(serverNonce, clientNonce) {
|
||||||
|
return errors.New("SCRAM: server nonce didn't start with client nonce")
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(serverReply, "s="):
|
||||||
|
salt, err = base64.StdEncoding.DecodeString(strings.SplitN(serverReply, "=", 2)[1])
|
||||||
|
if string(salt) == "" {
|
||||||
|
return errors.New("SCRAM: server sent empty salt")
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(serverReply, "i="):
|
||||||
|
iterations, err = strconv.Atoi(strings.SplitN(serverReply,
|
||||||
|
"=", 2)[1])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return errors.New("unexpected conted in SCRAM challenge")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
clientFinalMessageBare := "c=biws,r=" + serverNonce
|
||||||
|
saltedPassword := pbkdf2.Key([]byte(o.Password), salt,
|
||||||
|
iterations, shaNewFn().Size(), shaNewFn)
|
||||||
|
h := hmac.New(shaNewFn, saltedPassword)
|
||||||
|
_, err = h.Write([]byte("Client Key"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
clientKey := h.Sum(nil)
|
||||||
|
h.Reset()
|
||||||
|
var storedKey []byte
|
||||||
|
switch mechanism {
|
||||||
|
case "SCRAM-SHA-512":
|
||||||
|
storedKey512 := sha512.Sum512(clientKey)
|
||||||
|
storedKey = storedKey512[:]
|
||||||
|
case "SCRAM-SHA-256":
|
||||||
|
storedKey256 := sha256.Sum256(clientKey)
|
||||||
|
storedKey = storedKey256[:]
|
||||||
|
case "SCRAM-SHA-1":
|
||||||
|
storedKey1 := sha1.Sum(clientKey)
|
||||||
|
storedKey = storedKey1[:]
|
||||||
|
}
|
||||||
|
_, err = h.Write([]byte("Server Key"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
serverFirstMessage, err := base64.StdEncoding.DecodeString(sfm)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
authMessage := clientFirstMessage + "," + string(serverFirstMessage) +
|
||||||
|
"," + clientFinalMessageBare
|
||||||
|
h = hmac.New(shaNewFn, storedKey[:])
|
||||||
|
_, err = h.Write([]byte(authMessage))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
clientSignature := h.Sum(nil)
|
||||||
|
h.Reset()
|
||||||
|
if len(clientKey) != len(clientSignature) {
|
||||||
|
return errors.New("SCRAM: client key and signature length mismatch")
|
||||||
|
}
|
||||||
|
clientProof := make([]byte, len(clientKey))
|
||||||
|
for i := range clientKey {
|
||||||
|
clientProof[i] = clientKey[i] ^ clientSignature[i]
|
||||||
|
}
|
||||||
|
h = hmac.New(shaNewFn, saltedPassword)
|
||||||
|
_, err = h.Write([]byte("Server Key"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
serverKey := h.Sum(nil)
|
||||||
|
h.Reset()
|
||||||
|
h = hmac.New(shaNewFn, serverKey)
|
||||||
|
_, err = h.Write([]byte(authMessage))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
serverSignature = h.Sum(nil)
|
||||||
|
if string(serverSignature) == "" {
|
||||||
|
return errors.New("SCRAM: calculated an empty server signature")
|
||||||
|
}
|
||||||
|
clientFinalMessage := base64.StdEncoding.EncodeToString([]byte(clientFinalMessageBare +
|
||||||
|
",p=" + base64.StdEncoding.EncodeToString(clientProof)))
|
||||||
|
fmt.Fprintf(StanzaWriter, "<response xmlns='%s'>%s</response>", nsSASL,
|
||||||
|
clientFinalMessage)
|
||||||
|
}
|
||||||
|
if mechanism == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" {
|
||||||
// Oauth authentication: send base64-encoded \x00 user \x00 token.
|
// Oauth authentication: send base64-encoded \x00 user \x00 token.
|
||||||
raw := "\x00" + user + "\x00" + o.OAuthToken
|
raw := "\x00" + user + "\x00" + o.OAuthToken
|
||||||
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
|
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
|
||||||
base64.StdEncoding.Encode(enc, []byte(raw))
|
base64.StdEncoding.Encode(enc, []byte(raw))
|
||||||
fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='X-OAUTH2' auth:service='oauth2' "+
|
fmt.Fprintf(StanzaWriter, "<auth xmlns='%s' mechanism='X-OAUTH2' auth:service='oauth2' "+
|
||||||
"xmlns:auth='%s'>%s</auth>\n", nsSASL, o.OAuthXmlNs, enc)
|
"xmlns:auth='%s'>%s</auth>\n", nsSASL, o.OAuthXmlNs, enc)
|
||||||
break
|
|
||||||
}
|
}
|
||||||
if m == "PLAIN" {
|
if mechanism == "PLAIN" {
|
||||||
mechanism = m
|
|
||||||
// Plain authentication: send base64-encoded \x00 user \x00 password.
|
// Plain authentication: send base64-encoded \x00 user \x00 password.
|
||||||
raw := "\x00" + user + "\x00" + o.Password
|
raw := "\x00" + user + "\x00" + o.Password
|
||||||
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
|
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
|
||||||
base64.StdEncoding.Encode(enc, []byte(raw))
|
base64.StdEncoding.Encode(enc, []byte(raw))
|
||||||
fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='PLAIN'>%s</auth>\n", nsSASL, enc)
|
fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='PLAIN'>%s</auth>\n", nsSASL, enc)
|
||||||
break
|
|
||||||
}
|
}
|
||||||
if m == "DIGEST-MD5" {
|
if mechanism == "DIGEST-MD5" {
|
||||||
mechanism = m
|
|
||||||
// Digest-MD5 authentication
|
// Digest-MD5 authentication
|
||||||
fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='DIGEST-MD5'/>\n", nsSASL)
|
fmt.Fprintf(StanzaWriter, "<auth xmlns='%s' mechanism='DIGEST-MD5'/>\n", nsSASL)
|
||||||
var ch saslChallenge
|
var ch saslChallenge
|
||||||
if err = c.p.DecodeElement(&ch, nil); err != nil {
|
if err = c.p.DecodeElement(&ch, nil); err != nil {
|
||||||
return errors.New("unmarshal <challenge>: " + err.Error())
|
return errors.New("unmarshal <challenge>: " + err.Error())
|
||||||
@ -439,7 +586,7 @@ func (c *Client) init(o *Options) error {
|
|||||||
message := "username=\"" + user + "\", realm=\"" + realm + "\", nonce=\"" + nonce + "\", cnonce=\"" + cnonceStr +
|
message := "username=\"" + user + "\", realm=\"" + realm + "\", nonce=\"" + nonce + "\", cnonce=\"" + cnonceStr +
|
||||||
"\", nc=" + nonceCount + ", qop=" + qop + ", digest-uri=\"" + digestURI + "\", response=" + digest + ", charset=" + charset
|
"\", nc=" + nonceCount + ", qop=" + qop + ", digest-uri=\"" + digestURI + "\", response=" + digest + ", charset=" + charset
|
||||||
|
|
||||||
fmt.Fprintf(c.conn, "<response xmlns='%s'>%s</response>\n", nsSASL, base64.StdEncoding.EncodeToString([]byte(message)))
|
fmt.Fprintf(StanzaWriter, "<response xmlns='%s'>%s</response>\n", nsSASL, base64.StdEncoding.EncodeToString([]byte(message)))
|
||||||
|
|
||||||
var rspauth saslRspAuth
|
var rspauth saslRspAuth
|
||||||
if err = c.p.DecodeElement(&rspauth, nil); err != nil {
|
if err = c.p.DecodeElement(&rspauth, nil); err != nil {
|
||||||
@ -449,14 +596,12 @@ func (c *Client) init(o *Options) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fmt.Fprintf(c.conn, "<response xmlns='%s'/>\n", nsSASL)
|
fmt.Fprintf(StanzaWriter, "<response xmlns='%s'/>\n", nsSASL)
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if mechanism == "" {
|
if mechanism == "" {
|
||||||
return fmt.Errorf("PLAIN authentication is not an option: %v", f.Mechanisms.Mechanism)
|
return fmt.Errorf("PLAIN authentication is not an option: %v", f.Mechanisms.Mechanism)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
// Next message should be either success or failure.
|
// Next message should be either success or failure.
|
||||||
name, val, err := next(c.p)
|
name, val, err := next(c.p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -464,6 +609,23 @@ func (c *Client) init(o *Options) error {
|
|||||||
}
|
}
|
||||||
switch v := val.(type) {
|
switch v := val.(type) {
|
||||||
case *saslSuccess:
|
case *saslSuccess:
|
||||||
|
if strings.HasPrefix(mechanism, "SCRAM-SHA") {
|
||||||
|
successMsg, err := base64.StdEncoding.DecodeString(v.Text)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(string(successMsg), "v=") {
|
||||||
|
return errors.New("server sent unexpected content in SCRAM success message")
|
||||||
|
}
|
||||||
|
serverSignatureReply := strings.SplitN(string(successMsg), "v=", 2)[1]
|
||||||
|
serverSignatureRemote, err := base64.StdEncoding.DecodeString(serverSignatureReply)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if string(serverSignature) != string(serverSignatureRemote) {
|
||||||
|
return errors.New("SCRAM: server signature mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
case *saslFailure:
|
case *saslFailure:
|
||||||
errorMessage := v.Text
|
errorMessage := v.Text
|
||||||
if errorMessage == "" {
|
if errorMessage == "" {
|
||||||
|
Loading…
Reference in New Issue
Block a user