SCRAM: Add support for tls-server-end-point channel binding. (#177)

This commit is contained in:
Martin 2024-01-11 13:33:32 +01:00 committed by GitHub
parent 34d683d25a
commit f4c732fdc7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

53
xmpp.go
View File

@ -22,6 +22,7 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/sha512" "crypto/sha512"
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"encoding/xml" "encoding/xml"
@ -379,7 +380,8 @@ func (c *Client) init(o *Options) error {
} }
var mechanism, channelBinding, clientFirstMessage, clientFinalMessageBare, authMessage string var mechanism, channelBinding, clientFirstMessage, clientFinalMessageBare, authMessage string
var serverSignature, keyingMaterial []byte var serverSignature, keyingMaterial []byte
var scramPlus, ok, tlsConnOK, tls13 bool var scramPlus, ok, tlsConnOK, tls13, serverEndPoint bool
var cbsSlice []string
var tlsConn *tls.Conn var tlsConn *tls.Conn
if o.User == "" && o.Password == "" { if o.User == "" && o.Password == "" {
foundAnonymous := false foundAnonymous := false
@ -435,25 +437,59 @@ func (c *Client) init(o *Options) error {
scramPlus = true scramPlus = true
} }
if scramPlus { if scramPlus {
for _, cbs := range f.ChannelBindings.ChannelBinding {
cbsSlice = append(cbsSlice, cbs.Type)
}
tlsState := tlsConn.ConnectionState() tlsState := tlsConn.ConnectionState()
switch tlsState.Version { switch tlsState.Version {
case tls.VersionTLS13: case tls.VersionTLS13:
tls13 = true tls13 = true
if slices.Contains(cbsSlice, "tls-server-end-point") && !slices.Contains(cbsSlice, "tls-exporter") {
serverEndPoint = true
} else {
keyingMaterial, err = tlsState.ExportKeyingMaterial("EXPORTER-Channel-Binding", nil, 32) keyingMaterial, err = tlsState.ExportKeyingMaterial("EXPORTER-Channel-Binding", nil, 32)
if err != nil { if err != nil {
return err return err
} }
}
case tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12: case tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12:
if slices.Contains(cbsSlice, "tls-server-end-point") && !slices.Contains(cbsSlice, "tls-unique") {
serverEndPoint = true
} else {
keyingMaterial = tlsState.TLSUnique keyingMaterial = tlsState.TLSUnique
}
default: default:
return errors.New(mechanism + ": unknown TLS version") return errors.New(mechanism + ": unknown TLS version")
} }
if serverEndPoint {
switch tlsState.PeerCertificates[0].SignatureAlgorithm {
case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.ECDSAWithSHA1,
x509.ECDSAWithSHA256, x509.SHA256WithRSAPSS:
h := sha256.New()
h.Write(tlsState.PeerCertificates[0].Raw)
keyingMaterial = h.Sum(nil)
h.Reset()
case x509.SHA384WithRSA, x509.ECDSAWithSHA384, x509.SHA384WithRSAPSS:
h := sha512.New384()
h.Write(tlsState.PeerCertificates[0].Raw)
keyingMaterial = h.Sum(nil)
h.Reset()
case x509.SHA512WithRSA, x509.ECDSAWithSHA512, x509.SHA512WithRSAPSS:
h := sha512.New()
h.Write(tlsState.PeerCertificates[0].Raw)
keyingMaterial = h.Sum(nil)
h.Reset()
}
}
if len(keyingMaterial) == 0 { if len(keyingMaterial) == 0 {
return errors.New(mechanism + ": no keying material") return errors.New(mechanism + ": no keying material")
} }
if tls13 { switch {
case tls13 && !serverEndPoint:
channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-exporter,,"), keyingMaterial[:]...)) channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-exporter,,"), keyingMaterial[:]...))
} else { case serverEndPoint:
channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-server-end-point,,"), keyingMaterial[:]...))
default:
channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-unique,,"), keyingMaterial[:]...)) channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-unique,,"), keyingMaterial[:]...))
} }
} }
@ -470,9 +506,12 @@ func (c *Client) init(o *Options) error {
} }
clientNonce := cnonce() clientNonce := cnonce()
if scramPlus { if scramPlus {
if tls13 { switch {
case tls13 && !serverEndPoint:
clientFirstMessage = "p=tls-exporter,,n=" + user + ",r=" + clientNonce clientFirstMessage = "p=tls-exporter,,n=" + user + ",r=" + clientNonce
} else { case serverEndPoint:
clientFirstMessage = "p=tls-server-end-point,,n=" + user + ",r=" + clientNonce
default:
clientFirstMessage = "p=tls-unique,,n=" + user + ",r=" + clientNonce clientFirstMessage = "p=tls-unique,,n=" + user + ",r=" + clientNonce
} }
} else { } else {
@ -523,10 +562,6 @@ func (c *Client) init(o *Options) error {
} }
} }
dgProtect = dgProtect + "|" dgProtect = dgProtect + "|"
var cbsSlice []string
for _, cbs := range f.ChannelBindings.ChannelBinding {
cbsSlice = append(cbsSlice, cbs.Type)
}
slices.Sort(cbsSlice) slices.Sort(cbsSlice)
for i, cb := range cbsSlice { for i, cb := range cbsSlice {
if i == 0 { if i == 0 {