forked from jshiffer/go-xmpp
SCRAM: Add support for tls-server-end-point channel binding. (#177)
This commit is contained in:
parent
34d683d25a
commit
f4c732fdc7
61
xmpp.go
61
xmpp.go
@ -22,6 +22,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/xml"
|
||||
@ -379,7 +380,8 @@ func (c *Client) init(o *Options) error {
|
||||
}
|
||||
var mechanism, channelBinding, clientFirstMessage, clientFinalMessageBare, authMessage string
|
||||
var serverSignature, keyingMaterial []byte
|
||||
var scramPlus, ok, tlsConnOK, tls13 bool
|
||||
var scramPlus, ok, tlsConnOK, tls13, serverEndPoint bool
|
||||
var cbsSlice []string
|
||||
var tlsConn *tls.Conn
|
||||
if o.User == "" && o.Password == "" {
|
||||
foundAnonymous := false
|
||||
@ -435,25 +437,59 @@ func (c *Client) init(o *Options) error {
|
||||
scramPlus = true
|
||||
}
|
||||
if scramPlus {
|
||||
for _, cbs := range f.ChannelBindings.ChannelBinding {
|
||||
cbsSlice = append(cbsSlice, cbs.Type)
|
||||
}
|
||||
tlsState := tlsConn.ConnectionState()
|
||||
switch tlsState.Version {
|
||||
case tls.VersionTLS13:
|
||||
tls13 = true
|
||||
keyingMaterial, err = tlsState.ExportKeyingMaterial("EXPORTER-Channel-Binding", nil, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
if slices.Contains(cbsSlice, "tls-server-end-point") && !slices.Contains(cbsSlice, "tls-exporter") {
|
||||
serverEndPoint = true
|
||||
} else {
|
||||
keyingMaterial, err = tlsState.ExportKeyingMaterial("EXPORTER-Channel-Binding", nil, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12:
|
||||
keyingMaterial = tlsState.TLSUnique
|
||||
if slices.Contains(cbsSlice, "tls-server-end-point") && !slices.Contains(cbsSlice, "tls-unique") {
|
||||
serverEndPoint = true
|
||||
} else {
|
||||
keyingMaterial = tlsState.TLSUnique
|
||||
}
|
||||
default:
|
||||
return errors.New(mechanism + ": unknown TLS version")
|
||||
}
|
||||
if serverEndPoint {
|
||||
switch tlsState.PeerCertificates[0].SignatureAlgorithm {
|
||||
case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.ECDSAWithSHA1,
|
||||
x509.ECDSAWithSHA256, x509.SHA256WithRSAPSS:
|
||||
h := sha256.New()
|
||||
h.Write(tlsState.PeerCertificates[0].Raw)
|
||||
keyingMaterial = h.Sum(nil)
|
||||
h.Reset()
|
||||
case x509.SHA384WithRSA, x509.ECDSAWithSHA384, x509.SHA384WithRSAPSS:
|
||||
h := sha512.New384()
|
||||
h.Write(tlsState.PeerCertificates[0].Raw)
|
||||
keyingMaterial = h.Sum(nil)
|
||||
h.Reset()
|
||||
case x509.SHA512WithRSA, x509.ECDSAWithSHA512, x509.SHA512WithRSAPSS:
|
||||
h := sha512.New()
|
||||
h.Write(tlsState.PeerCertificates[0].Raw)
|
||||
keyingMaterial = h.Sum(nil)
|
||||
h.Reset()
|
||||
}
|
||||
}
|
||||
if len(keyingMaterial) == 0 {
|
||||
return errors.New(mechanism + ": no keying material")
|
||||
}
|
||||
if tls13 {
|
||||
switch {
|
||||
case tls13 && !serverEndPoint:
|
||||
channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-exporter,,"), keyingMaterial[:]...))
|
||||
} else {
|
||||
case serverEndPoint:
|
||||
channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-server-end-point,,"), keyingMaterial[:]...))
|
||||
default:
|
||||
channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-unique,,"), keyingMaterial[:]...))
|
||||
}
|
||||
}
|
||||
@ -470,9 +506,12 @@ func (c *Client) init(o *Options) error {
|
||||
}
|
||||
clientNonce := cnonce()
|
||||
if scramPlus {
|
||||
if tls13 {
|
||||
switch {
|
||||
case tls13 && !serverEndPoint:
|
||||
clientFirstMessage = "p=tls-exporter,,n=" + user + ",r=" + clientNonce
|
||||
} else {
|
||||
case serverEndPoint:
|
||||
clientFirstMessage = "p=tls-server-end-point,,n=" + user + ",r=" + clientNonce
|
||||
default:
|
||||
clientFirstMessage = "p=tls-unique,,n=" + user + ",r=" + clientNonce
|
||||
}
|
||||
} else {
|
||||
@ -523,10 +562,6 @@ func (c *Client) init(o *Options) error {
|
||||
}
|
||||
}
|
||||
dgProtect = dgProtect + "|"
|
||||
var cbsSlice []string
|
||||
for _, cbs := range f.ChannelBindings.ChannelBinding {
|
||||
cbsSlice = append(cbsSlice, cbs.Type)
|
||||
}
|
||||
slices.Sort(cbsSlice)
|
||||
for i, cb := range cbsSlice {
|
||||
if i == 0 {
|
||||
|
Loading…
Reference in New Issue
Block a user