add scram auth (#147)

* Fix syntax errors.

* gofmt

* Add SCRAM-SHA-1, SCRAM-SHA-256 and SCRAM-SHA-512 auth
This commit is contained in:
Martin 2023-05-21 09:26:59 +02:00 committed by GitHub
parent 9129a110df
commit bef3e549f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

192
xmpp.go
View File

@ -15,22 +15,30 @@ package xmpp
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"crypto/hmac"
"crypto/md5" "crypto/md5"
"crypto/rand" "crypto/rand"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
"math/big" "math/big"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"strconv"
"strings" "strings"
"time" "time"
"golang.org/x/crypto/pbkdf2"
) )
const ( const (
@ -364,7 +372,8 @@ func (c *Client) init(o *Options) error {
if f, err = c.startTLSIfRequired(f, o, domain); err != nil { if f, err = c.startTLSIfRequired(f, o, domain); err != nil {
return err return err
} }
var mechanism string
var serverSignature []byte
if o.User == "" && o.Password == "" { if o.User == "" && o.Password == "" {
foundAnonymous := false foundAnonymous := false
for _, m := range f.Mechanisms.Mechanism { for _, m := range f.Mechanisms.Mechanism {
@ -385,31 +394,169 @@ func (c *Client) init(o *Options) error {
return errors.New("refusing to authenticate over unencrypted TCP connection") return errors.New("refusing to authenticate over unencrypted TCP connection")
} }
mechanism := "" mechanism = ""
for _, m := range f.Mechanisms.Mechanism { for _, m := range f.Mechanisms.Mechanism {
if m == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" { switch m {
case "SCRAM-SHA-512":
mechanism = m mechanism = m
case "SCRAM-SHA-256":
if mechanism != "SCRAM-SHA-512" {
mechanism = m
}
case "SCRAM-SHA-1":
if mechanism != "SCRAM-SHA-512" &&
mechanism != "SCRAM-SHA-256" {
mechanism = m
}
case "X-OAUTH2":
if mechanism == "" {
mechanism = m
}
case "PLAIN":
if mechanism == "" {
mechanism = m
}
case "DIGEST-MD5":
if mechanism == "" {
mechanism = m
}
}
}
if strings.HasPrefix(mechanism, "SCRAM-SHA") {
var shaNewFn func() hash.Hash
switch mechanism {
case "SCRAM-SHA-512":
shaNewFn = sha512.New
case "SCRAM-SHA-256":
shaNewFn = sha256.New
case "SCRAM-SHA-1":
shaNewFn = sha1.New
default:
return errors.New("unsupported auth mechanism")
}
clientNonce := cnonce()
clientFirstMessage := "n=" + user + ",r=" + clientNonce
fmt.Fprintf(StanzaWriter, "<auth xmlns='%s' mechanism='%s'>%s</auth>",
nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte("n,,"+
clientFirstMessage)))
var sfm string
if err = c.p.DecodeElement(&sfm, nil); err != nil {
return errors.New("unmarshal <challenge>: " + err.Error())
}
b, err := base64.StdEncoding.DecodeString(string(sfm))
if err != nil {
return err
}
var serverNonce string
var salt []byte
var iterations int
for _, serverReply := range strings.Split(string(b), ",") {
switch {
case strings.HasPrefix(serverReply, "r="):
serverNonce = strings.SplitN(serverReply, "=", 2)[1]
if !strings.HasPrefix(serverNonce, clientNonce) {
return errors.New("SCRAM: server nonce didn't start with client nonce")
}
case strings.HasPrefix(serverReply, "s="):
salt, err = base64.StdEncoding.DecodeString(strings.SplitN(serverReply, "=", 2)[1])
if string(salt) == "" {
return errors.New("SCRAM: server sent empty salt")
}
case strings.HasPrefix(serverReply, "i="):
iterations, err = strconv.Atoi(strings.SplitN(serverReply,
"=", 2)[1])
if err != nil {
return err
}
default:
return errors.New("unexpected conted in SCRAM challenge")
}
}
clientFinalMessageBare := "c=biws,r=" + serverNonce
saltedPassword := pbkdf2.Key([]byte(o.Password), salt,
iterations, shaNewFn().Size(), shaNewFn)
h := hmac.New(shaNewFn, saltedPassword)
_, err = h.Write([]byte("Client Key"))
if err != nil {
return err
}
clientKey := h.Sum(nil)
h.Reset()
var storedKey []byte
switch mechanism {
case "SCRAM-SHA-512":
storedKey512 := sha512.Sum512(clientKey)
storedKey = storedKey512[:]
case "SCRAM-SHA-256":
storedKey256 := sha256.Sum256(clientKey)
storedKey = storedKey256[:]
case "SCRAM-SHA-1":
storedKey1 := sha1.Sum(clientKey)
storedKey = storedKey1[:]
}
_, err = h.Write([]byte("Server Key"))
if err != nil {
return err
}
serverFirstMessage, err := base64.StdEncoding.DecodeString(sfm)
if err != nil {
return err
}
authMessage := clientFirstMessage + "," + string(serverFirstMessage) +
"," + clientFinalMessageBare
h = hmac.New(shaNewFn, storedKey[:])
_, err = h.Write([]byte(authMessage))
if err != nil {
return err
}
clientSignature := h.Sum(nil)
h.Reset()
if len(clientKey) != len(clientSignature) {
return errors.New("SCRAM: client key and signature length mismatch")
}
clientProof := make([]byte, len(clientKey))
for i := range clientKey {
clientProof[i] = clientKey[i] ^ clientSignature[i]
}
h = hmac.New(shaNewFn, saltedPassword)
_, err = h.Write([]byte("Server Key"))
if err != nil {
return err
}
serverKey := h.Sum(nil)
h.Reset()
h = hmac.New(shaNewFn, serverKey)
_, err = h.Write([]byte(authMessage))
if err != nil {
return err
}
serverSignature = h.Sum(nil)
if string(serverSignature) == "" {
return errors.New("SCRAM: calculated an empty server signature")
}
clientFinalMessage := base64.StdEncoding.EncodeToString([]byte(clientFinalMessageBare +
",p=" + base64.StdEncoding.EncodeToString(clientProof)))
fmt.Fprintf(StanzaWriter, "<response xmlns='%s'>%s</response>", nsSASL,
clientFinalMessage)
}
if mechanism == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" {
// Oauth authentication: send base64-encoded \x00 user \x00 token. // Oauth authentication: send base64-encoded \x00 user \x00 token.
raw := "\x00" + user + "\x00" + o.OAuthToken raw := "\x00" + user + "\x00" + o.OAuthToken
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw))) enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
base64.StdEncoding.Encode(enc, []byte(raw)) base64.StdEncoding.Encode(enc, []byte(raw))
fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='X-OAUTH2' auth:service='oauth2' "+ fmt.Fprintf(StanzaWriter, "<auth xmlns='%s' mechanism='X-OAUTH2' auth:service='oauth2' "+
"xmlns:auth='%s'>%s</auth>\n", nsSASL, o.OAuthXmlNs, enc) "xmlns:auth='%s'>%s</auth>\n", nsSASL, o.OAuthXmlNs, enc)
break
} }
if m == "PLAIN" { if mechanism == "PLAIN" {
mechanism = m
// Plain authentication: send base64-encoded \x00 user \x00 password. // Plain authentication: send base64-encoded \x00 user \x00 password.
raw := "\x00" + user + "\x00" + o.Password raw := "\x00" + user + "\x00" + o.Password
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw))) enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
base64.StdEncoding.Encode(enc, []byte(raw)) base64.StdEncoding.Encode(enc, []byte(raw))
fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='PLAIN'>%s</auth>\n", nsSASL, enc) fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='PLAIN'>%s</auth>\n", nsSASL, enc)
break
} }
if m == "DIGEST-MD5" { if mechanism == "DIGEST-MD5" {
mechanism = m
// Digest-MD5 authentication // Digest-MD5 authentication
fmt.Fprintf(c.conn, "<auth xmlns='%s' mechanism='DIGEST-MD5'/>\n", nsSASL) fmt.Fprintf(StanzaWriter, "<auth xmlns='%s' mechanism='DIGEST-MD5'/>\n", nsSASL)
var ch saslChallenge var ch saslChallenge
if err = c.p.DecodeElement(&ch, nil); err != nil { if err = c.p.DecodeElement(&ch, nil); err != nil {
return errors.New("unmarshal <challenge>: " + err.Error()) return errors.New("unmarshal <challenge>: " + err.Error())
@ -439,7 +586,7 @@ func (c *Client) init(o *Options) error {
message := "username=\"" + user + "\", realm=\"" + realm + "\", nonce=\"" + nonce + "\", cnonce=\"" + cnonceStr + message := "username=\"" + user + "\", realm=\"" + realm + "\", nonce=\"" + nonce + "\", cnonce=\"" + cnonceStr +
"\", nc=" + nonceCount + ", qop=" + qop + ", digest-uri=\"" + digestURI + "\", response=" + digest + ", charset=" + charset "\", nc=" + nonceCount + ", qop=" + qop + ", digest-uri=\"" + digestURI + "\", response=" + digest + ", charset=" + charset
fmt.Fprintf(c.conn, "<response xmlns='%s'>%s</response>\n", nsSASL, base64.StdEncoding.EncodeToString([]byte(message))) fmt.Fprintf(StanzaWriter, "<response xmlns='%s'>%s</response>\n", nsSASL, base64.StdEncoding.EncodeToString([]byte(message)))
var rspauth saslRspAuth var rspauth saslRspAuth
if err = c.p.DecodeElement(&rspauth, nil); err != nil { if err = c.p.DecodeElement(&rspauth, nil); err != nil {
@ -449,14 +596,12 @@ func (c *Client) init(o *Options) error {
if err != nil { if err != nil {
return err return err
} }
fmt.Fprintf(c.conn, "<response xmlns='%s'/>\n", nsSASL) fmt.Fprintf(StanzaWriter, "<response xmlns='%s'/>\n", nsSASL)
break
} }
} }
if mechanism == "" { if mechanism == "" {
return fmt.Errorf("PLAIN authentication is not an option: %v", f.Mechanisms.Mechanism) return fmt.Errorf("PLAIN authentication is not an option: %v", f.Mechanisms.Mechanism)
} }
}
// Next message should be either success or failure. // Next message should be either success or failure.
name, val, err := next(c.p) name, val, err := next(c.p)
if err != nil { if err != nil {
@ -464,6 +609,23 @@ func (c *Client) init(o *Options) error {
} }
switch v := val.(type) { switch v := val.(type) {
case *saslSuccess: case *saslSuccess:
if strings.HasPrefix(mechanism, "SCRAM-SHA") {
successMsg, err := base64.StdEncoding.DecodeString(v.Text)
if err != nil {
return err
}
if !strings.HasPrefix(string(successMsg), "v=") {
return errors.New("server sent unexpected content in SCRAM success message")
}
serverSignatureReply := strings.SplitN(string(successMsg), "v=", 2)[1]
serverSignatureRemote, err := base64.StdEncoding.DecodeString(serverSignatureReply)
if err != nil {
return err
}
if string(serverSignature) != string(serverSignatureRemote) {
return errors.New("SCRAM: server signature mismatch")
}
}
case *saslFailure: case *saslFailure:
errorMessage := v.Text errorMessage := v.Text
if errorMessage == "" { if errorMessage == "" {