Merge pull request #186 from vcabbage/move-mutex

move nextMutex to Client to prevent blocking separate Clients
This commit is contained in:
Martin 2024-03-26 19:18:05 +01:00 committed by GitHub
commit 94ab540b80
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

50
xmpp.go
View File

@ -61,9 +61,6 @@ var DefaultConfig = &tls.Config{}
// DebugWriter is the writer used to write debugging output to. // DebugWriter is the writer used to write debugging output to.
var DebugWriter io.Writer = os.Stderr var DebugWriter io.Writer = os.Stderr
// Mutex to prevent multiple access to xml.Decoder
var nextMutex sync.Mutex
// Cookie is a unique XMPP session identifier // Cookie is a unique XMPP session identifier
type Cookie uint64 type Cookie uint64
@ -80,6 +77,7 @@ type Client struct {
conn net.Conn // connection to server conn net.Conn // connection to server
jid string // Jabber ID for our connection jid string // Jabber ID for our connection
domain string domain string
nextMutex sync.Mutex // Mutex to prevent multiple access to xml.Decoder
p *xml.Decoder p *xml.Decoder
stanzaWriter io.Writer stanzaWriter io.Writer
Mechanism string Mechanism string
@ -341,7 +339,7 @@ func (c *Client) Close() error {
case <-time.After(10 * time.Second): case <-time.After(10 * time.Second):
break break
default: default:
ee, err := nextEnd(c.p) ee, err := c.nextEnd()
// If the server already closed the stream it is // If the server already closed the stream it is
// likely to receive an error when trying to parse // likely to receive an error when trying to parse
// the stream. Therefore the connection is also closed // the stream. Therefore the connection is also closed
@ -540,7 +538,7 @@ func (c *Client) init(o *Options) error {
fmt.Fprintf(c.stanzaWriter, "<auth xmlns='%s' mechanism='%s'>%s</auth>\n", fmt.Fprintf(c.stanzaWriter, "<auth xmlns='%s' mechanism='%s'>%s</auth>\n",
nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte(clientFirstMessage))) nsSASL, mechanism, base64.StdEncoding.EncodeToString([]byte(clientFirstMessage)))
var sfm string var sfm string
_, val, err := next(c.p) _, val, err := c.next()
if err != nil { if err != nil {
return err return err
} }
@ -706,7 +704,7 @@ func (c *Client) init(o *Options) error {
return fmt.Errorf("no viable authentication method available: %v", f.Mechanisms.Mechanism) return fmt.Errorf("no viable authentication method available: %v", f.Mechanisms.Mechanism)
} }
// Next message should be either success or failure. // Next message should be either success or failure.
name, val, err := next(c.p) name, val, err := c.next()
if err != nil { if err != nil {
return err return err
} }
@ -757,7 +755,7 @@ func (c *Client) init(o *Options) error {
} else { } else {
fmt.Fprintf(c.stanzaWriter, "<iq type='set' id='%x'><bind xmlns='%s'><resource>%s</resource></bind></iq>\n", cookie, nsBind, o.Resource) fmt.Fprintf(c.stanzaWriter, "<iq type='set' id='%x'><bind xmlns='%s'><resource>%s</resource></bind></iq>\n", cookie, nsBind, o.Resource)
} }
_, val, err = next(c.p) _, val, err = c.next()
if err != nil { if err != nil {
return err return err
} }
@ -854,7 +852,7 @@ func (c *Client) startStream(o *Options, domain string) (*streamFeatures, error)
} }
// We expect the server to start a <stream>. // We expect the server to start a <stream>.
se, err := nextStart(c.p) se, err := c.nextStart()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -923,7 +921,7 @@ type IQ struct {
// Recv waits to receive the next XMPP stanza. // Recv waits to receive the next XMPP stanza.
func (c *Client) Recv() (stanza interface{}, err error) { func (c *Client) Recv() (stanza interface{}, err error) {
for { for {
_, val, err := next(c.p) _, val, err := c.next()
if err != nil { if err != nil {
return Chat{}, err return Chat{}, err
} }
@ -1462,57 +1460,57 @@ type rosterItem struct {
} }
// Scan XML token stream to find next StartElement. // Scan XML token stream to find next StartElement.
func nextStart(p *xml.Decoder) (xml.StartElement, error) { func (c *Client) nextStart() (xml.StartElement, error) {
for { for {
nextMutex.Lock() c.nextMutex.Lock()
t, err := p.Token() t, err := c.p.Token()
if err != nil || t == nil { if err != nil || t == nil {
nextMutex.Unlock() c.nextMutex.Unlock()
return xml.StartElement{}, err return xml.StartElement{}, err
} }
switch t := t.(type) { switch t := t.(type) {
case xml.StartElement: case xml.StartElement:
nextMutex.Unlock() c.nextMutex.Unlock()
return t, nil return t, nil
// Also check for stream end element and stop waiting // Also check for stream end element and stop waiting
// for new start elements if we received a closing stream // for new start elements if we received a closing stream
// element. // element.
case xml.EndElement: case xml.EndElement:
if t.Name.Local == "stream" { if t.Name.Local == "stream" {
nextMutex.Unlock() c.nextMutex.Unlock()
return xml.StartElement{}, nil return xml.StartElement{}, nil
} }
} }
nextMutex.Unlock() c.nextMutex.Unlock()
} }
} }
// Scan XML token stream to find next EndElement // Scan XML token stream to find next EndElement
func nextEnd(p *xml.Decoder) (xml.EndElement, error) { func (c *Client) nextEnd() (xml.EndElement, error) {
p.Strict = false c.p.Strict = false
for { for {
nextMutex.Lock() c.nextMutex.Lock()
to, err := p.RawToken() to, err := c.p.RawToken()
if err != nil || to == nil { if err != nil || to == nil {
nextMutex.Unlock() c.nextMutex.Unlock()
return xml.EndElement{}, err return xml.EndElement{}, err
} }
t := xml.CopyToken(to) t := xml.CopyToken(to)
switch t := t.(type) { switch t := t.(type) {
case xml.EndElement: case xml.EndElement:
nextMutex.Unlock() c.nextMutex.Unlock()
return t, nil return t, nil
} }
nextMutex.Unlock() c.nextMutex.Unlock()
} }
} }
// Scan XML token stream for next element and save into val. // Scan XML token stream for next element and save into val.
// If val == nil, allocate new element based on proto map. // If val == nil, allocate new element based on proto map.
// Either way, return val. // Either way, return val.
func next(p *xml.Decoder) (xml.Name, interface{}, error) { func (c *Client) next() (xml.Name, interface{}, error) {
// Read start element to find out what type we want. // Read start element to find out what type we want.
se, err := nextStart(p) se, err := c.nextStart()
if err != nil { if err != nil {
return xml.Name{}, nil, err return xml.Name{}, nil, err
} }
@ -1560,7 +1558,7 @@ func next(p *xml.Decoder) (xml.Name, interface{}, error) {
} }
// Unmarshal into that storage. // Unmarshal into that storage.
if err = p.DecodeElement(nv, &se); err != nil { if err = c.p.DecodeElement(nv, &se); err != nil {
return xml.Name{}, nil, err return xml.Name{}, nil, err
} }