Tests for Component and code style fixes (#129)

* Tests for Component and code style fixes
This commit is contained in:
remicorniere
2019-11-28 16:15:15 +00:00
committed by Jérôme Sautret
parent 7d89353156
commit 1822089db6
21 changed files with 612 additions and 74 deletions

View File

@@ -12,7 +12,7 @@ import (
type Handshake struct {
XMLName xml.Name `xml:"jabber:component:accept handshake"`
// TODO Add handshake value with test for proper serialization
// Value string `xml:",innerxml"`
Value string `xml:",innerxml"`
}
func (Handshake) Name() string {

View File

@@ -54,7 +54,7 @@ func (x *Err) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
textName := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"}
if elt.XMLName == textName {
x.Text = string(elt.Content)
x.Text = elt.Content
} else if elt.XMLName.Space == "urn:ietf:params:xml:ns:xmpp-stanzas" {
x.Reason = elt.XMLName.Local
}
@@ -94,16 +94,32 @@ func (x Err) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) {
// Reason
if x.Reason != "" {
reason := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: x.Reason}
e.EncodeToken(xml.StartElement{Name: reason})
e.EncodeToken(xml.EndElement{Name: reason})
err = e.EncodeToken(xml.StartElement{Name: reason})
if err != nil {
return err
}
err = e.EncodeToken(xml.EndElement{Name: reason})
if err != nil {
return err
}
}
// Text
if x.Text != "" {
text := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"}
e.EncodeToken(xml.StartElement{Name: text})
e.EncodeToken(xml.CharData(x.Text))
e.EncodeToken(xml.EndElement{Name: text})
err = e.EncodeToken(xml.StartElement{Name: text})
if err != nil {
return err
}
err = e.EncodeToken(xml.CharData(x.Text))
if err != nil {
return err
}
err = e.EncodeToken(xml.EndElement{Name: text})
if err != nil {
return err
}
}
return e.EncodeToken(xml.EndElement{Name: start.Name})

View File

@@ -2,6 +2,7 @@ package stanza
import (
"encoding/xml"
"strings"
"github.com/google/uuid"
)
@@ -23,7 +24,7 @@ type IQ struct { // Info/Query
// child element, which specifies the semantics of the particular
// request."
Payload IQPayload `xml:",omitempty"`
Error Err `xml:"error,omitempty"`
Error *Err `xml:"error,omitempty"`
// Any is used to decode unknown payload as a generic structure
Any *Node `xml:",any"`
}
@@ -52,7 +53,7 @@ func (iq IQ) MakeError(xerror Err) IQ {
iq.Type = "error"
iq.From = to
iq.To = from
iq.Error = xerror
iq.Error = &xerror
return iq
}
@@ -106,7 +107,7 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
if err != nil {
return err
}
iq.Error = xmppError
iq.Error = &xmppError
continue
}
if iqExt := TypeRegistry.GetIQExtension(tt.Name); iqExt != nil {
@@ -132,3 +133,39 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
}
}
}
// Following RFC-3920 for IQs
func (iq *IQ) IsValid() bool {
// ID is required
if len(strings.TrimSpace(iq.Id)) == 0 {
return false
}
// Type is required
if iq.Type.IsEmpty() {
return false
}
// Type get and set must contain one and only one child element that specifies the semantics
if iq.Type == IQTypeGet || iq.Type == IQTypeSet {
if iq.Payload == nil && iq.Any == nil {
return false
}
}
// A result must include zero or one child element
if iq.Type == IQTypeResult {
if iq.Payload != nil && iq.Any != nil {
return false
}
}
//Error type must contain an "error" child element
if iq.Type == IQTypeError {
if iq.Error == nil {
return false
}
}
return true
}

View File

