diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..4adb9c6
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,5 @@
+module github.com/mattn/go-xmpp
+
+go 1.20
+
+require golang.org/x/crypto v0.15.0
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..fcac726
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA=
+golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g=
diff --git a/xmpp.go b/xmpp.go
index 1f9109d..051c11a 100644
--- a/xmpp.go
+++ b/xmpp.go
@@ -74,6 +74,7 @@ type Client struct {
domain string
p *xml.Decoder
stanzaWriter io.Writer
+ Mechanism string
}
func (c *Client) JID() string {
@@ -207,6 +208,9 @@ type Options struct {
// Status message
StatusMessage string
+
+ // Auth mechanism to use
+ Mechanism string
}
// NewClient establishes a new Client connection based on a set of Options.
@@ -344,7 +348,6 @@ func cnonce() string {
}
func (c *Client) init(o *Options) error {
-
var domain string
var user string
a := strings.SplitN(o.User, "@", 2)
@@ -372,8 +375,10 @@ func (c *Client) init(o *Options) error {
if f, err = c.startTLSIfRequired(f, o, domain); err != nil {
return err
}
- var mechanism string
- var serverSignature []byte
+ var mechanism, channelBinding, clientFirstMessage, clientFinalMessageBare, authMessage string
+ var serverSignature, keyingMaterial []byte
+ var scramPlus, ok, tlsConnOK, tls13 bool
+ var tlsConn *tls.Conn
if o.User == "" && o.Password == "" {
foundAnonymous := false
for _, m := range f.Mechanisms.Mechanism {
@@ -394,51 +399,120 @@ func (c *Client) init(o *Options) error {
return errors.New("refusing to authenticate over unencrypted TCP connection")
}
+ tlsConn, ok = c.conn.(*tls.Conn)
+ if ok {
+ tlsConnOK = true
+ }
mechanism = ""
- for _, m := range f.Mechanisms.Mechanism {
- switch m {
- case "SCRAM-SHA-512":
- mechanism = m
- case "SCRAM-SHA-256":
- if mechanism != "SCRAM-SHA-512" {
- mechanism = m
+ if o.Mechanism == "" {
+ for _, m := range f.Mechanisms.Mechanism {
+ switch m {
+ case "SCRAM-SHA-512-PLUS":
+ if tlsConnOK {
+ mechanism = m
+ }
+ case "SCRAM-SHA-256-PLUS":
+ if mechanism != "SCRAM-SHA-512-PLUS" && tlsConnOK {
+ mechanism = m
+ }
+ case "SCRAM-SHA-1-PLUS":
+ if mechanism != "SCRAM-SHA-512-PLUS" &&
+ mechanism != "SCRAM-SHA-256-PLUS" &&
+ tlsConnOK {
+ mechanism = m
+ }
+ case "SCRAM-SHA-512":
+ if mechanism != "SCRAM-SHA-512-PLUS" &&
+ mechanism != "SCRAM-SHA-256-PLUS" &&
+ mechanism != "SCRAM-SHA-1-PLUS" {
+ mechanism = m
+ }
+ case "SCRAM-SHA-256":
+ if mechanism != "SCRAM-SHA-512-PLUS" &&
+ mechanism != "SCRAM-SHA-256-PLUS" &&
+ mechanism != "SCRAM-SHA-1-PLUS" &&
+ mechanism != "SCRAM-SHA-512" {
+ mechanism = m
+ }
+ case "SCRAM-SHA-1":
+ if mechanism != "SCRAM-SHA-512-PLUS" &&
+ mechanism != "SCRAM-SHA-256-PLUS" &&
+ mechanism != "SCRAM-SHA-1-PLUS" &&
+ mechanism != "SCRAM-SHA-512" &&
+ mechanism != "SCRAM-SHA-256" {
+ mechanism = m
+ }
+ case "X-OAUTH2":
+ if mechanism == "" {
+ mechanism = m
+ }
+ case "PLAIN":
+ if mechanism == "" {
+ mechanism = m
+ }
+ case "DIGEST-MD5":
+ if mechanism == "" {
+ mechanism = m
+ }
}
- case "SCRAM-SHA-1":
- if mechanism != "SCRAM-SHA-512" &&
- mechanism != "SCRAM-SHA-256" {
- mechanism = m
- }
- case "X-OAUTH2":
- if mechanism == "" {
- mechanism = m
- }
- case "PLAIN":
- if mechanism == "" {
- mechanism = m
- }
- case "DIGEST-MD5":
- if mechanism == "" {
- mechanism = m
+ }
+ } else {
+ for _, m := range f.Mechanisms.Mechanism {
+ if m == o.Mechanism {
+ mechanism = o.Mechanism
}
}
}
if strings.HasPrefix(mechanism, "SCRAM-SHA") {
+ if strings.HasSuffix(mechanism, "PLUS") {
+ scramPlus = true
+ }
+ if scramPlus {
+ tlsState := tlsConn.ConnectionState()
+ switch tlsState.Version {
+ case tls.VersionTLS13:
+ tls13 = true
+ keyingMaterial, err = tlsState.ExportKeyingMaterial("EXPORTER-Channel-Binding", nil, 32)
+ if err != nil {
+ return err
+ }
+ case tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12:
+ keyingMaterial = tlsState.TLSUnique
+ default:
+ return errors.New(mechanism + ": unknown TLS version")
+ }
+ if len(keyingMaterial) == 0 {
+ return errors.New(mechanism + ": no keying material")
+ }
+ if tls13 {
+ channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-exporter,,"), keyingMaterial[:]...))
+ } else {
+ channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-unique,,"), keyingMaterial[:]...))
+ }
+ }
var shaNewFn func() hash.Hash
switch mechanism {
- case "SCRAM-SHA-512":
+ case "SCRAM-SHA-512", "SCRAM-SHA-512-PLUS":
shaNewFn = sha512.New
- case "SCRAM-SHA-256":
+ case "SCRAM-SHA-256", "SCRAM-SHA-256-PLUS":
shaNewFn = sha256.New
- case "SCRAM-SHA-1":
+ case "SCRAM-SHA-1", "SCRAM-SHA-1-PLUS":
shaNewFn = sha1.New
default:
return errors.New("unsupported auth mechanism")
}
clientNonce := cnonce()
- clientFirstMessage := "n=" + user + ",r=" + clientNonce
+ if scramPlus {
+ if tls13 {
+ clientFirstMessage = "p=tls-exporter,,n=" + user + ",r=" + clientNonce
+ } else {
+ clientFirstMessage = "p=tls-unique,,n=" + user + ",r=" + clientNonce
+ }
+ } else {
+ clientFirstMessage = "n,,n=" + user + ",r=" + clientNonce
+ }
fmt.Fprintf(c.stanzaWriter, "%s\n",
- nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte("n,,"+
- clientFirstMessage)))
+ nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte(clientFirstMessage)))
var sfm string
if err = c.p.DecodeElement(&sfm, nil); err != nil {
return errors.New("unmarshal : " + err.Error())
@@ -469,10 +543,14 @@ func (c *Client) init(o *Options) error {
return err
}
default:
- return errors.New("unexpected conted in SCRAM challenge")
+ return errors.New("unexpected content in SCRAM challenge")
}
}
- clientFinalMessageBare := "c=biws,r=" + serverNonce
+ if scramPlus {
+ clientFinalMessageBare = "c=" + channelBinding + ",r=" + serverNonce
+ } else {
+ clientFinalMessageBare = "c=biws,r=" + serverNonce
+ }
saltedPassword := pbkdf2.Key([]byte(o.Password), salt,
iterations, shaNewFn().Size(), shaNewFn)
h := hmac.New(shaNewFn, saltedPassword)
@@ -484,13 +562,13 @@ func (c *Client) init(o *Options) error {
h.Reset()
var storedKey []byte
switch mechanism {
- case "SCRAM-SHA-512":
+ case "SCRAM-SHA-512", "SCRAM-SHA-512-PLUS":
storedKey512 := sha512.Sum512(clientKey)
storedKey = storedKey512[:]
- case "SCRAM-SHA-256":
+ case "SCRAM-SHA-256", "SCRAM-SH-256-PLUS":
storedKey256 := sha256.Sum256(clientKey)
storedKey = storedKey256[:]
- case "SCRAM-SHA-1":
+ case "SCRAM-SHA-1", "SCRAM-SHA-1-PLUS":
storedKey1 := sha1.Sum(clientKey)
storedKey = storedKey1[:]
}
@@ -502,8 +580,8 @@ func (c *Client) init(o *Options) error {
if err != nil {
return err
}
- authMessage := clientFirstMessage + "," + string(serverFirstMessage) +
- "," + clientFinalMessageBare
+ authMessage = strings.SplitAfter(clientFirstMessage, ",,")[1] + "," +
+ string(serverFirstMessage) + "," + clientFinalMessageBare
h = hmac.New(shaNewFn, storedKey[:])
_, err = h.Write([]byte(authMessage))
if err != nil {
@@ -625,6 +703,7 @@ func (c *Client) init(o *Options) error {
if string(serverSignature) != string(serverSignatureRemote) {
return errors.New("SCRAM: server signature mismatch")
}
+ c.Mechanism = mechanism
}
case *saslFailure:
errorMessage := v.Text
@@ -664,7 +743,7 @@ func (c *Client) init(o *Options) error {
c.domain = domain
if o.Session {
- //if server support session, open it
+ // if server support session, open it
fmt.Fprintf(c.stanzaWriter, "\n", xmlEscape(domain), cookie, nsSession)
}
@@ -700,7 +779,7 @@ func (c *Client) startTLSIfRequired(f *streamFeatures, o *Options, domain string
tc := o.TLSConfig
if tc == nil {
tc = DefaultConfig.Clone()
- //TODO(scott): we should consider using the server's address or reverse lookup
+ // TODO(scott): we should consider using the server's address or reverse lookup
tc.ServerName = domain
}
t := tls.Client(c.conn, tc)
@@ -892,8 +971,10 @@ func (c *Client) Recv() (stanza interface{}, err error) {
return Chat{}, err
}
- return IQ{ID: v.ID, From: v.From, To: v.To, Type: v.Type,
- Query: res}, nil
+ return IQ{
+ ID: v.ID, From: v.From, To: v.To, Type: v.Type,
+ Query: res,
+ }, nil
}
case v.Type == "result":
switch v.ID {
@@ -1021,8 +1102,10 @@ func (c *Client) Recv() (stanza interface{}, err error) {
return Chat{}, err
}
- return IQ{ID: v.ID, From: v.From, To: v.To, Type: v.Type,
- Query: res}, nil
+ return IQ{
+ ID: v.ID, From: v.From, To: v.To, Type: v.Type,
+ Query: res,
+ }, nil
}
case v.Query.XMLName.Local == "":
return IQ{ID: v.ID, From: v.From, To: v.To, Type: v.Type}, nil
@@ -1032,8 +1115,10 @@ func (c *Client) Recv() (stanza interface{}, err error) {
return Chat{}, err
}
- return IQ{ID: v.ID, From: v.From, To: v.To, Type: v.Type,
- Query: res}, nil
+ return IQ{
+ ID: v.ID, From: v.From, To: v.To, Type: v.Type,
+ Query: res,
+ }, nil
}
}
}
@@ -1109,11 +1194,12 @@ func (c *Client) Roster() error {
// RFC 3920 C.1 Streams name space
type streamFeatures struct {
- XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
- StartTLS *tlsStartTLS
- Mechanisms saslMechanisms
- Bind bindBind
- Session bool
+ XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
+ StartTLS *tlsStartTLS
+ Mechanisms saslMechanisms
+ ChannelBinding saslChannelBinding
+ Bind bindBind
+ Session bool
}
type streamError struct {
@@ -1147,6 +1233,16 @@ type saslAuth struct {
Mechanism string `xml:",attr"`
}
+type saslChannelBinding struct {
+ XMLName xml.Name `xml:"sasl-channel-binding"`
+ Text string `xml:",chardata"`
+ Xmlns string `xml:"xmlns,attr"`
+ ChannelBinding []struct {
+ Text string `xml:",chardata"`
+ Type string `xml:"type,attr"`
+ } `xml:"channel-binding"`
+}
+
type saslChallenge string
type saslRspAuth string