forked from jshiffer/go-xmpp
Tests for Component and code style fixes (#129)
* Tests for Component and code style fixes
This commit is contained in:
parent
7d89353156
commit
1822089db6
@ -58,7 +58,7 @@ func handleMessage(_ xmpp.Sender, p stanza.Packet) {
|
|||||||
func discoInfo(c xmpp.Sender, p stanza.Packet, opts xmpp.ComponentOptions) {
|
func discoInfo(c xmpp.Sender, p stanza.Packet, opts xmpp.ComponentOptions) {
|
||||||
// Type conversion & sanity checks
|
// Type conversion & sanity checks
|
||||||
iq, ok := p.(stanza.IQ)
|
iq, ok := p.(stanza.IQ)
|
||||||
if !ok || iq.Type != "get" {
|
if !ok || iq.Type != stanza.IQTypeGet {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -73,7 +73,7 @@ func discoInfo(c xmpp.Sender, p stanza.Packet, opts xmpp.ComponentOptions) {
|
|||||||
func discoItems(c xmpp.Sender, p stanza.Packet) {
|
func discoItems(c xmpp.Sender, p stanza.Packet) {
|
||||||
// Type conversion & sanity checks
|
// Type conversion & sanity checks
|
||||||
iq, ok := p.(stanza.IQ)
|
iq, ok := p.(stanza.IQ)
|
||||||
if !ok || iq.Type != "get" {
|
if !ok || iq.Type != stanza.IQTypeGet {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ func handleIQ(s xmpp.Sender, p stanza.Packet, player *mpg123.Player) {
|
|||||||
|
|
||||||
func sendUserTune(s xmpp.Sender, artist string, title string) {
|
func sendUserTune(s xmpp.Sender, artist string, title string) {
|
||||||
tune := stanza.Tune{Artist: artist, Title: title}
|
tune := stanza.Tune{Artist: artist, Title: title}
|
||||||
iq := stanza.NewIQ(stanza.Attrs{Type: "set", Id: "usertune-1", Lang: "en"})
|
iq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeSet, Id: "usertune-1", Lang: "en"})
|
||||||
payload := stanza.PubSub{Publish: &stanza.Publish{Node: "http://jabber.org/protocol/tune", Item: stanza.Item{Tune: &tune}}}
|
payload := stanza.PubSub{Publish: &stanza.Publish{Node: "http://jabber.org/protocol/tune", Item: stanza.Item{Tune: &tune}}}
|
||||||
iq.Payload = &payload
|
iq.Payload = &payload
|
||||||
_ = s.Send(iq)
|
_ = s.Send(iq)
|
||||||
|
5
auth.go
5
auth.go
@ -60,7 +60,10 @@ func authPlain(socket io.ReadWriter, decoder *xml.Decoder, mech string, user str
|
|||||||
raw := "\x00" + user + "\x00" + secret
|
raw := "\x00" + user + "\x00" + secret
|
||||||
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
|
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
|
||||||
base64.StdEncoding.Encode(enc, []byte(raw))
|
base64.StdEncoding.Encode(enc, []byte(raw))
|
||||||
fmt.Fprintf(socket, "<auth xmlns='%s' mechanism='%s'>%s</auth>", stanza.NSSASL, mech, enc)
|
_, err := fmt.Fprintf(socket, "<auth xmlns='%s' mechanism='%s'>%s</auth>", stanza.NSSASL, mech, enc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Next message should be either success or failure.
|
// Next message should be either success or failure.
|
||||||
val, err := stanza.NextPacket(decoder)
|
val, err := stanza.NextPacket(decoder)
|
||||||
|
@ -79,7 +79,10 @@ func (c *ServerCheck) Check() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := f.DoesStartTLS(); ok {
|
if _, ok := f.DoesStartTLS(); ok {
|
||||||
fmt.Fprintf(tcpconn, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
|
_, err = fmt.Fprintf(tcpconn, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
var k stanza.TLSProceed
|
var k stanza.TLSProceed
|
||||||
if err = decoder.DecodeElement(&k, nil); err != nil {
|
if err = decoder.DecodeElement(&k, nil); err != nil {
|
||||||
|
26
client.go
26
client.go
@ -50,7 +50,7 @@ type SMState struct {
|
|||||||
|
|
||||||
// EventHandler is use to pass events about state of the connection to
|
// EventHandler is use to pass events about state of the connection to
|
||||||
// client implementation.
|
// client implementation.
|
||||||
type EventHandler func(Event)
|
type EventHandler func(Event) error
|
||||||
|
|
||||||
type EventManager struct {
|
type EventManager struct {
|
||||||
// Store current state
|
// Store current state
|
||||||
@ -188,13 +188,16 @@ func (c *Client) Resume(state SMState) error {
|
|||||||
go keepalive(c.transport, keepaliveQuit)
|
go keepalive(c.transport, keepaliveQuit)
|
||||||
// Start the receiver go routine
|
// Start the receiver go routine
|
||||||
state = c.Session.SMState
|
state = c.Session.SMState
|
||||||
go c.recv(state, keepaliveQuit)
|
// Leaving this channel here for later. Not used atm. We should return this instead of an error because right
|
||||||
|
// now the returned error is lost in limbo.
|
||||||
|
errChan := make(chan error)
|
||||||
|
go c.recv(state, keepaliveQuit, errChan)
|
||||||
|
|
||||||
// We're connected and can now receive and send messages.
|
// We're connected and can now receive and send messages.
|
||||||
//fmt.Fprintf(client.conn, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>", "chat", "Online")
|
//fmt.Fprintf(client.conn, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>", "chat", "Online")
|
||||||
// TODO: Do we always want to send initial presence automatically ?
|
// TODO: Do we always want to send initial presence automatically ?
|
||||||
// Do we need an option to avoid that or do we rely on client to send the presence itself ?
|
// Do we need an option to avoid that or do we rely on client to send the presence itself ?
|
||||||
fmt.Fprintf(c.transport, "<presence/>")
|
_, err = fmt.Fprintf(c.transport, "<presence/>")
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -235,7 +238,7 @@ func (c *Client) Send(packet stanza.Packet) error {
|
|||||||
// result := <- client.SendIQ(ctx, iq)
|
// result := <- client.SendIQ(ctx, iq)
|
||||||
//
|
//
|
||||||
func (c *Client) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) {
|
func (c *Client) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) {
|
||||||
if iq.Attrs.Type != "set" && iq.Attrs.Type != "get" {
|
if iq.Attrs.Type != stanza.IQTypeSet && iq.Attrs.Type != stanza.IQTypeGet {
|
||||||
return nil, ErrCanOnlySendGetOrSetIq
|
return nil, ErrCanOnlySendGetOrSetIq
|
||||||
}
|
}
|
||||||
if err := c.Send(iq); err != nil {
|
if err := c.Send(iq); err != nil {
|
||||||
@ -267,13 +270,14 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error {
|
|||||||
// Go routines
|
// Go routines
|
||||||
|
|
||||||
// Loop: Receive data from server
|
// Loop: Receive data from server
|
||||||
func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) {
|
func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan<- error) {
|
||||||
for {
|
for {
|
||||||
val, err := stanza.NextPacket(c.transport.GetDecoder())
|
val, err := stanza.NextPacket(c.transport.GetDecoder())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
close(keepaliveQuit)
|
close(keepaliveQuit)
|
||||||
c.disconnected(state)
|
c.disconnected(state)
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle stream errors
|
// Handle stream errors
|
||||||
@ -282,18 +286,22 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error)
|
|||||||
c.router.route(c, val)
|
c.router.route(c, val)
|
||||||
close(keepaliveQuit)
|
close(keepaliveQuit)
|
||||||
c.streamError(packet.Error.Local, packet.Text)
|
c.streamError(packet.Error.Local, packet.Text)
|
||||||
return errors.New("stream error: " + packet.Error.Local)
|
errChan <- errors.New("stream error: " + packet.Error.Local)
|
||||||
|
return
|
||||||
// Process Stream management nonzas
|
// Process Stream management nonzas
|
||||||
case stanza.SMRequest:
|
case stanza.SMRequest:
|
||||||
answer := stanza.SMAnswer{XMLName: xml.Name{
|
answer := stanza.SMAnswer{XMLName: xml.Name{
|
||||||
Space: stanza.NSStreamManagement,
|
Space: stanza.NSStreamManagement,
|
||||||
Local: "a",
|
Local: "a",
|
||||||
}, H: state.Inbound}
|
}, H: state.Inbound}
|
||||||
c.Send(answer)
|
err = c.Send(answer)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
state.Inbound++
|
state.Inbound++
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do normal route processing in a go-routine so we can immediately
|
// Do normal route processing in a go-routine so we can immediately
|
||||||
// start receiving other stanzas. This also allows route handlers to
|
// start receiving other stanzas. This also allows route handlers to
|
||||||
// send and receive more stanzas.
|
// send and receive more stanzas.
|
||||||
|
23
component.go
23
component.go
@ -72,11 +72,13 @@ func (c *Component) Resume(sm SMState) error {
|
|||||||
c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
|
c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.updateState(StatePermanentError)
|
c.updateState(StatePermanentError)
|
||||||
|
|
||||||
return NewConnError(err, true)
|
return NewConnError(err, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
if streamId, err = c.transport.Connect(); err != nil {
|
if streamId, err = c.transport.Connect(); err != nil {
|
||||||
c.updateState(StatePermanentError)
|
c.updateState(StatePermanentError)
|
||||||
|
|
||||||
return NewConnError(err, true)
|
return NewConnError(err, true)
|
||||||
}
|
}
|
||||||
c.updateState(StateConnected)
|
c.updateState(StateConnected)
|
||||||
@ -84,6 +86,7 @@ func (c *Component) Resume(sm SMState) error {
|
|||||||
// Authentication
|
// Authentication
|
||||||
if _, err := fmt.Fprintf(c.transport, "<handshake>%s</handshake>", c.handshake(streamId)); err != nil {
|
if _, err := fmt.Fprintf(c.transport, "<handshake>%s</handshake>", c.handshake(streamId)); err != nil {
|
||||||
c.updateState(StateStreamError)
|
c.updateState(StateStreamError)
|
||||||
|
|
||||||
return NewConnError(errors.New("cannot send handshake "+err.Error()), false)
|
return NewConnError(errors.New("cannot send handshake "+err.Error()), false)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -101,12 +104,16 @@ func (c *Component) Resume(sm SMState) error {
|
|||||||
case stanza.Handshake:
|
case stanza.Handshake:
|
||||||
// Start the receiver go routine
|
// Start the receiver go routine
|
||||||
c.updateState(StateSessionEstablished)
|
c.updateState(StateSessionEstablished)
|
||||||
go c.recv()
|
// Leaving this channel here for later. Not used atm. We should return this instead of an error because right
|
||||||
return nil
|
// now the returned error is lost in limbo.
|
||||||
|
errChan := make(chan error)
|
||||||
|
go c.recv(errChan) // Sends to errChan
|
||||||
|
return err // Should be empty at this point
|
||||||
default:
|
default:
|
||||||
c.updateState(StatePermanentError)
|
c.updateState(StatePermanentError)
|
||||||
return NewConnError(errors.New("expecting handshake result, got "+v.Name()), true)
|
return NewConnError(errors.New("expecting handshake result, got "+v.Name()), true)
|
||||||
}
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Component) Disconnect() {
|
func (c *Component) Disconnect() {
|
||||||
@ -121,20 +128,22 @@ func (c *Component) SetHandler(handler EventHandler) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Receiver Go routine receiver
|
// Receiver Go routine receiver
|
||||||
func (c *Component) recv() (err error) {
|
func (c *Component) recv(errChan chan<- error) {
|
||||||
|
|
||||||
for {
|
for {
|
||||||
val, err := stanza.NextPacket(c.transport.GetDecoder())
|
val, err := stanza.NextPacket(c.transport.GetDecoder())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.updateState(StateDisconnected)
|
c.updateState(StateDisconnected)
|
||||||
return err
|
errChan <- err
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle stream errors
|
// Handle stream errors
|
||||||
switch p := val.(type) {
|
switch p := val.(type) {
|
||||||
case stanza.StreamError:
|
case stanza.StreamError:
|
||||||
c.router.route(c, val)
|
c.router.route(c, val)
|
||||||
c.streamError(p.Error.Local, p.Text)
|
c.streamError(p.Error.Local, p.Text)
|
||||||
return errors.New("stream error: " + p.Error.Local)
|
errChan <- errors.New("stream error: " + p.Error.Local)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
c.router.route(c, val)
|
c.router.route(c, val)
|
||||||
}
|
}
|
||||||
@ -168,7 +177,7 @@ func (c *Component) Send(packet stanza.Packet) error {
|
|||||||
// result := <- client.SendIQ(ctx, iq)
|
// result := <- client.SendIQ(ctx, iq)
|
||||||
//
|
//
|
||||||
func (c *Component) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) {
|
func (c *Component) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) {
|
||||||
if iq.Attrs.Type != "set" && iq.Attrs.Type != "get" {
|
if iq.Attrs.Type != stanza.IQTypeSet && iq.Attrs.Type != stanza.IQTypeGet {
|
||||||
return nil, ErrCanOnlySendGetOrSetIq
|
return nil, ErrCanOnlySendGetOrSetIq
|
||||||
}
|
}
|
||||||
if err := c.Send(iq); err != nil {
|
if err := c.Send(iq); err != nil {
|
||||||
|
@ -1,12 +1,34 @@
|
|||||||
package xmpp
|
package xmpp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/xml"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"gosrc.io/xmpp/stanza"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testComponentDomain = "localhost"
|
// Tests are ran in parallel, so each test creating a server must use a different port so we do not get any
|
||||||
const testComponentPort = "15222"
|
// conflict. Using iota for this should do the trick.
|
||||||
|
const (
|
||||||
|
testComponentDomain = "localhost"
|
||||||
|
defaultServerName = "testServer"
|
||||||
|
defaultStreamID = "91bd0bba-012f-4d92-bb17-5fc41e6fe545"
|
||||||
|
defaultComponentName = "Test Component"
|
||||||
|
|
||||||
|
// Default port is not standard XMPP port to avoid interfering
|
||||||
|
// with local running XMPP server
|
||||||
|
testHandshakePort = iota + 15222
|
||||||
|
testDecoderPort
|
||||||
|
testSendIqPort
|
||||||
|
testSendRawPort
|
||||||
|
testDisconnectPort
|
||||||
|
testSManDisconnectPort
|
||||||
|
)
|
||||||
|
|
||||||
func TestHandshake(t *testing.T) {
|
func TestHandshake(t *testing.T) {
|
||||||
opts := ComponentOptions{
|
opts := ComponentOptions{
|
||||||
@ -24,25 +46,43 @@ func TestHandshake(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests connection process with a handshake exchange
|
||||||
|
// Tests multiple session IDs. All connections should generate a unique stream ID
|
||||||
func TestGenerateHandshake(t *testing.T) {
|
func TestGenerateHandshake(t *testing.T) {
|
||||||
// TODO
|
// Using this array with a channel to make a queue of values to test
|
||||||
|
// These are stream IDs that will be used to test the connection process, mixing them with the "secret" to generate
|
||||||
|
// some handshake value
|
||||||
|
var uuidsArray = [5]string{
|
||||||
|
"cc9b3249-9582-4780-825f-4311b42f9b0e",
|
||||||
|
"bba8be3c-d98e-4e26-b9bb-9ed34578a503",
|
||||||
|
"dae72822-80e8-496b-b763-ab685f53a188",
|
||||||
|
"a45d6c06-de49-4bb0-935b-1a2201b71028",
|
||||||
|
"7dc6924f-0eca-4237-9898-18654b8d891e",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that NewStreamManager can accept a Component.
|
// Channel to pass stream IDs as a queue
|
||||||
//
|
var uchan = make(chan string, len(uuidsArray))
|
||||||
// This validates that Component conforms to StreamClient interface.
|
// Populate test channel
|
||||||
func TestStreamManager(t *testing.T) {
|
for _, elt := range uuidsArray {
|
||||||
NewStreamManager(&Component{}, nil)
|
uchan <- elt
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests that the decoder is properly initialized when connecting a component to a server.
|
// Performs a Component connection with a handshake. It expects to have an ID sent its way through the "uchan"
|
||||||
// The decoder is expected to be built after a valid connection
|
// channel of this file. Otherwise it will hang for ever.
|
||||||
// Based on the xmpp_component example.
|
h := func(t *testing.T, c net.Conn) {
|
||||||
func TestDecoder(t *testing.T) {
|
decoder := xml.NewDecoder(c)
|
||||||
testComponentAddess := fmt.Sprintf("%s:%s", testComponentDomain, testComponentPort)
|
checkOpenStreamHandshakeID(t, c, decoder, <-uchan)
|
||||||
|
readHandshakeComponent(t, decoder)
|
||||||
|
fmt.Fprintln(c, "<handshake/>") // That's all the server needs to return (see xep-0114)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init mock server
|
||||||
|
testComponentAddess := fmt.Sprintf("%s:%d", testComponentDomain, testHandshakePort)
|
||||||
mock := ServerMock{}
|
mock := ServerMock{}
|
||||||
mock.Start(t, testComponentAddess, handlerConnectSuccess)
|
mock.Start(t, testComponentAddess, h)
|
||||||
|
|
||||||
|
// Init component
|
||||||
opts := ComponentOptions{
|
opts := ComponentOptions{
|
||||||
TransportConfiguration: TransportConfiguration{
|
TransportConfiguration: TransportConfiguration{
|
||||||
Address: testComponentAddess,
|
Address: testComponentAddess,
|
||||||
@ -63,12 +103,352 @@ func TestDecoder(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%+v", err)
|
t.Errorf("%+v", err)
|
||||||
}
|
}
|
||||||
_, err = c.transport.Connect()
|
|
||||||
if err != nil {
|
// Try connecting, and storing the resulting streamID in a map.
|
||||||
t.Errorf("%+v", err)
|
m := make(map[string]bool)
|
||||||
|
for _, _ = range uuidsArray {
|
||||||
|
streamId, _ := c.transport.Connect()
|
||||||
|
m[c.handshake(streamId)] = true
|
||||||
}
|
}
|
||||||
|
if len(uuidsArray) != len(m) {
|
||||||
|
t.Errorf("Handshake does not produce a unique id. Expected: %d unique ids, got: %d", len(uuidsArray), len(m))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that NewStreamManager can accept a Component.
|
||||||
|
//
|
||||||
|
// This validates that Component conforms to StreamClient interface.
|
||||||
|
func TestStreamManager(t *testing.T) {
|
||||||
|
NewStreamManager(&Component{}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that the decoder is properly initialized when connecting a component to a server.
|
||||||
|
// The decoder is expected to be built after a valid connection
|
||||||
|
// Based on the xmpp_component example.
|
||||||
|
func TestDecoder(t *testing.T) {
|
||||||
|
c, _ := mockConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID)
|
||||||
if c.transport.GetDecoder() == nil {
|
if c.transport.GetDecoder() == nil {
|
||||||
t.Errorf("Failed to initialize decoder. Decoder is nil.")
|
t.Errorf("Failed to initialize decoder. Decoder is nil.")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests sending an IQ to the server, and getting the response
|
||||||
|
func TestSendIq(t *testing.T) {
|
||||||
|
//Connecting to a mock server, initialized with given port and handler function
|
||||||
|
c, m := mockConnection(t, testSendIqPort, handlerForComponentIQSend)
|
||||||
|
|
||||||
|
ctx, _ := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"})
|
||||||
|
disco := iqReq.DiscoInfo()
|
||||||
|
iqReq.Payload = disco
|
||||||
|
|
||||||
|
var res chan stanza.IQ
|
||||||
|
res, _ = c.SendIQ(ctx, iqReq)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-res:
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Errorf("Failed to receive response, to sent IQ, from mock server")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests sending raw xml to the mock server.
|
||||||
|
// TODO : check the server response client side ?
|
||||||
|
// Right now, the server response is not checked and an err is passed in a channel if the test is supposed to err.
|
||||||
|
// In this test, we use IQs
|
||||||
|
func TestSendRaw(t *testing.T) {
|
||||||
|
// Error channel for the handler
|
||||||
|
errChan := make(chan error)
|
||||||
|
// Handler for the mock server
|
||||||
|
h := func(t *testing.T, c net.Conn) {
|
||||||
|
// Completes the connection by exchanging handshakes
|
||||||
|
handlerForComponentHandshakeDefaultID(t, c)
|
||||||
|
receiveRawIq(t, c, errChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
req string
|
||||||
|
shouldErr bool
|
||||||
|
}
|
||||||
|
testRequests := make(map[string]testCase)
|
||||||
|
// Sending a correct IQ of type get. Not supposed to err
|
||||||
|
testRequests["Correct IQ"] = testCase{
|
||||||
|
req: `<iq type="get" id="91bd0bba-012f-4d92-bb17-5fc41e6fe545" from="test1@localhost/mremond-mbp" to="testServer" lang="en"><query xmlns="http://jabber.org/protocol/disco#info"></query></iq>`,
|
||||||
|
shouldErr: false,
|
||||||
|
}
|
||||||
|
// Sending an IQ with a missing ID. Should err
|
||||||
|
testRequests["IQ with missing ID"] = testCase{
|
||||||
|
req: `<iq type="get" from="test1@localhost/mremond-mbp" to="testServer" lang="en"><query xmlns="http://jabber.org/protocol/disco#info"></query></iq>`,
|
||||||
|
shouldErr: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests for all the IQs
|
||||||
|
for name, tcase := range testRequests {
|
||||||
|
t.Run(name, func(st *testing.T) {
|
||||||
|
//Connecting to a mock server, initialized with given port and handler function
|
||||||
|
c, m := mockConnection(t, testSendRawPort, h)
|
||||||
|
|
||||||
|
// Sending raw xml from test case
|
||||||
|
err := c.SendRaw(tcase.req)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error sending Raw string")
|
||||||
|
}
|
||||||
|
// Just wait a little so the message has time to arrive
|
||||||
|
select {
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
case err = <-errChan:
|
||||||
|
if err == nil && tcase.shouldErr {
|
||||||
|
t.Errorf("Failed to get closing stream err")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.transport.Close()
|
||||||
|
m.Stop()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests the Disconnect method for Components
|
||||||
|
func TestDisconnect(t *testing.T) {
|
||||||
|
c, m := mockConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID)
|
||||||
|
err := c.transport.Ping()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could not ping but not disconnected yet")
|
||||||
|
}
|
||||||
|
c.Disconnect()
|
||||||
|
err = c.transport.Ping()
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Did not disconnect properly")
|
||||||
|
}
|
||||||
|
m.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that a streamManager successfully disconnects when a handshake fails between the component and the server.
|
||||||
|
func TestStreamManagerDisconnect(t *testing.T) {
|
||||||
|
// Init mock server
|
||||||
|
testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, testSManDisconnectPort)
|
||||||
|
mock := ServerMock{}
|
||||||
|
// Handler fails the handshake, which is currently the only option to disconnect completely when using a streamManager
|
||||||
|
// a failed handshake being a permanent error, except for a "conflict"
|
||||||
|
mock.Start(t, testComponentAddress, handlerComponentFailedHandshakeDefaultID)
|
||||||
|
|
||||||
|
//==================================
|
||||||
|
// Create Component to connect to it
|
||||||
|
c := makeBasicComponent(defaultComponentName, testComponentAddress, t)
|
||||||
|
|
||||||
|
//========================================
|
||||||
|
// Connect the new Component to the server
|
||||||
|
cm := NewStreamManager(c, nil)
|
||||||
|
errChan := make(chan error)
|
||||||
|
runSMan := func(errChan chan error) {
|
||||||
|
errChan <- cm.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
go runSMan(errChan)
|
||||||
|
select {
|
||||||
|
case <-errChan:
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Errorf("The component and server seem to still be connected while they should not.")
|
||||||
|
}
|
||||||
|
mock.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
//=============================================================================
|
||||||
|
// Basic XMPP Server Mock Handlers.
|
||||||
|
// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant.
|
||||||
|
// Used in the mock server as a Handler
|
||||||
|
func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) {
|
||||||
|
decoder := xml.NewDecoder(c)
|
||||||
|
checkOpenStreamHandshakeDefaultID(t, c, decoder)
|
||||||
|
readHandshakeComponent(t, decoder)
|
||||||
|
fmt.Fprintln(c, "<handshake/>") // That's all the server needs to return (see xep-0114)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant.
|
||||||
|
// This handler is supposed to fail by sending a "message" stanza instead of a <handshake/> stanza to finalize the handshake.
|
||||||
|
func handlerComponentFailedHandshakeDefaultID(t *testing.T, c net.Conn) {
|
||||||
|
decoder := xml.NewDecoder(c)
|
||||||
|
checkOpenStreamHandshakeDefaultID(t, c, decoder)
|
||||||
|
readHandshakeComponent(t, decoder)
|
||||||
|
|
||||||
|
// Send a message, instead of a "<handshake/>" tag, to fail the handshake process dans disconnect the client.
|
||||||
|
me := stanza.Message{
|
||||||
|
Attrs: stanza.Attrs{Type: stanza.MessageTypeChat, From: defaultServerName, To: defaultComponentName, Lang: "en"},
|
||||||
|
Body: "Fail my handshake.",
|
||||||
|
}
|
||||||
|
s, _ := xml.Marshal(me)
|
||||||
|
fmt.Fprintln(c, string(s))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reads from the connection with the Component. Expects a handshake request, and returns the <handshake/> tag.
|
||||||
|
func readHandshakeComponent(t *testing.T, decoder *xml.Decoder) {
|
||||||
|
se, err := stanza.NextStart(decoder)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cannot read auth: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nv := &stanza.Handshake{}
|
||||||
|
// Decode element into pointer storage
|
||||||
|
if err = decoder.DecodeElement(nv, &se); err != nil {
|
||||||
|
t.Errorf("cannot decode handshake: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(strings.TrimSpace(nv.Value)) == 0 {
|
||||||
|
t.Errorf("did not receive handshake ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) {
|
||||||
|
checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used for ID and handshake related tests
|
||||||
|
func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) {
|
||||||
|
c.SetDeadline(time.Now().Add(defaultTimeout))
|
||||||
|
defer c.SetDeadline(time.Time{})
|
||||||
|
|
||||||
|
for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion.
|
||||||
|
token, err := decoder.Token()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cannot read next token: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch elem := token.(type) {
|
||||||
|
// Wait for first startElement
|
||||||
|
case xml.StartElement:
|
||||||
|
if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" {
|
||||||
|
err = errors.New("xmpp: expected <stream> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil {
|
||||||
|
t.Errorf("cannot write server stream open: %s", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//=============================================================================
|
||||||
|
// Sends IQ response to Component request.
|
||||||
|
// No parsing of the request here. We just check that it's valid, and send the default response.
|
||||||
|
func handlerForComponentIQSend(t *testing.T, c net.Conn) {
|
||||||
|
// Completes the connection by exchanging handshakes
|
||||||
|
handlerForComponentHandshakeDefaultID(t, c)
|
||||||
|
|
||||||
|
// Decoder to parse the request
|
||||||
|
decoder := xml.NewDecoder(c)
|
||||||
|
|
||||||
|
iqReq, err := receiveIq(t, c, decoder)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error receiving the IQ stanza : %v", err)
|
||||||
|
} else if !iqReq.IsValid() {
|
||||||
|
t.Errorf("server received an IQ stanza : %v", iqReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Crafting response
|
||||||
|
iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"})
|
||||||
|
disco := iqResp.DiscoInfo()
|
||||||
|
disco.AddFeatures("vcard-temp",
|
||||||
|
`http://jabber.org/protocol/address`)
|
||||||
|
|
||||||
|
disco.AddIdentity("Multicast", "service", "multicast")
|
||||||
|
iqResp.Payload = disco
|
||||||
|
|
||||||
|
// Sending response to the Component
|
||||||
|
mResp, err := xml.Marshal(iqResp)
|
||||||
|
_, err = fmt.Fprintln(c, string(mResp))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could not send response stanza : %s", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reads next request coming from the Component. Expecting it to be an IQ request
|
||||||
|
func receiveIq(t *testing.T, c net.Conn, decoder *xml.Decoder) (stanza.IQ, error) {
|
||||||
|
c.SetDeadline(time.Now().Add(defaultTimeout))
|
||||||
|
defer c.SetDeadline(time.Time{})
|
||||||
|
var iqStz stanza.IQ
|
||||||
|
err := decoder.Decode(&iqStz)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cannot read the received IQ stanza: %s", err)
|
||||||
|
}
|
||||||
|
if !iqStz.IsValid() {
|
||||||
|
t.Errorf("received IQ stanza is invalid : %s", err)
|
||||||
|
}
|
||||||
|
return iqStz, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func receiveRawIq(t *testing.T, c net.Conn, errChan chan error) {
|
||||||
|
c.SetDeadline(time.Now().Add(defaultTimeout))
|
||||||
|
defer c.SetDeadline(time.Time{})
|
||||||
|
decoder := xml.NewDecoder(c)
|
||||||
|
var iq stanza.IQ
|
||||||
|
err := decoder.Decode(&iq)
|
||||||
|
if err != nil || !iq.IsValid() {
|
||||||
|
s := stanza.StreamError{
|
||||||
|
XMLName: xml.Name{Local: "stream:error"},
|
||||||
|
Error: xml.Name{Local: "xml-not-well-formed"},
|
||||||
|
Text: `XML was not well-formed`,
|
||||||
|
}
|
||||||
|
raw, _ := xml.Marshal(s)
|
||||||
|
fmt.Fprintln(c, string(raw))
|
||||||
|
fmt.Fprintln(c, `</stream:stream>`) // TODO : check this client side
|
||||||
|
errChan <- fmt.Errorf("invalid xml")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errChan <- nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//===============================
|
||||||
|
// Init mock server and connection
|
||||||
|
// Creating a mock server and connecting a Component to it. Initialized with given port and handler function
|
||||||
|
// The Component and mock are both returned
|
||||||
|
func mockConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) {
|
||||||
|
// Init mock server
|
||||||
|
testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port)
|
||||||
|
mock := ServerMock{}
|
||||||
|
mock.Start(t, testComponentAddress, handler)
|
||||||
|
|
||||||
|
//==================================
|
||||||
|
// Create Component to connect to it
|
||||||
|
c := makeBasicComponent(defaultComponentName, testComponentAddress, t)
|
||||||
|
|
||||||
|
//========================================
|
||||||
|
// Connect the new Component to the server
|
||||||
|
err := c.Connect()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, &mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component {
|
||||||
|
opts := ComponentOptions{
|
||||||
|
TransportConfiguration: TransportConfiguration{
|
||||||
|
Address: mockServerAddr,
|
||||||
|
Domain: "localhost",
|
||||||
|
},
|
||||||
|
Domain: testComponentDomain,
|
||||||
|
Secret: "mypass",
|
||||||
|
Name: name,
|
||||||
|
Category: "gateway",
|
||||||
|
Type: "service",
|
||||||
|
}
|
||||||
|
router := NewRouter()
|
||||||
|
c, err := NewComponent(opts, router)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%+v", err)
|
||||||
|
}
|
||||||
|
c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%+v", err)
|
||||||
|
}
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ func ensurePort(addr string, port int) string {
|
|||||||
// This is IPV4 without port
|
// This is IPV4 without port
|
||||||
return addr + ":" + strconv.Itoa(port)
|
return addr + ":" + strconv.Itoa(port)
|
||||||
case 1:
|
case 1:
|
||||||
// This is IPV$ with port
|
// This is IPV6 with port
|
||||||
return addr
|
return addr
|
||||||
default:
|
default:
|
||||||
// This is IPV6 without port, as you need to use bracket with port in IPV6
|
// This is IPV6 without port, as you need to use bracket with port in IPV6
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
package xmpp
|
package xmpp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type params struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseAddr(t *testing.T) {
|
func TestParseAddr(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@ -33,3 +31,36 @@ func TestParseAddr(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnsurePort(t *testing.T) {
|
||||||
|
testAddresses := []string{
|
||||||
|
"1ca3:6c07:ee3a:89ca:e065:9a70:71d:daad",
|
||||||
|
"1ca3:6c07:ee3a:89ca:e065:9a70:71d:daad:5252",
|
||||||
|
"[::1]",
|
||||||
|
"127.0.0.1:5555",
|
||||||
|
"127.0.0.1",
|
||||||
|
"[::1]:5555",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, oldAddr := range testAddresses {
|
||||||
|
t.Run(oldAddr, func(st *testing.T) {
|
||||||
|
newAddr := ensurePort(oldAddr, 5222)
|
||||||
|
|
||||||
|
if len(newAddr) < len(oldAddr) {
|
||||||
|
st.Errorf("incorrect Result: transformed address is shorter than input : %v (old) > %v (new)", newAddr, oldAddr)
|
||||||
|
}
|
||||||
|
// If IPv6, the new address needs brackets to specify a port, like so : [2001:db8:85a3:0:0:8a2e:370:7334]:5222
|
||||||
|
if strings.Count(newAddr, "[") < strings.Count(oldAddr, "[") ||
|
||||||
|
strings.Count(newAddr, "]") < strings.Count(oldAddr, "]") {
|
||||||
|
|
||||||
|
st.Errorf("incorrect Result. Transformed address seems to not have correct brakets : %v => %v", oldAddr, newAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we messed up the colons, or didn't properly add a port
|
||||||
|
if strings.Count(newAddr, ":") < strings.Count(oldAddr, ":") {
|
||||||
|
st.Errorf("incorrect Result: transformed address doesn't seem to have a port %v (=> %v, no port ?)", oldAddr, newAddr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
@ -146,7 +146,7 @@ func TestTypeMatcher(t *testing.T) {
|
|||||||
|
|
||||||
// We do not match on other types
|
// We do not match on other types
|
||||||
conn = NewSenderMock()
|
conn = NewSenderMock()
|
||||||
iqVersion := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"})
|
iqVersion := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"})
|
||||||
iqVersion.Payload = &stanza.DiscoInfo{
|
iqVersion.Payload = &stanza.DiscoInfo{
|
||||||
XMLName: xml.Name{
|
XMLName: xml.Name{
|
||||||
Space: "jabber:iq:version",
|
Space: "jabber:iq:version",
|
||||||
@ -163,27 +163,27 @@ func TestCompositeMatcher(t *testing.T) {
|
|||||||
router := NewRouter()
|
router := NewRouter()
|
||||||
router.NewRoute().
|
router.NewRoute().
|
||||||
IQNamespaces("jabber:iq:version").
|
IQNamespaces("jabber:iq:version").
|
||||||
StanzaType("get").
|
StanzaType(string(stanza.IQTypeGet)).
|
||||||
HandlerFunc(func(s Sender, p stanza.Packet) {
|
HandlerFunc(func(s Sender, p stanza.Packet) {
|
||||||
_ = s.SendRaw(successFlag)
|
_ = s.SendRaw(successFlag)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Data set
|
// Data set
|
||||||
getVersionIq := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"})
|
getVersionIq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"})
|
||||||
getVersionIq.Payload = &stanza.Version{
|
getVersionIq.Payload = &stanza.Version{
|
||||||
XMLName: xml.Name{
|
XMLName: xml.Name{
|
||||||
Space: "jabber:iq:version",
|
Space: "jabber:iq:version",
|
||||||
Local: "query",
|
Local: "query",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
setVersionIq := stanza.NewIQ(stanza.Attrs{Type: "set", From: "service.localhost", To: "test@localhost", Id: "1"})
|
setVersionIq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeSet, From: "service.localhost", To: "test@localhost", Id: "1"})
|
||||||
setVersionIq.Payload = &stanza.Version{
|
setVersionIq.Payload = &stanza.Version{
|
||||||
XMLName: xml.Name{
|
XMLName: xml.Name{
|
||||||
Space: "jabber:iq:version",
|
Space: "jabber:iq:version",
|
||||||
Local: "query",
|
Local: "query",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
GetDiscoIq := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"})
|
GetDiscoIq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"})
|
||||||
GetDiscoIq.Payload = &stanza.DiscoInfo{
|
GetDiscoIq.Payload = &stanza.DiscoInfo{
|
||||||
XMLName: xml.Name{
|
XMLName: xml.Name{
|
||||||
Space: "http://jabber.org/protocol/disco#info",
|
Space: "http://jabber.org/protocol/disco#info",
|
||||||
@ -238,7 +238,7 @@ func TestCatchallMatcher(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
conn = NewSenderMock()
|
conn = NewSenderMock()
|
||||||
iqVersion := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"})
|
iqVersion := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"})
|
||||||
iqVersion.Payload = &stanza.DiscoInfo{
|
iqVersion.Payload = &stanza.DiscoInfo{
|
||||||
XMLName: xml.Name{
|
XMLName: xml.Name{
|
||||||
Space: "jabber:iq:version",
|
Space: "jabber:iq:version",
|
||||||
|
@ -12,7 +12,7 @@ import (
|
|||||||
type Handshake struct {
|
type Handshake struct {
|
||||||
XMLName xml.Name `xml:"jabber:component:accept handshake"`
|
XMLName xml.Name `xml:"jabber:component:accept handshake"`
|
||||||
// TODO Add handshake value with test for proper serialization
|
// TODO Add handshake value with test for proper serialization
|
||||||
// Value string `xml:",innerxml"`
|
Value string `xml:",innerxml"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (Handshake) Name() string {
|
func (Handshake) Name() string {
|
||||||
|
@ -54,7 +54,7 @@ func (x *Err) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
|
|||||||
|
|
||||||
textName := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"}
|
textName := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"}
|
||||||
if elt.XMLName == textName {
|
if elt.XMLName == textName {
|
||||||
x.Text = string(elt.Content)
|
x.Text = elt.Content
|
||||||
} else if elt.XMLName.Space == "urn:ietf:params:xml:ns:xmpp-stanzas" {
|
} else if elt.XMLName.Space == "urn:ietf:params:xml:ns:xmpp-stanzas" {
|
||||||
x.Reason = elt.XMLName.Local
|
x.Reason = elt.XMLName.Local
|
||||||
}
|
}
|
||||||
@ -94,16 +94,32 @@ func (x Err) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) {
|
|||||||
// Reason
|
// Reason
|
||||||
if x.Reason != "" {
|
if x.Reason != "" {
|
||||||
reason := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: x.Reason}
|
reason := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: x.Reason}
|
||||||
e.EncodeToken(xml.StartElement{Name: reason})
|
err = e.EncodeToken(xml.StartElement{Name: reason})
|
||||||
e.EncodeToken(xml.EndElement{Name: reason})
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = e.EncodeToken(xml.EndElement{Name: reason})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Text
|
// Text
|
||||||
if x.Text != "" {
|
if x.Text != "" {
|
||||||
text := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"}
|
text := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"}
|
||||||
e.EncodeToken(xml.StartElement{Name: text})
|
err = e.EncodeToken(xml.StartElement{Name: text})
|
||||||
e.EncodeToken(xml.CharData(x.Text))
|
if err != nil {
|
||||||
e.EncodeToken(xml.EndElement{Name: text})
|
return err
|
||||||
|
}
|
||||||
|
err = e.EncodeToken(xml.CharData(x.Text))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = e.EncodeToken(xml.EndElement{Name: text})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return e.EncodeToken(xml.EndElement{Name: start.Name})
|
return e.EncodeToken(xml.EndElement{Name: start.Name})
|
||||||
|
43
stanza/iq.go
43
stanza/iq.go
@ -2,6 +2,7 @@ package stanza
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@ -23,7 +24,7 @@ type IQ struct { // Info/Query
|
|||||||
// child element, which specifies the semantics of the particular
|
// child element, which specifies the semantics of the particular
|
||||||
// request."
|
// request."
|
||||||
Payload IQPayload `xml:",omitempty"`
|
Payload IQPayload `xml:",omitempty"`
|
||||||
Error Err `xml:"error,omitempty"`
|
Error *Err `xml:"error,omitempty"`
|
||||||
// Any is used to decode unknown payload as a generic structure
|
// Any is used to decode unknown payload as a generic structure
|
||||||
Any *Node `xml:",any"`
|
Any *Node `xml:",any"`
|
||||||
}
|
}
|
||||||
@ -52,7 +53,7 @@ func (iq IQ) MakeError(xerror Err) IQ {
|
|||||||
iq.Type = "error"
|
iq.Type = "error"
|
||||||
iq.From = to
|
iq.From = to
|
||||||
iq.To = from
|
iq.To = from
|
||||||
iq.Error = xerror
|
iq.Error = &xerror
|
||||||
|
|
||||||
return iq
|
return iq
|
||||||
}
|
}
|
||||||
@ -106,7 +107,7 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
iq.Error = xmppError
|
iq.Error = &xmppError
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if iqExt := TypeRegistry.GetIQExtension(tt.Name); iqExt != nil {
|
if iqExt := TypeRegistry.GetIQExtension(tt.Name); iqExt != nil {
|
||||||
@ -132,3 +133,39 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Following RFC-3920 for IQs
|
||||||
|
func (iq *IQ) IsValid() bool {
|
||||||
|
// ID is required
|
||||||
|
if len(strings.TrimSpace(iq.Id)) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type is required
|
||||||
|
if iq.Type.IsEmpty() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type get and set must contain one and only one child element that specifies the semantics
|
||||||
|
if iq.Type == IQTypeGet || iq.Type == IQTypeSet {
|
||||||
|
if iq.Payload == nil && iq.Any == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A result must include zero or one child element
|
||||||
|
if iq.Type == IQTypeResult {
|
||||||
|
if iq.Payload != nil && iq.Any != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//Error type must contain an "error" child element
|
||||||
|
if iq.Type == IQTypeError {
|
||||||
|
if iq.Error == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
@ -187,3 +187,38 @@ func TestUnknownPayload(t *testing.T) {
|
|||||||
t.Errorf("could not extract namespace: '%s'", parsedIQ.Any.XMLName.Space)
|
t.Errorf("could not extract namespace: '%s'", parsedIQ.Any.XMLName.Space)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsValid(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
iq string
|
||||||
|
shouldErr bool
|
||||||
|
}
|
||||||
|
testIQs := make(map[string]testCase)
|
||||||
|
testIQs["Valid IQ"] = testCase{
|
||||||
|
`<iq type="get" to="service.localhost" id="1" >
|
||||||
|
<query xmlns="unknown:ns"/>
|
||||||
|
</iq>`,
|
||||||
|
false,
|
||||||
|
}
|
||||||
|
testIQs["Invalid IQ"] = testCase{
|
||||||
|
`<iq type="get" to="service.localhost">
|
||||||
|
<query xmlns="unknown:ns"/>
|
||||||
|
</iq>`,
|
||||||
|
true,
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tcase := range testIQs {
|
||||||
|
t.Run(name, func(st *testing.T) {
|
||||||
|
parsedIQ := stanza.IQ{}
|
||||||
|
err := xml.Unmarshal([]byte(tcase.iq), &parsedIQ)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unmarshal error: %#v (%s)", err, tcase.iq)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !parsedIQ.IsValid() && !tcase.shouldErr {
|
||||||
|
t.Errorf("failed iq validation for : %s", tcase.iq)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
@ -46,9 +46,18 @@ func (n Node) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) {
|
|||||||
start.Name = n.XMLName
|
start.Name = n.XMLName
|
||||||
|
|
||||||
err = e.EncodeToken(start)
|
err = e.EncodeToken(start)
|
||||||
e.EncodeElement(n.Nodes, xml.StartElement{Name: n.XMLName})
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = e.EncodeElement(n.Nodes, xml.StartElement{Name: n.XMLName})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if n.Content != "" {
|
if n.Content != "" {
|
||||||
e.EncodeToken(xml.CharData(n.Content))
|
err = e.EncodeToken(xml.CharData(n.Content))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return e.EncodeToken(xml.EndElement{Name: start.Name})
|
return e.EncodeToken(xml.EndElement{Name: start.Name})
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package stanza
|
package stanza
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
type StanzaType string
|
type StanzaType string
|
||||||
|
|
||||||
// RFC 6120: part of A.5 Client Namespace and A.6 Server Namespace
|
// RFC 6120: part of A.5 Client Namespace and A.6 Server Namespace
|
||||||
@ -23,3 +25,7 @@ const (
|
|||||||
PresenceTypeUnsubscribe StanzaType = "unsubscribe"
|
PresenceTypeUnsubscribe StanzaType = "unsubscribe"
|
||||||
PresenceTypeUnsubscribed StanzaType = "unsubscribed"
|
PresenceTypeUnsubscribed StanzaType = "unsubscribed"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (s StanzaType) IsEmpty() bool {
|
||||||
|
return len(strings.TrimSpace(string(s))) == 0
|
||||||
|
}
|
||||||
|
@ -107,6 +107,6 @@ func (s *StreamSession) IsOptional() bool {
|
|||||||
// Registry init
|
// Registry init
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
TypeRegistry.MapExtension(PKTIQ, xml.Name{"urn:ietf:params:xml:ns:xmpp-bind", "bind"}, Bind{})
|
TypeRegistry.MapExtension(PKTIQ, xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-bind", Local: "bind"}, Bind{})
|
||||||
TypeRegistry.MapExtension(PKTIQ, xml.Name{"urn:ietf:params:xml:ns:xmpp-session", "session"}, StreamSession{})
|
TypeRegistry.MapExtension(PKTIQ, xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-session", Local: "session"}, StreamSession{})
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ type StreamFeatures struct {
|
|||||||
// Server capabilities hash
|
// Server capabilities hash
|
||||||
Caps Caps
|
Caps Caps
|
||||||
// Stream features
|
// Stream features
|
||||||
StartTLS tlsStartTLS
|
StartTLS TlsStartTLS
|
||||||
Mechanisms saslMechanisms
|
Mechanisms saslMechanisms
|
||||||
Bind Bind
|
Bind Bind
|
||||||
StreamManagement streamManagement
|
StreamManagement streamManagement
|
||||||
@ -60,13 +60,13 @@ type Caps struct {
|
|||||||
|
|
||||||
// StartTLS feature
|
// StartTLS feature
|
||||||
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
|
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
|
||||||
type tlsStartTLS struct {
|
type TlsStartTLS struct {
|
||||||
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
|
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
|
||||||
Required bool
|
Required bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalXML implements custom parsing startTLS required flag
|
// UnmarshalXML implements custom parsing startTLS required flag
|
||||||
func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
|
func (stls *TlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
|
||||||
stls.XMLName = start.Name
|
stls.XMLName = start.Name
|
||||||
|
|
||||||
// Check subelements to extract required field as boolean
|
// Check subelements to extract required field as boolean
|
||||||
@ -98,7 +98,7 @@ func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) er
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) {
|
func (sf *StreamFeatures) DoesStartTLS() (feature TlsStartTLS, isSupported bool) {
|
||||||
if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
|
if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
|
||||||
return sf.StartTLS, true
|
return sf.StartTLS, true
|
||||||
}
|
}
|
||||||
|
@ -74,7 +74,7 @@ func (sm *StreamManager) Run() error {
|
|||||||
return errors.New("missing stream client")
|
return errors.New("missing stream client")
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := func(e Event) {
|
handler := func(e Event) error {
|
||||||
switch e.State {
|
switch e.State {
|
||||||
case StateConnected:
|
case StateConnected:
|
||||||
sm.Metrics.setConnectTime()
|
sm.Metrics.setConnectTime()
|
||||||
@ -82,17 +82,18 @@ func (sm *StreamManager) Run() error {
|
|||||||
sm.Metrics.setLoginTime()
|
sm.Metrics.setLoginTime()
|
||||||
case StateDisconnected:
|
case StateDisconnected:
|
||||||
// Reconnect on disconnection
|
// Reconnect on disconnection
|
||||||
sm.resume(e.SMState)
|
return sm.resume(e.SMState)
|
||||||
case StateStreamError:
|
case StateStreamError:
|
||||||
sm.client.Disconnect()
|
sm.client.Disconnect()
|
||||||
// Only try reconnecting if we have not been kicked by another session to avoid connection loop.
|
// Only try reconnecting if we have not been kicked by another session to avoid connection loop.
|
||||||
// TODO: Make this conflict exception a permanent error
|
// TODO: Make this conflict exception a permanent error
|
||||||
if e.StreamError != "conflict" {
|
if e.StreamError != "conflict" {
|
||||||
sm.connect()
|
return sm.connect()
|
||||||
}
|
}
|
||||||
case StatePermanentError:
|
case StatePermanentError:
|
||||||
// Do not attempt to reconnect
|
// Do not attempt to reconnect
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
sm.client.SetHandler(handler)
|
sm.client.SetHandler(handler)
|
||||||
|
|
||||||
|
2
test.sh
2
test.sh
@ -5,7 +5,7 @@ export GO111MODULE=on
|
|||||||
echo "" > coverage.txt
|
echo "" > coverage.txt
|
||||||
|
|
||||||
for d in $(go list ./... | grep -v vendor); do
|
for d in $(go list ./... | grep -v vendor); do
|
||||||
go test -race -coverprofile=profile.out -covermode=atomic ${d}
|
go test -race -coverprofile=profile.out -covermode=atomic "${d}"
|
||||||
if [ -f profile.out ]; then
|
if [ -f profile.out ]; then
|
||||||
cat profile.out >> coverage.txt
|
cat profile.out >> coverage.txt
|
||||||
rm profile.out
|
rm profile.out
|
||||||
|
Loading…
Reference in New Issue
Block a user