Fix race condition for nextStart and nextEnd.

This commit is contained in:
Martin Dosch 2024-03-16 19:04:09 +01:00
parent 73f06c9f3d
commit 8ab32d885f

26
xmpp.go
View File

@ -38,6 +38,7 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"golang.org/x/crypto/pbkdf2" "golang.org/x/crypto/pbkdf2"
@ -60,6 +61,9 @@ var DefaultConfig = &tls.Config{}
// DebugWriter is the writer used to write debugging output to. // DebugWriter is the writer used to write debugging output to.
var DebugWriter io.Writer = os.Stderr var DebugWriter io.Writer = os.Stderr
// Mutex to prevent multiple access to xml.Decoder
var nextMutex sync.Mutex
// Cookie is a unique XMPP session identifier // Cookie is a unique XMPP session identifier
type Cookie uint64 type Cookie uint64
@ -917,7 +921,6 @@ type IQ struct {
} }
// Recv waits to receive the next XMPP stanza. // Recv waits to receive the next XMPP stanza.
// Return type is either a presence notification or a chat message.
func (c *Client) Recv() (stanza interface{}, err error) { func (c *Client) Recv() (stanza interface{}, err error) {
for { for {
_, val, err := next(c.p) _, val, err := next(c.p)
@ -1461,29 +1464,46 @@ type rosterItem struct {
// Scan XML token stream to find next StartElement. // Scan XML token stream to find next StartElement.
func nextStart(p *xml.Decoder) (xml.StartElement, error) { func nextStart(p *xml.Decoder) (xml.StartElement, error) {
for { for {
nextMutex.Lock()
t, err := p.Token() t, err := p.Token()
if err != nil || t == nil { if err != nil || t == nil {
nextMutex.Unlock()
return xml.StartElement{}, err return xml.StartElement{}, err
} }
switch t := t.(type) { switch t := t.(type) {
case xml.StartElement: case xml.StartElement:
nextMutex.Unlock()
return t, nil return t, nil
// Also check for stream end element and stop waiting
// for new start elements if we received a closing stream
// element.
case xml.EndElement:
if t.Name.Local == "stream" {
nextMutex.Unlock()
return xml.StartElement{}, nil
} }
} }
nextMutex.Unlock()
}
} }
// Scan XML token stream to find next EndElement // Scan XML token stream to find next EndElement
func nextEnd(p *xml.Decoder) (xml.EndElement, error) { func nextEnd(p *xml.Decoder) (xml.EndElement, error) {
p.Strict = false p.Strict = false
for { for {
t, err := p.RawToken() nextMutex.Lock()
if err != nil || t == nil { to, err := p.RawToken()
if err != nil || to == nil {
nextMutex.Unlock()
return xml.EndElement{}, err return xml.EndElement{}, err
} }
t := xml.CopyToken(to)
switch t := t.(type) { switch t := t.(type) {
case xml.EndElement: case xml.EndElement:
nextMutex.Unlock()
return t, nil return t, nil
} }
nextMutex.Unlock()
} }
} }