diff --git a/xmpp.go b/xmpp.go index 77fc909..f11f154 100644 --- a/xmpp.go +++ b/xmpp.go @@ -22,6 +22,7 @@ import ( "crypto/sha256" "crypto/sha512" "crypto/tls" + "crypto/x509" "encoding/base64" "encoding/binary" "encoding/xml" @@ -379,7 +380,8 @@ func (c *Client) init(o *Options) error { } var mechanism, channelBinding, clientFirstMessage, clientFinalMessageBare, authMessage string var serverSignature, keyingMaterial []byte - var scramPlus, ok, tlsConnOK, tls13 bool + var scramPlus, ok, tlsConnOK, tls13, serverEndPoint bool + var cbsSlice []string var tlsConn *tls.Conn if o.User == "" && o.Password == "" { foundAnonymous := false @@ -435,25 +437,59 @@ func (c *Client) init(o *Options) error { scramPlus = true } if scramPlus { + for _, cbs := range f.ChannelBindings.ChannelBinding { + cbsSlice = append(cbsSlice, cbs.Type) + } tlsState := tlsConn.ConnectionState() switch tlsState.Version { case tls.VersionTLS13: tls13 = true - keyingMaterial, err = tlsState.ExportKeyingMaterial("EXPORTER-Channel-Binding", nil, 32) - if err != nil { - return err + if slices.Contains(cbsSlice, "tls-server-end-point") && !slices.Contains(cbsSlice, "tls-exporter") { + serverEndPoint = true + } else { + keyingMaterial, err = tlsState.ExportKeyingMaterial("EXPORTER-Channel-Binding", nil, 32) + if err != nil { + return err + } } case tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12: - keyingMaterial = tlsState.TLSUnique + if slices.Contains(cbsSlice, "tls-server-end-point") && !slices.Contains(cbsSlice, "tls-unique") { + serverEndPoint = true + } else { + keyingMaterial = tlsState.TLSUnique + } default: return errors.New(mechanism + ": unknown TLS version") } + if serverEndPoint { + switch tlsState.PeerCertificates[0].SignatureAlgorithm { + case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.ECDSAWithSHA1, + x509.ECDSAWithSHA256, x509.SHA256WithRSAPSS: + h := sha256.New() + h.Write(tlsState.PeerCertificates[0].Raw) + keyingMaterial = h.Sum(nil) + h.Reset() + case x509.SHA384WithRSA, x509.ECDSAWithSHA384, x509.SHA384WithRSAPSS: + h := sha512.New384() + h.Write(tlsState.PeerCertificates[0].Raw) + keyingMaterial = h.Sum(nil) + h.Reset() + case x509.SHA512WithRSA, x509.ECDSAWithSHA512, x509.SHA512WithRSAPSS: + h := sha512.New() + h.Write(tlsState.PeerCertificates[0].Raw) + keyingMaterial = h.Sum(nil) + h.Reset() + } + } if len(keyingMaterial) == 0 { return errors.New(mechanism + ": no keying material") } - if tls13 { + switch { + case tls13 && !serverEndPoint: channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-exporter,,"), keyingMaterial[:]...)) - } else { + case serverEndPoint: + channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-server-end-point,,"), keyingMaterial[:]...)) + default: channelBinding = base64.StdEncoding.EncodeToString(append([]byte("p=tls-unique,,"), keyingMaterial[:]...)) } } @@ -470,9 +506,12 @@ func (c *Client) init(o *Options) error { } clientNonce := cnonce() if scramPlus { - if tls13 { + switch { + case tls13 && !serverEndPoint: clientFirstMessage = "p=tls-exporter,,n=" + user + ",r=" + clientNonce - } else { + case serverEndPoint: + clientFirstMessage = "p=tls-server-end-point,,n=" + user + ",r=" + clientNonce + default: clientFirstMessage = "p=tls-unique,,n=" + user + ",r=" + clientNonce } } else { @@ -523,10 +562,6 @@ func (c *Client) init(o *Options) error { } } dgProtect = dgProtect + "|" - var cbsSlice []string - for _, cbs := range f.ChannelBindings.ChannelBinding { - cbsSlice = append(cbsSlice, cbs.Type) - } slices.Sort(cbsSlice) for i, cb := range cbsSlice { if i == 0 {