diff --git a/xmpp.go b/xmpp.go index 117262e..9af89a2 100644 --- a/xmpp.go +++ b/xmpp.go @@ -15,22 +15,30 @@ package xmpp import ( "bufio" "bytes" + "crypto/hmac" "crypto/md5" "crypto/rand" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" "crypto/tls" "encoding/base64" "encoding/binary" "encoding/xml" "errors" "fmt" + "hash" "io" "math/big" "net" "net/http" "net/url" "os" + "strconv" "strings" "time" + + "golang.org/x/crypto/pbkdf2" ) const ( @@ -364,7 +372,8 @@ func (c *Client) init(o *Options) error { if f, err = c.startTLSIfRequired(f, o, domain); err != nil { return err } - + var mechanism string + var serverSignature []byte if o.User == "" && o.Password == "" { foundAnonymous := false for _, m := range f.Mechanisms.Mechanism { @@ -385,77 +394,213 @@ func (c *Client) init(o *Options) error { return errors.New("refusing to authenticate over unencrypted TCP connection") } - mechanism := "" + mechanism = "" for _, m := range f.Mechanisms.Mechanism { - if m == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" { + switch m { + case "SCRAM-SHA-512": mechanism = m - // Oauth authentication: send base64-encoded \x00 user \x00 token. - raw := "\x00" + user + "\x00" + o.OAuthToken - enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw))) - base64.StdEncoding.Encode(enc, []byte(raw)) - fmt.Fprintf(c.conn, "%s\n", nsSASL, o.OAuthXmlNs, enc) - break - } - if m == "PLAIN" { - mechanism = m - // Plain authentication: send base64-encoded \x00 user \x00 password. - raw := "\x00" + user + "\x00" + o.Password - enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw))) - base64.StdEncoding.Encode(enc, []byte(raw)) - fmt.Fprintf(c.conn, "%s\n", nsSASL, enc) - break - } - if m == "DIGEST-MD5" { - mechanism = m - // Digest-MD5 authentication - fmt.Fprintf(c.conn, "\n", nsSASL) - var ch saslChallenge - if err = c.p.DecodeElement(&ch, nil); err != nil { - return errors.New("unmarshal : " + err.Error()) + case "SCRAM-SHA-256": + if mechanism != "SCRAM-SHA-512" { + mechanism = m } - b, err := base64.StdEncoding.DecodeString(string(ch)) - if err != nil { - return err + case "SCRAM-SHA-1": + if mechanism != "SCRAM-SHA-512" && + mechanism != "SCRAM-SHA-256" { + mechanism = m } - tokens := map[string]string{} - for _, token := range strings.Split(string(b), ",") { - kv := strings.SplitN(strings.TrimSpace(token), "=", 2) - if len(kv) == 2 { - if kv[1][0] == '"' && kv[1][len(kv[1])-1] == '"' { - kv[1] = kv[1][1 : len(kv[1])-1] - } - tokens[kv[0]] = kv[1] + case "X-OAUTH2": + if mechanism == "" { + mechanism = m + } + case "PLAIN": + if mechanism == "" { + mechanism = m + } + case "DIGEST-MD5": + if mechanism == "" { + mechanism = m + } + } + } + if strings.HasPrefix(mechanism, "SCRAM-SHA") { + var shaNewFn func() hash.Hash + switch mechanism { + case "SCRAM-SHA-512": + shaNewFn = sha512.New + case "SCRAM-SHA-256": + shaNewFn = sha256.New + case "SCRAM-SHA-1": + shaNewFn = sha1.New + default: + return errors.New("unsupported auth mechanism") + } + clientNonce := cnonce() + clientFirstMessage := "n=" + user + ",r=" + clientNonce + fmt.Fprintf(StanzaWriter, "%s", + nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte("n,,"+ + clientFirstMessage))) + var sfm string + if err = c.p.DecodeElement(&sfm, nil); err != nil { + return errors.New("unmarshal : " + err.Error()) + } + b, err := base64.StdEncoding.DecodeString(string(sfm)) + if err != nil { + return err + } + var serverNonce string + var salt []byte + var iterations int + for _, serverReply := range strings.Split(string(b), ",") { + switch { + case strings.HasPrefix(serverReply, "r="): + serverNonce = strings.SplitN(serverReply, "=", 2)[1] + if !strings.HasPrefix(serverNonce, clientNonce) { + return errors.New("SCRAM: server nonce didn't start with client nonce") } + case strings.HasPrefix(serverReply, "s="): + salt, err = base64.StdEncoding.DecodeString(strings.SplitN(serverReply, "=", 2)[1]) + if string(salt) == "" { + return errors.New("SCRAM: server sent empty salt") + } + case strings.HasPrefix(serverReply, "i="): + iterations, err = strconv.Atoi(strings.SplitN(serverReply, + "=", 2)[1]) + if err != nil { + return err + } + default: + return errors.New("unexpected conted in SCRAM challenge") } - realm, _ := tokens["realm"] - nonce, _ := tokens["nonce"] - qop, _ := tokens["qop"] - charset, _ := tokens["charset"] - cnonceStr := cnonce() - digestURI := "xmpp/" + domain - nonceCount := fmt.Sprintf("%08x", 1) - digest := saslDigestResponse(user, realm, o.Password, nonce, cnonceStr, "AUTHENTICATE", digestURI, nonceCount) - message := "username=\"" + user + "\", realm=\"" + realm + "\", nonce=\"" + nonce + "\", cnonce=\"" + cnonceStr + - "\", nc=" + nonceCount + ", qop=" + qop + ", digest-uri=\"" + digestURI + "\", response=" + digest + ", charset=" + charset - - fmt.Fprintf(c.conn, "%s\n", nsSASL, base64.StdEncoding.EncodeToString([]byte(message))) - - var rspauth saslRspAuth - if err = c.p.DecodeElement(&rspauth, nil); err != nil { - return errors.New("unmarshal : " + err.Error()) - } - b, err = base64.StdEncoding.DecodeString(string(rspauth)) - if err != nil { - return err - } - fmt.Fprintf(c.conn, "\n", nsSASL) - break } + clientFinalMessageBare := "c=biws,r=" + serverNonce + saltedPassword := pbkdf2.Key([]byte(o.Password), salt, + iterations, shaNewFn().Size(), shaNewFn) + h := hmac.New(shaNewFn, saltedPassword) + _, err = h.Write([]byte("Client Key")) + if err != nil { + return err + } + clientKey := h.Sum(nil) + h.Reset() + var storedKey []byte + switch mechanism { + case "SCRAM-SHA-512": + storedKey512 := sha512.Sum512(clientKey) + storedKey = storedKey512[:] + case "SCRAM-SHA-256": + storedKey256 := sha256.Sum256(clientKey) + storedKey = storedKey256[:] + case "SCRAM-SHA-1": + storedKey1 := sha1.Sum(clientKey) + storedKey = storedKey1[:] + } + _, err = h.Write([]byte("Server Key")) + if err != nil { + return err + } + serverFirstMessage, err := base64.StdEncoding.DecodeString(sfm) + if err != nil { + return err + } + authMessage := clientFirstMessage + "," + string(serverFirstMessage) + + "," + clientFinalMessageBare + h = hmac.New(shaNewFn, storedKey[:]) + _, err = h.Write([]byte(authMessage)) + if err != nil { + return err + } + clientSignature := h.Sum(nil) + h.Reset() + if len(clientKey) != len(clientSignature) { + return errors.New("SCRAM: client key and signature length mismatch") + } + clientProof := make([]byte, len(clientKey)) + for i := range clientKey { + clientProof[i] = clientKey[i] ^ clientSignature[i] + } + h = hmac.New(shaNewFn, saltedPassword) + _, err = h.Write([]byte("Server Key")) + if err != nil { + return err + } + serverKey := h.Sum(nil) + h.Reset() + h = hmac.New(shaNewFn, serverKey) + _, err = h.Write([]byte(authMessage)) + if err != nil { + return err + } + serverSignature = h.Sum(nil) + if string(serverSignature) == "" { + return errors.New("SCRAM: calculated an empty server signature") + } + clientFinalMessage := base64.StdEncoding.EncodeToString([]byte(clientFinalMessageBare + + ",p=" + base64.StdEncoding.EncodeToString(clientProof))) + fmt.Fprintf(StanzaWriter, "%s", nsSASL, + clientFinalMessage) } - if mechanism == "" { - return fmt.Errorf("PLAIN authentication is not an option: %v", f.Mechanisms.Mechanism) + if mechanism == "X-OAUTH2" && o.OAuthToken != "" && o.OAuthScope != "" { + // Oauth authentication: send base64-encoded \x00 user \x00 token. + raw := "\x00" + user + "\x00" + o.OAuthToken + enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw))) + base64.StdEncoding.Encode(enc, []byte(raw)) + fmt.Fprintf(StanzaWriter, "%s\n", nsSASL, o.OAuthXmlNs, enc) } + if mechanism == "PLAIN" { + // Plain authentication: send base64-encoded \x00 user \x00 password. + raw := "\x00" + user + "\x00" + o.Password + enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw))) + base64.StdEncoding.Encode(enc, []byte(raw)) + fmt.Fprintf(c.conn, "%s\n", nsSASL, enc) + } + if mechanism == "DIGEST-MD5" { + // Digest-MD5 authentication + fmt.Fprintf(StanzaWriter, "\n", nsSASL) + var ch saslChallenge + if err = c.p.DecodeElement(&ch, nil); err != nil { + return errors.New("unmarshal : " + err.Error()) + } + b, err := base64.StdEncoding.DecodeString(string(ch)) + if err != nil { + return err + } + tokens := map[string]string{} + for _, token := range strings.Split(string(b), ",") { + kv := strings.SplitN(strings.TrimSpace(token), "=", 2) + if len(kv) == 2 { + if kv[1][0] == '"' && kv[1][len(kv[1])-1] == '"' { + kv[1] = kv[1][1 : len(kv[1])-1] + } + tokens[kv[0]] = kv[1] + } + } + realm, _ := tokens["realm"] + nonce, _ := tokens["nonce"] + qop, _ := tokens["qop"] + charset, _ := tokens["charset"] + cnonceStr := cnonce() + digestURI := "xmpp/" + domain + nonceCount := fmt.Sprintf("%08x", 1) + digest := saslDigestResponse(user, realm, o.Password, nonce, cnonceStr, "AUTHENTICATE", digestURI, nonceCount) + message := "username=\"" + user + "\", realm=\"" + realm + "\", nonce=\"" + nonce + "\", cnonce=\"" + cnonceStr + + "\", nc=" + nonceCount + ", qop=" + qop + ", digest-uri=\"" + digestURI + "\", response=" + digest + ", charset=" + charset + + fmt.Fprintf(StanzaWriter, "%s\n", nsSASL, base64.StdEncoding.EncodeToString([]byte(message))) + + var rspauth saslRspAuth + if err = c.p.DecodeElement(&rspauth, nil); err != nil { + return errors.New("unmarshal : " + err.Error()) + } + b, err = base64.StdEncoding.DecodeString(string(rspauth)) + if err != nil { + return err + } + fmt.Fprintf(StanzaWriter, "\n", nsSASL) + } + } + if mechanism == "" { + return fmt.Errorf("PLAIN authentication is not an option: %v", f.Mechanisms.Mechanism) } // Next message should be either success or failure. name, val, err := next(c.p) @@ -464,6 +609,23 @@ func (c *Client) init(o *Options) error { } switch v := val.(type) { case *saslSuccess: + if strings.HasPrefix(mechanism, "SCRAM-SHA") { + successMsg, err := base64.StdEncoding.DecodeString(v.Text) + if err != nil { + return err + } + if !strings.HasPrefix(string(successMsg), "v=") { + return errors.New("server sent unexpected content in SCRAM success message") + } + serverSignatureReply := strings.SplitN(string(successMsg), "v=", 2)[1] + serverSignatureRemote, err := base64.StdEncoding.DecodeString(serverSignatureReply) + if err != nil { + return err + } + if string(serverSignature) != string(serverSignatureRemote) { + return errors.New("SCRAM: server signature mismatch") + } + } case *saslFailure: errorMessage := v.Text if errorMessage == "" {