@@ -187,3 +187,38 @@ func TestUnknownPayload(t *testing.T) {
t.Errorf("could not extract namespace: '%s'", parsedIQ.Any.XMLName.Space)
}
}
func TestIsValid(t *testing.T) {
type testCase struct {
iq string
shouldErr bool
}
testIQs := make(map[string]testCase)
testIQs["Valid IQ"] = testCase{
`<iq type="get" to="service.localhost" id="1" >
<query xmlns="unknown:ns"/>
</iq>`,
false,
}
testIQs["Invalid IQ"] = testCase{
`<iq type="get" to="service.localhost">
<query xmlns="unknown:ns"/>
</iq>`,
true,
}
for name, tcase := range testIQs {
t.Run(name, func(st *testing.T) {
parsedIQ := stanza.IQ{}
err := xml.Unmarshal([]byte(tcase.iq), &parsedIQ)
if err != nil {
t.Errorf("Unmarshal error: %#v (%s)", err, tcase.iq)
return
}
if !parsedIQ.IsValid() && !tcase.shouldErr {
t.Errorf("failed iq validation for : %s", tcase.iq)
}
})
}
}

View File

@@ -46,9 +46,18 @@ func (n Node) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) {
start.Name = n.XMLName
err = e.EncodeToken(start)
e.EncodeElement(n.Nodes, xml.StartElement{Name: n.XMLName})
if err != nil {
return err
}
err = e.EncodeElement(n.Nodes, xml.StartElement{Name: n.XMLName})
if err != nil {
return err
}
if n.Content != "" {
e.EncodeToken(xml.CharData(n.Content))
err = e.EncodeToken(xml.CharData(n.Content))
if err != nil {
return err
}
}
return e.EncodeToken(xml.EndElement{Name: start.Name})
}

View File

@@ -1,5 +1,7 @@
package stanza
import "strings"
type StanzaType string
// RFC 6120: part of A.5 Client Namespace and A.6 Server Namespace
@@ -23,3 +25,7 @@ const (
PresenceTypeUnsubscribe StanzaType = "unsubscribe"
PresenceTypeUnsubscribed StanzaType = "unsubscribed"
)
func (s StanzaType) IsEmpty() bool {
return len(strings.TrimSpace(string(s))) == 0
}

View File

@@ -107,6 +107,6 @@ func (s *StreamSession) IsOptional() bool {
// Registry init
func init() {
TypeRegistry.MapExtension(PKTIQ, xml.Name{"urn:ietf:params:xml:ns:xmpp-bind", "bind"}, Bind{})
TypeRegistry.MapExtension(PKTIQ, xml.Name{"urn:ietf:params:xml:ns:xmpp-session", "session"}, StreamSession{})
TypeRegistry.MapExtension(PKTIQ, xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-bind", Local: "bind"}, Bind{})
TypeRegistry.MapExtension(PKTIQ, xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-session", Local: "session"}, StreamSession{})
}

View File

@@ -8,7 +8,7 @@ import "encoding/xml"
type Stream struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams stream"`
From string `xml:"from,attr"`
To string `xml:"to,attr"`
To string `xml:"to,attr"`
Id string `xml:"id,attr"`
Version string `xml:"version,attr"`
}

View File

@@ -15,7 +15,7 @@ type StreamFeatures struct {
// Server capabilities hash
Caps Caps
// Stream features
StartTLS tlsStartTLS
StartTLS TlsStartTLS
Mechanisms saslMechanisms
Bind Bind
StreamManagement streamManagement
@@ -60,13 +60,13 @@ type Caps struct {
// StartTLS feature
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
type tlsStartTLS struct {
type TlsStartTLS struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
Required bool
}
// UnmarshalXML implements custom parsing startTLS required flag
func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
func (stls *TlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
stls.XMLName = start.Name
// Check subelements to extract required field as boolean
@@ -98,7 +98,7 @@ func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) er
}
}
func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) {
func (sf *StreamFeatures) DoesStartTLS() (feature TlsStartTLS, isSupported bool) {
if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
return sf.StartTLS, true
}