Authentication improvements. (#165)

* Add XEP-0474 support.
* Add missing error handling.
This commit is contained in:
Martin 2024-01-09 10:24:56 +01:00 committed by GitHub
parent 31c7eb6919
commit 39f5b80375
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

54
xmpp.go
View File

@ -34,6 +34,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"slices"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -46,6 +47,7 @@ const (
nsTLS = "urn:ietf:params:xml:ns:xmpp-tls" nsTLS = "urn:ietf:params:xml:ns:xmpp-tls"
nsSASL = "urn:ietf:params:xml:ns:xmpp-sasl" nsSASL = "urn:ietf:params:xml:ns:xmpp-sasl"
nsBind = "urn:ietf:params:xml:ns:xmpp-bind" nsBind = "urn:ietf:params:xml:ns:xmpp-bind"
nsSASLCB = "urn:xmpp:sasl-cb:0"
nsClient = "jabber:client" nsClient = "jabber:client"
nsSession = "urn:ietf:params:xml:ns:xmpp-session" nsSession = "urn:ietf:params:xml:ns:xmpp-session"
) )
@ -521,7 +523,7 @@ func (c *Client) init(o *Options) error {
if err != nil { if err != nil {
return err return err
} }
var serverNonce string var serverNonce, dgProtect string
var salt []byte var salt []byte
var iterations int var iterations int
for _, serverReply := range strings.Split(string(b), ",") { for _, serverReply := range strings.Split(string(b), ",") {
@ -533,6 +535,9 @@ func (c *Client) init(o *Options) error {
} }
case strings.HasPrefix(serverReply, "s="): case strings.HasPrefix(serverReply, "s="):
salt, err = base64.StdEncoding.DecodeString(strings.SplitN(serverReply, "=", 2)[1]) salt, err = base64.StdEncoding.DecodeString(strings.SplitN(serverReply, "=", 2)[1])
if err != nil {
return err
}
if string(salt) == "" { if string(salt) == "" {
return errors.New("SCRAM: server sent empty salt") return errors.New("SCRAM: server sent empty salt")
} }
@ -542,6 +547,37 @@ func (c *Client) init(o *Options) error {
if err != nil { if err != nil {
return err return err
} }
case strings.HasPrefix(serverReply, "d="):
serverDgProtectHash := strings.SplitN(serverReply, "=", 2)[1]
slices.Sort(f.Mechanisms.Mechanism)
for _, mech := range f.Mechanisms.Mechanism {
if dgProtect == "" {
dgProtect = mech
} else {
dgProtect = dgProtect + "," + mech
}
}
dgProtect = dgProtect + "|"
var cbsSlice []string
for _, cbs := range f.ChannelBindings.ChannelBinding {
cbsSlice = append(cbsSlice, cbs.Type)
}
slices.Sort(cbsSlice)
for i, cb := range cbsSlice {
if i == 0 {
dgProtect = dgProtect + cb
} else {
dgProtect = dgProtect + "," + cb
}
}
dgh := shaNewFn()
dgh.Write([]byte(dgProtect))
dHash := dgh.Sum(nil)
dHashb64 := base64.StdEncoding.EncodeToString(dHash)
if dHashb64 != serverDgProtectHash {
return errors.New("SCRAM: downgrade protection hash mismatch")
}
dgh.Reset()
default: default:
return errors.New("unexpected content in SCRAM challenge") return errors.New("unexpected content in SCRAM challenge")
} }
@ -1194,12 +1230,12 @@ func (c *Client) Roster() error {
// RFC 3920 C.1 Streams name space // RFC 3920 C.1 Streams name space
type streamFeatures struct { type streamFeatures struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
StartTLS *tlsStartTLS StartTLS *tlsStartTLS
Mechanisms saslMechanisms Mechanisms saslMechanisms
ChannelBinding saslChannelBinding ChannelBindings saslChannelBindings
Bind bindBind Bind bindBind
Session bool Session bool
} }
type streamError struct { type streamError struct {
@ -1233,7 +1269,7 @@ type saslAuth struct {
Mechanism string `xml:",attr"` Mechanism string `xml:",attr"`
} }
type saslChannelBinding struct { type saslChannelBindings struct {
XMLName xml.Name `xml:"sasl-channel-binding"` XMLName xml.Name `xml:"sasl-channel-binding"`
Text string `xml:",chardata"` Text string `xml:",chardata"`
Xmlns string `xml:"xmlns,attr"` Xmlns string `xml:"xmlns,attr"`
@ -1436,6 +1472,8 @@ func next(p *xml.Decoder) (xml.Name, interface{}, error) {
nv = &saslSuccess{} nv = &saslSuccess{}
case nsSASL + " failure": case nsSASL + " failure":
nv = &saslFailure{} nv = &saslFailure{}
case nsSASLCB + " sasl-channel-binding":
nv = &saslChannelBindings{}
case nsBind + " bind": case nsBind + " bind":
nv = &bindBind{} nv = &bindBind{}
case nsClient + " message": case nsClient + " message":